Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-13 21:25:33 +08:00)

Deduplicate Transformers backend code using inheritance (#21461)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>

Parent: 6d8d0a24c0
Commit: dde295a934
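For readers skimming the diff below: the duplicated constructor and weight-loading logic that previously lived in `TransformersModel`, `TransformersForCausalLM` and `TransformersForMultimodalLM` moves into a single `TransformersBase`, and the subclasses only override the parts that differ (for example setting `skip_prefixes` when word embeddings are tied, or delegating `forward` for multimodal inputs). A minimal, self-contained sketch of that pattern follows; it uses toy stand-in classes, not vLLM's real configs or signatures, and only borrows the names `load_weights` and `skip_prefixes` from the diff.

# Illustrative sketch only: toy stand-ins, not vLLM's actual classes.
from typing import Iterable, Optional


class TransformersBase:
    """Shared construction and weight loading; subclasses tweak behaviour."""

    def __init__(self, *, vllm_config: dict, prefix: str = ""):
        self.config = vllm_config  # toy stand-in for the real config objects
        # Subclasses update this to drop weights they do not need.
        self.skip_prefixes: Optional[list[str]] = None

    def load_weights(self, weights: Iterable[tuple[str, object]]) -> set[str]:
        skip = tuple(self.skip_prefixes or ())
        loaded = set()
        for name, _tensor in weights:
            if skip and name.startswith(skip):
                continue  # e.g. skip "lm_head." when embeddings are tied
            loaded.add(name)
        return loaded


class TransformersForCausalLM(TransformersBase):
    def __init__(self, *, vllm_config: dict, prefix: str = ""):
        super().__init__(vllm_config=vllm_config, prefix=prefix)
        # Only the delta relative to the base class lives here.
        if vllm_config.get("tie_word_embeddings"):
            self.skip_prefixes = ["lm_head."]


model = TransformersForCausalLM(vllm_config={"tie_word_embeddings": True})
print(model.load_weights([("lm_head.weight", None), ("model.x", None)]))
# -> {'model.x'}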
@@ -39,7 +39,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
 from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
@@ -55,8 +54,8 @@ from vllm.utils import is_list_of
 from .interfaces import (SupportsLoRA, SupportsMultiModal, SupportsPP,
                          SupportsQuant)
 from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper,
-                    flatten_bn, is_pp_missing_parameter,
-                    make_empty_intermediate_tensors_factory, maybe_prefix)
+                    flatten_bn, make_empty_intermediate_tensors_factory,
+                    maybe_prefix)
 
 logger = init_logger(__name__)
 
@@ -414,40 +413,40 @@ class ConfigOverride:
             setattr(self.config, key, value)
 
 
-class TransformersModel:
+class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
+    embedding_padding_modules = ["lm_head"]
+    embedding_modules = ["embed_tokens"
+                         ]  # TODO transformers will have a util to get it
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         logger.info("Using Transformers backend.")
 
-        config: PretrainedConfig = vllm_config.model_config.hf_config
-        cache_config: CacheConfig = vllm_config.cache_config
-        device_config: DeviceConfig = vllm_config.device_config
-        model_config: ModelConfig = vllm_config.model_config
-        parallel_config: ParallelConfig = vllm_config.parallel_config
-        quant_config: QuantizationConfig = vllm_config.quant_config
-
-        self.config = config
-        self.text_config = config.get_text_config()
-        self.cache_config = cache_config
-        self.device_config = device_config
-        self.model_config = model_config
-        self.parallel_config = parallel_config
-        self.quant_config = quant_config
+        self.config: PretrainedConfig = vllm_config.model_config.hf_config
+        self.text_config: PretrainedConfig = self.config.get_text_config()
+        self.cache_config: CacheConfig = vllm_config.cache_config
+        self.device_config: DeviceConfig = vllm_config.device_config
+        self.model_config: ModelConfig = vllm_config.model_config
+        self.parallel_config: ParallelConfig = vllm_config.parallel_config
+        self.quant_config: QuantizationConfig = vllm_config.quant_config
 
         self.pp_group = get_pp_group()
         self.pp_size = self.pp_group.world_size
         self.pp_rank = self.pp_group.rank_in_group
         self.tp_size = get_tensor_model_parallel_world_size()
 
+        # To be updated in child classes for use in `load_weights`
+        self.skip_prefixes: Optional[list[str]] = None
+
         # vLLM handles interleaved sliding window attention by creating a new
         # interleaved_sliding_window attribute and deleting the sliding_window
         # attribute. This breaks the constructors in Transformers so we
         # temporarily add the attribute back to construct the model.
         config_override = nullcontext()
-        if hasattr(config, "interleaved_sliding_window"):
+        if hasattr(self.config, "interleaved_sliding_window"):
             config_override = ConfigOverride(
-                config, sliding_window=config.interleaved_sliding_window)
+                self.config,
+                sliding_window=self.config.interleaved_sliding_window)
 
         # Set correct attn and init on "meta" to delay allocating GPU tensors
         # TODO: @raushan, use the public `model.set_attn_implementation()`
@@ -455,23 +454,22 @@ class TransformersModel:
         self.text_config._attn_implementation = "vllm"
         with init_on_device_without_buffers("meta"), config_override:
             self.model: PreTrainedModel = AutoModel.from_config(
-                config,
-                torch_dtype=model_config.dtype,
-                trust_remote_code=model_config.trust_remote_code,
+                self.config,
+                torch_dtype=self.model_config.dtype,
+                trust_remote_code=self.model_config.trust_remote_code,
             )
 
         self.pipeline_parallel()
         self.tensor_parallel()
 
         # Input embeddings
-        text_config = config.get_text_config()
         if not isinstance(self.model.get_input_embeddings(), PPMissingLayer):
             self.model.set_input_embeddings(
                 VocabParallelEmbedding(
-                    text_config.vocab_size,
-                    text_config.hidden_size,
-                    org_num_embeddings=text_config.vocab_size,
-                    quant_config=quant_config,
+                    self.text_config.vocab_size,
+                    self.text_config.hidden_size,
+                    org_num_embeddings=self.text_config.vocab_size,
+                    quant_config=self.quant_config,
                 ))
 
         # Attention layers
@@ -481,8 +479,8 @@ class TransformersModel:
         self.init_parameters(self.model)
 
         self.make_empty_intermediate_tensors = (
-            make_empty_intermediate_tensors_factory(["hidden_states"],
-                                                    text_config.hidden_size))
+            make_empty_intermediate_tensors_factory(
+                ["hidden_states"], self.text_config.hidden_size))
 
     def pipeline_parallel(self):
         """
@@ -654,78 +652,40 @@ class TransformersModel:
 
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
-        params_dict = dict(self.named_parameters())
-
-        loaded_params = set[str]()
-        for name, loaded_weight in weights:
-            # Use "model" instead of base_model_prefix because
-            # the base model attribute in vLLM is always `model`
-            if not name.startswith(prefix := "model."):
-                name = prefix + name
-
-            if is_pp_missing_parameter(name, self):
-                continue
-            if name in params_dict:
-                param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                weight_loader(param, loaded_weight)
-                loaded_params.add(name)
-        return loaded_params
+        loader = AutoWeightsLoader(self, skip_prefixes=self.skip_prefixes)
+        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
 
 
 @support_torch_compile
-class TransformersForCausalLM(nn.Module, SupportsQuant, SupportsLoRA,
-                              SupportsPP):
-    embedding_padding_modules = ["lm_head"]
-    embedding_modules = ["embed_tokens"
-                         ]  # TODO transformers will have a util to get it
+class TransformersForCausalLM(TransformersBase):
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
-        super().__init__()
-        config: PretrainedConfig = vllm_config.model_config.hf_config
-        quant_config: QuantizationConfig = vllm_config.quant_config
-
-        self.config = config
-
-        self.transformers_model = TransformersModel(vllm_config=vllm_config,
-                                                    prefix=prefix)
-        self.model = self.transformers_model.model
+        super().__init__(vllm_config=vllm_config, prefix=prefix)
+
+        # Tell `TransformersBase.load_weights` to skip
+        # `lm_head` if the model has tied word embeddings
+        if self.text_config.tie_word_embeddings:
+            self.skip_prefixes = ["lm_head."]
 
         if get_pp_group().is_last_rank:
-            self.unpadded_vocab_size = config.vocab_size
+            self.unpadded_vocab_size = self.text_config.vocab_size
             self.lm_head = ParallelLMHead(
-                config.vocab_size,
-                config.hidden_size,
-                quant_config=quant_config,
+                self.text_config.vocab_size,
+                self.text_config.hidden_size,
+                quant_config=self.quant_config,
                 prefix=maybe_prefix(prefix, "lm_head"),
             )
-            if config.tie_word_embeddings:
+            if self.text_config.tie_word_embeddings:
                 self.lm_head = self.lm_head.tie_weights(
                     self.model.get_input_embeddings())
 
-            logit_scale = getattr(config, "logit_scale", 1.0)
-            self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
-                                                    config.vocab_size,
-                                                    logit_scale)
+            logit_scale = getattr(self.text_config, "logit_scale", 1.0)
+            self.logits_processor = LogitsProcessor(
+                self.unpadded_vocab_size, self.text_config.vocab_size,
+                logit_scale)
         else:
             self.lm_head = PPMissingLayer()
 
-        self.make_empty_intermediate_tensors = (
-            self.transformers_model.make_empty_intermediate_tensors)
-
-    def forward(
-        self,
-        input_ids: Optional[torch.Tensor],
-        positions: torch.Tensor,
-        intermediate_tensors: Optional[IntermediateTensors] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
-    ) -> Union[torch.Tensor, IntermediateTensors]:
-        model_output = self.transformers_model.forward(input_ids, positions,
-                                                       intermediate_tensors,
-                                                       inputs_embeds)
-        return model_output
-
     def compute_logits(
         self,
         hidden_states: torch.Tensor,
@@ -735,23 +695,12 @@ class TransformersForCausalLM(nn.Module, SupportsQuant, SupportsLoRA,
                                        sampling_metadata)
         return logits
 
-    def load_weights(self, weights: Iterable[tuple[str,
-                                                   torch.Tensor]]) -> set[str]:
-        skip_prefixes = ["lm_head."
-                         ] if self.config.tie_word_embeddings else None
-        loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
-        return loader.load_weights(weights)
-
 
 @MULTIMODAL_REGISTRY.register_processor(
     MultiModalProcessor,
     info=MultiModalProcessingInfo,
     dummy_inputs=MultiModalDummyInputsBuilder)
-class TransformersForMultimodalLM(nn.Module, SupportsQuant, SupportsLoRA,
-                                  SupportsPP, SupportsMultiModal):
-    embedding_padding_modules = ["lm_head"]
-    embedding_modules = ["embed_tokens"]
-
+class TransformersForMultimodalLM(TransformersForCausalLM, SupportsMultiModal):
     # Backwards compatibility for prev released models. State dicts back then
     # had different formats and cannot be loaded with `AutoModel` mapping as is
     hf_to_vllm_mapper = WeightsMapper(
@@ -776,40 +725,10 @@ class TransformersForMultimodalLM(nn.Module, SupportsQuant, SupportsLoRA,
         })
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
-        super().__init__()
-        config: PretrainedConfig = vllm_config.model_config.hf_config
-        quant_config: QuantizationConfig = vllm_config.quant_config
+        super().__init__(vllm_config=vllm_config, prefix=prefix)
 
-        self.config = config
         self.dtype = vllm_config.model_config.dtype
 
-        self.transformers_model = TransformersModel(vllm_config=vllm_config,
-                                                    prefix=prefix)
-        self.model = self.transformers_model.model
-        text_config = config.get_text_config()
-
-        if get_pp_group().is_last_rank:
-            self.unpadded_vocab_size = text_config.vocab_size
-            self.lm_head = ParallelLMHead(
-                text_config.vocab_size,
-                text_config.hidden_size,
-                quant_config=quant_config,
-                prefix=maybe_prefix(prefix, "lm_head"),
-            )
-            if text_config.tie_word_embeddings:
-                self.lm_head = self.lm_head.tie_weights(
-                    self.model.get_input_embeddings())
-
-            logit_scale = getattr(config, "logit_scale", 1.0)
-            self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
-                                                    text_config.vocab_size,
-                                                    logit_scale)
-        else:
-            self.lm_head = PPMissingLayer()
-
-        self.make_empty_intermediate_tensors = (
-            self.transformers_model.make_empty_intermediate_tensors)
-
     def forward(
         self,
         input_ids: Optional[torch.Tensor],
@@ -828,30 +747,10 @@ class TransformersForMultimodalLM(nn.Module, SupportsQuant, SupportsLoRA,
                 input_ids, multimodal_embeds)
             input_ids = None
 
-        model_output = self.transformers_model.forward(input_ids, positions,
-                                                       intermediate_tensors,
-                                                       inputs_embeds)
+        model_output = super().forward(input_ids, positions,
+                                       intermediate_tensors, inputs_embeds)
         return model_output
 
-    def compute_logits(
-        self,
-        hidden_states: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
-    ) -> Optional[torch.Tensor]:
-        logits = self.logits_processor(self.lm_head, hidden_states,
-                                       sampling_metadata)
-        return logits
-
-    def load_weights(self, weights: Iterable[tuple[str,
-                                                   torch.Tensor]]) -> set[str]:
-        loader = AutoWeightsLoader(
-            self,
-            skip_prefixes=([
-                "lm_head."
-            ] if self.config.get_text_config().tie_word_embeddings else None),
-        )
-        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
-
     def get_multimodal_embeddings(self, **kwargs):
         pixel_values = kwargs.pop("pixel_values", None)
         pixel_values = pixel_values if pixel_values is not None else kwargs.pop(