[bugfix] fix aria model and add torch.compile (#10645)

Signed-off-by: youkaichao <youkaichao@gmail.com>
Authored by youkaichao on 2024-11-25 18:32:09 -08:00, committed by GitHub
parent 6e9ff050c8
commit 45ac4ff270
2 changed files with 14 additions and 28 deletions
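
In outline, the commit moves layer construction for the Aria language model into LlamaModel itself: LlamaModel.__init__ gains a layer_type parameter (defaulting to LlamaDecoderLayer), and AriaMoELMModel passes its MoE decoder layer class instead of rebuilding self.layers after the parent constructor has run, so the model stays covered by the @support_torch_compile decorator on the base class. Below is a minimal standalone sketch of that pattern, using plain torch.nn and hypothetical class names rather than vLLM's real modules:

from typing import Type

import torch
from torch import nn


class DecoderLayerSketch(nn.Module):
    """Stand-in for a dense decoder layer (hypothetical, not vLLM's)."""

    def __init__(self, hidden_size: int):
        super().__init__()
        self.mlp = nn.Linear(hidden_size, hidden_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.mlp(x)


class MoEDecoderLayerSketch(DecoderLayerSketch):
    """Stand-in for the Aria MoE decoder layer."""


class BaseModelSketch(nn.Module):
    """The base class owns layer construction, as LlamaModel now does."""

    def __init__(self, hidden_size: int, num_layers: int,
                 layer_type: Type[DecoderLayerSketch] = DecoderLayerSketch):
        super().__init__()
        # Subclasses inject a different layer class instead of rebuilding
        # self.layers after this constructor has finished.
        self.layers = nn.ModuleList(
            layer_type(hidden_size) for _ in range(num_layers))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for layer in self.layers:
            x = layer(x)
        return x


class AriaLMSketch(BaseModelSketch):
    """Mirrors AriaMoELMModel delegating to the parent constructor."""

    def __init__(self, hidden_size: int, num_layers: int):
        super().__init__(hidden_size, num_layers,
                         layer_type=MoEDecoderLayerSketch)
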

vllm/model_executor/models/aria.py

@@ -29,7 +29,7 @@ from vllm.model_executor.models.llama import (LlamaDecoderLayer, LlamaMLP,
                                               LlamaModel)
 from vllm.model_executor.models.utils import (AutoWeightsLoader, WeightsMapper,
                                               is_pp_missing_parameter,
-                                              make_layers, maybe_prefix,
+                                              maybe_prefix,
                                               merge_multimodal_embeddings)
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.base import MultiModalInputs
@@ -363,27 +363,9 @@ class AriaMoELMModel(LlamaModel):
     """

     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
-        super().__init__(vllm_config=vllm_config, prefix=prefix)
-        config = vllm_config.model_config.hf_config
-        cache_config = vllm_config.cache_config
-        quant_config = vllm_config.quant_config
-
-        # FIXME: this is a hack to disable the compilation of the model
-        self.do_not_compile = True
-
-        self.layers = None
-
-        self.start_layer, self.end_layer, self.layers = make_layers(
-            config.num_hidden_layers,
-            lambda prefix: MoEDecoderLayer(
-                config=config,
-                cache_config=cache_config,
-                quant_config=quant_config,
-                prefix=prefix,
-            ),
-            prefix=f"{prefix}.layers",
-        )
+        super().__init__(vllm_config=vllm_config,
+                         prefix=prefix,
+                         layer_type=MoEDecoderLayer)

     # Adapted from LlamaModel.load_weights with the modification of adding
     # the expert weights mapping to `stacked_params_mapping`
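
With the do_not_compile hack and the manual rebuild of self.layers gone, the MoE layers are created inside the parent constructor, so whatever compilation wrapping the base class carries also applies to the Aria model. Continuing the sketch above (plain torch.compile here, not vLLM's support_torch_compile machinery, PyTorch 2.x assumed):

model = AriaLMSketch(hidden_size=64, num_layers=2)
# The layers really are the injected MoE class, built by the parent.
assert all(isinstance(layer, MoEDecoderLayerSketch) for layer in model.layers)

compiled = torch.compile(model)   # one compiled forward path covers the subclass
x = torch.randn(1, 64)
torch.testing.assert_close(compiled(x), model(x))
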

vllm/model_executor/models/llama.py

@@ -20,7 +20,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only LLaMA model compatible with HuggingFace weights."""
-from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
+from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Type, Union

 import torch
 from torch import nn
@@ -273,7 +273,11 @@ class LlamaDecoderLayer(nn.Module):
 @support_torch_compile
 class LlamaModel(nn.Module):

-    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+    def __init__(self,
+                 *,
+                 vllm_config: VllmConfig,
+                 prefix: str = "",
+                 layer_type: Type[LlamaDecoderLayer] = LlamaDecoderLayer):
         super().__init__()

         config = vllm_config.model_config.hf_config
@@ -299,10 +303,10 @@ class LlamaModel(nn.Module):
             self.embed_tokens = PPMissingLayer()
         self.start_layer, self.end_layer, self.layers = make_layers(
             config.num_hidden_layers,
-            lambda prefix: LlamaDecoderLayer(config=config,
-                                             cache_config=cache_config,
-                                             quant_config=quant_config,
-                                             prefix=prefix),
+            lambda prefix: layer_type(config=config,
+                                      cache_config=cache_config,
+                                      quant_config=quant_config,
+                                      prefix=prefix),
             prefix=f"{prefix}.layers",
         )
         if get_pp_group().is_last_rank:
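
The lambda handed to make_layers receives a dotted per-layer prefix, and swapping the hard-coded LlamaDecoderLayer for the injected layer_type is the whole change here. For orientation, a simplified stand-in for such a helper (not vLLM's actual make_layers, which also partitions the layer range for pipeline parallelism):

from typing import Callable, Tuple

from torch import nn


def make_layers_sketch(
    num_layers: int,
    layer_fn: Callable[[str], nn.Module],
    prefix: str,
) -> Tuple[int, int, nn.ModuleList]:
    # Every layer is produced by the factory with a dotted prefix such as
    # "model.layers.0", which keeps parameter names aligned with checkpoint
    # weights. The real helper also slices the range per pipeline-parallel
    # rank; this sketch keeps all layers local.
    layers = nn.ModuleList(
        layer_fn(f"{prefix}.{i}") for i in range(num_layers))
    return 0, num_layers, layers
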