[Model] Support pp for qwen2-vl (#8696)
This commit is contained in: parent 3e83c12b5c, commit a79e522984
@@ -8,6 +8,8 @@ WARNING: This test runs in both single-node (4 GPUs) and multi-node
 import os
 
 import pytest
+from packaging import version
+from transformers import __version__ as transformers_version
 
 from vllm.logger import init_logger
 
@@ -37,6 +39,7 @@ VLLM_MULTI_NODE = os.getenv("VLLM_MULTI_NODE", "0") == "1"
         (1, 2, 1, 1, 1, "OpenGVLab/InternVL2-1B", "mp"),
         (1, 2, 1, 1, 1, "OpenGVLab/InternVL2-2B", "mp"),
         (1, 2, 1, 0, 1, "OpenGVLab/InternVL2-4B", "mp"),
+        (1, 2, 0, 1, 0, "Qwen/Qwen2-VL-2B-Instruct", "mp")
     ],
 )
 @fork_new_process_for_each_test
@@ -46,6 +49,11 @@ def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL,
         pytest.skip("Skipping multi-node pipeline parallel test for "
                     "multiprocessing distributed backend")
 
+    # Skip tests that require transformers>=4.45.0
+    if "Qwen2-VL" in MODEL_NAME and version.parse(
+            transformers_version) < version.parse("4.45.0.dev0"):
+        pytest.skip("This test requires transformers>=4.45.0")
+
     pp_args = [
         # use half precision for speed and memory savings in CI environment
         "--dtype",
@@ -51,6 +51,7 @@ _PP_SUPPORTED_MODELS = [
     "Qwen2ForCausalLM",
     "Qwen2MoeForCausalLM",
     "QWenLMHeadModel",
+    "Qwen2VLForConditionalGeneration",
 ]
 
 
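Adding the architecture name to _PP_SUPPORTED_MODELS is what lets the engine accept pipeline_parallel_size > 1 for this model. Below is a minimal standalone sketch of such an allow-list gate; the check_pp_supported function and its error message are illustrative assumptions, not vLLM's exact code.

```python
# Minimal sketch of an allow-list gate for pipeline parallelism; the entries
# mirror the diff above, but this function is illustrative, not vLLM's code.
_PP_SUPPORTED_MODELS = [
    "Qwen2ForCausalLM",
    "Qwen2MoeForCausalLM",
    "QWenLMHeadModel",
    "Qwen2VLForConditionalGeneration",
]


def check_pp_supported(architectures: list, pp_size: int) -> None:
    """Raise if pipeline parallelism is requested for an unsupported model."""
    if pp_size <= 1:
        return
    if not any(arch in _PP_SUPPORTED_MODELS for arch in architectures):
        raise NotImplementedError(
            f"Pipeline parallelism is not supported for {architectures}.")


# Example: now passes because Qwen2VLForConditionalGeneration is listed.
check_pp_supported(["Qwen2VLForConditionalGeneration"], pp_size=2)
```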
@@ -49,7 +49,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
 
 from .interfaces import SupportsLoRA
-from .utils import is_pp_missing_parameter, make_layers
+from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers
 
 
 class Qwen2MLP(nn.Module):
@@ -235,11 +235,16 @@ class Qwen2Model(nn.Module):
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
 
-        self.embed_tokens = VocabParallelEmbedding(
-            config.vocab_size,
-            config.hidden_size,
-            quant_config=quant_config,
-        )
+        if get_pp_group().is_first_rank or (config.tie_word_embeddings
+                                            and get_pp_group().is_last_rank):
+            self.embed_tokens = VocabParallelEmbedding(
+                config.vocab_size,
+                config.hidden_size,
+                quant_config=quant_config,
+            )
+        else:
+            self.embed_tokens = PPMissingLayer()
+
         self.start_layer, self.end_layer, self.layers = make_layers(
             config.num_hidden_layers,
             lambda prefix: Qwen2DecoderLayer(config=config,
@@ -248,7 +253,10 @@ class Qwen2Model(nn.Module):
             prefix=f"{prefix}.layers",
         )
 
-        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        if get_pp_group().is_last_rank:
+            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        else:
+            self.norm = PPMissingLayer()
 
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.embed_tokens(input_ids)
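These two hunks give Qwen2Model a pipeline-parallel layout: only the first rank (or the last rank, when embeddings are tied) instantiates embed_tokens, and only the last rank instantiates the final norm; every other rank gets a PPMissingLayer placeholder. The sketch below is a minimal standalone illustration of that placeholder pattern; the ToyStage class and the class bodies are assumptions for demonstration, not vLLM's implementation.

```python
import torch.nn as nn


class PPMissingLayer(nn.Module):
    """Placeholder for a layer owned by another pipeline-parallel rank.

    It keeps the module tree shape identical on every rank while holding no
    parameters; calling it simply passes the input through.
    """

    def forward(self, *args, **kwargs):
        return args[0] if args else None


class ToyStage(nn.Module):
    """Toy model sharded over pipeline ranks, mirroring the diff above."""

    def __init__(self, rank: int, world_size: int, hidden: int = 16):
        super().__init__()
        is_first = rank == 0
        is_last = rank == world_size - 1
        # Embedding lives only on the first rank, final norm only on the last.
        self.embed_tokens = (nn.Embedding(1000, hidden)
                             if is_first else PPMissingLayer())
        self.norm = nn.LayerNorm(hidden) if is_last else PPMissingLayer()


if __name__ == "__main__":
    stage0, stage1 = ToyStage(0, 2), ToyStage(1, 2)
    print(type(stage0.norm).__name__, type(stage1.embed_tokens).__name__)
    # -> PPMissingLayer PPMissingLayer
```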
@@ -45,7 +45,7 @@ from vllm.attention import AttentionMetadata
 from vllm.attention.selector import (_Backend, backend_name_to_enum,
                                      get_global_forced_attn_backend)
 from vllm.config import CacheConfig, MultiModalConfig
-from vllm.distributed import parallel_state
+from vllm.distributed import get_pp_group, parallel_state
 from vllm.distributed import utils as dist_utils
 from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
 from vllm.logger import init_logger
@@ -68,6 +68,9 @@ from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors, SequenceData
 from vllm.transformers_utils.processor import get_processor
 
+from .utils import (PPMissingLayer, is_pp_missing_parameter,
+                    make_empty_intermediate_tensors_factory)
+
 logger = init_logger(__name__)
 
 # === Vision Inputs === #
@@ -856,15 +859,21 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal):
 
         self.model = Qwen2Model(config, cache_config, quant_config)
 
-        if config.tie_word_embeddings:
-            self.lm_head = self.model.embed_tokens
+        if get_pp_group().is_last_rank:
+            if config.tie_word_embeddings:
+                self.lm_head = self.model.embed_tokens
+            else:
+                self.lm_head = ParallelLMHead(config.vocab_size,
+                                              config.hidden_size,
+                                              quant_config=quant_config)
         else:
-            self.lm_head = ParallelLMHead(config.vocab_size,
-                                          config.hidden_size,
-                                          quant_config=quant_config)
+            self.lm_head = PPMissingLayer()
 
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()
+        self.make_empty_intermediate_tensors = (
+            make_empty_intermediate_tensors_factory(
+                ["hidden_states", "residual"], config.hidden_size))
 
     def _validate_and_reshape_mm_tensor(self,
                                         mm_input: Union[torch.Tensor,
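On the last pipeline rank the LM head is built as before; every other rank gets a PPMissingLayer stand-in plus a make_empty_intermediate_tensors helper that allocates the receive buffers (hidden_states and residual) handed over from the previous stage. A simplified standalone version of such a factory is sketched below; it returns a plain dict rather than vLLM's IntermediateTensors and is purely illustrative.

```python
from typing import Callable, Dict, List

import torch


def make_empty_intermediate_tensors_factory(
        keys: List[str],
        hidden_size: int) -> Callable[..., Dict[str, torch.Tensor]]:
    """Return a function that allocates per-stage receive buffers.

    Simplified stand-in for the helper named in the diff: it builds zeroed
    [num_tokens, hidden_size] tensors for each requested key.
    """

    def make_empty(num_tokens: int, dtype: torch.dtype,
                   device: torch.device) -> Dict[str, torch.Tensor]:
        return {
            key: torch.zeros((num_tokens, hidden_size),
                             dtype=dtype,
                             device=device)
            for key in keys
        }

    return make_empty


if __name__ == "__main__":
    factory = make_empty_intermediate_tensors_factory(
        ["hidden_states", "residual"], hidden_size=32)
    buffers = factory(8, torch.float16, torch.device("cpu"))
    print({k: tuple(v.shape) for k, v in buffers.items()})
    # -> {'hidden_states': (8, 32), 'residual': (8, 32)}
```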
@@ -979,7 +988,8 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal):
         image_input = self._parse_and_validate_image_input(**kwargs)
         video_input = self._parse_and_validate_video_input(**kwargs)
 
-        if image_input is None and video_input is None:
+        if (image_input is None
+                and video_input is None) or not get_pp_group().is_first_rank:
             inputs_embeds = None
         else:
             if getattr(self.config, "rope_scaling", {}).get("type",
@@ -1015,6 +1025,7 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal):
             positions=positions,
             kv_caches=kv_caches,
             attn_metadata=attn_metadata,
+            intermediate_tensors=intermediate_tensors,
             inputs_embeds=inputs_embeds,
         )
         return hidden_states
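With these changes, multimodal embedding happens only on the first pipeline rank (later ranks receive hidden states rather than input ids), and the intermediate tensors from the previous stage are threaded into the language model. The toy sketch below illustrates how such a staged forward might hand data between ranks; the ToyPPStage class and its flow are assumptions for illustration, not vLLM's scheduler or model code.

```python
from typing import Dict, Optional

import torch
import torch.nn as nn


class ToyPPStage(nn.Module):
    """Toy pipeline stage: embeds on the first rank, norms on the last."""

    def __init__(self, rank: int, world_size: int, hidden: int = 8):
        super().__init__()
        self.is_first = rank == 0
        self.is_last = rank == world_size - 1
        self.embed = nn.Embedding(100, hidden) if self.is_first else None
        self.block = nn.Linear(hidden, hidden)
        self.norm = nn.LayerNorm(hidden) if self.is_last else None

    def forward(
        self,
        input_ids: torch.Tensor,
        intermediate_tensors: Optional[Dict[str, torch.Tensor]] = None,
    ):
        if self.is_first:
            hidden_states = self.embed(input_ids)
        else:
            # Non-first ranks ignore input_ids and consume the tensors
            # received from the previous stage.
            hidden_states = intermediate_tensors["hidden_states"]
        hidden_states = self.block(hidden_states)
        if not self.is_last:
            return {"hidden_states": hidden_states}  # send downstream
        return self.norm(hidden_states)  # final hidden states


if __name__ == "__main__":
    stage0, stage1 = ToyPPStage(0, 2), ToyPPStage(1, 2)
    ids = torch.tensor([1, 2, 3])
    out = stage1(ids, intermediate_tensors=stage0(ids))
    print(out.shape)  # torch.Size([3, 8])
```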
@@ -1055,6 +1066,8 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
+                if is_pp_missing_parameter(name, self):
+                    continue
                 param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
@@ -1081,6 +1094,8 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal):
                     # Skip loading extra bias for GPTQ models.
                     if name.endswith(".bias") and name not in params_dict:
                         continue
+                    if is_pp_missing_parameter(name, self):
+                        continue
                     param = params_dict[name]
                 except KeyError:
                     print(params_dict.keys())
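Both weight-loading paths now skip checkpoint tensors whose layers live on another pipeline rank instead of failing the params_dict lookup. A minimal standalone sketch of that guard follows; the helper here is simplified (it checks the local parameter dict), whereas vLLM's is_pp_missing_parameter inspects the model's missing-layer bookkeeping.

```python
import torch
import torch.nn as nn


class PPMissingLayer(nn.Module):
    """Parameterless stand-in for a layer owned by another pipeline rank."""

    def forward(self, x):
        return x


def is_pp_missing_parameter(name: str, model: nn.Module) -> bool:
    # Simplified: treat a checkpoint weight as "missing" on this rank if no
    # local parameter matches it (its layer is only a placeholder here).
    return name not in dict(model.named_parameters())


def load_weights(model: nn.Module, weights) -> None:
    params_dict = dict(model.named_parameters())
    for name, loaded_weight in weights:
        if is_pp_missing_parameter(name, model):
            continue  # owned by another pipeline rank; skip, don't KeyError
        params_dict[name].data.copy_(loaded_weight)


if __name__ == "__main__":
    model = nn.Module()
    model.norm = PPMissingLayer()                  # not owned by this rank
    model.proj = nn.Linear(4, 4, bias=False)
    checkpoint = [("norm.weight", torch.ones(4)),
                  ("proj.weight", torch.eye(4))]
    load_weights(model, checkpoint)                # norm.weight is skipped
    print(torch.equal(model.proj.weight.data, torch.eye(4)))  # True
```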