[Bugfix] Fix Fuyu tensor parallel inference (#8986)

Isotr0py 2024-10-01 17:51:41 +08:00 committed by GitHub
parent 82f3937e59
commit bc4eb65b54
3 changed files with 15 additions and 12 deletions

tests/distributed/test_pipeline_parallel.py

@@ -37,7 +37,9 @@ VLLM_MULTI_NODE = os.getenv("VLLM_MULTI_NODE", "0") == "1"
         (1, 2, 1, 1, 1, "OpenGVLab/InternVL2-1B", "mp"),
         (1, 2, 1, 1, 1, "OpenGVLab/InternVL2-2B", "mp"),
         (1, 2, 1, 0, 1, "OpenGVLab/InternVL2-4B", "mp"),
-        (1, 2, 0, 1, 0, "Qwen/Qwen2-VL-2B-Instruct", "mp")
+        (1, 2, 0, 1, 0, "Qwen/Qwen2-VL-2B-Instruct", "mp"),
+        # TP only models
+        (2, 1, 1, 0, 0, "adept/fuyu-8b", "mp"),
     ],
 )
 @fork_new_process_for_each_test

vllm/model_executor/models/fuyu.py

@@ -237,8 +237,9 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal):
             self.image_feature_size,
             config.hidden_size,
             quant_config=quant_config,
+            gather_output=True,
         )
-        self.language_model = PersimmonForCausalLM(config,
+        self.language_model = PersimmonForCausalLM(config.text_config,
                                                    cache_config=cache_config,
                                                    quant_config=quant_config)
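
Why `gather_output=True` is needed here: `ColumnParallelLinear` splits its output dimension across tensor-parallel ranks, so without gathering, each rank would hold only a `hidden_size / tp_size` slice of the projected image embeddings, while the Persimmon language model expects the full hidden size. A minimal single-process sketch of that sharding, in plain torch with no vLLM distributed setup (the all-gather is simulated with a concat):

```python
import torch

# Single-process sketch of why Fuyu's vision projection needs
# gather_output=True under tensor parallelism. A column-parallel linear
# splits the output dimension across TP ranks, so each rank computes only
# a slice of the image embedding; the language model expects the full
# hidden_size, so the slices must be all-gathered first.

torch.manual_seed(0)
tp_size = 2
x = torch.randn(4, 16)    # patch features
w = torch.randn(32, 16)   # full projection weight, shape (out, in)

# Each "rank" holds a column shard of the output dimension.
shards = torch.chunk(w, tp_size, dim=0)
partial_outputs = [x @ w_r.t() for w_r in shards]  # per-rank output slices

# gather_output=True corresponds to concatenating (all-gathering)
# the per-rank slices back into the full output.
gathered = torch.cat(partial_outputs, dim=-1)
assert torch.allclose(gathered, x @ w.t(), atol=1e-5)
```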

vllm/model_executor/models/persimmon.py

@@ -25,11 +25,11 @@ from typing import Iterable, List, Optional, Tuple
 import torch
 from torch import nn
 from transformers import PersimmonConfig
-from transformers.activations import ReLUSquaredActivation

 from vllm.attention import Attention, AttentionMetadata
 from vllm.config import CacheConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                QKVParallelLinear,
                                                RowParallelLinear)
@@ -57,7 +57,7 @@ class PersimmonMLP(nn.Module):
         self.dense_4h_to_h = RowParallelLinear(config.intermediate_size,
                                                config.hidden_size,
                                                quant_config=quant_config)
-        self.act = ReLUSquaredActivation()
+        self.act = get_act_fn(config.hidden_act, quant_config)

     def forward(self, hidden_states) -> torch.Tensor:
         hidden_states, _ = self.dense_h_to_4h(hidden_states)
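
Persimmon's configured activation is `relu2` (ReLU squared), so resolving it through vLLM's `get_act_fn` registry is behaviorally equivalent to the removed `ReLUSquaredActivation` import from transformers, while letting vLLM substitute quantization-aware kernels where available. A quick sketch of what `relu2` computes:

```python
import torch

# relu2 ("ReLU squared"): relu(x) ** 2. This is what Persimmon's
# hidden_act resolves to; shown here as a plain-torch reference.
def relu_squared(x: torch.Tensor) -> torch.Tensor:
    return torch.square(torch.relu(x))

x = torch.tensor([-2.0, -0.5, 0.0, 1.5])
print(relu_squared(x))  # tensor([0.0000, 0.0000, 0.0000, 2.2500])
```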
@@ -96,7 +96,7 @@ class PersimmonAttention(nn.Module):
             quant_config=quant_config,
         )
         self.dense = RowParallelLinear(
-            self.num_heads * self.head_dim,
+            self.total_num_heads * self.head_dim,
             self.hidden_size,
             bias=True,
             quant_config=quant_config,
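
Why `total_num_heads` matters: `RowParallelLinear` is constructed with the full, unsharded input size and partitions it across ranks internally, so passing the per-rank `self.num_heads * self.head_dim` under-sizes the weight whenever `tp_size > 1`. A single-process sketch of the row-parallel contraction (the all-reduce is simulated with a plain sum):

```python
import torch

# Sketch of the attention output projection under TP. The layer is sized
# by the FULL input dimension (total_num_heads * head_dim); each rank
# multiplies its local head outputs by its input shard of the weight,
# and the partial products are summed (all-reduced) into the result.

torch.manual_seed(0)
tp_size = 2
total_heads, head_dim, hidden = 8, 4, 32
x = torch.randn(1, total_heads * head_dim)       # all heads' outputs
w = torch.randn(hidden, total_heads * head_dim)  # full (out, in) weight

x_shards = torch.chunk(x, tp_size, dim=-1)       # per-rank head outputs
w_shards = torch.chunk(w, tp_size, dim=-1)       # per-rank input shards
partials = [xs @ ws.t() for xs, ws in zip(x_shards, w_shards)]
reduced = sum(partials)                          # the all-reduce
assert torch.allclose(reduced, x @ w.t(), atol=1e-4)
```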
@@ -213,10 +213,10 @@ class PersimmonModel(nn.Module):
                  cache_config: Optional[CacheConfig] = None,
                  quant_config: Optional[QuantizationConfig] = None):
         super().__init__()
-        self.vocab_size = config.text_config.vocab_size
-        self.embed_tokens = VocabParallelEmbedding(
-            config.text_config.vocab_size, config.hidden_size)
+        self.vocab_size = config.vocab_size
+        self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
+                                                   config.hidden_size)
         self.layers = nn.ModuleList([
             PersimmonDecoderLayer(config,
                                   cache_config=cache_config,
@@ -252,19 +252,19 @@ class PersimmonModel(nn.Module):

 class PersimmonForCausalLM(nn.Module):

     def __init__(self,
-                 config,
+                 config: PersimmonConfig,
                  cache_config: Optional[CacheConfig] = None,
                  quant_config: Optional[QuantizationConfig] = None):
         super().__init__()
         self.config = config
-        self.vocab_size = config.text_config.vocab_size
+        self.vocab_size = config.vocab_size
         self.model = PersimmonModel(config,
                                     cache_config=cache_config,
                                     quant_config=quant_config)
-        self.lm_head = ParallelLMHead(config.text_config.vocab_size,
+        self.lm_head = ParallelLMHead(config.vocab_size,
                                       config.hidden_size,
                                       bias=False)
-        self.logits_processor = LogitsProcessor(config.text_config.vocab_size)
+        self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()

     def forward(
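
With the fuyu.py change above, `PersimmonForCausalLM` now receives the nested text config directly, which is why the `config.text_config.*` accessors collapse to `config.*`: the same class backs both standalone Persimmon checkpoints and Fuyu's language tower. A hedged sketch of the config nesting, assuming a transformers version that ships `FuyuConfig` with a Persimmon `text_config`:

```python
from transformers import FuyuConfig

# Fuyu's HF config nests the Persimmon language-model config as
# `text_config`. Passing config.text_config into PersimmonForCausalLM
# lets the Persimmon code read plain fields (vocab_size, hidden_size,
# hidden_act) without knowing it is embedded in a multimodal model.
config = FuyuConfig()
print(type(config.text_config).__name__)  # PersimmonConfig
print(config.text_config.vocab_size)      # the language model's vocab size
print(config.text_config.hidden_act)      # "relu2"
```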