[Model] Composite weight loading for multimodal Qwen2 (#10944)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung 2024-12-07 22:22:52 +08:00 committed by GitHub
parent b26b4cd03c
commit bf0e382e16
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 148 additions and 206 deletions

View File

@ -2472,7 +2472,15 @@ class VllmConfig:
return quant_config return quant_config
return None return None
def with_hf_config(self, hf_config: PretrainedConfig) -> "VllmConfig": def with_hf_config(
self,
hf_config: PretrainedConfig,
architectures: Optional[list[str]] = None,
) -> "VllmConfig":
if architectures is not None:
hf_config = copy.deepcopy(hf_config)
hf_config.architectures = architectures
model_config = copy.deepcopy(self.model_config) model_config = copy.deepcopy(self.model_config)
model_config.hf_config = hf_config model_config.hf_config = hf_config

View File

@ -101,12 +101,10 @@ def _initialize_model(
vllm_config: VllmConfig, vllm_config: VllmConfig,
*, *,
prefix: str = "", prefix: str = "",
architectures: Optional[list[str]] = None,
) -> nn.Module: ) -> nn.Module:
"""Initialize a model with the given configurations.""" """Initialize a model with the given configurations."""
model_config = vllm_config.model_config model_config = vllm_config.model_config
model_class, _ = get_model_architecture(model_config, model_class, _ = get_model_architecture(model_config)
architectures=architectures)
signatures = inspect.signature(model_class.__init__) signatures = inspect.signature(model_class.__init__)
all_params = [param.name for param in signatures.parameters.values()] all_params = [param.name for param in signatures.parameters.values()]

View File

@ -1,6 +1,6 @@
"""Utilities for selecting and loading models.""" """Utilities for selecting and loading models."""
import contextlib import contextlib
from typing import Optional, Tuple, Type from typing import Tuple, Type
import torch import torch
from torch import nn from torch import nn
@ -20,12 +20,8 @@ def set_default_torch_dtype(dtype: torch.dtype):
def get_model_architecture( def get_model_architecture(
model_config: ModelConfig, model_config: ModelConfig) -> Tuple[Type[nn.Module], str]:
*, architectures = getattr(model_config.hf_config, "architectures", [])
architectures: Optional[list[str]] = None,
) -> Tuple[Type[nn.Module], str]:
if architectures is None:
architectures = getattr(model_config.hf_config, "architectures", [])
# Special handling for quantized Mixtral. # Special handling for quantized Mixtral.
# FIXME(woosuk): This is a temporary hack. # FIXME(woosuk): This is a temporary hack.

View File

