diff --git a/vllm/utils.py b/vllm/utils.py
index 516b33dca1dc8..77f4e2dcf5e45 100644
--- a/vllm/utils.py
+++ b/vllm/utils.py
@@ -61,7 +61,7 @@ import vllm.envs as envs
 from vllm.logger import enable_trace_function_call, init_logger
 
 if TYPE_CHECKING:
-    from vllm.config import VllmConfig
+    from vllm.config import ModelConfig, VllmConfig
 
 logger = init_logger(__name__)
 
@@ -2498,6 +2498,18 @@ def cprofile(save_file: Optional[str] = None, enabled: bool = True):
     return decorator
 
 
+# Only relevant for models using ALiBi (e.g., MPT)
+def check_use_alibi(model_config: ModelConfig) -> bool:
+    return (getattr(model_config.hf_text_config, "alibi", False)  # Falcon
+            or ("BloomForCausalLM" in getattr(model_config.hf_config,
+                                              "architectures", []))  # Bloom
+            or getattr(model_config.hf_text_config, "position_encoding_type",
+                       "") == "alibi"  # codellm_1b_alibi
+            or (hasattr(model_config.hf_text_config, "attn_config")  # MPT
+                and model_config.hf_text_config.attn_config.get(
+                    "alibi", False)))
+
+
 def sha256(input) -> int:
     """Hash any picklable Python object using SHA-256.
 
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index bcf7762b44496..230479f3f15e7 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -25,7 +25,7 @@ from vllm.multimodal.utils import group_mm_inputs_by_modality
 from vllm.sampling_params import SamplingType
 from vllm.sequence import IntermediateTensors
 from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
-                        LayerBlockType, LazyLoader, cdiv,
+                        LayerBlockType, LazyLoader, cdiv, check_use_alibi,
                         is_pin_memory_available)
 from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
 from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
@@ -223,6 +223,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             device="cpu",
             pin_memory=self.pin_memory)
 
+        # Only relevant for models using ALiBi (e.g., MPT)
+        self.use_alibi = check_use_alibi(model_config)
+
         self.inputs_embeds = torch.zeros(
             (self.max_num_tokens, self.hidden_size),
             dtype=self.dtype,
@@ -689,7 +692,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             query_lens=num_scheduled_tokens,
             num_query_heads=self.num_query_heads,
             num_kv_heads=self.num_kv_heads,
-            use_alibi=False,  # FIXME
+            use_alibi=self.use_alibi,
             use_sliding_window=self.window_size is not None,
             num_sms=self.num_sms,
         )
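
For reference, a minimal, self-contained sketch of the detection logic `check_use_alibi` applies across the HF config layouts it probes. The `SimpleNamespace` fixtures are hypothetical stand-ins for `ModelConfig`, not real vLLM objects:

```python
# Illustrative sketch: mirrors the four checks in check_use_alibi()
# (Falcon's `alibi` flag, BLOOM's architecture name, a
# `position_encoding_type` of "alibi", and MPT's nested attn_config dict).
from types import SimpleNamespace


def check_use_alibi_sketch(model_config) -> bool:
    return (getattr(model_config.hf_text_config, "alibi", False)  # Falcon
            or "BloomForCausalLM" in getattr(model_config.hf_config,
                                             "architectures", [])  # Bloom
            or getattr(model_config.hf_text_config, "position_encoding_type",
                       "") == "alibi"  # codellm_1b_alibi
            or (hasattr(model_config.hf_text_config, "attn_config")  # MPT
                and model_config.hf_text_config.attn_config.get(
                    "alibi", False)))


# MPT-style configs keep the flag in a nested attn_config dict.
mpt_like = SimpleNamespace(
    hf_config=SimpleNamespace(architectures=["MPTForCausalLM"]),
    hf_text_config=SimpleNamespace(attn_config={"alibi": True}),
)

# A RoPE-based model matches none of the four checks.
rope_like = SimpleNamespace(
    hf_config=SimpleNamespace(architectures=["LlamaForCausalLM"]),
    hf_text_config=SimpleNamespace(),
)

assert check_use_alibi_sketch(mpt_like)
assert not check_use_alibi_sketch(rope_like)
```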
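
The model-runner change replaces the hard-coded `use_alibi=False  # FIXME` previously passed to the cascade-attention heuristic. The sketch below illustrates the kind of early-out such a heuristic performs when ALiBi is active; the function, its signature, and its logic are illustrative stand-ins, not vLLM's actual `use_cascade_attention`:

```python
# Hypothetical sketch of an ALiBi-aware early-out in a cascade-attention
# heuristic; names and logic are illustrative, not vLLM's implementation.
def should_use_cascade_attention(use_alibi: bool,
                                 use_sliding_window: bool,
                                 common_prefix_len: int) -> bool:
    # A shared-prefix (cascade) kernel that does not model ALiBi's
    # positional biases or sliding windows must be skipped for correctness.
    if use_alibi or use_sliding_window:
        return False
    # Otherwise, cascading only pays off when requests share a prefix.
    return common_prefix_len > 0
```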