diff --git a/docs/source/models/supported_models.md b/docs/source/models/supported_models.md index ba01f2309b3ca..56ea8c5d8372b 100644 --- a/docs/source/models/supported_models.md +++ b/docs/source/models/supported_models.md @@ -73,7 +73,7 @@ The Transformers fallback explicitly supports the following features: - (except GGUF) - -- (pipeline parallel coming soon !) +- (requires `transformers>=4.49.0`) #### Remote code diff --git a/tests/distributed/test_pipeline_parallel.py b/tests/distributed/test_pipeline_parallel.py index 1342f0da29d89..e757db45c8cf5 100644 --- a/tests/distributed/test_pipeline_parallel.py +++ b/tests/distributed/test_pipeline_parallel.py @@ -175,6 +175,8 @@ TEXT_GENERATION_MODELS = { "inceptionai/jais-13b-chat": PPTestSettings.fast(), "ai21labs/Jamba-tiny-dev": PPTestSettings.fast(), "meta-llama/Llama-3.2-1B-Instruct": PPTestSettings.detailed(), + # Tests TransformersModel + "ArthurZ/Ilama-3.2-1B": PPTestSettings.fast(), "openbmb/MiniCPM-2B-sft-bf16": PPTestSettings.fast(), "openbmb/MiniCPM3-4B": PPTestSettings.fast(), # Uses Llama @@ -243,6 +245,7 @@ TEST_MODELS = [ # [LANGUAGE GENERATION] "microsoft/Phi-3.5-MoE-instruct", "meta-llama/Llama-3.2-1B-Instruct", + # "ArthurZ/Ilama-3.2-1B", NOTE: Uncomment after #13905 "ibm/PowerLM-3b", # [LANGUAGE EMBEDDING] "intfloat/e5-mistral-7b-instruct", diff --git a/vllm/model_executor/models/transformers.py b/vllm/model_executor/models/transformers.py index be788d6320029..fe6a9d7a4aa43 100644 --- a/vllm/model_executor/models/transformers.py +++ b/vllm/model_executor/models/transformers.py @@ -15,21 +15,25 @@ # limitations under the License. """Wrapper around `transformers` models""" import re +from itertools import chain from typing import Iterable, Literal, Optional, Union import torch from torch import nn -from transformers import AutoModel, PreTrainedModel +from transformers import AutoModel, PretrainedConfig, PreTrainedModel from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS from vllm.attention import Attention -from vllm.config import VllmConfig -from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.config import (CacheConfig, DeviceConfig, ModelConfig, + ParallelConfig, VllmConfig) +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.distributed.utils import get_pp_indices from vllm.logger import init_logger from vllm.model_executor.layers.linear import (ColumnParallelLinear, ReplicatedLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) @@ -37,8 +41,9 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors -from .interfaces import SupportsLoRA, SupportsQuant -from .utils import maybe_prefix +from .interfaces import SupportsLoRA, SupportsPP, SupportsQuant +from .utils import (PPMissingLayer, is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, maybe_prefix) logger = init_logger(__name__) @@ -53,7 +58,7 @@ def vllm_flash_attention_forward( # Transformers kwargs scaling: Optional[float] = None, # vLLM kwargs - attention_instances: Optional[list[Attention]] = None, + attention_instances: Optional[dict[Attention]] = None, **kwargs): self_attn = attention_instances[module.layer_idx] if scaling is not None: @@ -72,13 +77,12 @@ def log_replacement(name: str, old_module: nn.Module, new_module: nn.Module): def replace_linear_class( - linear: nn.Linear, - style: Literal["colwise", "rowwise"], - quant_config=None) -> Union[ColumnParallelLinear, RowParallelLinear]: + linear: nn.Linear, style: Literal["colwise", "rowwise"], + quant_config: QuantizationConfig +) -> Union[ColumnParallelLinear, RowParallelLinear]: """ Replace nn.Linear with one of vLLM's tensor parallel linear classes. - `quant_config` is not yet supported. Args: linear (nn.Linear): `nn.Linear` to be replaced. style (str): Tensor parallel style of the new linear, e.g. "colwise". @@ -105,7 +109,7 @@ def replace_linear_class( ) -class TransformersModel(nn.Module, SupportsQuant, SupportsLoRA): +class TransformersModel(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP): embedding_padding_modules = ["lm_head"] embedding_modules = ["embed_tokens" ] # TODO transformers will have a util to get it @@ -114,31 +118,175 @@ class TransformersModel(nn.Module, SupportsQuant, SupportsLoRA): super().__init__() logger.info("Using Transformers backend.") - config = vllm_config.model_config.hf_config - cache_config = vllm_config.cache_config - model_config = vllm_config.model_config - parallel_config = vllm_config.parallel_config + config: PretrainedConfig = vllm_config.model_config.hf_config + cache_config: CacheConfig = vllm_config.cache_config + device_config: DeviceConfig = vllm_config.device_config + model_config: ModelConfig = vllm_config.model_config + parallel_config: ParallelConfig = vllm_config.parallel_config + quant_config: QuantizationConfig = vllm_config.quant_config self.config = config + self.cache_config = cache_config + self.device_config = device_config + self.model_config = model_config + self.parallel_config = parallel_config + self.quant_config = quant_config + self.vocab_size = model_config.get_vocab_size() self.unpadded_vocab_size = model_config.get_vocab_size() - self.model: PreTrainedModel = AutoModel.from_config( - self.config, - attn_implementation="vllm", - torch_dtype=vllm_config.model_config.dtype, - trust_remote_code=vllm_config.model_config.trust_remote_code, - ) + self.pp_group = get_pp_group() + self.pp_size = self.pp_group.world_size + self.pp_rank = self.pp_group.rank_in_group + self.tp_size = get_tensor_model_parallel_world_size() + + # Use meta device to delay allocating GPU tensors + with torch.device("meta"): + self.model: PreTrainedModel = AutoModel.from_config( + config, + attn_implementation="vllm", + torch_dtype=model_config.dtype, + trust_remote_code=model_config.trust_remote_code, + ) prefix = self.model.base_model_prefix - # MLP modifications - self.apply_base_model_tp_plan(self.model) + self.pipeline_parallel() + self.tensor_parallel() - # Attention modifications (assumes 1 attention op per hidden layer) - num_heads = model_config.get_num_attention_heads(parallel_config) - head_size = model_config.get_head_size() - num_kv_heads = model_config.get_num_kv_heads(parallel_config) - self.attention_instances = [ + # Input embeddings + if not isinstance(self.model.get_input_embeddings(), PPMissingLayer): + self.model.set_input_embeddings( + VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + quant_config=quant_config, + )) + + # Attention layers + self.attention_instances = self.create_attention_instances() + + # Output embeddings + if not isinstance(getattr(self, "lm_head", None), PPMissingLayer): + self.unpadded_vocab_size = config.vocab_size + self.lm_head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head"), + ) + if config.tie_word_embeddings: + self.lm_head = self.lm_head.tie_weights( + self.model.get_input_embeddings()) + + logit_scale = getattr(config, "logit_scale", 1.0) + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size, + logit_scale) + + # Initialize buffers (e.g. rotary embedding inverse frequency) + self.init_buffers(self.model) + + # Move remaining meta tensors to device (should happen last) + self.meta_to_empty(self.model) + + self.sampler = get_sampler() + + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory(["hidden_states"], + config.hidden_size)) + + def pipeline_parallel(self): + """ + Apply the model's pipeline parallelization plan. + """ + if self.pp_size <= 1: + return + + if not self.model.supports_pp_plan: + raise ValueError( + f"{type(self.model)} does not support pipeline parallel yet!") + + module_lists = [] + module_list_idx = None + pp_plan = list(self.model._pp_plan.keys()) + for i, name in enumerate(pp_plan): + if isinstance(getattr(self.model, name), nn.ModuleList): + module_lists.append(name) + module_list_idx = i + + if len(module_lists) > 1: + raise ValueError( + "Pipeline parallel of models with multiple `ModuleList`s " + "in the base model are not supported yet!") + if module_list_idx is None: + raise ValueError( + f"Could not find `ModuleList` in {type(self.model)}") + + # Layers before module list + for name in pp_plan[:module_list_idx]: + if self.pp_group.is_first_rank or (self.config.tie_word_embeddings + and self.pp_group.is_last_rank): + continue + setattr(self.model, name, PPMissingLayer()) + + # Module list + start_layer, end_layer = get_pp_indices(self.config.num_hidden_layers, + self.pp_rank, self.pp_size) + layers_name = pp_plan[module_list_idx] + layers = getattr(self.model, layers_name) + for i in range(len(layers)): + if start_layer <= i and i < end_layer: + continue + layers[i] = PPMissingLayer(return_tuple=True) + + # Layers after module list + for name in pp_plan[module_list_idx + 1:]: + # Modules that should be on last rank + if not self.pp_group.is_last_rank: + setattr(self.model, name, PPMissingLayer()) + + if not self.pp_group.is_last_rank: + self.lm_head = PPMissingLayer() + + def tensor_parallel(self): + """ + Apply the model's tensor parallelization plan. + Currently only supports linear layers. + """ + if self.tp_size > 1 and self.config.base_model_tp_plan is None: + raise ValueError( + f"{type(self.model)} does not support tensor parallel yet!") + + tp_plan = self.model._tp_plan + + def _tensor_parallel(module: nn.Module, prefix: str = ""): + for child_name, child_module in module.named_children(): + qual_name = maybe_prefix(prefix, child_name) + for pattern, style in tp_plan.items(): + if re.match(pattern, qual_name) and isinstance( + child_module, nn.Linear): + new_module = replace_linear_class( + child_module, style, self.quant_config) + setattr(module, child_name, new_module) + log_replacement(qual_name, child_module, new_module) + else: + _tensor_parallel(child_module, prefix=qual_name) + + _tensor_parallel(self.model) + + def create_attention_instances(self) -> dict[int, Attention]: + """ + Create `Attention` instances to inform KV cache allocation. + """ + num_heads = self.model_config.get_num_attention_heads( + self.parallel_config) + head_size = self.model_config.get_head_size() + num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config) + start, end = get_pp_indices(self.config.num_hidden_layers, + self.pp_rank, self.pp_size) + return { + i: Attention( num_heads=num_heads, head_size=head_size, @@ -146,77 +294,70 @@ class TransformersModel(nn.Module, SupportsQuant, SupportsLoRA): # Transformers, it's updated in vllm_flash_attention_forward scale=head_size**-0.5, num_kv_heads=num_kv_heads, - cache_config=cache_config, + cache_config=self.cache_config, quant_config=self.quant_config, - prefix=f"{i}.attn") for i in range(config.num_hidden_layers) - ] + prefix=f"{i}.attn") + for i in range(start, end) + } - # Model modifications - self.replace_vocab_embed_class(self.model) - - # ForCausalLM modifications - self.lm_head = ParallelLMHead(self.vocab_size, - config.hidden_size, - quant_config=self.quant_config, - prefix=maybe_prefix(prefix, "lm_head")) - if config.tie_word_embeddings: - self.lm_head.weight = self.model.get_input_embeddings().weight - - logit_scale = getattr(config, "logit_scale", 1.0) - self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, - self.vocab_size, logit_scale) - self.sampler = get_sampler() - - def apply_base_model_tp_plan(self, module: nn.Module, prefix: str = ""): + def init_buffers(self, module: nn.Module): """ - Apply the base model tensor parallelization plan to a module. - Currently only supports linear layers. + If a `buffer` is on the `meta` device, then its parent + `module` is the original module created by: + + ```python + with torch.device("meta"): + self.model: PreTrainedModel = AutoModel.from_config(...) + ``` + + This means that: + - `type(module)` is a class from `transformers` + - This class is constructed using a `PretrainedConfig` """ - if (self.config.base_model_tp_plan is None - and get_tensor_model_parallel_world_size() > 1): - raise ValueError( - "Trying to run tensor parallelization but the model does not " - "support it yet!") + for name, buffer in module.named_buffers(recurse=False): + if buffer.device == torch.device("meta"): + new_buffer = getattr(type(module)(self.config), name) + setattr(module, name, new_buffer) + for child in module.children(): + self.init_buffers(child) - for child_name, child_module in module.named_children(): - qual_name = maybe_prefix(prefix, child_name) - for pattern, style in self.config.base_model_tp_plan.items(): - if re.match(pattern, qual_name) and isinstance( - child_module, nn.Linear): - new_module = replace_linear_class(child_module, style, - self.quant_config) - setattr(module, child_name, new_module) - log_replacement(qual_name, child_module, new_module) - else: - self.apply_base_model_tp_plan(child_module, prefix=qual_name) - - def replace_vocab_embed_class(self, module: nn.Module): - # Use native set input embeddings - new_module = VocabParallelEmbedding( - self.vocab_size, - self.config.hidden_size, - org_num_embeddings=self.vocab_size, - quant_config=None, - ) - log_replacement("input embedding", self.model.get_input_embeddings(), - new_module) - module.set_input_embeddings(new_module) + def meta_to_empty(self, module: nn.Module): + tensors = list(chain(module.buffers(), module.parameters())) + if tensors and all(t.device == torch.device("meta") for t in tensors): + module.to_empty(device=self.device_config.device) + return # We can stop recursing because to_empty is recursive + for child in module.children(): + self.meta_to_empty(child) def forward( self, - input_ids: torch.Tensor, + input_ids: Optional[torch.Tensor], positions: torch.Tensor, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - model_output = self.model( - input_ids[None, ...], + if not get_pp_group().is_first_rank: + assert intermediate_tensors is not None + input_ids = None + inputs_embeds = intermediate_tensors["hidden_states"] + + if input_ids is not None: + input_ids = input_ids[None, ...] + if inputs_embeds is not None: + inputs_embeds = inputs_embeds[None, ...] + + hidden_states = self.model( + input_ids=input_ids, + inputs_embeds=inputs_embeds, use_cache=False, position_ids=positions[None, ...], - intermediate_tensors=intermediate_tensors, attention_instances=self.attention_instances, return_dict=False)[0][0, ...] # we remove batch dimension for now - return model_output + + if not get_pp_group().is_last_rank: + return IntermediateTensors({"hidden_states": hidden_states}) + + return hidden_states def compute_logits( self, @@ -238,8 +379,11 @@ class TransformersModel(nn.Module, SupportsQuant, SupportsLoRA): params_dict = dict(self.named_parameters()) loaded_params = set[str]() for name, loaded_weight in weights: - if name not in params_dict: - name = f"{self.model.base_model_prefix}.{name}" + # Necessary for some models which use remote code + if not name.startswith(prefix := self.model.base_model_prefix): + name = maybe_prefix(prefix, name) + if is_pp_missing_parameter(name, self): + continue if name in params_dict: param = params_dict[name] weight_loader = getattr(param, "weight_loader", diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index a705aeffef35a..1e3d78c7f6fd7 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -472,6 +472,16 @@ class PPMissingLayer(torch.nn.Identity): def __init__(self, *args, **kwargs): super().__init__() + self.return_tuple = kwargs.get("return_tuple", False) + + def forward(self, *args, **kwargs): + """ + Return the first arg from args or the first value from kwargs. + + Wraps the input in a tuple if `self.return_tuple` is True. + """ + input = args[0] if args else next(iter(kwargs.values())) + return (input, ) if self.return_tuple else input _CPU_OFFLOAD_BYTES = 0 @@ -650,4 +660,4 @@ def cast_overflow_tensors( if tensors.isinf().any() or tensors.isnan().any(): clamp_value = torch.finfo(tensors.dtype).max - offset tensors = torch.clamp(tensors, min=-clamp_value, max=clamp_value) - return tensors \ No newline at end of file + return tensors