@ -444,14 +444,17 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self.model = Qwen2Model(vllm_config=vllm_config, self.model = Qwen2Model(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model")) prefix=maybe_prefix(prefix, "model"))
if config.tie_word_embeddings: if get_pp_group().is_last_rank:
self.lm_head = self.model.embed_tokens if config.tie_word_embeddings:
self.lm_head = self.model.embed_tokens
else:
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix=maybe_prefix(
prefix, "lm_head"))
else: else:
self.lm_head = ParallelLMHead(config.vocab_size, self.lm_head = PPMissingLayer()
config.hidden_size,
quant_config=quant_config,
prefix=maybe_prefix(
prefix, "lm_head"))
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = get_sampler() self.sampler = get_sampler()

View File

@ -19,7 +19,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only Qwen2-Audio model compatible with HuggingFace weights.""" """Inference-only Qwen2-Audio model compatible with HuggingFace weights."""
from functools import lru_cache from functools import cached_property, lru_cache
from typing import (Iterable, List, Mapping, Optional, Set, Tuple, TypedDict, from typing import (Iterable, List, Mapping, Optional, Set, Tuple, TypedDict,
Union) Union)
@ -34,12 +34,7 @@ from vllm.config import VllmConfig
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
InputContext, token_inputs) InputContext, token_inputs)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.models.qwen2 import Qwen2Model
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
from vllm.multimodal.inputs import NestedTensors from vllm.multimodal.inputs import NestedTensors
@ -47,15 +42,11 @@ from vllm.multimodal.utils import consecutive_placeholder_ranges
from vllm.sequence import IntermediateTensors, SequenceData from vllm.sequence import IntermediateTensors, SequenceData
from .interfaces import SupportsMultiModal, SupportsPP from .interfaces import SupportsMultiModal, SupportsPP
from .utils import merge_multimodal_embeddings from .utils import (AutoWeightsLoader, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings)
logger = init_logger(__name__) logger = init_logger(__name__)
_KEYS_TO_MODIFY_MAPPING = {
"language_model.lm_head": "lm_head",
"language_model.model": "language_model",
}
# # === Audio Inputs === # # # === Audio Inputs === #
class Qwen2AudioInputs(TypedDict): class Qwen2AudioInputs(TypedDict):
@ -281,25 +272,23 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal,
self.quant_config = quant_config self.quant_config = quant_config
self.language_model = Qwen2Model( self.language_model = init_vllm_registered_model(
vllm_config=vllm_config.with_hf_config(config.text_config), vllm_config=vllm_config,
prefix=prefix) hf_config=config.text_config,
self.unpadded_vocab_size = config.text_config.vocab_size prefix=maybe_prefix(prefix, "language_model"),
if config.text_config.tie_word_embeddings: architectures=["Qwen2ForCausalLM"],
self.lm_head = self.language_model.embed_tokens )
else:
self.lm_head = ParallelLMHead(config.text_config.vocab_size,
config.text_config.hidden_size,
quant_config=quant_config)
logit_scale = getattr(config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.text_config.vocab_size,
logit_scale)
self.sampler = get_sampler()
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors) self.language_model.make_empty_intermediate_tensors)
@cached_property
def sampler(self):
    """Return the sampler used by this model.

    Prefers the ``sampler`` attribute of the wrapped ``language_model``
    when it has one; otherwise falls back to ``get_sampler()``. Because
    this is a ``cached_property``, the lookup runs once per instance.
    """
    inner_model = self.language_model
    if not hasattr(inner_model, "sampler"):
        return get_sampler()
    return inner_model.sampler
def _validate_and_reshape_mm_tensor(self, def _validate_and_reshape_mm_tensor(self,
mm_input: Union[torch.Tensor, mm_input: Union[torch.Tensor,
List[torch.Tensor]], List[torch.Tensor]],
@ -414,72 +403,30 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal,
multimodal_embeddings) multimodal_embeddings)
input_ids = None input_ids = None
hidden_states = self.language_model(input_ids, hidden_states = self.language_model.model(input_ids,
positions, positions,
kv_caches, kv_caches,
attn_metadata, attn_metadata,
intermediate_tensors, intermediate_tensors,
inputs_embeds=inputs_embeds) inputs_embeds=inputs_embeds)
return hidden_states return hidden_states
def compute_logits(self, hidden_states: torch.Tensor, def compute_logits(
sampling_metadata: SamplingMetadata) -> torch.Tensor: self,
logits = self.logits_processor(self.lm_head, hidden_states, hidden_states: torch.Tensor,
sampling_metadata) sampling_metadata: SamplingMetadata,
return logits ) -> Optional[torch.Tensor]:
return self.language_model.compute_logits(hidden_states,
sampling_metadata)
def sample( def sample(
self, self,
logits: torch.Tensor, logits: torch.Tensor,
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]: ) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata) return self.language_model.sample(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [ loader = AutoWeightsLoader(self)
# (param_name, shard_name, shard_id) return loader.load_weights(weights)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters(remove_duplicate=False))
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
if (self.config.text_config.tie_word_embeddings
and "lm_head.weight" in name):
continue
for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
if key_to_modify in name:
name = name.replace(key_to_modify, new_key)
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name or 'audio' in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
# Remapping the name of FP8 kv-scale.
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params

View File

@ -21,7 +21,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only Qwen2-VL model compatible with HuggingFace weights.""" """Inference-only Qwen2-VL model compatible with HuggingFace weights."""
from functools import partial from functools import cached_property, partial
from typing import (Any, Callable, Dict, Iterable, List, Literal, Mapping, from typing import (Any, Callable, Dict, Iterable, List, Literal, Mapping,
Optional, Set, Tuple, Type, TypedDict, Union) Optional, Set, Tuple, Type, TypedDict, Union)
@ -40,7 +40,7 @@ from transformers.models.qwen2_vl.image_processing_qwen2_vl import (
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.distributed import get_pp_group, parallel_state from vllm.distributed import parallel_state
from vllm.distributed import utils as dist_utils from vllm.distributed import utils as dist_utils
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
InputContext, token_inputs) InputContext, token_inputs)
@ -49,15 +49,12 @@ from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.activation import QuickGELU from vllm.model_executor.layers.activation import QuickGELU
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.gptq import GPTQConfig from vllm.model_executor.layers.quantization.gptq import GPTQConfig
from vllm.model_executor.layers.quantization.gptq_marlin import ( from vllm.model_executor.layers.quantization.gptq_marlin import (
GPTQMarlinConfig) GPTQMarlinConfig)
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.qwen2 import Qwen2Model
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.image import cached_get_image_processor from vllm.multimodal.image import cached_get_image_processor
from vllm.multimodal.inputs import (MultiModalData, MultiModalDataDict, from vllm.multimodal.inputs import (MultiModalData, MultiModalDataDict,
@ -69,9 +66,8 @@ from vllm.transformers_utils.config import uses_mrope
from vllm.transformers_utils.processor import cached_get_processor from vllm.transformers_utils.processor import cached_get_processor
from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP
from .utils import (PPMissingLayer, get_vit_attn_backend, from .utils import (AutoWeightsLoader, WeightsMapper, get_vit_attn_backend,
is_pp_missing_parameter, init_vllm_registered_model, maybe_prefix)
make_empty_intermediate_tensors_factory, maybe_prefix)
logger = init_logger(__name__) logger = init_logger(__name__)
@ -506,6 +502,8 @@ class Qwen2VisionTransformer(nn.Module):
mlp_ratio: float = vision_config.mlp_ratio mlp_ratio: float = vision_config.mlp_ratio
self.spatial_merge_size = spatial_merge_size self.spatial_merge_size = spatial_merge_size
self.num_heads = num_heads
self.embed_dim = embed_dim
self.patch_embed = Qwen2VisionPatchEmbed( self.patch_embed = Qwen2VisionPatchEmbed(
patch_size=patch_size, patch_size=patch_size,
@ -595,6 +593,53 @@ class Qwen2VisionTransformer(nn.Module):
x = self.merger(x) x = self.merger(x)
return x return x
def load_weights(self, weights: Iterable[Tuple[str,
                                               torch.Tensor]]) -> Set[str]:
    """Load checkpoint tensors into this module's parameters.

    Args:
        weights: Iterable of ``(name, tensor)`` pairs.

    Returns:
        The set of parameter names (after any renaming) that were loaded.

    Raises:
        KeyError: If a (possibly renamed) weight name is not among this
            module's named parameters.
    """
    # (fused parameter name, checkpoint shard name, shard id)
    fused_shards = (
        ("qkv_proj", "q_proj", "q"),
        ("qkv_proj", "k_proj", "k"),
        ("qkv_proj", "v_proj", "v"),
    )
    known_params = dict(self.named_parameters(remove_duplicate=False))
    seen: Set[str] = set()

    for name, tensor in weights:
        # The first shard entry whose checkpoint name occurs in `name` wins.
        shard = next(
            (entry for entry in fused_shards if entry[1] in name), None)
        if shard is not None:
            fused_name, shard_name, shard_id = shard
            name = name.replace(shard_name, fused_name)
            target = known_params[name]
            target.weight_loader(target, tensor, shard_id)
            seen.add(name)
            continue

        if name.endswith(("qkv.weight", "qkv.bias")):
            heads = self.num_heads
            width = self.embed_dim
            per_head = width // heads
            # View the tensor with leading (3, heads) axes, swap them to
            # (heads, 3), then flatten again.
            # NOTE(review): presumably this matches the layout the fused
            # qkv parameter expects; confirm against the attention module.
            if name.endswith("qkv.weight"):
                tensor = tensor.view(3, heads, per_head, width)
                tensor = tensor.transpose(0, 1)
                tensor = tensor.reshape(-1, width)
            else:
                tensor = tensor.view(3, heads, per_head)
                tensor = tensor.transpose(0, 1)
                tensor = tensor.reshape(-1)

        target = known_params[name]
        # Parameters without a custom loader use the default one.
        loader = getattr(target, "weight_loader", default_weight_loader)
        loader(target, tensor)
        seen.add(name)

    return seen
# === Vision input helpers === # # === Vision input helpers === #
@ -1082,27 +1127,21 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
prefix=maybe_prefix(prefix, "visual"), prefix=maybe_prefix(prefix, "visual"),
) )
self.model = Qwen2Model(vllm_config=vllm_config, self.language_model = init_vllm_registered_model(
prefix=maybe_prefix(prefix, "model")) vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "language_model"),
if get_pp_group().is_last_rank: architectures=["Qwen2ForCausalLM"],
if config.tie_word_embeddings: )
self.lm_head = self.model.embed_tokens
else:
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix=maybe_prefix(
prefix, "lm_head"))
else:
self.lm_head = PPMissingLayer()
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = get_sampler()
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory( self.language_model.make_empty_intermediate_tensors)
["hidden_states", "residual"], config.hidden_size))
@cached_property
def sampler(self):
    """Return the sampler used by this model.

    Prefers the ``sampler`` attribute of the wrapped ``language_model``
    when it has one; otherwise falls back to ``get_sampler()``. Because
    this is a ``cached_property``, the lookup runs once per instance.
    """
    inner_model = self.language_model
    if not hasattr(inner_model, "sampler"):
        return get_sampler()
    return inner_model.sampler
