Deduplicate Transformers backend code using inheritance (#21461)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Harry Mellor 2025-07-24 08:16:23 +01:00 committed by GitHub
parent 6d8d0a24c0
commit dde295a934

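At a high level, the change swaps composition for inheritance: instead of TransformersForCausalLM and TransformersForMultimodalLM each wrapping a TransformersModel and repeating its setup, both now subclass a shared TransformersBase and only add what differs (LM head, logits processing, multimodal inputs). A minimal sketch of the pattern, using simplified toy names rather than the real vLLM classes:

from typing import Optional


class ToyBase:
    """Shared construction and weight loading live here exactly once."""

    def __init__(self, *, config: dict, prefix: str = ""):
        self.config = config
        # Subclasses set this so the shared load_weights can skip modules.
        self.skip_prefixes: Optional[list[str]] = None

    def load_weights(self, weights: dict) -> set[str]:
        skip = tuple(self.skip_prefixes or [])
        # startswith(()) is always False, so nothing is skipped by default.
        return {name for name in weights if not name.startswith(skip)}


class ToyForCausalLM(ToyBase):
    """Only the causal-LM specifics remain in the subclass."""

    def __init__(self, *, config: dict, prefix: str = ""):
        super().__init__(config=config, prefix=prefix)
        if self.config.get("tie_word_embeddings"):
            # Tied embeddings: tell the shared loader to ignore lm_head.*
            self.skip_prefixes = ["lm_head."]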

@@ -39,7 +39,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
 from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
@@ -55,8 +54,8 @@ from vllm.utils import is_list_of
 from .interfaces import (SupportsLoRA, SupportsMultiModal, SupportsPP,
                          SupportsQuant)
 from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper,
-                    flatten_bn, is_pp_missing_parameter,
-                    make_empty_intermediate_tensors_factory, maybe_prefix)
+                    flatten_bn, make_empty_intermediate_tensors_factory,
+                    maybe_prefix)
 
 
 logger = init_logger(__name__)
@@ -414,40 +413,40 @@ class ConfigOverride:
                 setattr(self.config, key, value)
 
 
-class TransformersModel:
+class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
     embedding_padding_modules = ["lm_head"]
     embedding_modules = ["embed_tokens"
                          ] # TODO transformers will have a util to get it
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         logger.info("Using Transformers backend.")
 
-        config: PretrainedConfig = vllm_config.model_config.hf_config
-        cache_config: CacheConfig = vllm_config.cache_config
-        device_config: DeviceConfig = vllm_config.device_config
-        model_config: ModelConfig = vllm_config.model_config
-        parallel_config: ParallelConfig = vllm_config.parallel_config
-        quant_config: QuantizationConfig = vllm_config.quant_config
-
-        self.config = config
-        self.text_config = config.get_text_config()
-        self.cache_config = cache_config
-        self.device_config = device_config
-        self.model_config = model_config
-        self.parallel_config = parallel_config
-        self.quant_config = quant_config
+        self.config: PretrainedConfig = vllm_config.model_config.hf_config
+        self.text_config: PretrainedConfig = self.config.get_text_config()
+        self.cache_config: CacheConfig = vllm_config.cache_config
+        self.device_config: DeviceConfig = vllm_config.device_config
+        self.model_config: ModelConfig = vllm_config.model_config
+        self.parallel_config: ParallelConfig = vllm_config.parallel_config
+        self.quant_config: QuantizationConfig = vllm_config.quant_config
 
         self.pp_group = get_pp_group()
         self.pp_size = self.pp_group.world_size
         self.pp_rank = self.pp_group.rank_in_group
         self.tp_size = get_tensor_model_parallel_world_size()
 
+        # To be updated in child classes for use in `load_weights`
+        self.skip_prefixes: Optional[list[str]] = None
+
         # vLLM handles interleaved sliding window attention by creating a new
         # interleaved_sliding_window attribute and deleting the sliding_window
         # attribute. This breaks the constructors in Transformers so we
         # temporarily add the attribute back to construct the model.
         config_override = nullcontext()
-        if hasattr(config, "interleaved_sliding_window"):
+        if hasattr(self.config, "interleaved_sliding_window"):
             config_override = ConfigOverride(
-                config, sliding_window=config.interleaved_sliding_window)
+                self.config,
+                sliding_window=self.config.interleaved_sliding_window)
 
         # Set correct attn and init on "meta" to delay allocating GPU tensors
         # TODO: @raushan, use the public `model.set_attn_implementation()`
@@ -455,23 +454,22 @@ class TransformersModel:
         self.text_config._attn_implementation = "vllm"
         with init_on_device_without_buffers("meta"), config_override:
             self.model: PreTrainedModel = AutoModel.from_config(
-                config,
-                torch_dtype=model_config.dtype,
-                trust_remote_code=model_config.trust_remote_code,
+                self.config,
+                torch_dtype=self.model_config.dtype,
+                trust_remote_code=self.model_config.trust_remote_code,
             )
 
         self.pipeline_parallel()
         self.tensor_parallel()
 
         # Input embeddings
-        text_config = config.get_text_config()
         if not isinstance(self.model.get_input_embeddings(), PPMissingLayer):
             self.model.set_input_embeddings(
                 VocabParallelEmbedding(
-                    text_config.vocab_size,
-                    text_config.hidden_size,
-                    org_num_embeddings=text_config.vocab_size,
-                    quant_config=quant_config,
+                    self.text_config.vocab_size,
+                    self.text_config.hidden_size,
+                    org_num_embeddings=self.text_config.vocab_size,
+                    quant_config=self.quant_config,
                 ))
 
         # Attention layers
@@ -481,8 +479,8 @@ class TransformersModel:
         self.init_parameters(self.model)
 
         self.make_empty_intermediate_tensors = (
-            make_empty_intermediate_tensors_factory(["hidden_states"],
-                                                    text_config.hidden_size))
+            make_empty_intermediate_tensors_factory(
+                ["hidden_states"], self.text_config.hidden_size))
 
     def pipeline_parallel(self):
         """
@@ -654,78 +652,40 @@ class TransformersModel:
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
-        params_dict = dict(self.named_parameters())
-
-        loaded_params = set[str]()
-        for name, loaded_weight in weights:
-            # Use "model" instead of base_model_prefix because
-            # the base model attribute in vLLM is always `model`
-            if not name.startswith(prefix := "model."):
-                name = prefix + name
-
-            if is_pp_missing_parameter(name, self):
-                continue
-
-            if name in params_dict:
-                param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                weight_loader(param, loaded_weight)
-                loaded_params.add(name)
-        return loaded_params
+        loader = AutoWeightsLoader(self, skip_prefixes=self.skip_prefixes)
+        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
 
 
 @support_torch_compile
-class TransformersForCausalLM(nn.Module, SupportsQuant, SupportsLoRA,
-                              SupportsPP):
-    embedding_padding_modules = ["lm_head"]
-    embedding_modules = ["embed_tokens"
-                         ] # TODO transformers will have a util to get it
+class TransformersForCausalLM(TransformersBase):
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
-        super().__init__()
-        config: PretrainedConfig = vllm_config.model_config.hf_config
-        quant_config: QuantizationConfig = vllm_config.quant_config
+        super().__init__(vllm_config=vllm_config, prefix=prefix)
 
-        self.config = config
-
-        self.transformers_model = TransformersModel(vllm_config=vllm_config,
-                                                    prefix=prefix)
-        self.model = self.transformers_model.model
+        # Tell `TransformersBase.load_weights` to skip
+        # `lm_head` if the model has tied word embeddings
+        if self.text_config.tie_word_embeddings:
+            self.skip_prefixes = ["lm_head."]
 
         if get_pp_group().is_last_rank:
-            self.unpadded_vocab_size = config.vocab_size
+            self.unpadded_vocab_size = self.text_config.vocab_size
             self.lm_head = ParallelLMHead(
-                config.vocab_size,
-                config.hidden_size,
-                quant_config=quant_config,
+                self.text_config.vocab_size,
+                self.text_config.hidden_size,
+                quant_config=self.quant_config,
                 prefix=maybe_prefix(prefix, "lm_head"),
             )
-            if config.tie_word_embeddings:
+            if self.text_config.tie_word_embeddings:
                 self.lm_head = self.lm_head.tie_weights(
                     self.model.get_input_embeddings())
 
-            logit_scale = getattr(config, "logit_scale", 1.0)
-            self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
-                                                    config.vocab_size,
-                                                    logit_scale)
+            logit_scale = getattr(self.text_config, "logit_scale", 1.0)
+            self.logits_processor = LogitsProcessor(
+                self.unpadded_vocab_size, self.text_config.vocab_size,
+                logit_scale)
         else:
             self.lm_head = PPMissingLayer()
 
-        self.make_empty_intermediate_tensors = (
-            self.transformers_model.make_empty_intermediate_tensors)
-
-    def forward(
-        self,
-        input_ids: Optional[torch.Tensor],
-        positions: torch.Tensor,
-        intermediate_tensors: Optional[IntermediateTensors] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
-    ) -> Union[torch.Tensor, IntermediateTensors]:
-        model_output = self.transformers_model.forward(input_ids, positions,
-                                                       intermediate_tensors,
-                                                       inputs_embeds)
-        return model_output
-
     def compute_logits(
         self,
         hidden_states: torch.Tensor,
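
The hunk above also collapses the hand-written weight loop into a single AutoWeightsLoader call shared by every subclass; subclasses now influence loading only through data (self.skip_prefixes and the class-level hf_to_vllm_mapper) instead of overriding load_weights. A rough sketch of that flow, using made-up helpers in place of vLLM's AutoWeightsLoader and WeightsMapper:

from typing import Iterable, Optional


def toy_load_weights(
    weights: Iterable[tuple[str, object]],
    skip_prefixes: Optional[list[str]] = None,
    prefix_map: Optional[dict[str, str]] = None,
) -> set[str]:
    """Rename checkpoint keys, then drop any parameter under a skipped prefix."""
    loaded: set[str] = set()
    for name, _tensor in weights:
        # Apply an old-prefix -> new-prefix mapping, as a WeightsMapper would.
        for old, new in (prefix_map or {}).items():
            if name.startswith(old):
                name = new + name[len(old):]
                break
        # Honour skip_prefixes, e.g. ["lm_head."] when embeddings are tied.
        if any(name.startswith(p) for p in (skip_prefixes or [])):
            continue
        loaded.add(name)
    return loaded


# Example: toy_load_weights([("lm_head.weight", 0), ("layers.0.w", 0)],
#                           skip_prefixes=["lm_head."]) -> {"layers.0.w"}
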
@@ -735,23 +695,12 @@ class TransformersForCausalLM(nn.Module, SupportsQuant, SupportsLoRA,
                                        sampling_metadata)
         return logits
 
-    def load_weights(self, weights: Iterable[tuple[str,
-                                                   torch.Tensor]]) -> set[str]:
-        skip_prefixes = ["lm_head."
-                         ] if self.config.tie_word_embeddings else None
-        loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
-        return loader.load_weights(weights)
-
 
 @MULTIMODAL_REGISTRY.register_processor(
     MultiModalProcessor,
     info=MultiModalProcessingInfo,
     dummy_inputs=MultiModalDummyInputsBuilder)
-class TransformersForMultimodalLM(nn.Module, SupportsQuant, SupportsLoRA,
-                                  SupportsPP, SupportsMultiModal):
-    embedding_padding_modules = ["lm_head"]
-    embedding_modules = ["embed_tokens"]
-
+class TransformersForMultimodalLM(TransformersForCausalLM, SupportsMultiModal):
     # Backwards compatibility for prev released models. State dicts back then
     # had different formats and cannot be loaded with `AutoModel` mapping as is
     hf_to_vllm_mapper = WeightsMapper(
@@ -776,40 +725,10 @@ class TransformersForMultimodalLM(nn.Module, SupportsQuant, SupportsLoRA,
         })
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
-        super().__init__()
-        config: PretrainedConfig = vllm_config.model_config.hf_config
-        quant_config: QuantizationConfig = vllm_config.quant_config
+        super().__init__(vllm_config=vllm_config, prefix=prefix)
 
-        self.config = config
         self.dtype = vllm_config.model_config.dtype
 
-        self.transformers_model = TransformersModel(vllm_config=vllm_config,
-                                                    prefix=prefix)
-        self.model = self.transformers_model.model
-
-        text_config = config.get_text_config()
-        if get_pp_group().is_last_rank:
-            self.unpadded_vocab_size = text_config.vocab_size
-            self.lm_head = ParallelLMHead(
-                text_config.vocab_size,
-                text_config.hidden_size,
-                quant_config=quant_config,
-                prefix=maybe_prefix(prefix, "lm_head"),
-            )
-            if text_config.tie_word_embeddings:
-                self.lm_head = self.lm_head.tie_weights(
-                    self.model.get_input_embeddings())
-
-            logit_scale = getattr(config, "logit_scale", 1.0)
-            self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
-                                                    text_config.vocab_size,
-                                                    logit_scale)
-        else:
-            self.lm_head = PPMissingLayer()
-
-        self.make_empty_intermediate_tensors = (
-            self.transformers_model.make_empty_intermediate_tensors)
-
     def forward(
         self,
         input_ids: Optional[torch.Tensor],
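
With TransformersForMultimodalLM now deriving from TransformersForCausalLM, its __init__ shrinks to a super() call plus the dtype it needs for multimodal inputs, and (in the next hunk) forward delegates to super().forward(...) once the multimodal embeddings have been merged, rather than going through a wrapped self.transformers_model. A generic illustration of that delegation, with toy classes and simplified signatures rather than the real vLLM ones:

class ToyCausalLM:
    def forward(self, input_ids, inputs_embeds=None):
        # Shared decode path used by text-only and multimodal subclasses alike.
        return inputs_embeds if inputs_embeds is not None else input_ids


class ToyMultimodalLM(ToyCausalLM):
    def forward(self, input_ids, pixel_values=None):
        inputs_embeds = None
        if pixel_values is not None:
            # Merge image features into token embeddings, then hand off.
            inputs_embeds = ("merged", input_ids, pixel_values)
            input_ids = None
        # Reuse the parent implementation instead of a wrapped helper object.
        return super().forward(input_ids, inputs_embeds=inputs_embeds)
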
@@ -828,30 +747,10 @@ class TransformersForMultimodalLM(nn.Module, SupportsQuant, SupportsLoRA,
                 input_ids, multimodal_embeds)
             input_ids = None
 
-        model_output = self.transformers_model.forward(input_ids, positions,
-                                                       intermediate_tensors,
-                                                       inputs_embeds)
+        model_output = super().forward(input_ids, positions,
+                                       intermediate_tensors, inputs_embeds)
         return model_output
 
-    def compute_logits(
-        self,
-        hidden_states: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
-    ) -> Optional[torch.Tensor]:
-        logits = self.logits_processor(self.lm_head, hidden_states,
-                                       sampling_metadata)
-        return logits
-
-    def load_weights(self, weights: Iterable[tuple[str,
-                                                   torch.Tensor]]) -> set[str]:
-        loader = AutoWeightsLoader(
-            self,
-            skip_prefixes=([
-                "lm_head."
-            ] if self.config.get_text_config().tie_word_embeddings else None),
-        )
-        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
-
     def get_multimodal_embeddings(self, **kwargs):
         pixel_values = kwargs.pop("pixel_values", None)
         pixel_values = pixel_values if pixel_values is not None else kwargs.pop(