mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-13 09:46:09 +08:00
[Model] Support InternLM2 Reward models (#11571)
Signed-off-by: Isotr0py <2037008807@qq.com> Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
This commit is contained in:
parent
b5cbe8eeb3
commit
d34be24bb1
@ -450,6 +450,11 @@ of the whole prompt are extracted from the normalized hidden state corresponding
|
|||||||
- Example HF Models
|
- Example HF Models
|
||||||
- :ref:`LoRA <lora-adapter>`
|
- :ref:`LoRA <lora-adapter>`
|
||||||
- :ref:`PP <distributed-serving>`
|
- :ref:`PP <distributed-serving>`
|
||||||
|
* - :code:`InternLM2ForRewardModel`
|
||||||
|
- InternLM2-based
|
||||||
|
- :code:`internlm/internlm2-1_8b-reward`, :code:`internlm/internlm2-7b-reward`, etc.
|
||||||
|
- ✅︎
|
||||||
|
- ✅︎
|
||||||
* - :code:`LlamaForCausalLM`
|
* - :code:`LlamaForCausalLM`
|
||||||
- Llama-based
|
- Llama-based
|
||||||
- :code:`peiyi9979/math-shepherd-mistral-7b-prm`, etc.
|
- :code:`peiyi9979/math-shepherd-mistral-7b-prm`, etc.
|
||||||
|
|||||||
@ -140,6 +140,8 @@ _EMBEDDING_EXAMPLE_MODELS = {
|
|||||||
"BertModel": _HfExamplesInfo("BAAI/bge-base-en-v1.5"),
|
"BertModel": _HfExamplesInfo("BAAI/bge-base-en-v1.5"),
|
||||||
"Gemma2Model": _HfExamplesInfo("BAAI/bge-multilingual-gemma2"),
|
"Gemma2Model": _HfExamplesInfo("BAAI/bge-multilingual-gemma2"),
|
||||||
"GritLM": _HfExamplesInfo("parasail-ai/GritLM-7B-vllm"),
|
"GritLM": _HfExamplesInfo("parasail-ai/GritLM-7B-vllm"),
|
||||||
|
"InternLM2ForRewardModel": _HfExamplesInfo("internlm/internlm2-1_8b-reward",
|
||||||
|
trust_remote_code=True),
|
||||||
"JambaForSequenceClassification": _HfExamplesInfo("ai21labs/Jamba-tiny-reward-dev"), # noqa: E501
|
"JambaForSequenceClassification": _HfExamplesInfo("ai21labs/Jamba-tiny-reward-dev"), # noqa: E501
|
||||||
"LlamaModel": _HfExamplesInfo("llama", is_available_online=False),
|
"LlamaModel": _HfExamplesInfo("llama", is_available_online=False),
|
||||||
"MistralModel": _HfExamplesInfo("intfloat/e5-mistral-7b-instruct"),
|
"MistralModel": _HfExamplesInfo("intfloat/e5-mistral-7b-instruct"),
|
||||||
|
|||||||
@ -18,14 +18,16 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
|
|||||||
QKVParallelLinear,
|
QKVParallelLinear,
|
||||||
RowParallelLinear)
|
RowParallelLinear)
|
||||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||||
|
from vllm.model_executor.layers.pooler import Pooler, PoolingType
|
||||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||||
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
|
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
|
||||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||||
ParallelLMHead, VocabParallelEmbedding)
|
ParallelLMHead, VocabParallelEmbedding)
|
||||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||||
|
from vllm.model_executor.pooling_metadata import PoolingMetadata
|
||||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors, PoolerOutput
|
||||||
|
|
||||||
from .interfaces import SupportsLoRA, SupportsPP
|
from .interfaces import SupportsLoRA, SupportsPP
|
||||||
from .utils import (is_pp_missing_parameter,
|
from .utils import (is_pp_missing_parameter,
|
||||||
@ -433,3 +435,59 @@ class InternLM2ForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
|
|||||||
weight_loader(param, loaded_weight)
|
weight_loader(param, loaded_weight)
|
||||||
loaded_params.add(name)
|
loaded_params.add(name)
|
||||||
return loaded_params
|
return loaded_params
|
||||||
|
|
||||||
|
|
||||||
|
class InternLM2ForRewardModel(InternLM2ForCausalLM):
    """InternLM2 model whose output is a single value per position.

    Reuses the ``InternLM2ForCausalLM`` constructor, removes the
    ``output``, ``logits_processor`` and ``sampler`` attributes, and adds
    a bias-free ``v_head`` projection from ``hidden_size`` to 1 together
    with a pooler configured for ``PoolingType.ALL`` (no normalization,
    no softmax).
    """

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        model_type: Type[InternLM2Model] = InternLM2Model,
    ):
        """Build the backbone, then swap the LM head for ``v_head``.

        Args:
            vllm_config: Engine configuration; ``model_config.hf_config``
                and ``model_config.pooler_config`` are read from it.
            prefix: Name prefix forwarded to the parent and used to
                derive the ``v_head`` prefix.
            model_type: Backbone class forwarded to the parent
                constructor.
        """
        super().__init__(
            vllm_config=vllm_config,
            prefix=prefix,
            model_type=model_type,
        )

        # These attributes are not referenced anywhere in this class
        # (presumably they serve token generation in the parent), so
        # they are removed from the instance one by one.
        del self.output
        del self.logits_processor
        del self.sampler

        hf_config = vllm_config.model_config.hf_config
        self.v_head = RowParallelLinear(
            hf_config.hidden_size,
            1,
            bias=False,
            input_is_parallel=False,
            prefix=maybe_prefix(prefix, "v_head"),
        )

        # The supplied pooler config is passed through; the keyword
        # arguments are the defaults handed to the factory.
        self._pooler = Pooler.from_config_with_defaults(
            vllm_config.model_config.pooler_config,
            pooling_type=PoolingType.ALL,
            normalize=False,
            softmax=False,
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the backbone and project its output through ``v_head``.

        All arguments are forwarded positionally to ``self.model`` in
        the order they are declared here. Returns the first element of
        the pair produced by ``self.v_head``.
        """
        # NOTE(review): the return annotation allows IntermediateTensors,
        # yet v_head is applied unconditionally -- confirm what
        # self.model returns on non-final pipeline-parallel ranks.
        backbone_output = self.model(
            input_ids,
            positions,
            kv_caches,
            attn_metadata,
            intermediate_tensors,
            inputs_embeds,
        )
        # v_head yields a 2-tuple; only the first element is used.
        head_output, _ = self.v_head(backbone_output)
        return head_output

    def pooler(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> Optional[PoolerOutput]:
        """Delegate pooling of ``hidden_states`` to ``self._pooler``."""
        pooled = self._pooler(hidden_states, pooling_metadata)
        return pooled
|
||||||
|
|||||||
@ -113,6 +113,7 @@ _EMBEDDING_MODELS = {
|
|||||||
"Gemma2Model": ("gemma2", "Gemma2ForCausalLM"),
|
"Gemma2Model": ("gemma2", "Gemma2ForCausalLM"),
|
||||||
"GlmForCausalLM": ("glm", "GlmForCausalLM"),
|
"GlmForCausalLM": ("glm", "GlmForCausalLM"),
|
||||||
"GritLM": ("gritlm", "GritLM"),
|
"GritLM": ("gritlm", "GritLM"),
|
||||||
|
"InternLM2ForRewardModel": ("internlm2", "InternLM2ForRewardModel"),
|
||||||
"JambaForSequenceClassification": ("jamba", "JambaForSequenceClassification"), # noqa: E501
|
"JambaForSequenceClassification": ("jamba", "JambaForSequenceClassification"), # noqa: E501
|
||||||
"LlamaModel": ("llama", "LlamaForCausalLM"),
|
"LlamaModel": ("llama", "LlamaForCausalLM"),
|
||||||
**{
|
**{
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user