def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig): def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig):
# GPTQ configs do not have a list of ignored modules, however AutoGPTQ # GPTQ configs do not have a list of ignored modules, however AutoGPTQ
@ -1261,7 +1300,7 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
multimodal_embeddings: Optional[List[Tuple[NestedTensors, multimodal_embeddings: Optional[List[Tuple[NestedTensors,
str]]] = None, str]]] = None,
) -> torch.Tensor: ) -> torch.Tensor:
inputs_embeds = self.model.get_input_embeddings(input_ids) inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None: if multimodal_embeddings is not None:
for embeddings, modality in multimodal_embeddings: for embeddings, modality in multimodal_embeddings:
if modality == "image": if modality == "image":
@ -1330,7 +1369,7 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
multimodal_embeddings) multimodal_embeddings)
input_ids = None input_ids = None
hidden_states = self.model( hidden_states = self.language_model.model(
input_ids=input_ids, input_ids=input_ids,
positions=positions, positions=positions,
kv_caches=kv_caches, kv_caches=kv_caches,
@ -1340,80 +1379,28 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
) )
return hidden_states return hidden_states
def compute_logits(self, hidden_states: torch.Tensor, def compute_logits(
sampling_metadata: SamplingMetadata) -> torch.Tensor: self,
logits = self.logits_processor(self.lm_head, hidden_states, hidden_states: torch.Tensor,
sampling_metadata) sampling_metadata: SamplingMetadata,
return logits ) -> Optional[torch.Tensor]:
return self.language_model.compute_logits(hidden_states,
sampling_metadata)
def sample( def sample(
self, self,
logits: torch.Tensor, logits: torch.Tensor,
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]: ) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata) return self.language_model.sample(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [ hf_to_vllm_mapper = WeightsMapper(
# (param_name, shard_name, shard_id) orig_to_new_prefix={
("qkv_proj", "q_proj", "q"), "lm_head.": "language_model.lm_head.",
("qkv_proj", "k_proj", "k"), "model.": "language_model.model.",
("qkv_proj", "v_proj", "v"), })
("gate_up_proj", "up_proj", 1),
("gate_up_proj", "gate_proj", 0),
]
params_dict = dict(self.named_parameters(remove_duplicate=False))
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
if self.config.tie_word_embeddings and "lm_head.weight" in name:
continue
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
if "visual" in name and name.endswith("qkv.weight"):
visual_num_heads = self.config.vision_config.num_heads
visual_embed_dim = self.config.vision_config.embed_dim
head_size = visual_embed_dim // visual_num_heads
loaded_weight = loaded_weight.view(3, visual_num_heads,
head_size,
visual_embed_dim)
loaded_weight = loaded_weight.transpose(0, 1)
loaded_weight = loaded_weight.reshape(-1, visual_embed_dim)
elif "visual" in name and name.endswith("qkv.bias"):
visual_num_heads = self.config.vision_config.num_heads
visual_embed_dim = self.config.vision_config.embed_dim
head_size = visual_embed_dim // visual_num_heads
loaded_weight = loaded_weight.view(3, visual_num_heads,
head_size)
loaded_weight = loaded_weight.transpose(0, 1)
loaded_weight = loaded_weight.reshape(-1)
try:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
except KeyError:
raise ValueError(f"Unexpected weight: {name}") from None
weight_loader = getattr(param, "weight_loader", loader = AutoWeightsLoader(self)
default_weight_loader) return loader.load_weights(weights, mapper=hf_to_vllm_mapper)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params

