[Model] Support pp for qwen2-vl (#8696)

This commit is contained in:
Yanyi Liu 2024-09-23 21:46:59 +08:00 committed by GitHub
parent 3e83c12b5c
commit a79e522984
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 46 additions and 14 deletions

View File

@ -8,6 +8,8 @@ WARNING: This test runs in both single-node (4 GPUs) and multi-node
import os
import pytest
from packaging import version
from transformers import __version__ as transformers_version
from vllm.logger import init_logger
@ -37,6 +39,7 @@ VLLM_MULTI_NODE = os.getenv("VLLM_MULTI_NODE", "0") == "1"
(1, 2, 1, 1, 1, "OpenGVLab/InternVL2-1B", "mp"),
(1, 2, 1, 1, 1, "OpenGVLab/InternVL2-2B", "mp"),
(1, 2, 1, 0, 1, "OpenGVLab/InternVL2-4B", "mp"),
(1, 2, 0, 1, 0, "Qwen/Qwen2-VL-2B-Instruct", "mp")
],
)
@fork_new_process_for_each_test
@ -46,6 +49,11 @@ def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL,
pytest.skip("Skipping multi-node pipeline parallel test for "
"multiprocessing distributed backend")
# Skip tests that require transformers>=4.45.0
if "Qwen2-VL" in MODEL_NAME and version.parse(
transformers_version) < version.parse("4.45.0.dev0"):
pytest.skip("This test requires transformers>=4.45.0")
pp_args = [
# use half precision for speed and memory savings in CI environment
"--dtype",

View File

@ -51,6 +51,7 @@ _PP_SUPPORTED_MODELS = [
"Qwen2ForCausalLM",
"Qwen2MoeForCausalLM",
"QWenLMHeadModel",
"Qwen2VLForConditionalGeneration",
]

View File

@ -49,7 +49,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA
from .utils import is_pp_missing_parameter, make_layers
from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers
class Qwen2MLP(nn.Module):
@ -235,11 +235,16 @@ class Qwen2Model(nn.Module):
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
quant_config=quant_config,
)
if get_pp_group().is_first_rank or (config.tie_word_embeddings
and get_pp_group().is_last_rank):
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
quant_config=quant_config,
)
else:
self.embed_tokens = PPMissingLayer()
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: Qwen2DecoderLayer(config=config,
@ -248,7 +253,10 @@ class Qwen2Model(nn.Module):
prefix=f"{prefix}.layers",
)
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
if get_pp_group().is_last_rank:
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
else:
self.norm = PPMissingLayer()
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)

View File

@ -45,7 +45,7 @@ from vllm.attention import AttentionMetadata
from vllm.attention.selector import (_Backend, backend_name_to_enum,
get_global_forced_attn_backend)
from vllm.config import CacheConfig, MultiModalConfig
from vllm.distributed import parallel_state
from vllm.distributed import get_pp_group, parallel_state
from vllm.distributed import utils as dist_utils
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.logger import init_logger
@ -68,6 +68,9 @@ from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors, SequenceData
from vllm.transformers_utils.processor import get_processor
from .utils import (PPMissingLayer, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory)
logger = init_logger(__name__)
# === Vision Inputs === #
@ -856,15 +859,21 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal):
self.model = Qwen2Model(config, cache_config, quant_config)
if config.tie_word_embeddings:
self.lm_head = self.model.embed_tokens
if get_pp_group().is_last_rank:
if config.tie_word_embeddings:
self.lm_head = self.model.embed_tokens
else:
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config)
else:
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config)
self.lm_head = PPMissingLayer()
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))
def _validate_and_reshape_mm_tensor(self,
mm_input: Union[torch.Tensor,
@ -979,7 +988,8 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal):
image_input = self._parse_and_validate_image_input(**kwargs)
video_input = self._parse_and_validate_video_input(**kwargs)
if image_input is None and video_input is None:
if (image_input is None
and video_input is None) or not get_pp_group().is_first_rank:
inputs_embeds = None
else:
if getattr(self.config, "rope_scaling", {}).get("type",
@ -1015,6 +1025,7 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal):
positions=positions,
kv_caches=kv_caches,
attn_metadata=attn_metadata,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds,
)
return hidden_states
@ -1055,6 +1066,8 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal):
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
@ -1081,6 +1094,8 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal):
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
except KeyError:
print(params_dict.keys())