Minor fixes for Mixtral (#2015)
commit 4ff0203987 (parent b5f882cc98)

This commit adds Mixtral to the supported-models table in the docs and tidies the Mixtral implementation: it drops the now-unused Union import, removes four placeholder attribute annotations in MixtralForCausalLM.__init__, and applies the final self.norm in forward() instead of sample().
@@ -50,6 +50,9 @@ Alongside each architecture, we include some popular models that use it.
   * - :code:`MistralForCausalLM`
     - Mistral, Mistral-Instruct
     - :code:`mistralai/Mistral-7B-v0.1`, :code:`mistralai/Mistral-7B-Instruct-v0.1`, etc.
+  * - :code:`MixtralForCausalLM`
+    - Mixtral-8x7B, Mixtral-8x7B-Instruct
+    - :code:`mistralai/Mixtral-8x7B-v0.1`, :code:`mistralai/Mixtral-8x7B-Instruct-v0.1`, etc.
   * - :code:`MPTForCausalLM`
     - MPT, MPT-Instruct, MPT-Chat, MPT-StoryWriter
     - :code:`mosaicml/mpt-7b`, :code:`mosaicml/mpt-7b-storywriter`, :code:`mosaicml/mpt-30b`, etc.
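For readers of the updated table, here is a minimal sketch of loading one of the newly listed Mixtral checkpoints with vLLM's offline LLM API. The prompt, sampling settings, and tensor_parallel_size are illustrative choices, not part of this commit.

from vllm import LLM, SamplingParams

# Illustrative only: load a newly listed Mixtral checkpoint and generate once.
# tensor_parallel_size, the prompt, and the sampling settings are assumptions.
llm = LLM(model="mistralai/Mixtral-8x7B-Instruct-v0.1", tensor_parallel_size=2)
params = SamplingParams(temperature=0.8, max_tokens=64)
outputs = llm.generate(["[INST] Summarize what Mixtral is. [/INST]"], params)
print(outputs[0].outputs[0].text)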
@@ -21,7 +21,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only Mixtral model."""
-from typing import List, Optional, Tuple, Union
+from typing import List, Optional, Tuple

 import numpy as np

@@ -453,10 +453,6 @@ class MixtralForCausalLM(nn.Module):
         assert linear_method is None
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
-        self.tok_embeddings: Union[nn.Embedding, None] = None
-        self.layers: nn.ModuleList = None
-        self.output: Union[nn.Linear, None] = None
-        self.sampler: Union[Sampler, None] = None
         self.tok_embeddings = VocabParallelEmbedding(
             config.vocab_size,
             config.hidden_size,
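The four deleted lines only pre-declared attributes as None before real modules were assigned a few lines later, which is also why the Union import above became unused. A standalone toy sketch of the pattern, not the vLLM code:

from typing import Optional

import torch.nn as nn

class Before(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        # Placeholder that is overwritten immediately; it only widens the
        # attribute's type to Optional[nn.Embedding].
        self.embed: Optional[nn.Embedding] = None
        self.embed = nn.Embedding(16, 8)

class After(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        # Assign the real module directly; the attribute is never None.
        self.embed = nn.Embedding(16, 8)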
@@ -492,6 +488,7 @@ class MixtralForCausalLM(nn.Module):
             input_metadata,
             cache_event,
         )
+        hidden_states = self.norm(hidden_states)
         return hidden_states

     def sample(
@@ -499,7 +496,6 @@ class MixtralForCausalLM(nn.Module):
         hidden_states: Optional[torch.Tensor],
         sampling_metadata: SamplingMetadata,
     ) -> SamplerOutput:
-        hidden_states = self.norm(hidden_states)
         next_tokens = self.sampler(self.output.weight, hidden_states,
                                    sampling_metadata)
         return next_tokens
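Taken together, the last two hunks move the final self.norm call out of sample() and into forward(), so forward() already returns normalized hidden states and sample() only projects them to logits and samples. Below is a self-contained toy sketch of that split; it uses LayerNorm as a stand-in for Mixtral's RMSNorm and greedy argmax as a stand-in for vLLM's Sampler, so none of the names are the actual vLLM signatures.

import torch
import torch.nn as nn

class TinyModel(nn.Module):
    """Toy model mirroring the norm placement after this commit."""

    def __init__(self, hidden: int = 8, vocab: int = 16) -> None:
        super().__init__()
        self.embed = nn.Embedding(vocab, hidden)
        # Stand-in for Mixtral's final RMSNorm.
        self.norm = nn.LayerNorm(hidden)
        self.output = nn.Linear(hidden, vocab, bias=False)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        hidden_states = self.embed(input_ids)
        # ... decoder layers would run here ...
        # The final norm now lives in forward(), as in the patched Mixtral.
        return self.norm(hidden_states)

    def sample(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # No norm here any more: forward() already normalized the states.
        logits = hidden_states @ self.output.weight.t()
        return torch.argmax(logits, dim=-1)

model = TinyModel()
tokens = torch.tensor([[1, 2, 3]])
print(model.sample(model(tokens)))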