diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py
index afc30f93b524..8e98eb273cd8 100644
--- a/vllm/model_executor/models/mllama.py
+++ b/vllm/model_executor/models/mllama.py
@@ -1070,8 +1070,8 @@ class MllamaTextModel(nn.Module):
             inputs_embeds = self.embed_tokens(input_ids)
         hidden_states = inputs_embeds
 
-        for decoder_layer in self.layers:
-            if isinstance(decoder_layer, MllamaCrossAttentionDecoderLayer):
+        for idx, decoder_layer in enumerate(self.layers):
+            if idx in self.cross_attention_layers:
                 if not skip_cross_attention:
                     hidden_states = decoder_layer(
                         hidden_states=hidden_states,
@@ -1081,16 +1081,13 @@ class MllamaTextModel(nn.Module):
                         full_text_row_masked_out_mask=
                         full_text_row_masked_out_mask,
                     )
-            elif isinstance(decoder_layer, LlamaDecoderLayer):
+            else:
                 hidden_states, residual = decoder_layer(
                     positions=positions,
                     hidden_states=hidden_states,
                     residual=None,
                 )
                 hidden_states = hidden_states + residual
-            else:
-                raise ValueError(
-                    f"Unknown decoder layer type {type(decoder_layer)}")
         hidden_states = self.norm(hidden_states)
         return hidden_states
 
@@ -1551,4 +1548,4 @@ def convert_dense_cross_attention_mask_to_tensor(
     full_text_mask = ((mask != ninf).any(dim=-1).type_as(mask)[..., None])
     mask *= full_text_mask
     # (num_prompt_tokens, num_encoder_tokens)
-    return mask
\ No newline at end of file
+    return mask
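
Below is a minimal, self-contained sketch (not the vLLM module itself) of the dispatch pattern the hunks above switch to: decoder layers are selected by index against `self.cross_attention_layers` rather than by `isinstance` on the concrete layer class. `TinyTextModel`, the `nn.Linear` stand-in layers, and the hidden size are illustrative assumptions only.

import torch
import torch.nn as nn


class TinyTextModel(nn.Module):
    """Toy model: layers whose index appears in cross_attention_layers are
    treated as cross-attention layers and skipped when skip_cross_attention
    is True (e.g. when the batch carries no image context)."""

    def __init__(self, num_layers: int, cross_attention_layers: list):
        super().__init__()
        # Indices of cross-attention layers; the real model is assumed to
        # take this list from its config.
        self.cross_attention_layers = set(cross_attention_layers)
        self.layers = nn.ModuleList(
            nn.Linear(16, 16) for _ in range(num_layers))
        self.norm = nn.LayerNorm(16)

    def forward(self, hidden_states: torch.Tensor,
                skip_cross_attention: bool) -> torch.Tensor:
        for idx, layer in enumerate(self.layers):
            if idx in self.cross_attention_layers:
                # Run the cross-attention layer only when cross-attention
                # inputs are available for this batch.
                if not skip_cross_attention:
                    hidden_states = layer(hidden_states)
            else:
                # Plain self-attention decoder layer (stand-in here).
                hidden_states = layer(hidden_states)
        return self.norm(hidden_states)


model = TinyTextModel(num_layers=4, cross_attention_layers=[1, 3])
out = model(torch.randn(2, 16), skip_cross_attention=True)
print(out.shape)  # torch.Size([2, 16])

Dispatching by index keeps the loop independent of the layers' concrete classes, so the per-layer type check (and the fallback ValueError branch removed above) is no longer needed.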