[Model] Support math-shepherd-mistral-7b-prm model (#9697)

Signed-off-by: Went-Liang <wenteng_liang@163.com>
Went-Liang 2024-10-31 00:33:42 +08:00 committed by GitHub
parent cc98f1e079
commit 81f09cfd80
14 changed files with 312 additions and 62 deletions

View File

@@ -112,38 +112,58 @@ class ModelConfig:
Defaults to 'auto' which defaults to 'hf'.
mm_processor_kwargs: Arguments to be forwarded to the model's processor
for multi-modal data, e.g., image processor.
pooling_type: Used to configure the pooling method in the embedding
model.
pooling_norm: Used to determine whether to normalize the pooled
data in the embedding model.
pooling_softmax: Used to determine whether to softmax the pooled
data in the embedding model.
pooling_step_tag_id: When pooling_step_tag_id is not -1, it indicates
that the score corresponding to the pooling_step_tag_id in the
generated sentence should be returned. Otherwise, it returns
the scores for all tokens.
pooling_returned_token_ids: pooling_returned_token_ids represents a
list of indices for the vocabulary dimensions to be extracted,
such as the token IDs of good_token and bad_token in the
math-shepherd-mistral-7b-prm model.
""" """
def __init__(self, def __init__(
model: str, self,
task: Union[TaskOption, _Task], model: str,
tokenizer: str, task: Union[TaskOption, _Task],
tokenizer_mode: str, tokenizer: str,
trust_remote_code: bool, tokenizer_mode: str,
dtype: Union[str, torch.dtype], trust_remote_code: bool,
seed: int, dtype: Union[str, torch.dtype],
revision: Optional[str] = None, seed: int,
code_revision: Optional[str] = None, revision: Optional[str] = None,
rope_scaling: Optional[dict] = None, code_revision: Optional[str] = None,
rope_theta: Optional[float] = None, rope_scaling: Optional[dict] = None,
tokenizer_revision: Optional[str] = None, rope_theta: Optional[float] = None,
max_model_len: Optional[int] = None, tokenizer_revision: Optional[str] = None,
spec_target_max_model_len: Optional[int] = None, max_model_len: Optional[int] = None,
quantization: Optional[str] = None, spec_target_max_model_len: Optional[int] = None,
quantization_param_path: Optional[str] = None, quantization: Optional[str] = None,
enforce_eager: Optional[bool] = None, quantization_param_path: Optional[str] = None,
max_context_len_to_capture: Optional[int] = None, enforce_eager: Optional[bool] = None,
max_seq_len_to_capture: Optional[int] = None, max_context_len_to_capture: Optional[int] = None,
max_logprobs: int = 20, max_seq_len_to_capture: Optional[int] = None,
disable_sliding_window: bool = False, max_logprobs: int = 20,
skip_tokenizer_init: bool = False, disable_sliding_window: bool = False,
served_model_name: Optional[Union[str, List[str]]] = None, skip_tokenizer_init: bool = False,
limit_mm_per_prompt: Optional[Mapping[str, int]] = None, served_model_name: Optional[Union[str, List[str]]] = None,
use_async_output_proc: bool = True, limit_mm_per_prompt: Optional[Mapping[str, int]] = None,
override_neuron_config: Optional[Dict[str, Any]] = None, use_async_output_proc: bool = True,
config_format: ConfigFormat = ConfigFormat.AUTO, override_neuron_config: Optional[Dict[str, Any]] = None,
chat_template_text_format: str = "string", config_format: ConfigFormat = ConfigFormat.AUTO,
mm_processor_kwargs: Optional[Dict[str, Any]] = None) -> None: chat_template_text_format: str = "string",
mm_processor_kwargs: Optional[Dict[str, Any]] = None,
pooling_type: Optional[str] = None,
pooling_norm: Optional[bool] = None,
pooling_softmax: Optional[bool] = None,
pooling_step_tag_id: Optional[int] = None,
pooling_returned_token_ids: Optional[List[int]] = None) -> None:
self.model = model
self.tokenizer = tokenizer
self.tokenizer_mode = tokenizer_mode
@@ -224,6 +244,13 @@ class ModelConfig:
supported_tasks, task = self._resolve_task(task, self.hf_config)
self.supported_tasks = supported_tasks
self.task: Final = task
self.pooler_config = self._init_pooler_config(
pooling_type,
pooling_norm,
pooling_softmax,
pooling_step_tag_id,
pooling_returned_token_ids,
)
self._verify_quantization()
self._verify_cuda_graph()
@@ -242,6 +269,23 @@ class ModelConfig:
return None
def _init_pooler_config(
self,
pooling_type: Optional[str] = None,
pooling_norm: Optional[bool] = None,
pooling_softmax: Optional[bool] = None,
pooling_step_tag_id: Optional[int] = None,
pooling_returned_token_ids: Optional[List[int]] = None
) -> Optional["PoolerConfig"]:
if self.task == "embedding":
return PoolerConfig(
pooling_type=pooling_type,
pooling_norm=pooling_norm,
pooling_softmax=pooling_softmax,
pooling_step_tag_id=pooling_step_tag_id,
pooling_returned_token_ids=pooling_returned_token_ids)
return None
def _init_attention_free(self) -> bool:
architectures = getattr(self.hf_config, "architectures", [])
return ModelRegistry.is_attention_free_model(architectures)
@@ -1647,6 +1691,17 @@ class MultiModalConfig:
# TODO: Add configs to init vision tower or not.
@dataclass
class PoolerConfig:
"""Controls the behavior of pooler in embedding model"""
pooling_type: Optional[str] = None
pooling_norm: Optional[bool] = None
pooling_softmax: Optional[bool] = None
pooling_step_tag_id: Optional[int] = None
pooling_returned_token_ids: Optional[List[int]] = None
_STR_DTYPE_TO_TORCH_DTYPE = {
"half": torch.float16,
"float16": torch.float16,

View File

@@ -184,6 +184,13 @@ class EngineArgs:
mm_processor_kwargs: Optional[Dict[str, Any]] = None
scheduling_policy: Literal["fcfs", "priority"] = "fcfs"
# Pooling configuration.
pooling_type: Optional[str] = None
pooling_norm: Optional[bool] = None
pooling_softmax: Optional[bool] = None
pooling_step_tag_id: Optional[int] = None
pooling_returned_token_ids: Optional[List[int]] = None
def __post_init__(self):
if not self.tokenizer:
self.tokenizer = self.model
@@ -850,6 +857,58 @@ class EngineArgs:
'priority (lower value means earlier handling) and time of '
'arrival deciding any ties).')
parser.add_argument(
'--pooling-type',
choices=['LAST', 'ALL', 'CLS', 'STEP'],
default=None,
help='Used to configure the pooling method in the embedding model.'
)
parser.add_argument('--pooling-norm',
default=None,
action='store_true',
help="Used to determine whether to normalize "
"the pooled data in the embedding model.")
parser.add_argument('--no-pooling-norm',
default=None,
action='store_false',
dest='pooling_norm',
help="Used to determine whether to normalize "
"the pooled data in the embedding model.")
parser.add_argument('--pooling-softmax',
default=None,
action='store_true',
help="Used to determine whether to softmax "
"the pooled data in the embedding model.")
parser.add_argument('--no-pooling-softmax',
default=None,
action='store_false',
dest='pooling_softmax',
help="Used to determine whether to softmax "
"the pooled data in the embedding model.")
parser.add_argument(
'--pooling-step-tag-id',
type=int,
default=None,
help="When pooling-step-tag-id is not -1, it indicates "
"that the score corresponding to the step-tag-ids in the "
"generated sentence should be returned. Otherwise, it "
"returns the scores for all tokens.")
parser.add_argument(
'--pooling-returned-token-ids',
nargs='+',
type=int,
default=None,
help="pooling-returned-token-ids represents a list of "
"indices for the vocabulary dimensions to be extracted, "
"such as the token IDs of good_token and bad_token in "
"the math-shepherd-mistral-7b-prm model.")
return parser
@classmethod
@@ -891,6 +950,11 @@ class EngineArgs:
override_neuron_config=self.override_neuron_config,
config_format=self.config_format,
mm_processor_kwargs=self.mm_processor_kwargs,
pooling_type=self.pooling_type,
pooling_norm=self.pooling_norm,
pooling_softmax=self.pooling_softmax,
pooling_step_tag_id=self.pooling_step_tag_id,
pooling_returned_token_ids=self.pooling_returned_token_ids,
)
def create_load_config(self) -> LoadConfig:
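A sketch of how the new flags flow from the CLI into EngineArgs. The paired --pooling-norm/--no-pooling-norm flags default to None, so an unset flag stays distinguishable from an explicit False and the model's own Pooler defaults still apply; the model path and token ids below are placeholders, not values fixed by this commit:

from vllm.engine.arg_utils import EngineArgs
from vllm.utils import FlexibleArgumentParser

parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
args = parser.parse_args([
    "--model", "math-shepherd-mistral-7b-prm",   # placeholder model path
    "--task", "embedding",
    "--pooling-type", "STEP",
    "--pooling-step-tag-id", "12345",            # placeholder token id
    "--pooling-returned-token-ids", "648", "387",  # placeholder token ids
])
engine_args = EngineArgs.from_cli_args(args)
assert engine_args.pooling_norm is None  # unset -> model default wins later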

View File

@@ -257,7 +257,8 @@ class LLMEngine:
"num_scheduler_steps=%d, chunked_prefill_enabled=%s "
"multi_step_stream_outputs=%s, enable_prefix_caching=%s, "
"use_async_output_proc=%s, use_cached_outputs=%s, "
"chat_template_text_format=%s, mm_processor_kwargs=%s, "
"pooler_config=%r)",
VLLM_VERSION,
model_config.model,
speculative_config,
@@ -294,6 +295,7 @@ class LLMEngine:
use_cached_outputs,
model_config.chat_template_text_format,
model_config.mm_processor_kwargs,
model_config.pooler_config,
)
# TODO(woosuk): Print more configs in debug mode.
self.model_config = model_config

View File

@@ -159,6 +159,11 @@ class LLM:
mm_processor_kwargs: Optional[Dict[str, Any]] = None,
# After positional args are removed, move this right below `model`
task: TaskOption = "auto",
pooling_type: Optional[str] = None,
pooling_norm: Optional[bool] = None,
pooling_softmax: Optional[bool] = None,
pooling_step_tag_id: Optional[int] = None,
pooling_returned_token_ids: Optional[List[int]] = None,
**kwargs,
) -> None:
'''
@@ -193,6 +198,11 @@ class LLM:
disable_custom_all_reduce=disable_custom_all_reduce,
disable_async_output_proc=disable_async_output_proc,
mm_processor_kwargs=mm_processor_kwargs,
pooling_type=pooling_type,
pooling_norm=pooling_norm,
pooling_softmax=pooling_softmax,
pooling_step_tag_id=pooling_step_tag_id,
pooling_returned_token_ids=pooling_returned_token_ids,
**kwargs,
)
self.llm_engine = LLMEngine.from_engine_args(
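An offline-usage sketch of the new keyword arguments on the LLM entrypoint; the model path, token ids, and step-tag marker in the prompt are placeholders, not values fixed by this commit:

from vllm import LLM

llm = LLM(
    model="math-shepherd-mistral-7b-prm",
    task="embedding",
    pooling_type="STEP",
    pooling_step_tag_id=12345,              # placeholder id of the step-tag token
    pooling_returned_token_ids=[648, 387],  # placeholder good/bad token ids
)
outputs = llm.encode(["Step 1: ... [step] Step 2: ... [step]"])
print(outputs[0].outputs.embedding)  # per-step scores rather than one vector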

View File

@@ -1,8 +1,10 @@
from enum import IntEnum
from typing import List, Optional
import torch
import torch.nn as nn
from vllm.config import PoolerConfig
from vllm.model_executor.pooling_metadata import (PoolingMetadata,
PoolingTensors)
from vllm.sequence import EmbeddingSequenceGroupOutput, PoolerOutput
@@ -13,6 +15,7 @@ class PoolingType(IntEnum):
LAST = 0
ALL = 1
CLS = 2
STEP = 3
class Pooler(nn.Module):
@@ -28,15 +31,47 @@ class Pooler(nn.Module):
normalize: Whether to normalize the pooled data.
"""
def __init__(
self,
pooling_type: PoolingType,
normalize: bool,
softmax: bool,
step_tag_id: Optional[int] = None,
returned_token_ids: Optional[List[int]] = None,
):
super().__init__()
self.pooling_type = pooling_type
self.normalize = normalize
self.softmax = softmax
self.step_tag_id = step_tag_id
self.returned_token_ids = returned_token_ids
@classmethod
def from_config_with_defaults(
cls,
pooler_config: PoolerConfig,
pooling_type: PoolingType,
normalize: bool,
softmax: bool,
step_tag_id: Optional[int] = None,
returned_token_ids: Optional[List[int]] = None,
) -> Optional["Pooler"]:
if pooler_config is None:
return None
return cls(
pooling_type=PoolingType[pooler_config.pooling_type]
if pooler_config.pooling_type is not None else pooling_type,
normalize=pooler_config.pooling_norm
if pooler_config.pooling_norm is not None else normalize,
softmax=pooler_config.pooling_softmax
if pooler_config.pooling_softmax is not None else softmax,
step_tag_id=pooler_config.pooling_step_tag_id
if pooler_config.pooling_step_tag_id is not None else step_tag_id,
returned_token_ids=pooler_config.pooling_returned_token_ids
if pooler_config.pooling_returned_token_ids is not None else
returned_token_ids,
)
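A minimal sketch of the override precedence implemented by from_config_with_defaults: only the fields the user actually set in PoolerConfig win, everything left as None falls back to the defaults the model passes in.

from vllm.config import PoolerConfig
from vllm.model_executor.layers.pooler import Pooler, PoolingType

pooler = Pooler.from_config_with_defaults(
    PoolerConfig(pooling_norm=False),   # user only turns normalization off
    pooling_type=PoolingType.LAST,      # model defaults
    normalize=True,
    softmax=False)
assert pooler.pooling_type is PoolingType.LAST
assert pooler.normalize is False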
def forward(
self,
@@ -62,6 +97,25 @@ class Pooler(nn.Module):
for prompt_len in prompt_lens:
pooled_data.append(hidden_states[offset:offset + prompt_len])
offset += prompt_len
elif self.pooling_type == PoolingType.STEP:
if self.returned_token_ids is not None and len(
self.returned_token_ids) > 0:
logits = hidden_states[:,
self.returned_token_ids].softmax(dim=-1)
else:
logits = hidden_states.softmax(dim=-1)
offset = 0
pooled_data = []
for prompt_len, seq_data_i in zip(
prompt_lens, pooling_metadata.seq_data.values()):
if self.step_tag_id is None:
pooled_data.append(logits[offset:offset + prompt_len])
else:
step_idxs = torch.tensor(
seq_data_i.prompt_token_ids) == self.step_tag_id
pooled_data.append(logits[offset:offset +
prompt_len][step_idxs])
offset += prompt_len
else:
raise ValueError(f"Invalid pooling type: {self.pooling_type}")

View File

@@ -23,7 +23,7 @@ from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoadFormat,
LoRAConfig, ModelConfig, MultiModalConfig,
ParallelConfig, PoolerConfig, SchedulerConfig)
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.envs import VLLM_USE_MODELSCOPE
@@ -122,7 +122,8 @@ def _get_model_initialization_kwargs(
model_class: Type[nn.Module],
lora_config: Optional[LoRAConfig],
multimodal_config: Optional[MultiModalConfig],
scheduler_config: Optional[SchedulerConfig] = None,
pooler_config: Optional[PoolerConfig] = None) -> Dict[str, Any]:
"""Get extra kwargs for model initialization."""
extra_kwargs: Dict[str, Any] = {}
@@ -143,7 +144,8 @@ def _get_model_initialization_kwargs(
if has_inner_state(model_class) and scheduler_config:
extra_kwargs["scheduler_config"] = scheduler_config
if pooler_config:
extra_kwargs["pooler_config"] = pooler_config
return extra_kwargs
@@ -155,10 +157,12 @@ def build_model(model_class: Type[nn.Module],
lora_config: Optional[LoRAConfig],
multimodal_config: Optional[MultiModalConfig],
scheduler_config: Optional[SchedulerConfig],
prefix: Optional[str] = None,
pooler_config: Optional[PoolerConfig] = None) -> nn.Module:
extra_kwargs = _get_model_initialization_kwargs(model_class, lora_config,
multimodal_config,
scheduler_config,
pooler_config)
if prefix:
extra_kwargs["prefix"] = prefix
@@ -185,6 +189,7 @@ def _initialize_model(
lora_config=lora_config,
multimodal_config=model_config.multimodal_config,
scheduler_config=scheduler_config,
pooler_config=model_config.pooler_config,
)

View File

@@ -6,7 +6,7 @@ from transformers import BertConfig
from vllm.attention import Attention, AttentionMetadata, AttentionType
from vllm.attention.backends.xformers import XFormersImpl
from vllm.config import CacheConfig, PoolerConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@@ -387,10 +387,15 @@ class BertEmbeddingModel(nn.Module):
config: BertConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
pooler_config: Optional[PoolerConfig] = None,
) -> None:
super().__init__()
self.model = BertModel(config, cache_config, quant_config)
self._pooler = Pooler.from_config_with_defaults(
pooler_config,
pooling_type=PoolingType.CLS,
normalize=True,
softmax=False)
def forward(
self,

View File

@@ -22,7 +22,7 @@ from transformers import Gemma2Config
from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, LoRAConfig, PoolerConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import GeluAndMul
@@ -473,13 +473,17 @@ class Gemma2EmbeddingModel(nn.Module, SupportsPP):
def __init__(
self,
pooler_config: Optional[PoolerConfig] = None,
**kwargs,
) -> None:
super().__init__()
self.model = Gemma2Model(**kwargs)
self._pooler = Pooler.from_config_with_defaults(
pooler_config,
pooling_type=PoolingType.LAST,
normalize=True,
softmax=False)
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)

View File

@@ -29,7 +29,7 @@ from transformers import LlamaConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, LoRAConfig, PoolerConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import SiluAndMul
@@ -502,6 +502,7 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
prefix: str = "",
pooler_config: Optional[PoolerConfig] = None,
) -> None:
super().__init__()
@@ -543,6 +544,11 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self.lm_head = PPMissingLayer()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
self._pooler = Pooler.from_config_with_defaults(
pooler_config,
pooling_type=PoolingType.STEP,
normalize=False,
softmax=False)
def forward(
self,
@@ -565,6 +571,14 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
sampling_metadata)
return logits
def pooler(
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> Optional[PoolerOutput]:
logits = self.compute_logits(hidden_states, None)
return self._pooler(logits, pooling_metadata)
def sample(self, logits: torch.Tensor,
sampling_metadata: SamplingMetadata) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
@@ -630,12 +644,17 @@ class LlamaEmbeddingModel(nn.Module, SupportsPP):
def __init__(
self,
pooler_config: Optional[PoolerConfig] = None,
**kwargs,
) -> None:
super().__init__()
self.model = LlamaModel(**kwargs)
self._pooler = Pooler.from_config_with_defaults(
pooler_config,
pooling_type=PoolingType.LAST,
normalize=True,
softmax=False)
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)

View File

@@ -11,7 +11,7 @@ from transformers.models.llava_next.modeling_llava_next import (
from typing_extensions import NotRequired
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig, PoolerConfig
from vllm.inputs import INPUT_REGISTRY, DecoderOnlyInputs, InputContext
from vllm.model_executor.layers.pooler import Pooler, PoolingType
from vllm.model_executor.layers.quantization import QuantizationConfig
@@ -285,7 +285,8 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
config: LlavaNextConfig,
multimodal_config: MultiModalConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
pooler_config: Optional[PoolerConfig] = None) -> None:
super().__init__()
self.config = config
@@ -312,8 +313,11 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
# The same model class supports both language generation and embedding
# because the architecture name is the same
self._pooler = Pooler.from_config_with_defaults(
pooler_config,
pooling_type=PoolingType.LAST,
normalize=True,
softmax=False)
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors)

View File

@@ -26,7 +26,8 @@ from PIL import Image
from transformers import CLIPVisionConfig, PretrainedConfig
from vllm.attention import AttentionMetadata
from vllm.config import (CacheConfig, ModelConfig, MultiModalConfig,
PoolerConfig)
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext,
token_inputs)
from vllm.logger import init_logger
@@ -530,7 +531,8 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
config: PretrainedConfig,
multimodal_config: MultiModalConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
pooler_config: Optional[PoolerConfig] = None) -> None:
super().__init__()
self.config = config
@@ -556,8 +558,11 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
# The same model class supports both language generation and embedding
# because the architecture name is the same
self._pooler = Pooler.from_config_with_defaults(
pooler_config,
pooling_type=PoolingType.LAST,
normalize=True,
softmax=False)
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors)

View File

@@ -12,7 +12,7 @@ from torch import nn
from transformers import Qwen2Config
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig, PoolerConfig
from vllm.model_executor.layers.linear import RowParallelLinear
from vllm.model_executor.layers.pooler import Pooler, PoolingType
from vllm.model_executor.layers.quantization.base_config import (
@@ -53,6 +53,7 @@ class Qwen2ForSequenceClassification(nn.Module):
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
pooler_config: Optional[PoolerConfig] = None,
) -> None:
# TODO (@robertgshaw2): see if this can be moved out
if (cache_config.sliding_window is not None
@@ -77,9 +78,11 @@ class Qwen2ForSequenceClassification(nn.Module):
self.score = RowParallelLinear(config.hidden_size,
config.num_labels,
quant_config=quant_config)
self._pooler = Pooler.from_config_with_defaults(
pooler_config,
pooling_type=PoolingType.LAST,
normalize=False,
softmax=True)
def forward(
self,

View File

@@ -11,7 +11,7 @@ from torch import nn
from transformers import Qwen2Config
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig, PoolerConfig
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.pooler import Pooler, PoolingType
@@ -64,6 +64,7 @@ class Qwen2ForRewardModel(nn.Module, SupportsPP):
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
pooler_config: Optional[PoolerConfig] = None,
) -> None:
# TODO (@robertgshaw2): see if this can be moved out
if (cache_config.sliding_window is not None
@@ -93,8 +94,11 @@ class Qwen2ForRewardModel(nn.Module, SupportsPP):
RowParallelLinear(config.hidden_size, 1,
quant_config=quant_config),
)
self._pooler = Pooler.from_config_with_defaults(
pooler_config,
pooling_type=PoolingType.ALL,
normalize=False,
softmax=False)
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)

View File

@@ -100,11 +100,27 @@ _EMBEDDING_MODELS = {
"Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
"Qwen2ForSequenceClassification": (
"qwen2_cls", "Qwen2ForSequenceClassification"),
"LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
"Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
"DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"),
# [Multimodal]
"LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"),  # noqa: E501
"Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
}
def add_embedding_models(base_models, embedding_models):
with_pooler_method_models = {}
embedding_models_name = embedding_models.keys()
for name, (path, arch) in base_models.items():
if arch in embedding_models_name:
with_pooler_method_models[name] = (path, arch)
return with_pooler_method_models
_EMBEDDING_MODELS = {
**add_embedding_models(_TEXT_GENERATION_MODELS, _EMBEDDING_MODELS),
**_EMBEDDING_MODELS,
}
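A toy illustration of add_embedding_models as defined above: any text-generation entry whose architecture name is also declared as an embedding model gets mirrored into the embedding map, so an architecture such as "LlamaForCausalLM" can now resolve for either task depending on --task. The dictionaries here are made up for the example.

gen = {"LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
       "OPTForCausalLM": ("opt", "OPTForCausalLM")}
emb = {"LlamaForCausalLM": ("llama", "LlamaForCausalLM")}
print(add_embedding_models(gen, emb))
# {'LlamaForCausalLM': ('llama', 'LlamaForCausalLM')}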
_MULTIMODAL_MODELS = {
# [Decoder-only]
"Blip2ForConditionalGeneration": ("blip2", "Blip2ForConditionalGeneration"),