View File

@ -17,7 +17,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.multimodal import MultiModalPlaceholderMap, NestedTensors from vllm.multimodal import MultiModalPlaceholderMap, NestedTensors
from vllm.platforms import _Backend, current_platform from vllm.platforms import _Backend, current_platform
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils import is_pin_memory_available from vllm.utils import is_pin_memory_available, print_warning_once
logger = init_logger(__name__) logger = init_logger(__name__)
@ -251,12 +251,15 @@ def init_vllm_registered_model(
""" """
from vllm.model_executor.model_loader.loader import _initialize_model from vllm.model_executor.model_loader.loader import _initialize_model
if hf_config is not None: if hf_config is None and architectures is not None:
vllm_config = vllm_config.with_hf_config(hf_config) # So that the architectures field is overridden
hf_config = vllm_config.model_config.hf_config
return _initialize_model(vllm_config=vllm_config, if hf_config is not None:
prefix=prefix, vllm_config = vllm_config.with_hf_config(hf_config,
architectures=architectures) architectures=architectures)
return _initialize_model(vllm_config=vllm_config, prefix=prefix)
@overload @overload
@ -592,7 +595,7 @@ def get_vit_attn_backend(support_fa: bool = False) -> _Backend:
if is_flash_attn_2_available(): if is_flash_attn_2_available():
selected_backend = _Backend.FLASH_ATTN selected_backend = _Backend.FLASH_ATTN
else: else:
logger.warning( print_warning_once(
"Current `vllm-flash-attn` has a bug inside vision module, " "Current `vllm-flash-attn` has a bug inside vision module, "
"so we use xformers backend instead. You can run " "so we use xformers backend instead. You can run "
"`pip install flash-attn` to use flash-attention backend.") "`pip install flash-attn` to use flash-attention backend.")