Fix Mistral model (#1220)

This commit is contained in:
Woosuk Kwon 2023-09-28 10:44:05 -07:00 committed by GitHub
parent bb1ba58f06
commit a8e98aee0c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 27 additions and 14 deletions

View File

@ -29,7 +29,6 @@ from typing import List, Optional, Tuple
import torch import torch
from torch import nn from torch import nn
from vllm.transformers_utils.configs.mistral import MistralConfig
from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
@ -46,6 +45,7 @@ from vllm.model_executor.weight_utils import (
convert_pyslice_to_tensor, hf_model_weights_iterator, convert_pyslice_to_tensor, hf_model_weights_iterator,
load_tensor_parallel_weights, load_padded_tensor_parallel_vocab) load_tensor_parallel_weights, load_padded_tensor_parallel_vocab)
from vllm.sequence import SamplerOutput from vllm.sequence import SamplerOutput
from vllm.transformers_utils.configs.mistral import MistralConfig
KVCache = Tuple[torch.Tensor, torch.Tensor] KVCache = Tuple[torch.Tensor, torch.Tensor]

View File

@ -17,6 +17,15 @@ _CONFIG_REGISTRY = {
def get_config(model: str, def get_config(model: str,
trust_remote_code: bool, trust_remote_code: bool,
revision: Optional[str] = None) -> PretrainedConfig: revision: Optional[str] = None) -> PretrainedConfig:
# NOTE: Because the Mistral model in HF hub does not have
# `configuration_mistral.py`, we cannot use `AutoConfig` to load the
# config. Instead, we use `MistralConfig` directly.
# NOTE: This is a hack. This does not work for local models.
# FIXME: Remove this once the Mistral model is available in the stable
# version of HF transformers.
if "mistral" in model.lower():
return MistralConfig.from_pretrained(model, revision=revision)
try: try:
config = AutoConfig.from_pretrained( config = AutoConfig.from_pretrained(
model, trust_remote_code=trust_remote_code, revision=revision) model, trust_remote_code=trust_remote_code, revision=revision)

View File

@ -6,6 +6,7 @@ from vllm.transformers_utils.configs.qwen import QWenConfig
# tiiuae/falcon-7b(-instruct) models. Newer Falcon models will use the # tiiuae/falcon-7b(-instruct) models. Newer Falcon models will use the
# `FalconConfig` class from the official HuggingFace transformers library. # `FalconConfig` class from the official HuggingFace transformers library.
from vllm.transformers_utils.configs.falcon import RWConfig from vllm.transformers_utils.configs.falcon import RWConfig
from vllm.transformers_utils.configs.mistral import MistralConfig
__all__ = [ __all__ = [
"MPTConfig", "MPTConfig",
@ -13,4 +14,5 @@ __all__ = [
"AquilaConfig", "AquilaConfig",
"QWenConfig", "QWenConfig",
"RWConfig", "RWConfig",
"MistralConfig",
] ]

View File

@ -42,6 +42,7 @@ class Worker:
# self.init_cache_engine(). # self.init_cache_engine().
self.cache_config = None self.cache_config = None
self.block_size = None self.block_size = None
self.sliding_window = None
self.cache_engine = None self.cache_engine = None
self.cache_events = None self.cache_events = None
self.gpu_cache = None self.gpu_cache = None
@ -136,10 +137,13 @@ class Worker:
def init_cache_engine(self, cache_config: CacheConfig) -> None: def init_cache_engine(self, cache_config: CacheConfig) -> None:
self.cache_config = cache_config self.cache_config = cache_config
self.block_size = cache_config.block_size self.block_size = cache_config.block_size
self.sliding_window = cache_config.sliding_window
max_seq_len = min(self.scheduler_config.max_model_len, if self.sliding_window is None:
cache_config.sliding_window or float("inf")) max_seq_len = self.scheduler_config.max_model_len
else:
max_seq_len = min(self.scheduler_config.max_model_len,
self.sliding_window)
_check_if_can_support_max_seq_len(max_seq_len, self.block_size) _check_if_can_support_max_seq_len(max_seq_len, self.block_size)
self.cache_engine = CacheEngine(self.cache_config, self.model_config, self.cache_engine = CacheEngine(self.cache_config, self.model_config,
@ -151,6 +155,8 @@ class Worker:
self, self,
seq_group_metadata_list: List[SequenceGroupMetadata], seq_group_metadata_list: List[SequenceGroupMetadata],
) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata]: ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata]:
assert self.block_size is not None
seq_groups: List[Tuple[List[int], SamplingParams]] = [] seq_groups: List[Tuple[List[int], SamplingParams]] = []
input_tokens: List[int] = [] input_tokens: List[int] = []
input_positions: List[int] = [] input_positions: List[int] = []
@ -193,9 +199,6 @@ class Worker:
slot = block_number * self.block_size + block_offset slot = block_number * self.block_size + block_offset
slot_mapping.append(slot) slot_mapping.append(slot)
sliding_window = getattr(self.model_config.hf_config, "sliding_window",
float("inf"))
# Add generation tokens. # Add generation tokens.
max_context_len = 0 max_context_len = 0
max_num_blocks_per_seq = 0 max_num_blocks_per_seq = 0
@ -216,8 +219,8 @@ class Worker:
context_len = seq_data.get_len() context_len = seq_data.get_len()
position = context_len - 1 position = context_len - 1
if sliding_window: if self.sliding_window is not None:
context_len = min(context_len, sliding_window) context_len = min(context_len, self.sliding_window)
input_positions.append(position) input_positions.append(position)
block_table = seq_group_metadata.block_tables[seq_id] block_table = seq_group_metadata.block_tables[seq_id]
@ -232,10 +235,9 @@ class Worker:
slot = block_number * self.block_size + block_offset slot = block_number * self.block_size + block_offset
slot_mapping.append(slot) slot_mapping.append(slot)
if sliding_window: if self.sliding_window is not None:
assert self.cache_config is not None sliding_window_blocks = (self.sliding_window //
sliding_window_blocks = (sliding_window // self.block_size)
self.cache_config.block_size)
block_table = block_table[-sliding_window_blocks:] block_table = block_table[-sliding_window_blocks:]
generation_block_tables.append(block_table) generation_block_tables.append(block_table)
@ -277,7 +279,7 @@ class Worker:
context_lens=context_lens_tensor, context_lens=context_lens_tensor,
max_context_len=max_context_len, max_context_len=max_context_len,
block_tables=block_tables_tensor, block_tables=block_tables_tensor,
sliding_window=sliding_window, sliding_window=self.sliding_window,
) )
return tokens_tensor, positions_tensor, input_metadata return tokens_tensor, positions_tensor, input_metadata