mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-10 00:06:06 +08:00)

[Bugfix] Fix weight loading for Chameleon when TP>1 (#7410)

parent 5469146bcc
commit 7025b11d94
@@ -4,7 +4,8 @@ import os
 import sys
 from collections import UserList
 from enum import Enum
-from typing import Any, Dict, List, Optional, Tuple, TypedDict, TypeVar, Union
+from typing import (Any, Callable, Dict, List, Optional, Tuple, TypedDict,
+                    TypeVar, Union)
 
 import pytest
 import torch

@@ -27,7 +28,7 @@ from vllm.logger import init_logger
 from vllm.outputs import RequestOutput
 from vllm.sequence import SampleLogprobs
 from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, cuda_device_count_stateless,
-                        is_cpu)
+                        identity, is_cpu)
 
 logger = init_logger(__name__)

@@ -197,6 +198,8 @@ class HfRunner:
         is_embedding_model: bool = False,
         is_vision_model: bool = False,
         is_encoder_decoder_model: bool = False,
+        postprocess_inputs: Callable[[BatchEncoding],
+                                     BatchEncoding] = identity,
     ) -> None:
         torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype]

@@ -242,12 +245,14 @@ class HfRunner:
                 torch_dtype=torch_dtype,
                 trust_remote_code=True,
             )
-        except Exception:
+        except Exception as exc:
             logger.warning(
-                "Unable to auto-load processor from HuggingFace for "
-                "model %s. Using tokenizer instead.", model_name)
+                "Unable to auto-load HuggingFace processor for model (%s). "
+                "Using tokenizer instead. Reason: %s", model_name, exc)
             self.processor = self.tokenizer
 
+        self.postprocess_inputs = postprocess_inputs
+
     def generate(
         self,
         prompts: List[str],

@@ -267,6 +272,7 @@ class HfRunner:
                 processor_kwargs["images"] = images[i]
 
             inputs = self.processor(**processor_kwargs)
+            inputs = self.postprocess_inputs(inputs)
 
             output_ids = self.model.generate(
                 **self.wrap_device(inputs),

@@ -336,6 +342,7 @@ class HfRunner:
                 processor_kwargs["images"] = images[i]
 
             inputs = self.processor(**processor_kwargs)
+            inputs = self.postprocess_inputs(inputs)
 
             output = self.model.generate(
                 **self.wrap_device(inputs),

@@ -420,6 +427,7 @@ class HfRunner:
                 processor_kwargs["images"] = images[i]
 
             inputs = self.processor(**processor_kwargs)
+            inputs = self.postprocess_inputs(inputs)
 
             output = self.model.generate(
                 **self.wrap_device(inputs),

@@ -552,7 +560,8 @@ class VllmRunner:
         self,
         prompts: List[str],
         sampling_params: SamplingParams,
-        images: Optional[List[Image.Image]] = None,
+        images: Optional[Union[List[Image.Image],
+                               List[List[Image.Image]]]] = None,
     ) -> List[Tuple[List[List[int]], List[str]]]:
         if images is not None:
             assert len(prompts) == len(images)

@@ -587,7 +596,7 @@ class VllmRunner:
         for req_output in req_outputs:
             for sample in req_output.outputs:
                 output_str = sample.text
-                output_ids = sample.token_ids
+                output_ids = list(sample.token_ids)
                 output_logprobs = sample.logprobs
                 outputs.append((output_ids, output_str, output_logprobs))
         return outputs

@@ -596,7 +605,8 @@ class VllmRunner:
         self,
         prompts: List[str],
         sampling_params: SamplingParams,
-        images: Optional[List[Image.Image]] = None,
+        images: Optional[Union[List[Image.Image],
+                               List[List[Image.Image]]]] = None,
     ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
         assert sampling_params.logprobs is not None
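The postprocess_inputs hook added above runs on every processed prompt, between the HuggingFace processor call and model.generate. A minimal sketch of how a vision test can use it to cast pixel values to the dtype under test; the helper name cast_pixel_values is illustrative and not part of this change:

import torch
from transformers import BatchEncoding


def cast_pixel_values(torch_dtype: torch.dtype):
    """Build a postprocess_inputs callable for HfRunner (sketch)."""

    def process(hf_inputs: BatchEncoding) -> BatchEncoding:
        # HfRunner applies this right after self.processor(**processor_kwargs)
        # and before model.generate(**self.wrap_device(inputs)).
        hf_inputs["pixel_values"] = hf_inputs["pixel_values"].to(torch_dtype)
        return hf_inputs

    return process


# Assumed fixture usage:
# with hf_runner(model, dtype="bfloat16", is_vision_model=True,
#                postprocess_inputs=cast_pixel_values(torch.bfloat16)) as hf_model:
#     ...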
@@ -18,8 +18,10 @@ from ..utils import fork_new_process_for_each_test
 @pytest.mark.parametrize("model, distributed_executor_backend", [
     ("llava-hf/llava-1.5-7b-hf", "ray"),
     ("llava-hf/llava-v1.6-mistral-7b-hf", "ray"),
+    ("facebook/chameleon-7b", "ray"),
     ("llava-hf/llava-1.5-7b-hf", "mp"),
     ("llava-hf/llava-v1.6-mistral-7b-hf", "mp"),
+    ("facebook/chameleon-7b", "mp"),
 ])
 @fork_new_process_for_each_test
 def test_models(hf_runner, vllm_runner, image_assets, model: str,

@@ -34,6 +36,8 @@ def test_models(hf_runner, vllm_runner, image_assets, model: str,
         from ..models.test_llava import models, run_test
     elif model.startswith("llava-hf/llava-v1.6"):
         from ..models.test_llava_next import models, run_test
+    elif model.startswith("facebook/chameleon"):
+        from ..models.test_chameleon import models, run_test
     else:
         raise NotImplementedError(f"Unsupported model: {model}")
@@ -1,5 +1,6 @@
 import sys
 import time
+from typing import Optional
 
 import torch
 from openai import OpenAI, OpenAIError

@@ -17,8 +18,11 @@ assert chatml_jinja_path.exists()
 
 class MyOPTForCausalLM(OPTForCausalLM):
 
-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
         # this dummy model always predicts the first token
         logits = super().compute_logits(hidden_states, sampling_metadata)
         logits.zero_()
@@ -1,11 +1,13 @@
-import re
 from typing import List, Optional, Type
 
 import pytest
+from transformers import BatchEncoding
 
 from vllm.multimodal.utils import rescale_image_size
+from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
 
-from ..conftest import IMAGE_ASSETS, VllmRunner, _ImageAssets
+from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets
 from .utils import check_outputs_equal
 
 pytestmark = pytest.mark.vlm
 

@@ -19,9 +21,8 @@ HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
 models = ["facebook/chameleon-7b"]
 
 
-#TODO (ywang96): Add correctness test when chameleon is
-# available on transformers.
 def run_test(
+    hf_runner: Type[HfRunner],
     vllm_runner: Type[VllmRunner],
     image_assets: _ImageAssets,
     model: str,

@@ -29,13 +30,20 @@ def run_test(
     size_factors: List[float],
     dtype: str,
     max_tokens: int,
+    num_logprobs: int,
     tensor_parallel_size: int,
     distributed_executor_backend: Optional[str] = None,
 ):
-    """Test if the model can generate text given
-    a batch of images and prompts.
+    """Inference result should be the same between hf and vllm.
+
+    All the image fixtures for the test is under tests/images.
+    For huggingface runner, we provide the PIL images as input.
+    For vllm runner, we provide MultiModalDataDict objects
+    and corresponding vision language config as input.
+    Note, the text input is also adjusted to abide by vllm contract.
+    The text output is sanitized to be able to compare with hf.
     """
+    torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype]
     images = [asset.pil_image for asset in image_assets]
 
     inputs_per_image = [(

@@ -50,35 +58,49 @@ def run_test(
                      distributed_executor_backend=distributed_executor_backend,
                      enforce_eager=True) as vllm_model:
 
-        for prompts, images in inputs_per_image:
-            vllm_outputs = vllm_model.generate_greedy(prompts,
-                                                      max_tokens,
-                                                      images=images)
-            for i in range(len(vllm_outputs)):
+        vllm_outputs_per_image = [
+            vllm_model.generate_greedy_logprobs(prompts,
+                                                max_tokens,
+                                                num_logprobs=num_logprobs,
+                                                images=images)
+            for prompts, images in inputs_per_image
+        ]
 
-                # format prompt back to original
-                replacements = {
-                    "<racm3:break>": "",
-                    "<eoss>": "",
-                    "<reserved08706>": ""
-                }
-                pattern = '|'.join(replacements.keys())
-                vllm_result = re.sub(
-                    pattern,
-                    lambda match: replacements[match.group(0)],  #noqa B023
-                    vllm_outputs[i][1])
-                vllm_result = vllm_result.replace("<image>", "", 1023)
-                assert vllm_result[:len(prompts[i])] == prompts[i]
+    def process(hf_inputs: BatchEncoding):
+        hf_inputs["pixel_values"] = hf_inputs["pixel_values"] \
+            .to(torch_dtype)  # type: ignore
+        return hf_inputs
 
-                # assert at least 10 new characters are generated
-                # (to take stop token into account)
-                assert len(vllm_outputs[i][1]) - len(prompts[i]) > 10
+    with hf_runner(model,
+                   dtype=dtype,
+                   postprocess_inputs=process,
+                   is_vision_model=True) as hf_model:
+        hf_outputs_per_image = [
+            hf_model.generate_greedy_logprobs_limit(prompts,
+                                                    max_tokens,
+                                                    num_logprobs=num_logprobs,
+                                                    images=images)
+            for prompts, images in inputs_per_image
+        ]
 
+    for hf_outputs, vllm_outputs in zip(hf_outputs_per_image,
+                                        vllm_outputs_per_image):
+        # HF Logprobs include image tokens, unlike vLLM, so we don't directly
+        # compare them
+        check_outputs_equal(
+            outputs_0_lst=[outputs[:2] for outputs in hf_outputs],
+            outputs_1_lst=[outputs[:2] for outputs in vllm_outputs],
+            name_0="hf",
+            name_1="vllm",
+        )
 
 
 @pytest.mark.parametrize("model", models)
 @pytest.mark.parametrize(
     "size_factors",
     [
         # No image
         [],
         # Single-scale
         [1.0],
         # Single-scale, batched

@@ -88,15 +110,18 @@ def run_test(
     ],
 )
 @pytest.mark.parametrize("dtype", ["bfloat16"])
-@pytest.mark.parametrize("max_tokens", [128])
-def test_models(vllm_runner, image_assets, model, size_factors, dtype: str,
-                max_tokens: int) -> None:
+@pytest.mark.parametrize("max_tokens", [8])
+@pytest.mark.parametrize("num_logprobs", [5])
+def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
+                dtype, max_tokens, num_logprobs) -> None:
     run_test(
+        hf_runner,
         vllm_runner,
         image_assets,
         model,
         size_factors=size_factors,
        dtype=dtype,
         max_tokens=max_tokens,
+        num_logprobs=num_logprobs,
         tensor_parallel_size=1,
     )
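Both runners return (token_ids, text, logprobs) tuples, and the test above compares only the first two fields because HF logprobs include image tokens while vLLM's do not. A toy illustration (made-up values) of what the outputs[:2] slicing keeps:

hf_output = ([101, 7592, 2088], "a red stop sign", {"fake": "hf logprobs"})
vllm_output = ([101, 7592, 2088], "a red stop sign", {"fake": "vllm logprobs"})

# check_outputs_equal effectively compares these trimmed tuples
assert hf_output[:2] == vllm_output[:2]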
@@ -1,7 +1,7 @@
 from typing import List, Optional, Tuple, Type
 
 import pytest
-from transformers import AutoConfig, AutoTokenizer
+from transformers import AutoConfig, AutoTokenizer, BatchEncoding
 
 from vllm.multimodal.utils import rescale_image_size
 from vllm.sequence import SampleLogprobs

@@ -110,16 +110,21 @@ def run_test(
         for prompts, images in inputs_per_image
     ]
 
-    with hf_runner(model, dtype=dtype, is_vision_model=True) as hf_model:
-        if mantis_processor is not None:
+    if mantis_processor is not None:
 
-            def process(*args, **kwargs):
-                output = mantis_processor(*args, **kwargs)
-                output["pixel_values"] = output["pixel_values"].to(torch_dtype)
-                return output
+        def process(hf_inputs: BatchEncoding):
+            hf_inputs["pixel_values"] = hf_inputs["pixel_values"] \
+                .to(torch_dtype)  # type: ignore
+            return hf_inputs
+    else:
 
-            hf_model.processor = process
+        def process(hf_inputs: BatchEncoding):
+            return hf_inputs
 
+    with hf_runner(model,
+                   dtype=dtype,
+                   postprocess_inputs=process,
+                   is_vision_model=True) as hf_model:
         hf_outputs_per_image = [
             hf_model.generate_greedy_logprobs_limit(prompts,
                                                     max_tokens,
@@ -1,10 +1,9 @@
-from collections import UserDict
 from typing import List, Optional, Tuple, Type
 
 import pytest
 import torch
-import torch.types
-from transformers import BatchFeature
+from transformers import BatchEncoding
 
 from vllm.multimodal.utils import rescale_image_size
 from vllm.sequence import SampleLogprobs

@@ -14,18 +13,6 @@ from .utils import check_logprobs_close
 
 pytestmark = pytest.mark.vlm
 
 
-class NestedInputs(UserDict):
-
-    def __init__(self, model_inputs: BatchFeature):
-        super().__init__({"model_inputs": model_inputs})
-
-        self.model_inputs = model_inputs
-
-    def to(self, device: torch.types.Device):
-        return NestedInputs(self.model_inputs.to(device))
-
-
 # The image token is placed before "user" on purpose so that the test can pass
 HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
     "stop_sign":

@@ -41,6 +28,10 @@ HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
 models = ["openbmb/MiniCPM-Llama3-V-2_5"]
 
 
+def _wrap_inputs(hf_inputs: BatchEncoding) -> BatchEncoding:
+    return BatchEncoding({"model_inputs": hf_inputs})
+
+
 def trunc_hf_output(hf_output: Tuple[List[int], str,
                                      Optional[SampleLogprobs]]):
     output_ids, output_str, out_logprobs = hf_output

@@ -105,11 +96,8 @@ def run_test(
         for prompts, images in inputs_per_image
     ]
 
-    with hf_runner(model, dtype=dtype) as hf_model, torch.no_grad():
-        hf_processor = hf_model.processor
-        hf_model.processor = lambda **kw: NestedInputs(
-            hf_processor(**kw)  # type: ignore
-        )
+    hf_model = hf_runner(model, dtype=dtype, postprocess_inputs=_wrap_inputs)
+    with hf_model, torch.no_grad():
         hf_outputs_per_image = [
             hf_model.generate_greedy_logprobs_limit(prompts,
                                                     max_tokens,

@@ -224,11 +212,8 @@ def run_multi_image_test(
         for prompts, images in inputs_per_case
     ]
 
-    with hf_runner(model, dtype=dtype) as hf_model, torch.no_grad():
-        hf_processor = hf_model.processor
-        hf_model.processor = lambda **kw: NestedInputs(
-            hf_processor(**kw)  # type: ignore
-        )
+    hf_model = hf_runner(model, dtype=dtype, postprocess_inputs=_wrap_inputs)
+    with hf_model, torch.no_grad():
         hf_outputs_per_case = [
             hf_model.generate_greedy_logprobs_limit(prompts,
                                                     max_tokens,
@@ -1,3 +1,5 @@
+from typing import Optional
+
 import torch
 
 from vllm import LLM, ModelRegistry, SamplingParams

@@ -7,8 +9,11 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
 
 class MyOPTForCausalLM(OPTForCausalLM):
 
-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
         # this dummy model always predicts the first token
         logits = super().compute_logits(hidden_states, sampling_metadata)
         logits.zero_()
@@ -19,7 +19,7 @@ def tensor_model_parallel_all_gather(input_: torch.Tensor,
 
 def tensor_model_parallel_gather(input_: torch.Tensor,
                                  dst: int = 0,
-                                 dim: int = -1) -> torch.Tensor:
+                                 dim: int = -1) -> Optional[torch.Tensor]:
     """Gather the input tensor across model parallel group."""
     return get_tp_group().gather(input_, dst, dim)
@@ -329,7 +329,7 @@ class GroupCoordinator:
     def gather(self,
                input_: torch.Tensor,
                dst: int = 0,
-               dim: int = -1) -> torch.Tensor:
+               dim: int = -1) -> Optional[torch.Tensor]:
         """
         NOTE: We assume that the input tensor is on the same device across
         all the ranks.
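The widened return type reflects the collective's contract: only the destination rank materializes the gathered tensor, every other rank ends up with None. A self-contained sketch of that behavior using plain torch.distributed (not vLLM's GroupCoordinator implementation), assuming an already-initialized process group:

from typing import Optional

import torch
import torch.distributed as dist


def gather_to_dst(input_: torch.Tensor,
                  dst: int = 0,
                  dim: int = -1) -> Optional[torch.Tensor]:
    """Gather input_ onto rank dst; all other ranks return None."""
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    gather_list = ([torch.empty_like(input_) for _ in range(world_size)]
                   if rank == dst else None)
    dist.gather(input_, gather_list, dst=dst)
    if rank != dst:
        return None  # nothing was gathered on this rank
    return torch.cat(gather_list, dim=dim)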
@@ -50,7 +50,7 @@ class LogitsProcessor(nn.Module):
         hidden_states: torch.Tensor,
         sampling_metadata: SamplingMetadata,
         embedding_bias: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
+    ) -> Optional[torch.Tensor]:
         if self.logits_as_input:
             logits = hidden_states
         else:

@@ -73,14 +73,18 @@ class LogitsProcessor(nn.Module):
 
         return logits
 
-    def _get_logits(self, hidden_states: torch.Tensor,
-                    lm_head: VocabParallelEmbedding,
-                    embedding_bias: Optional[torch.Tensor]) -> torch.Tensor:
+    def _get_logits(
+        self,
+        hidden_states: torch.Tensor,
+        lm_head: VocabParallelEmbedding,
+        embedding_bias: Optional[torch.Tensor],
+    ) -> Optional[torch.Tensor]:
         # Get the logits for the next tokens.
         logits = lm_head.linear_method.apply(lm_head,
                                              hidden_states,
                                              bias=embedding_bias)
         if self.use_gather:
+            # None may be returned for rank > 0
            logits = tensor_model_parallel_gather(logits)
         else:
             # Gather is not supported for some devices such as TPUs.
@@ -19,6 +19,7 @@ from tqdm.auto import tqdm
 from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
 
 from vllm.config import LoadConfig, ModelConfig
+from vllm.distributed import get_tensor_model_parallel_rank
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization import (QuantizationConfig,
                                                      get_quantization_config)

@@ -514,8 +515,30 @@ def convert_pyslice_to_tensor(x: Any) -> torch.Tensor:
 def default_weight_loader(param: torch.Tensor,
                           loaded_weight: torch.Tensor) -> None:
     """Default weight loader."""
-    assert param.size() == loaded_weight.size()
-    param.data.copy_(loaded_weight)
+    try:
+        assert param.size() == loaded_weight.size(), (
+            f"Attempted to load weight ({loaded_weight.size()}) "
+            f"into parameter ({param.size()})")
+
+        param.data.copy_(loaded_weight)
+    except Exception:
+        # NOTE: This exception is added for the purpose of setting breakpoint to
+        # debug weight loading issues.
+        raise
+
+
+def row_parallel_weight_loader(param: torch.Tensor,
+                               loaded_weight: torch.Tensor) -> None:
+    """Load weights that are row-parallelized."""
+    tp_rank = get_tensor_model_parallel_rank()
+    shard_dim = 0 if param.dim() != 1 else None
+
+    if shard_dim is not None:
+        shard_size = param.data.shape[shard_dim]
+        start_idx = tp_rank * shard_size
+        loaded_weight = loaded_weight.narrow(shard_dim, start_idx, shard_size)
+
+    return default_weight_loader(param, loaded_weight)
 
 
 def initialize_dummy_weights(
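A worked example of the slicing row_parallel_weight_loader performs, using made-up shapes loosely modeled on a per-head norm weight split two ways across tensor-parallel ranks (this shows only the indexing, not vLLM's loading pipeline):

import torch

tp_rank = 1                                       # pretend we are rank 1 of 2
loaded_weight = torch.arange(16.).reshape(8, 2)   # full checkpoint weight: 8 heads x head_dim 2
param = torch.empty(4, 2)                         # this rank holds only 4 of the 8 heads

shard_dim = 0 if param.dim() != 1 else None       # 2-D weight -> shard along dim 0
if shard_dim is not None:
    shard_size = param.data.shape[shard_dim]      # 4 rows per rank
    start_idx = tp_rank * shard_size              # rank 1 starts at row 4
    loaded_weight = loaded_weight.narrow(shard_dim, start_idx, shard_size)

# Now the shapes agree, so default_weight_loader's assert passes.
assert param.size() == loaded_weight.size()
param.data.copy_(loaded_weight)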
The next five hunks apply the same mechanical change to other model files: compute_logits takes one argument per line and its return type is widened from torch.Tensor to Optional[torch.Tensor] (logits can be None on tensor-parallel ranks other than the gather destination); the surrounding method bodies are unchanged.

@@ -433,8 +433,11 @@ class ArcticForCausalLM(nn.Module):
@@ -346,8 +346,11 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA):
@@ -872,8 +872,11 @@ class BartForConditionalGeneration(nn.Module):
@@ -637,8 +637,11 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsVision):
@@ -292,8 +292,11 @@ class BloomForCausalLM(nn.Module):
@@ -25,8 +25,10 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.model_loader.weight_utils import (
+    default_weight_loader, row_parallel_weight_loader)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.model_executor.utils import set_weight_attrs
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.image import (cached_get_tokenizer,
                                    repeat_and_pad_image_tokens)

@@ -141,6 +143,11 @@ class ChameleonLayerNorm(nn.LayerNorm):
         super().__init__(hidden_size, *args, **kwargs)
         self.normalized_shape = (hidden_size[-1], )
 
+        set_weight_attrs(self.weight,
+                         {"weight_loader": row_parallel_weight_loader})
+        set_weight_attrs(self.bias,
+                         {"weight_loader": row_parallel_weight_loader})
+
     def forward(self, hidden_states):
         hidden_states = F.layer_norm(hidden_states,
                                      self.normalized_shape,

@@ -697,6 +704,8 @@ class ChameleonVQVAEEncoder(nn.Module):
         )
 
     def forward(self, pixel_values: torch.Tensor):
+        pixel_values = pixel_values.to(self.conv_in.weight.dtype)
+
         # downsampling
         hidden_states = [self.conv_in(pixel_values)]
         for i_level in range(self.num_resolutions):

@@ -959,15 +968,19 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsVision):
                                    attn_metadata)
         return hidden_states
 
-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
         logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
 
         # Disallow image tokens which does not include special
         # begin-image and end-image tokens
-        image_tokens = self.model.vocabulary_mapping.image_tokens
-        logits[:, image_tokens] = torch.finfo(logits.dtype).min
+        if logits is not None:
+            image_tokens = self.model.vocabulary_mapping.image_tokens
+            logits[:, image_tokens] = torch.finfo(logits.dtype).min
 
         return logits
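This is the core of the TP>1 fix: ChameleonLayerNorm's q_norm/k_norm parameters are laid out per attention head, and the heads are split across tensor-parallel ranks, so each rank owns only a slice of the checkpoint tensor. Tagging the weight and bias with row_parallel_weight_loader makes loading narrow the checkpoint tensor to that slice instead of tripping default_weight_loader's size assert. A hedged sketch of how such a tag is consumed during loading (simplified, not the literal load_weights loop):

from vllm.model_executor.model_loader.weight_utils import default_weight_loader


def load_one_param(param, loaded_weight):
    # Parameters tagged via set_weight_attrs carry their own weight_loader;
    # ChameleonLayerNorm's weight/bias now resolve to row_parallel_weight_loader,
    # while untagged parameters fall back to default_weight_loader.
    weight_loader = getattr(param, "weight_loader", default_weight_loader)
    weight_loader(param, loaded_weight)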
@@ -372,8 +372,11 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA):
                                    attn_metadata)
         return hidden_states
 
-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
         logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
         return logits
@@ -25,13 +25,11 @@ from typing import Iterable, List, Optional, Set, Tuple
 import torch
 import torch.utils.checkpoint
 from torch import nn
-from torch.nn.parameter import Parameter
 from transformers import CohereConfig
 
 from vllm.attention import Attention, AttentionMetadata
 from vllm.config import CacheConfig, LoRAConfig
-from vllm.distributed import (get_tensor_model_parallel_rank,
-                              get_tensor_model_parallel_world_size)
+from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
                                                QKVParallelLinear,

@@ -43,7 +41,8 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.model_loader.weight_utils import (
+    default_weight_loader, row_parallel_weight_loader)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.sequence import IntermediateTensors, SamplerOutput

@@ -67,25 +66,14 @@ class LayerNorm(nn.Module):
         super().__init__()
         self.weight = nn.Parameter(torch.ones(param_shape))
         self.variance_epsilon = eps
-        set_weight_attrs(self.weight, {"weight_loader": self.weight_loader})
+        set_weight_attrs(self.weight,
+                         {"weight_loader": row_parallel_weight_loader})
 
     def forward(self, hidden_states, residuals=None):
         hidden_states = layer_norm_func(hidden_states, self.weight,
                                         self.variance_epsilon)
         return hidden_states, residuals
 
-    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
-        tp_rank = get_tensor_model_parallel_rank()
-        shard_dim = 0 if param.dim() != 1 else None
-        param_data = param.data
-        if shard_dim is not None:
-            shard_size = param_data.shape[shard_dim]
-            start_idx = tp_rank * shard_size
-            loaded_weight = loaded_weight.narrow(shard_dim, start_idx,
-                                                 shard_size)
-        assert param_data.shape == loaded_weight.shape
-        param_data.copy_(loaded_weight)
-
 
 # Copied from transformers.models.llama.modeling_llama.LlamaMLP Llama->Cohere
 class CohereMLP(nn.Module):

@@ -359,8 +347,11 @@ class CohereForCausalLM(nn.Module):
                                    attn_metadata)
         return hidden_states
 
-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
         is_not_lora = hasattr(self.model.embed_tokens, 'weight')
         if is_not_lora:
             logits = self.logits_processor(self.model.embed_tokens,
The same compute_logits signature change (per-line arguments, Optional[torch.Tensor] return) is repeated, with otherwise unchanged bodies, in the following hunks:

@@ -388,8 +388,11 @@ class DbrxForCausalLM(nn.Module):
@@ -395,8 +395,11 @@ class DeepseekForCausalLM(nn.Module):
@@ -505,8 +505,11 @@ class DeepseekV2ForCausalLM(nn.Module):
@@ -420,8 +420,11 @@ class FalconForCausalLM(nn.Module):
@@ -287,8 +287,11 @@ class FuyuForCausalLM(nn.Module, SupportsVision):
@@ -352,8 +352,11 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA):
@@ -343,8 +343,11 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA):
@@ -265,8 +265,11 @@ class GPT2LMHeadModel(nn.Module):
@@ -279,8 +279,11 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA):
@@ -246,8 +246,11 @@ class GPTJForCausalLM(nn.Module):
@@ -258,8 +258,11 @@ class GPTNeoXForCausalLM(nn.Module):
@@ -279,8 +279,11 @@ class InternLM2ForCausalLM(nn.Module):
@@ -466,8 +466,11 @@ class InternVLChatModel(nn.Module, SupportsVision):
@@ -295,8 +295,11 @@ class JAISLMHeadModel(nn.Module):
@@ -861,8 +861,11 @@ class JambaForCausalLM(nn.Module, HasInnerState):
@@ -430,8 +430,11 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
@@ -355,8 +355,11 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
@@ -588,8 +588,11 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
@@ -65,22 +65,28 @@ class Medusa(nn.Module):
     def compute_logits(
             self, hidden_states: List[torch.Tensor],
             sampling_metadata: SamplingMetadata) -> List[torch.Tensor]:
-        logits = []
+        logits_lst: List[torch.Tensor] = []
 
         for hs, lm_head in zip(hidden_states, self.lm_heads):
             _logits = self.logits_processor(lm_head, hs, sampling_metadata)
 
+            if _logits is None:
+                # _logits should only be None on rank > 0, in which case
+                # it should remain true for every lm_head
+                assert len(logits_lst) == 0
+                continue
+
             if self.token_map is None:
-                logits.append(_logits)
+                logits_lst.append(_logits)
             else:
-                logits.append(-torch.inf * torch.ones(
+                logits_lst.append(-torch.inf * torch.ones(
                     size=(*_logits.shape[:-1], self.orig_vocab_size),
                     device=_logits.device,
                     dtype=_logits.dtype))
 
-                logits[-1][..., self.token_map] = _logits
+                logits_lst[-1][..., self.token_map] = _logits
 
-        return logits
+        return logits_lst
 
     def sample(
         self,
And likewise, again with otherwise unchanged bodies, in:

@@ -470,8 +470,11 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA):
@@ -630,8 +630,11 @@ class MiniCPMVBaseModel(nn.Module, SupportsVision):
@@ -375,8 +375,11 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
@@ -362,8 +362,11 @@ class MixtralForCausalLM(nn.Module):
@@ -279,8 +279,11 @@ class MPTForCausalLM(nn.Module):
@@ -453,8 +453,11 @@ class NemotronForCausalLM(nn.Module, SupportsLoRA):
@@ -311,8 +311,11 @@ class OlmoForCausalLM(nn.Module):
@@ -323,8 +323,11 @@ class OPTForCausalLM(nn.Module):
@@ -277,8 +277,11 @@ class OrionForCausalLM(nn.Module):
@@ -262,8 +262,11 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsVision):
@@ -285,8 +285,11 @@ class PersimmonForCausalLM(nn.Module):
@@ -286,8 +286,11 @@ class PhiForCausalLM(nn.Module, SupportsLoRA):
@@ -399,8 +399,11 @@ class Phi3SmallForCausalLM(nn.Module):
@@ -584,8 +584,11 @@ class Phi3VForCausalLM(nn.Module, SupportsVision):
@@ -281,8 +281,11 @@ class QWenLMHeadModel(nn.Module):
@@ -362,8 +362,11 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA):
@@ -400,8 +400,11 @@ class Qwen2MoeForCausalLM(nn.Module):
@@ -258,8 +258,11 @@ class StablelmForCausalLM(nn.Module):
@@ -268,8 +268,11 @@ class Starcoder2ForCausalLM(nn.Module):
@@ -328,8 +328,11 @@ class XverseForCausalLM(nn.Module, SupportsLoRA):
@@ -1,6 +1,8 @@
 import time
 from dataclasses import dataclass
-from typing import List, Optional, Tuple, Union
+from typing import List, Optional
+from typing import Sequence as GenericSequence
+from typing import Union
 
 from vllm.lora.request import LoRARequest
 from vllm.sequence import (PromptLogprobs, RequestMetrics, SampleLogprobs,

@@ -28,7 +30,7 @@ class CompletionOutput:
 
     index: int
     text: str
-    token_ids: Tuple[int, ...]
+    token_ids: GenericSequence[int]
     cumulative_logprob: Optional[float]
     logprobs: Optional[SampleLogprobs]
     finish_reason: Optional[str] = None

@@ -139,7 +141,7 @@ class RequestOutput:
                 CompletionOutput(
                     seqs.index(seq),
                     seq.get_output_text_to_return(text_buffer_length),
-                    seq.data._output_token_ids,  # type: ignore
+                    seq.data._output_token_ids,
                     seq.get_cumulative_logprob() if include_logprobs else None,
                     seq.output_logprobs if include_logprobs else None,
                     SequenceStatus.get_finished_reason(seq.status),
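CompletionOutput.token_ids is now typed as a generic integer sequence rather than a tuple, so the engine can expose its internal token container directly; callers that need a concrete list (like the conftest change above) convert explicitly. A tiny illustration of what the looser annotation accepts:

from typing import Sequence as GenericSequence


def count_tokens(token_ids: GenericSequence[int]) -> int:
    # Any sequence of ints satisfies the annotation now.
    return len(token_ids)


count_tokens([1, 2, 3])     # list
count_tokens((1, 2, 3))     # tuple still works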