diff --git a/vllm/model_executor/models/mistral.py b/vllm/model_executor/models/mistral.py
index 7bc857575651..f2a4faa18b17 100644
--- a/vllm/model_executor/models/mistral.py
+++ b/vllm/model_executor/models/mistral.py
@@ -29,7 +29,6 @@ from typing import List, Optional, Tuple
 
 import torch
 from torch import nn
-from vllm.transformers_utils.configs.mistral import MistralConfig
 
 from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.layers.activation import SiluAndMul
@@ -46,6 +45,7 @@ from vllm.model_executor.weight_utils import (
     convert_pyslice_to_tensor, hf_model_weights_iterator,
     load_tensor_parallel_weights, load_padded_tensor_parallel_vocab)
 from vllm.sequence import SamplerOutput
+from vllm.transformers_utils.configs.mistral import MistralConfig
 
 KVCache = Tuple[torch.Tensor, torch.Tensor]
 
diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py
index fd5618bd81ba..a1efbedb6895 100644
--- a/vllm/transformers_utils/config.py
+++ b/vllm/transformers_utils/config.py
@@ -17,6 +17,15 @@ _CONFIG_REGISTRY = {
 def get_config(model: str,
                trust_remote_code: bool,
                revision: Optional[str] = None) -> PretrainedConfig:
+    # NOTE: Because the Mistral model in HF hub does not have
+    # `configuration_mistral.py`, we cannot use `AutoConfig` to load the
+    # config. Instead, we use `MistralConfig` directly.
+    # NOTE: This is a hack. This does not work for local models.
+    # FIXME: Remove this once the Mistral model is available in the stable
+    # version of HF transformers.
+    if "mistral" in model.lower():
+        return MistralConfig.from_pretrained(model, revision=revision)
+
     try:
         config = AutoConfig.from_pretrained(
             model, trust_remote_code=trust_remote_code, revision=revision)
diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py
index 6611697d25ae..3955c772b7b3 100644
--- a/vllm/transformers_utils/configs/__init__.py
+++ b/vllm/transformers_utils/configs/__init__.py
@@ -6,6 +6,7 @@ from vllm.transformers_utils.configs.qwen import QWenConfig
 # tiiuae/falcon-7b(-instruct) models. Newer Falcon models will use the
 # `FalconConfig` class from the official HuggingFace transformers library.
 from vllm.transformers_utils.configs.falcon import RWConfig
+from vllm.transformers_utils.configs.mistral import MistralConfig
 
 __all__ = [
     "MPTConfig",
@@ -13,4 +14,5 @@ __all__ = [
     "AquilaConfig",
     "QWenConfig",
     "RWConfig",
+    "MistralConfig",
 ]
diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py
index e14086f27107..0de9d248a3e7 100644
--- a/vllm/worker/worker.py
+++ b/vllm/worker/worker.py
@@ -42,6 +42,7 @@ class Worker:
         # self.init_cache_engine().
         self.cache_config = None
         self.block_size = None
+        self.sliding_window = None
         self.cache_engine = None
         self.cache_events = None
         self.gpu_cache = None
@@ -136,10 +137,13 @@ class Worker:
     def init_cache_engine(self, cache_config: CacheConfig) -> None:
         self.cache_config = cache_config
         self.block_size = cache_config.block_size
+        self.sliding_window = cache_config.sliding_window
 
-        max_seq_len = min(self.scheduler_config.max_model_len,
-                          cache_config.sliding_window or float("inf"))
-
+        if self.sliding_window is None:
+            max_seq_len = self.scheduler_config.max_model_len
+        else:
+            max_seq_len = min(self.scheduler_config.max_model_len,
+                              self.sliding_window)
         _check_if_can_support_max_seq_len(max_seq_len, self.block_size)
 
         self.cache_engine = CacheEngine(self.cache_config, self.model_config,
@@ -151,6 +155,8 @@ class Worker:
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
     ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata]:
+        assert self.block_size is not None
+
         seq_groups: List[Tuple[List[int], SamplingParams]] = []
         input_tokens: List[int] = []
         input_positions: List[int] = []
@@ -193,9 +199,6 @@ class Worker:
                 slot = block_number * self.block_size + block_offset
                 slot_mapping.append(slot)
 
-        sliding_window = getattr(self.model_config.hf_config, "sliding_window",
-                                 float("inf"))
-
         # Add generation tokens.
         max_context_len = 0
         max_num_blocks_per_seq = 0
@@ -216,8 +219,8 @@ class Worker:
 
                 context_len = seq_data.get_len()
                 position = context_len - 1
-                if sliding_window:
-                    context_len = min(context_len, sliding_window)
+                if self.sliding_window is not None:
+                    context_len = min(context_len, self.sliding_window)
                 input_positions.append(position)
 
                 block_table = seq_group_metadata.block_tables[seq_id]
@@ -232,10 +235,9 @@ class Worker:
                 slot = block_number * self.block_size + block_offset
                 slot_mapping.append(slot)
 
-                if sliding_window:
-                    assert self.cache_config is not None
-                    sliding_window_blocks = (sliding_window //
-                                             self.cache_config.block_size)
+                if self.sliding_window is not None:
+                    sliding_window_blocks = (self.sliding_window //
+                                             self.block_size)
                     block_table = block_table[-sliding_window_blocks:]
                 generation_block_tables.append(block_table)
 
@@ -277,7 +279,7 @@ class Worker:
             context_lens=context_lens_tensor,
             max_context_len=max_context_len,
             block_tables=block_tables_tensor,
-            sliding_window=sliding_window,
+            sliding_window=self.sliding_window,
         )
         return tokens_tensor, positions_tensor, input_metadata
 
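For reference, the sketch below illustrates the sliding-window bookkeeping that the worker.py hunks implement: the attention context length is clamped to the window size, and the block table is truncated to the last sliding_window // block_size physical blocks. The helper name and the numbers are hypothetical; only the arithmetic mirrors the diff.

from typing import List, Optional, Tuple


def clamp_for_sliding_window(
        context_len: int,
        block_table: List[int],
        block_size: int,
        sliding_window: Optional[int]) -> Tuple[int, List[int]]:
    # Hypothetical helper mirroring the logic added to Worker._prepare_inputs.
    if sliding_window is not None:
        # Tokens that fall outside the window are never attended to.
        context_len = min(context_len, sliding_window)
        # Keep only the physical blocks that can still hold in-window tokens.
        sliding_window_blocks = sliding_window // block_size
        block_table = block_table[-sliding_window_blocks:]
    return context_len, block_table


# Illustrative numbers: a 4096-token window with 16-token blocks keeps 256 blocks.
ctx, table = clamp_for_sliding_window(
    context_len=5000,
    block_table=list(range(400)),
    block_size=16,
    sliding_window=4096)
assert ctx == 4096
assert len(table) == 256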