Use Transformers helper get_text_config() instead of checking for text_config (#17105)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Authored by Harry Mellor on 2025-04-25 16:47:35 +01:00, committed by GitHub
parent 0bd7f8fca5
commit 423e9f1cbe
7 changed files with 30 additions and 46 deletions
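
Note: the change relies on the Transformers helper PretrainedConfig.get_text_config(), which returns the nested text sub-config for composite (multimodal) configs and the config itself for pure text models. A minimal sketch of the before/after pattern follows, assuming a transformers release that ships the helper; it is an illustration, not code from this PR.

    # Sketch of the pattern this commit replaces, assuming a transformers
    # release that provides PretrainedConfig.get_text_config().
    from transformers import LlavaConfig

    config = LlavaConfig()  # multimodal config with a nested text_config

    # Before: callers probed for a nested text config by hand.
    text_config_old = (config.text_config
                       if hasattr(config, "text_config") else config)

    # After: the helper resolves the nested config, or returns the
    # config itself for pure text models.
    text_config_new = config.get_text_config()

    assert (text_config_new.num_attention_heads
            == text_config_old.num_attention_heads)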


@@ -553,9 +553,8 @@ def main(args: argparse.Namespace):
         intermediate_size = config.moe_intermediate_size
         shard_intermediate_size = 2 * intermediate_size // args.tp_size
     else:
-        if not hasattr(config, "hidden_size"):
-            # Support for llama4
-            config = config.text_config
+        # Support for llama4
+        config = config.get_text_config()
         # Default: Mixtral.
         E = config.num_local_experts
         topk = config.num_experts_per_tok


@@ -24,10 +24,7 @@ def test_can_initialize(model_arch):
     def hf_overrides(hf_config: PretrainedConfig) -> PretrainedConfig:
         hf_config.update(model_info.hf_overrides)
-        if hasattr(hf_config, "text_config"):
-            text_config: PretrainedConfig = hf_config.text_config
-        else:
-            text_config = hf_config
+        text_config = hf_config.get_text_config()
         text_config.update({
             "num_layers": 1,


@@ -2841,12 +2841,10 @@ def _get_and_verify_dtype(
 ) -> torch.dtype:
     # NOTE: getattr(config, "torch_dtype", torch.float32) is not correct
     # because config.torch_dtype can be None.
-    config_dtype = getattr(config, "torch_dtype", None)
-    # Fallbacks for multi-modal models if the root config
+    config_dtype = getattr(config.get_text_config(), "torch_dtype", None)
+    # Fallback for multi-modal models if the root config
     # does not define torch_dtype
-    if config_dtype is None and hasattr(config, "text_config"):
-        config_dtype = getattr(config.text_config, "torch_dtype", None)
     if config_dtype is None and hasattr(config, "vision_config"):
         config_dtype = getattr(config.vision_config, "torch_dtype", None)


@@ -760,19 +760,22 @@ def get_hf_text_config(config: PretrainedConfig):
     """Get the "sub" config relevant to llm for multi modal models.
     No op for pure text models.
     """
-    if hasattr(config, "text_config"):
-        # The code operates under the assumption that text_config should have
-        # `num_attention_heads` (among others). Assert here to fail early
-        # if transformers config doesn't align with this assumption.
-        assert hasattr(config.text_config, "num_attention_heads")
-        return config.text_config
-    elif hasattr(config, "thinker_config"):
+    # This block should be unnecessary after https://github.com/huggingface/transformers/pull/37517
+    if hasattr(config, "thinker_config"):
         # TODO(suyang.fy): Refactor code.
         # For Qwen2.5-Omni, change hf_text_config to
         # thinker_config.text_config.
         return config.thinker_config.text_config
-    else:
-        return config
+
+    text_config = config.get_text_config()
+
+    if text_config is not config:
+        # The code operates under the assumption that text_config should have
+        # `num_attention_heads` (among others). Assert here to fail early
+        # if transformers config doesn't align with this assumption.
+        assert hasattr(text_config, "num_attention_heads")
+
+    return text_config


 def try_get_generation_config(
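
Note on the "text_config is not config" guard above: get_text_config() returns the config object itself for pure text models, so the num_attention_heads assert only runs when a genuinely nested sub-config comes back. A rough sketch of the assumed helper behaviour (not vLLM code):

    # Assumed behaviour of PretrainedConfig.get_text_config().
    from transformers import LlamaConfig, LlavaConfig

    text_only = LlamaConfig()
    # For a pure text model the helper returns the config itself,
    # so there is no separate sub-config to validate.
    assert text_only.get_text_config() is text_only

    multimodal = LlavaConfig()
    nested = multimodal.get_text_config()
    # For a multimodal model a nested config comes back, and the early
    # assert in get_hf_text_config() checks it looks like a text config.
    assert nested is not multimodal
    assert hasattr(nested, "num_attention_heads")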


@@ -508,13 +508,8 @@ class CPUModelRunnerBase(ModelRunnerBase[TModelInputForCPU]):
             logger.warning("Regarding multimodal models, vLLM currently "
                            "only supports adding LoRA to language model.")
-            # It's necessary to distinguish between the max_position_embeddings
-            # of VLMs and LLMs.
-            if hasattr(self.model.config, "max_position_embeddings"):
-                max_pos_embeddings = self.model.config.max_position_embeddings
-            else:
-                max_pos_embeddings = (
-                    self.model.config.text_config.max_position_embeddings)
+            # Use get_text_config() in case of multimodal models
+            text_config = self.model_config.hf_config.get_text_config()

             self.lora_manager = LRUCacheWorkerLoRAManager(
                 self.scheduler_config.max_num_seqs,
@@ -524,7 +519,7 @@ class CPUModelRunnerBase(ModelRunnerBase[TModelInputForCPU]):
                 self.device,
                 self.model.embedding_modules,
                 self.model.embedding_padding_modules,
-                max_position_embeddings=max_pos_embeddings,
+                max_position_embeddings=text_config.max_position_embeddings,
             )
             self.model = self.lora_manager.create_lora_manager(self.model)


@@ -724,14 +724,9 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
                 "Bias support in LoRA is not enabled in HPU yet."
             assert not self.lora_config.fully_sharded_loras, \
                 "Fully sharded LoRAs is not enabled in HPU yet."
-            # It's necessary to distinguish between the
-            # max_position_embeddings of VLMs and LLMs.
-            if hasattr(self.model.config, "max_position_embeddings"):
-                max_pos_embeddings = (
-                    self.model.config.max_position_embeddings)
-            else:
-                max_pos_embeddings = (
-                    self.model.config.text_config.max_position_embeddings)
+
+            # Use get_text_config() in case of multimodal models
+            text_config = self.model_config.hf_config.get_text_config()

             self.lora_manager = LRUCacheWorkerLoRAManager(
                 self.scheduler_config.max_num_seqs,
@@ -741,7 +736,8 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
                 self.device,
                 self.model.embedding_modules,
                 self.model.embedding_padding_modules,
-                max_position_embeddings=max_pos_embeddings,
+                max_position_embeddings=text_config.
+                max_position_embeddings,
             )
             self.model = self.lora_manager.create_lora_manager(self.model)


@@ -1130,14 +1130,9 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
             logger.warning(
                 "Regarding multimodal models, vLLM currently "
                 "only supports adding LoRA to language model.")
-            # It's necessary to distinguish between the
-            # max_position_embeddings of VLMs and LLMs.
-            if hasattr(self.model.config, "max_position_embeddings"):
-                max_pos_embeddings = (
-                    self.model.config.max_position_embeddings)
-            else:
-                max_pos_embeddings = (
-                    self.model.config.text_config.max_position_embeddings)
+
+            # Use get_text_config() in case of multimodal models
+            text_config = self.model_config.hf_config.get_text_config()

             self.lora_manager = LRUCacheWorkerLoRAManager(
                 self.scheduler_config.max_num_seqs,
@@ -1147,7 +1142,8 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
                 self.device,
                 self.model.embedding_modules,
                 self.model.embedding_padding_modules,
-                max_position_embeddings=max_pos_embeddings,
+                max_position_embeddings=text_config.
+                max_position_embeddings,
             )
             self.model = self.lora_manager.create_lora_manager(self.model)
             time_after_load = time.perf_counter()