Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-14 16:27:27 +08:00)
[Model] Support math-shepherd-mistral-7b-prm model (#9697)

Signed-off-by: Went-Liang <wenteng_liang@163.com>

This commit is contained in:
parent cc98f1e079
commit 81f09cfd80

vllm/config.py (115)
@@ -112,38 +112,58 @@ class ModelConfig:
             Defaults to 'auto' which defaults to 'hf'.
         mm_processor_kwargs: Arguments to be forwarded to the model's processor
             for multi-modal data, e.g., image processor.
+        pooling_type: Used to configure the pooling method in the embedding
+            model.
+        pooling_norm: Used to determine whether to normalize the pooled
+            data in the embedding model.
+        pooling_softmax: Used to determine whether to softmax the pooled
+            data in the embedding model.
+        pooling_step_tag_id: When pooling_step_tag_id is not -1, it indicates
+            that the score corresponding to the pooling_step_tag_id in the
+            generated sentence should be returned. Otherwise, it returns
+            the scores for all tokens.
+        pooling_returned_token_ids: pooling_returned_token_ids represents a
+            list of indices for the vocabulary dimensions to be extracted,
+            such as the token IDs of good_token and bad_token in the
+            math-shepherd-mistral-7b-prm model.
     """
 
-    def __init__(self,
-                 model: str,
-                 task: Union[TaskOption, _Task],
-                 tokenizer: str,
-                 tokenizer_mode: str,
-                 trust_remote_code: bool,
-                 dtype: Union[str, torch.dtype],
-                 seed: int,
-                 revision: Optional[str] = None,
-                 code_revision: Optional[str] = None,
-                 rope_scaling: Optional[dict] = None,
-                 rope_theta: Optional[float] = None,
-                 tokenizer_revision: Optional[str] = None,
-                 max_model_len: Optional[int] = None,
-                 spec_target_max_model_len: Optional[int] = None,
-                 quantization: Optional[str] = None,
-                 quantization_param_path: Optional[str] = None,
-                 enforce_eager: Optional[bool] = None,
-                 max_context_len_to_capture: Optional[int] = None,
-                 max_seq_len_to_capture: Optional[int] = None,
-                 max_logprobs: int = 20,
-                 disable_sliding_window: bool = False,
-                 skip_tokenizer_init: bool = False,
-                 served_model_name: Optional[Union[str, List[str]]] = None,
-                 limit_mm_per_prompt: Optional[Mapping[str, int]] = None,
-                 use_async_output_proc: bool = True,
-                 override_neuron_config: Optional[Dict[str, Any]] = None,
-                 config_format: ConfigFormat = ConfigFormat.AUTO,
-                 chat_template_text_format: str = "string",
-                 mm_processor_kwargs: Optional[Dict[str, Any]] = None) -> None:
+    def __init__(
+            self,
+            model: str,
+            task: Union[TaskOption, _Task],
+            tokenizer: str,
+            tokenizer_mode: str,
+            trust_remote_code: bool,
+            dtype: Union[str, torch.dtype],
+            seed: int,
+            revision: Optional[str] = None,
+            code_revision: Optional[str] = None,
+            rope_scaling: Optional[dict] = None,
+            rope_theta: Optional[float] = None,
+            tokenizer_revision: Optional[str] = None,
+            max_model_len: Optional[int] = None,
+            spec_target_max_model_len: Optional[int] = None,
+            quantization: Optional[str] = None,
+            quantization_param_path: Optional[str] = None,
+            enforce_eager: Optional[bool] = None,
+            max_context_len_to_capture: Optional[int] = None,
+            max_seq_len_to_capture: Optional[int] = None,
+            max_logprobs: int = 20,
+            disable_sliding_window: bool = False,
+            skip_tokenizer_init: bool = False,
+            served_model_name: Optional[Union[str, List[str]]] = None,
+            limit_mm_per_prompt: Optional[Mapping[str, int]] = None,
+            use_async_output_proc: bool = True,
+            override_neuron_config: Optional[Dict[str, Any]] = None,
+            config_format: ConfigFormat = ConfigFormat.AUTO,
+            chat_template_text_format: str = "string",
+            mm_processor_kwargs: Optional[Dict[str, Any]] = None,
+            pooling_type: Optional[str] = None,
+            pooling_norm: Optional[bool] = None,
+            pooling_softmax: Optional[bool] = None,
+            pooling_step_tag_id: Optional[int] = None,
+            pooling_returned_token_ids: Optional[List[int]] = None) -> None:
         self.model = model
         self.tokenizer = tokenizer
         self.tokenizer_mode = tokenizer_mode
@@ -224,6 +244,13 @@ class ModelConfig:
         supported_tasks, task = self._resolve_task(task, self.hf_config)
         self.supported_tasks = supported_tasks
         self.task: Final = task
+        self.pooler_config = self._init_pooler_config(
+            pooling_type,
+            pooling_norm,
+            pooling_softmax,
+            pooling_step_tag_id,
+            pooling_returned_token_ids,
+        )
 
         self._verify_quantization()
         self._verify_cuda_graph()
@@ -242,6 +269,23 @@ class ModelConfig:
 
         return None
 
+    def _init_pooler_config(
+            self,
+            pooling_type: Optional[str] = None,
+            pooling_norm: Optional[bool] = None,
+            pooling_softmax: Optional[bool] = None,
+            pooling_step_tag_id: Optional[int] = None,
+            pooling_returned_token_ids: Optional[List[int]] = None
+    ) -> Optional["PoolerConfig"]:
+        if self.task == "embedding":
+            return PoolerConfig(
+                pooling_type=pooling_type,
+                pooling_norm=pooling_norm,
+                pooling_softmax=pooling_softmax,
+                pooling_step_tag_id=pooling_step_tag_id,
+                pooling_returned_token_ids=pooling_returned_token_ids)
+        return None
+
     def _init_attention_free(self) -> bool:
         architectures = getattr(self.hf_config, "architectures", [])
         return ModelRegistry.is_attention_free_model(architectures)
@@ -1647,6 +1691,17 @@ class MultiModalConfig:
     # TODO: Add configs to init vision tower or not.
 
 
+@dataclass
+class PoolerConfig:
+    """Controls the behavior of pooler in embedding model"""
+
+    pooling_type: Optional[str] = None
+    pooling_norm: Optional[bool] = None
+    pooling_softmax: Optional[bool] = None
+    pooling_step_tag_id: Optional[int] = None
+    pooling_returned_token_ids: Optional[List[int]] = None
+
+
 _STR_DTYPE_TO_TORCH_DTYPE = {
     "half": torch.float16,
     "float16": torch.float16,
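The five new pooling_* arguments only take effect when the resolved task is "embedding"; for any other task _init_pooler_config returns None and the overrides are ignored. A minimal sketch of that resolution rule, using a standalone helper rather than the real ModelConfig method (names outside the diff are illustrative):

    from dataclasses import dataclass
    from typing import List, Optional

    @dataclass
    class PoolerConfig:  # same fields as the dataclass added above
        pooling_type: Optional[str] = None
        pooling_norm: Optional[bool] = None
        pooling_softmax: Optional[bool] = None
        pooling_step_tag_id: Optional[int] = None
        pooling_returned_token_ids: Optional[List[int]] = None

    def resolve_pooler_config(task: str, **overrides) -> Optional[PoolerConfig]:
        # Mirrors ModelConfig._init_pooler_config: only embedding-task models
        # get a PoolerConfig; generation models keep pooler_config = None.
        return PoolerConfig(**overrides) if task == "embedding" else None

    assert resolve_pooler_config("generate", pooling_type="STEP") is None
    assert resolve_pooler_config("embedding", pooling_type="STEP").pooling_type == "STEP"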
@@ -184,6 +184,13 @@ class EngineArgs:
     mm_processor_kwargs: Optional[Dict[str, Any]] = None
     scheduling_policy: Literal["fcfs", "priority"] = "fcfs"
 
+    # Pooling configuration.
+    pooling_type: Optional[str] = None
+    pooling_norm: Optional[bool] = None
+    pooling_softmax: Optional[bool] = None
+    pooling_step_tag_id: Optional[int] = None
+    pooling_returned_token_ids: Optional[List[int]] = None
+
     def __post_init__(self):
         if not self.tokenizer:
             self.tokenizer = self.model
@@ -850,6 +857,58 @@ class EngineArgs:
             'priority (lower value means earlier handling) and time of '
             'arrival deciding any ties).')
 
+        parser.add_argument(
+            '--pooling-type',
+            choices=['LAST', 'ALL', 'CLS', 'STEP'],
+            default=None,
+            help='Used to configure the pooling method in the embedding model.'
+        )
+
+        parser.add_argument('--pooling-norm',
+                            default=None,
+                            action='store_true',
+                            help="Used to determine whether to normalize "
+                            "the pooled data in the embedding model.")
+
+        parser.add_argument('--no-pooling-norm',
+                            default=None,
+                            action='store_false',
+                            dest='pooling_norm',
+                            help="Used to determine whether to normalize "
+                            "the pooled data in the embedding model.")
+
+        parser.add_argument('--pooling-softmax',
+                            default=None,
+                            action='store_true',
+                            help="Used to determine whether to softmax "
+                            "the pooled data in the embedding model.")
+
+        parser.add_argument('--no-pooling-softmax',
+                            default=None,
+                            action='store_false',
+                            dest='pooling_softmax',
+                            help="Used to determine whether to softmax "
+                            "the pooled data in the embedding model.")
+
+        parser.add_argument(
+            '--pooling-step-tag-id',
+            type=int,
+            default=None,
+            help="When pooling-step-tag-id is not -1, it indicates "
+            "that the score corresponding to the step-tag-ids in the "
+            "generated sentence should be returned. Otherwise, it "
+            "returns the scores for all tokens.")
+
+        parser.add_argument(
+            '--pooling-returned-token-ids',
+            nargs='+',
+            type=int,
+            default=None,
+            help="pooling-returned-token-ids represents a list of "
+            "indices for the vocabulary dimensions to be extracted, "
+            "such as the token IDs of good_token and bad_token in "
+            "the math-shepherd-mistral-7b-prm model.")
+
         return parser
 
     @classmethod
@@ -891,6 +950,11 @@ class EngineArgs:
             override_neuron_config=self.override_neuron_config,
             config_format=self.config_format,
             mm_processor_kwargs=self.mm_processor_kwargs,
+            pooling_type=self.pooling_type,
+            pooling_norm=self.pooling_norm,
+            pooling_softmax=self.pooling_softmax,
+            pooling_step_tag_id=self.pooling_step_tag_id,
+            pooling_returned_token_ids=self.pooling_returned_token_ids,
         )
 
     def create_load_config(self) -> LoadConfig:
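The paired --pooling-norm/--no-pooling-norm and --pooling-softmax/--no-pooling-softmax flags form a tri-state switch: leaving both unset keeps the value at None so the model's own default wins, while either flag forces True or False. A self-contained sketch of the same argparse pattern (not the vLLM parser itself):

    import argparse

    parser = argparse.ArgumentParser()
    # Unset -> None (defer to the model default); the flags force True/False.
    parser.add_argument('--pooling-norm', default=None, action='store_true')
    parser.add_argument('--no-pooling-norm', default=None, action='store_false',
                        dest='pooling_norm')

    assert parser.parse_args([]).pooling_norm is None
    assert parser.parse_args(['--pooling-norm']).pooling_norm is True
    assert parser.parse_args(['--no-pooling-norm']).pooling_norm is False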
@@ -257,7 +257,8 @@ class LLMEngine:
             "num_scheduler_steps=%d, chunked_prefill_enabled=%s "
             "multi_step_stream_outputs=%s, enable_prefix_caching=%s, "
             "use_async_output_proc=%s, use_cached_outputs=%s, "
-            "chat_template_text_format=%s, mm_processor_kwargs=%s)",
+            "chat_template_text_format=%s, mm_processor_kwargs=%s, "
+            "pooler_config=%r)",
             VLLM_VERSION,
             model_config.model,
             speculative_config,
@@ -294,6 +295,7 @@ class LLMEngine:
             use_cached_outputs,
             model_config.chat_template_text_format,
             model_config.mm_processor_kwargs,
+            model_config.pooler_config,
         )
         # TODO(woosuk): Print more configs in debug mode.
         self.model_config = model_config
@@ -159,6 +159,11 @@ class LLM:
         mm_processor_kwargs: Optional[Dict[str, Any]] = None,
         # After positional args are removed, move this right below `model`
         task: TaskOption = "auto",
+        pooling_type: Optional[str] = None,
+        pooling_norm: Optional[bool] = None,
+        pooling_softmax: Optional[bool] = None,
+        pooling_step_tag_id: Optional[int] = None,
+        pooling_returned_token_ids: Optional[List[int]] = None,
         **kwargs,
     ) -> None:
         '''
@@ -193,6 +198,11 @@ class LLM:
             disable_custom_all_reduce=disable_custom_all_reduce,
             disable_async_output_proc=disable_async_output_proc,
             mm_processor_kwargs=mm_processor_kwargs,
+            pooling_type=pooling_type,
+            pooling_norm=pooling_norm,
+            pooling_softmax=pooling_softmax,
+            pooling_step_tag_id=pooling_step_tag_id,
+            pooling_returned_token_ids=pooling_returned_token_ids,
             **kwargs,
         )
         self.llm_engine = LLMEngine.from_engine_args(
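With the LLM entry point accepting these keyword arguments, offline scoring of a process-reward model could look roughly like the sketch below. The model id, prompt format, and the llm.encode() call follow the usual vLLM embedding flow rather than this diff, and the step-tag ("ки") and "+"/"-" scoring tokens follow the Math-Shepherd convention; all ids are looked up from the tokenizer instead of being hard-coded:

    from transformers import AutoTokenizer
    from vllm import LLM

    model_id = "peiyi9979/math-shepherd-mistral-7b-prm"  # assumed HF repo id
    tok = AutoTokenizer.from_pretrained(model_id)

    # Look up the step-tag token and the "good"/"bad" scoring tokens.
    step_tag_id = tok.encode("ки", add_special_tokens=False)[-1]
    good_id = tok.encode("+", add_special_tokens=False)[-1]
    bad_id = tok.encode("-", add_special_tokens=False)[-1]

    llm = LLM(
        model=model_id,
        task="embedding",            # routes the model through the pooler path
        pooling_type="STEP",
        pooling_step_tag_id=step_tag_id,
        pooling_returned_token_ids=[good_id, bad_id],
    )

    # One (good, bad) score pair per "ки"-tagged step comes back per prompt.
    out = llm.encode("Step 1: 1+1=2 ки Step 2: 2+2=4 ки")
    print(out[0].outputs.embedding)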
@@ -1,8 +1,10 @@
 from enum import IntEnum
+from typing import List, Optional
 
 import torch
 import torch.nn as nn
 
+from vllm.config import PoolerConfig
 from vllm.model_executor.pooling_metadata import (PoolingMetadata,
                                                   PoolingTensors)
 from vllm.sequence import EmbeddingSequenceGroupOutput, PoolerOutput
@@ -13,6 +15,7 @@ class PoolingType(IntEnum):
     LAST = 0
     ALL = 1
     CLS = 2
+    STEP = 3
 
 
 class Pooler(nn.Module):
@@ -28,15 +31,47 @@ class Pooler(nn.Module):
         normalize: Whether to normalize the pooled data.
     """
 
-    def __init__(self,
-                 pooling_type: PoolingType,
-                 normalize: bool,
-                 softmax: bool = False):
+    def __init__(
+        self,
+        pooling_type: PoolingType,
+        normalize: bool,
+        softmax: bool,
+        step_tag_id: Optional[int] = None,
+        returned_token_ids: Optional[List[int]] = None,
+    ):
         super().__init__()
 
         self.pooling_type = pooling_type
         self.normalize = normalize
         self.softmax = softmax
+        self.step_tag_id = step_tag_id
+        self.returned_token_ids = returned_token_ids
+
+    @classmethod
+    def from_config_with_defaults(
+        cls,
+        pooler_config: PoolerConfig,
+        pooling_type: PoolingType,
+        normalize: bool,
+        softmax: bool,
+        step_tag_id: Optional[int] = None,
+        returned_token_ids: Optional[List[int]] = None,
+    ) -> Optional["Pooler"]:
+        if pooler_config is None:
+            return None
+        return cls(
+            pooling_type=PoolingType[pooler_config.pooling_type]
+            if pooler_config.pooling_type is not None else pooling_type,
+            normalize=pooler_config.pooling_norm
+            if pooler_config.pooling_norm is not None else normalize,
+            softmax=pooler_config.pooling_softmax
+            if pooler_config.pooling_softmax is not None else softmax,
+            step_tag_id=pooler_config.pooling_step_tag_id
+            if pooler_config.pooling_step_tag_id is not None else step_tag_id,
+            returned_token_ids=pooler_config.pooling_returned_token_ids
+            if pooler_config.pooling_returned_token_ids is not None else
+            returned_token_ids,
+        )
 
     def forward(
         self,
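from_config_with_defaults keeps each model's historical pooling behavior while letting any field the user actually set in PoolerConfig win. The precedence rule, reduced to a single illustrative helper:

    from typing import Optional, TypeVar

    T = TypeVar("T")

    def pick(user_value: Optional[T], model_default: T) -> T:
        # Field-by-field rule used above: an explicit user setting wins,
        # otherwise fall back to the default baked into the model class.
        return model_default if user_value is None else user_value

    # BertEmbeddingModel defaults to CLS pooling; --pooling-type LAST overrides it.
    assert pick(None, "CLS") == "CLS"
    assert pick("LAST", "CLS") == "LAST"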
@@ -62,6 +97,25 @@
             for prompt_len in prompt_lens:
                 pooled_data.append(hidden_states[offset:offset + prompt_len])
                 offset += prompt_len
+        elif self.pooling_type == PoolingType.STEP:
+            if self.returned_token_ids is not None and len(
+                    self.returned_token_ids) > 0:
+                logits = hidden_states[:,
+                                       self.returned_token_ids].softmax(dim=-1)
+            else:
+                logits = hidden_states.softmax(dim=-1)
+            offset = 0
+            pooled_data = []
+            for prompt_len, seq_data_i in zip(
+                    prompt_lens, pooling_metadata.seq_data.values()):
+                if self.step_tag_id is None:
+                    pooled_data.append(logits[offset:offset + prompt_len])
+                else:
+                    step_idxs = torch.tensor(
+                        seq_data_i.prompt_token_ids) == self.step_tag_id
+                    pooled_data.append(logits[offset:offset +
+                                              prompt_len][step_idxs])
+                offset += prompt_len
         else:
             raise ValueError(f"Invalid pooling type: {self.pooling_type}")
 
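In isolation, the STEP branch softmaxes the incoming scores (optionally restricted to the returned_token_ids columns) and then keeps only the rows whose prompt token equals step_tag_id. A toy reproduction of that selection with made-up shapes and ids:

    import torch

    hidden_states = torch.randn(7, 32000)        # (num_prompt_tokens, vocab_size)
    prompt_token_ids = torch.tensor([5, 9, 42, 9, 7, 42, 2])
    step_tag_id = 42                              # illustrative id
    returned_token_ids = [11, 12]                 # illustrative good/bad token ids

    logits = hidden_states[:, returned_token_ids].softmax(dim=-1)
    step_mask = prompt_token_ids == step_tag_id
    step_scores = logits[step_mask]               # one (good, bad) pair per tagged step
    assert step_scores.shape == (2, 2)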
@@ -23,7 +23,7 @@ from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
 
 from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoadFormat,
                          LoRAConfig, ModelConfig, MultiModalConfig,
-                         ParallelConfig, SchedulerConfig)
+                         ParallelConfig, PoolerConfig, SchedulerConfig)
 from vllm.distributed import (get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size)
 from vllm.envs import VLLM_USE_MODELSCOPE
@@ -122,7 +122,8 @@ def _get_model_initialization_kwargs(
         model_class: Type[nn.Module],
         lora_config: Optional[LoRAConfig],
         multimodal_config: Optional[MultiModalConfig],
-        scheduler_config: Optional[SchedulerConfig] = None) -> Dict[str, Any]:
+        scheduler_config: Optional[SchedulerConfig] = None,
+        pooler_config: Optional[PoolerConfig] = None) -> Dict[str, Any]:
     """Get extra kwargs for model initialization."""
     extra_kwargs: Dict[str, Any] = {}
 
@@ -143,7 +144,8 @@ def _get_model_initialization_kwargs(
 
     if has_inner_state(model_class) and scheduler_config:
         extra_kwargs["scheduler_config"] = scheduler_config
-
+    if pooler_config:
+        extra_kwargs["pooler_config"] = pooler_config
     return extra_kwargs
 
 
@@ -155,10 +157,12 @@ def build_model(model_class: Type[nn.Module],
                 lora_config: Optional[LoRAConfig],
                 multimodal_config: Optional[MultiModalConfig],
                 scheduler_config: Optional[SchedulerConfig],
-                prefix: Optional[str] = None) -> nn.Module:
+                prefix: Optional[str] = None,
+                pooler_config: Optional[PoolerConfig] = None) -> nn.Module:
     extra_kwargs = _get_model_initialization_kwargs(model_class, lora_config,
                                                     multimodal_config,
-                                                    scheduler_config)
+                                                    scheduler_config,
+                                                    pooler_config)
     if prefix:
         extra_kwargs["prefix"] = prefix
 
@@ -185,6 +189,7 @@ def _initialize_model(
         lora_config=lora_config,
         multimodal_config=model_config.multimodal_config,
         scheduler_config=scheduler_config,
+        pooler_config=model_config.pooler_config,
     )
 
@@ -6,7 +6,7 @@ from transformers import BertConfig
 
 from vllm.attention import Attention, AttentionMetadata, AttentionType
 from vllm.attention.backends.xformers import XFormersImpl
-from vllm.config import CacheConfig
+from vllm.config import CacheConfig, PoolerConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@@ -387,10 +387,15 @@ class BertEmbeddingModel(nn.Module):
         config: BertConfig,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
+        pooler_config: Optional[PoolerConfig] = None,
     ) -> None:
         super().__init__()
         self.model = BertModel(config, cache_config, quant_config)
-        self._pooler = Pooler(pooling_type=PoolingType.CLS, normalize=True)
+        self._pooler = Pooler.from_config_with_defaults(
+            pooler_config,
+            pooling_type=PoolingType.CLS,
+            normalize=True,
+            softmax=False)
 
     def forward(
         self,
@@ -22,7 +22,7 @@ from transformers import Gemma2Config
 
 from vllm.attention import Attention, AttentionMetadata
 from vllm.compilation.decorators import support_torch_compile
-from vllm.config import CacheConfig, LoRAConfig
+from vllm.config import CacheConfig, LoRAConfig, PoolerConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
 from vllm.logger import init_logger
 from vllm.model_executor.layers.activation import GeluAndMul
@@ -473,13 +473,17 @@ class Gemma2EmbeddingModel(nn.Module, SupportsPP):
 
     def __init__(
         self,
+        pooler_config: Optional[PoolerConfig] = None,
         **kwargs,
     ) -> None:
         super().__init__()
 
         self.model = Gemma2Model(**kwargs)
-        self._pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
-
+        self._pooler = Pooler.from_config_with_defaults(
+            pooler_config,
+            pooling_type=PoolingType.LAST,
+            normalize=True,
+            softmax=False)
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)
 
@@ -29,7 +29,7 @@ from transformers import LlamaConfig
 
 from vllm.attention import Attention, AttentionMetadata
 from vllm.compilation.decorators import support_torch_compile
-from vllm.config import CacheConfig, LoRAConfig
+from vllm.config import CacheConfig, LoRAConfig, PoolerConfig
 from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size)
 from vllm.model_executor.layers.activation import SiluAndMul
@@ -502,6 +502,7 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
         quant_config: Optional[QuantizationConfig] = None,
         lora_config: Optional[LoRAConfig] = None,
         prefix: str = "",
+        pooler_config: Optional[PoolerConfig] = None,
     ) -> None:
         super().__init__()
 
@@ -543,6 +544,11 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
             self.lm_head = PPMissingLayer()
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)
+        self._pooler = Pooler.from_config_with_defaults(
+            pooler_config,
+            pooling_type=PoolingType.STEP,
+            normalize=False,
+            softmax=False)
 
     def forward(
         self,
@@ -565,6 +571,14 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
                                        sampling_metadata)
         return logits
 
+    def pooler(
+        self,
+        hidden_states: torch.Tensor,
+        pooling_metadata: PoolingMetadata,
+    ) -> Optional[PoolerOutput]:
+        logits = self.compute_logits(hidden_states, None)
+        return self._pooler(logits, pooling_metadata)
+
     def sample(self, logits: torch.Tensor,
                sampling_metadata: SamplingMetadata) -> Optional[SamplerOutput]:
         next_tokens = self.sampler(logits, sampling_metadata)
@@ -630,12 +644,17 @@ class LlamaEmbeddingModel(nn.Module, SupportsPP):
 
     def __init__(
         self,
+        pooler_config: Optional[PoolerConfig] = None,
        **kwargs,
     ) -> None:
         super().__init__()
 
         self.model = LlamaModel(**kwargs)
-        self._pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
+        self._pooler = Pooler.from_config_with_defaults(
+            pooler_config,
+            pooling_type=PoolingType.LAST,
+            normalize=True,
+            softmax=False)
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)
 
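Note the hand-off in LlamaForCausalLM.pooler above: unlike the embedding models, the reward-model path first runs compute_logits, so the tensor the STEP pooler receives already lives in vocabulary space and indexing it with returned_token_ids selects token ids rather than hidden dimensions. A shape-only sketch of that hand-off (tensors and ids are made up):

    import torch

    num_tokens, hidden_size, vocab_size = 7, 4096, 32000
    hidden_states = torch.randn(num_tokens, hidden_size)
    lm_head_weight = torch.randn(vocab_size, hidden_size)

    # pooler() calls compute_logits first, so the pooler's "hidden_states"
    # are really per-token vocabulary logits ...
    logits = hidden_states @ lm_head_weight.t()   # (num_tokens, vocab_size)

    # ... which makes slicing by token ids meaningful in the STEP branch.
    good_bad = logits[:, [11, 12]]                # illustrative token ids
    assert good_bad.shape == (num_tokens, 2)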
@@ -11,7 +11,7 @@ from transformers.models.llava_next.modeling_llava_next import (
 from typing_extensions import NotRequired
 
 from vllm.attention import AttentionMetadata
-from vllm.config import CacheConfig, MultiModalConfig
+from vllm.config import CacheConfig, MultiModalConfig, PoolerConfig
 from vllm.inputs import INPUT_REGISTRY, DecoderOnlyInputs, InputContext
 from vllm.model_executor.layers.pooler import Pooler, PoolingType
 from vllm.model_executor.layers.quantization import QuantizationConfig
@@ -285,7 +285,8 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
                  config: LlavaNextConfig,
                  multimodal_config: MultiModalConfig,
                  cache_config: Optional[CacheConfig] = None,
-                 quant_config: Optional[QuantizationConfig] = None) -> None:
+                 quant_config: Optional[QuantizationConfig] = None,
+                 pooler_config: Optional[PoolerConfig] = None) -> None:
         super().__init__()
 
         self.config = config
@@ -312,8 +313,11 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
 
         # The same model class supports both language generation and embedding
         # because the architecture name is the same
-        self._pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
-
+        self._pooler = Pooler.from_config_with_defaults(
+            pooler_config,
+            pooling_type=PoolingType.LAST,
+            normalize=True,
+            softmax=False)
         self.make_empty_intermediate_tensors = (
             self.language_model.make_empty_intermediate_tensors)
 
@@ -26,7 +26,8 @@ from PIL import Image
 from transformers import CLIPVisionConfig, PretrainedConfig
 
 from vllm.attention import AttentionMetadata
-from vllm.config import CacheConfig, ModelConfig, MultiModalConfig
+from vllm.config import (CacheConfig, ModelConfig, MultiModalConfig,
+                         PoolerConfig)
 from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext,
                          token_inputs)
 from vllm.logger import init_logger
@@ -530,7 +531,8 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
                  config: PretrainedConfig,
                  multimodal_config: MultiModalConfig,
                  cache_config: Optional[CacheConfig] = None,
-                 quant_config: Optional[QuantizationConfig] = None) -> None:
+                 quant_config: Optional[QuantizationConfig] = None,
+                 pooler_config: Optional[PoolerConfig] = None) -> None:
         super().__init__()
 
         self.config = config
@@ -556,8 +558,11 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
 
         # The same model class supports both language generation and embedding
         # because the architecture name is the same
-        self._pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
-
+        self._pooler = Pooler.from_config_with_defaults(
+            pooler_config,
+            pooling_type=PoolingType.LAST,
+            normalize=True,
+            softmax=False)
         self.make_empty_intermediate_tensors = (
             self.language_model.make_empty_intermediate_tensors)
 
@@ -12,7 +12,7 @@ from torch import nn
 from transformers import Qwen2Config
 
 from vllm.attention import AttentionMetadata
-from vllm.config import CacheConfig, LoRAConfig
+from vllm.config import CacheConfig, LoRAConfig, PoolerConfig
 from vllm.model_executor.layers.linear import RowParallelLinear
 from vllm.model_executor.layers.pooler import Pooler, PoolingType
 from vllm.model_executor.layers.quantization.base_config import (
@@ -53,6 +53,7 @@ class Qwen2ForSequenceClassification(nn.Module):
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
         lora_config: Optional[LoRAConfig] = None,
+        pooler_config: Optional[PoolerConfig] = None,
     ) -> None:
         # TODO (@robertgshaw2): see if this can be moved out
         if (cache_config.sliding_window is not None
@@ -77,9 +78,11 @@ class Qwen2ForSequenceClassification(nn.Module):
         self.score = RowParallelLinear(config.hidden_size,
                                        config.num_labels,
                                        quant_config=quant_config)
-        self._pooler = Pooler(pooling_type=PoolingType.LAST,
-                              normalize=False,
-                              softmax=True)
+        self._pooler = Pooler.from_config_with_defaults(
+            pooler_config,
+            pooling_type=PoolingType.LAST,
+            normalize=False,
+            softmax=True)
 
     def forward(
         self,
@@ -11,7 +11,7 @@ from torch import nn
 from transformers import Qwen2Config
 
 from vllm.attention import AttentionMetadata
-from vllm.config import CacheConfig, LoRAConfig
+from vllm.config import CacheConfig, LoRAConfig, PoolerConfig
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.pooler import Pooler, PoolingType
@@ -64,6 +64,7 @@ class Qwen2ForRewardModel(nn.Module, SupportsPP):
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
         lora_config: Optional[LoRAConfig] = None,
+        pooler_config: Optional[PoolerConfig] = None,
     ) -> None:
         # TODO (@robertgshaw2): see if this can be moved out
         if (cache_config.sliding_window is not None
@@ -93,8 +94,11 @@ class Qwen2ForRewardModel(nn.Module, SupportsPP):
             RowParallelLinear(config.hidden_size, 1,
                               quant_config=quant_config),
         )
-        self._pooler = Pooler(pooling_type=PoolingType.ALL, normalize=False)
-
+        self._pooler = Pooler.from_config_with_defaults(
+            pooler_config,
+            pooling_type=PoolingType.ALL,
+            normalize=False,
+            softmax=False)
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)
 
@@ -100,11 +100,27 @@ _EMBEDDING_MODELS = {
     "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
     "Qwen2ForSequenceClassification": (
         "qwen2_cls", "Qwen2ForSequenceClassification"),
+    "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
+    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
+    "DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"),
     # [Multimodal]
     "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"),  # noqa: E501
     "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
 }
 
+def add_embedding_models(base_models, embedding_models):
+    with_pooler_method_models = {}
+    embedding_models_name = embedding_models.keys()
+    for name, (path, arch) in base_models.items():
+        if arch in embedding_models_name:
+            with_pooler_method_models[name] = (path, arch)
+    return with_pooler_method_models
+
+_EMBEDDING_MODELS = {
+    **add_embedding_models(_TEXT_GENERATION_MODELS, _EMBEDDING_MODELS),
+    **_EMBEDDING_MODELS,
+}
+
 _MULTIMODAL_MODELS = {
     # [Decoder-only]
     "Blip2ForConditionalGeneration": ("blip2", "Blip2ForConditionalGeneration"),
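add_embedding_models lets a text-generation architecture double as an embedding model: every _TEXT_GENERATION_MODELS entry whose architecture also appears in _EMBEDDING_MODELS is copied into the merged registry, so LlamaForCausalLM, Phi3ForCausalLM and DeciLMForCausalLM stay loadable for generation while also resolving for the embedding task. A tiny self-contained illustration of the merge (the dictionaries are made up):

    def add_embedding_models(base_models, embedding_models):
        out = {}
        for name, (path, arch) in base_models.items():
            if arch in embedding_models.keys():
                out[name] = (path, arch)
        return out

    base = {"LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
            "OPTForCausalLM": ("opt", "OPTForCausalLM")}
    embed = {"LlamaForCausalLM": ("llama", "LlamaForCausalLM")}

    # Only architectures present in both registries are carried over.
    assert add_embedding_models(base, embed) == {
        "LlamaForCausalLM": ("llama", "LlamaForCausalLM")}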