diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index f56d6eaccfdd..545e41829bba 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -50,6 +50,9 @@ Alongside each architecture, we include some popular models that use it. * - :code:`MistralForCausalLM` - Mistral, Mistral-Instruct - :code:`mistralai/Mistral-7B-v0.1`, :code:`mistralai/Mistral-7B-Instruct-v0.1`, etc. + * - :code:`MixtralForCausalLM` + - Mixtral-8x7B, Mixtral-8x7B-Instruct + - :code:`mistralai/Mixtral-8x7B-v0.1`, :code:`mistralai/Mixtral-8x7B-Instruct-v0.1`, etc. * - :code:`MPTForCausalLM` - MPT, MPT-Instruct, MPT-Chat, MPT-StoryWriter - :code:`mosaicml/mpt-7b`, :code:`mosaicml/mpt-7b-storywriter`, :code:`mosaicml/mpt-30b`, etc. diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index 5e30a0dd642e..37afccda7f01 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -21,7 +21,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only Mixtral model.""" -from typing import List, Optional, Tuple, Union +from typing import List, Optional, Tuple import numpy as np @@ -453,10 +453,6 @@ class MixtralForCausalLM(nn.Module): assert linear_method is None self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size - self.tok_embeddings: Union[nn.Embedding, None] = None - self.layers: nn.ModuleList = None - self.output: Union[nn.Linear, None] = None - self.sampler: Union[Sampler, None] = None self.tok_embeddings = VocabParallelEmbedding( config.vocab_size, config.hidden_size, @@ -492,6 +488,7 @@ class MixtralForCausalLM(nn.Module): input_metadata, cache_event, ) + hidden_states = self.norm(hidden_states) return hidden_states def sample( @@ -499,7 +496,6 @@ class MixtralForCausalLM(nn.Module): hidden_states: Optional[torch.Tensor], sampling_metadata: SamplingMetadata, ) -> SamplerOutput: - hidden_states = self.norm(hidden_states) next_tokens = self.sampler(self.output.weight, hidden_states, sampling_metadata) return next_tokens