[Bugfix] Fix weight loading for Chameleon when TP>1 (#7410)

Cyrus Leung 2024-08-13 13:33:41 +08:00 committed by GitHub
parent 5469146bcc
commit 7025b11d94
59 changed files with 411 additions and 202 deletions

View File

@ -4,7 +4,8 @@ import os
import sys
from collections import UserList
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple, TypedDict, TypeVar, Union
from typing import (Any, Callable, Dict, List, Optional, Tuple, TypedDict,
TypeVar, Union)
import pytest
import torch
@ -27,7 +28,7 @@ from vllm.logger import init_logger
from vllm.outputs import RequestOutput
from vllm.sequence import SampleLogprobs
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, cuda_device_count_stateless,
is_cpu)
identity, is_cpu)
logger = init_logger(__name__)
@ -197,6 +198,8 @@ class HfRunner:
is_embedding_model: bool = False,
is_vision_model: bool = False,
is_encoder_decoder_model: bool = False,
postprocess_inputs: Callable[[BatchEncoding],
BatchEncoding] = identity,
) -> None:
torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype]
@ -242,12 +245,14 @@ class HfRunner:
torch_dtype=torch_dtype,
trust_remote_code=True,
)
except Exception:
except Exception as exc:
logger.warning(
"Unable to auto-load processor from HuggingFace for "
"model %s. Using tokenizer instead.", model_name)
"Unable to auto-load HuggingFace processor for model (%s). "
"Using tokenizer instead. Reason: %s", model_name, exc)
self.processor = self.tokenizer
self.postprocess_inputs = postprocess_inputs
def generate(
self,
prompts: List[str],
@ -267,6 +272,7 @@ class HfRunner:
processor_kwargs["images"] = images[i]
inputs = self.processor(**processor_kwargs)
inputs = self.postprocess_inputs(inputs)
output_ids = self.model.generate(
**self.wrap_device(inputs),
@ -336,6 +342,7 @@ class HfRunner:
processor_kwargs["images"] = images[i]
inputs = self.processor(**processor_kwargs)
inputs = self.postprocess_inputs(inputs)
output = self.model.generate(
**self.wrap_device(inputs),
@ -420,6 +427,7 @@ class HfRunner:
processor_kwargs["images"] = images[i]
inputs = self.processor(**processor_kwargs)
inputs = self.postprocess_inputs(inputs)
output = self.model.generate(
**self.wrap_device(inputs),
@ -552,7 +560,8 @@ class VllmRunner:
self,
prompts: List[str],
sampling_params: SamplingParams,
images: Optional[List[Image.Image]] = None,
images: Optional[Union[List[Image.Image],
List[List[Image.Image]]]] = None,
) -> List[Tuple[List[List[int]], List[str]]]:
if images is not None:
assert len(prompts) == len(images)
@ -587,7 +596,7 @@ class VllmRunner:
for req_output in req_outputs:
for sample in req_output.outputs:
output_str = sample.text
output_ids = sample.token_ids
output_ids = list(sample.token_ids)
output_logprobs = sample.logprobs
outputs.append((output_ids, output_str, output_logprobs))
return outputs
@ -596,7 +605,8 @@ class VllmRunner:
self,
prompts: List[str],
sampling_params: SamplingParams,
images: Optional[List[Image.Image]] = None,
images: Optional[Union[List[Image.Image],
List[List[Image.Image]]]] = None,
) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
assert sampling_params.logprobs is not None
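
The postprocess_inputs hook added above lets a test adjust the HuggingFace processor output before generation. A minimal sketch of such a callback, assuming the processor returns a BatchEncoding with a "pixel_values" tensor; the callback name and the bfloat16 cast are illustrative, and the Chameleon test later in this diff wires up an equivalent one:

import torch
from transformers import BatchEncoding

def cast_pixel_values_to_bfloat16(hf_inputs: BatchEncoding) -> BatchEncoding:
    # Cast only the image tensor; token ids must stay integral.
    hf_inputs["pixel_values"] = hf_inputs["pixel_values"].to(torch.bfloat16)
    return hf_inputs

# Hypothetical usage with the runner fixture:
# with hf_runner(model, dtype="bfloat16", is_vision_model=True,
#                postprocess_inputs=cast_pixel_values_to_bfloat16) as hf_model:
#     ...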

View File

@ -18,8 +18,10 @@ from ..utils import fork_new_process_for_each_test
@pytest.mark.parametrize("model, distributed_executor_backend", [
("llava-hf/llava-1.5-7b-hf", "ray"),
("llava-hf/llava-v1.6-mistral-7b-hf", "ray"),
("facebook/chameleon-7b", "ray"),
("llava-hf/llava-1.5-7b-hf", "mp"),
("llava-hf/llava-v1.6-mistral-7b-hf", "mp"),
("facebook/chameleon-7b", "mp"),
])
@fork_new_process_for_each_test
def test_models(hf_runner, vllm_runner, image_assets, model: str,
@ -34,6 +36,8 @@ def test_models(hf_runner, vllm_runner, image_assets, model: str,
from ..models.test_llava import models, run_test
elif model.startswith("llava-hf/llava-v1.6"):
from ..models.test_llava_next import models, run_test
elif model.startswith("facebook/chameleon"):
from ..models.test_chameleon import models, run_test
else:
raise NotImplementedError(f"Unsupported model: {model}")

View File

@ -1,5 +1,6 @@
import sys
import time
from typing import Optional
import torch
from openai import OpenAI, OpenAIError
@ -17,8 +18,11 @@ assert chatml_jinja_path.exists()
class MyOPTForCausalLM(OPTForCausalLM):
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
# this dummy model always predicts the first token
logits = super().compute_logits(hidden_states, sampling_metadata)
logits.zero_()

View File

@ -1,11 +1,13 @@
import re
from typing import List, Optional, Type
import pytest
from transformers import BatchEncoding
from vllm.multimodal.utils import rescale_image_size
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
from ..conftest import IMAGE_ASSETS, VllmRunner, _ImageAssets
from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets
from .utils import check_outputs_equal
pytestmark = pytest.mark.vlm
@ -19,9 +21,8 @@ HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
models = ["facebook/chameleon-7b"]
#TODO (ywang96): Add correctness test when chameleon is
# available on transformers.
def run_test(
hf_runner: Type[HfRunner],
vllm_runner: Type[VllmRunner],
image_assets: _ImageAssets,
model: str,
@ -29,13 +30,20 @@ def run_test(
size_factors: List[float],
dtype: str,
max_tokens: int,
num_logprobs: int,
tensor_parallel_size: int,
distributed_executor_backend: Optional[str] = None,
):
"""Test if the model can generate text given
a batch of images and prompts.
"""Inference result should be the same between hf and vllm.
All the image fixtures for the test is under tests/images.
For huggingface runner, we provide the PIL images as input.
For vllm runner, we provide MultiModalDataDict objects
and corresponding vision language config as input.
Note, the text input is also adjusted to abide by vllm contract.
The text output is sanitized to be able to compare with hf.
"""
torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype]
images = [asset.pil_image for asset in image_assets]
inputs_per_image = [(
@ -50,35 +58,49 @@ def run_test(
distributed_executor_backend=distributed_executor_backend,
enforce_eager=True) as vllm_model:
for prompts, images in inputs_per_image:
vllm_outputs = vllm_model.generate_greedy(prompts,
max_tokens,
images=images)
for i in range(len(vllm_outputs)):
vllm_outputs_per_image = [
vllm_model.generate_greedy_logprobs(prompts,
max_tokens,
num_logprobs=num_logprobs,
images=images)
for prompts, images in inputs_per_image
]
# format prompt back to original
replacements = {
"<racm3:break>": "",
"<eoss>": "",
"<reserved08706>": ""
}
pattern = '|'.join(replacements.keys())
vllm_result = re.sub(
pattern,
lambda match: replacements[match.group(0)], #noqa B023
vllm_outputs[i][1])
vllm_result = vllm_result.replace("<image>", "", 1023)
assert vllm_result[:len(prompts[i])] == prompts[i]
def process(hf_inputs: BatchEncoding):
hf_inputs["pixel_values"] = hf_inputs["pixel_values"] \
.to(torch_dtype) # type: ignore
return hf_inputs
# assert at least 10 new characters are generated
# (to take stop token into account)
assert len(vllm_outputs[i][1]) - len(prompts[i]) > 10
with hf_runner(model,
dtype=dtype,
postprocess_inputs=process,
is_vision_model=True) as hf_model:
hf_outputs_per_image = [
hf_model.generate_greedy_logprobs_limit(prompts,
max_tokens,
num_logprobs=num_logprobs,
images=images)
for prompts, images in inputs_per_image
]
for hf_outputs, vllm_outputs in zip(hf_outputs_per_image,
vllm_outputs_per_image):
# HF Logprobs include image tokens, unlike vLLM, so we don't directly
# compare them
check_outputs_equal(
outputs_0_lst=[outputs[:2] for outputs in hf_outputs],
outputs_1_lst=[outputs[:2] for outputs in vllm_outputs],
name_0="hf",
name_1="vllm",
)
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize(
"size_factors",
[
# No image
[],
# Single-scale
[1.0],
# Single-scale, batched
@ -88,15 +110,18 @@ def run_test(
],
)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [128])
def test_models(vllm_runner, image_assets, model, size_factors, dtype: str,
max_tokens: int) -> None:
@pytest.mark.parametrize("max_tokens", [8])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
dtype, max_tokens, num_logprobs) -> None:
run_test(
hf_runner,
vllm_runner,
image_assets,
model,
size_factors=size_factors,
dtype=dtype,
max_tokens=max_tokens,
num_logprobs=num_logprobs,
tensor_parallel_size=1,
)

View File

@ -1,7 +1,7 @@
from typing import List, Optional, Tuple, Type
import pytest
from transformers import AutoConfig, AutoTokenizer
from transformers import AutoConfig, AutoTokenizer, BatchEncoding
from vllm.multimodal.utils import rescale_image_size
from vllm.sequence import SampleLogprobs
@ -110,16 +110,21 @@ def run_test(
for prompts, images in inputs_per_image
]
with hf_runner(model, dtype=dtype, is_vision_model=True) as hf_model:
if mantis_processor is not None:
if mantis_processor is not None:
def process(*args, **kwargs):
output = mantis_processor(*args, **kwargs)
output["pixel_values"] = output["pixel_values"].to(torch_dtype)
return output
def process(hf_inputs: BatchEncoding):
hf_inputs["pixel_values"] = hf_inputs["pixel_values"] \
.to(torch_dtype) # type: ignore
return hf_inputs
else:
hf_model.processor = process
def process(hf_inputs: BatchEncoding):
return hf_inputs
with hf_runner(model,
dtype=dtype,
postprocess_inputs=process,
is_vision_model=True) as hf_model:
hf_outputs_per_image = [
hf_model.generate_greedy_logprobs_limit(prompts,
max_tokens,

View File

@ -1,10 +1,9 @@
from collections import UserDict
from typing import List, Optional, Tuple, Type
import pytest
import torch
import torch.types
from transformers import BatchFeature
from transformers import BatchEncoding
from vllm.multimodal.utils import rescale_image_size
from vllm.sequence import SampleLogprobs
@ -14,18 +13,6 @@ from .utils import check_logprobs_close
pytestmark = pytest.mark.vlm
class NestedInputs(UserDict):
def __init__(self, model_inputs: BatchFeature):
super().__init__({"model_inputs": model_inputs})
self.model_inputs = model_inputs
def to(self, device: torch.types.Device):
return NestedInputs(self.model_inputs.to(device))
# The image token is placed before "user" on purpose so that the test can pass
HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
"stop_sign":
@ -41,6 +28,10 @@ HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
models = ["openbmb/MiniCPM-Llama3-V-2_5"]
def _wrap_inputs(hf_inputs: BatchEncoding) -> BatchEncoding:
return BatchEncoding({"model_inputs": hf_inputs})
def trunc_hf_output(hf_output: Tuple[List[int], str,
Optional[SampleLogprobs]]):
output_ids, output_str, out_logprobs = hf_output
@ -105,11 +96,8 @@ def run_test(
for prompts, images in inputs_per_image
]
with hf_runner(model, dtype=dtype) as hf_model, torch.no_grad():
hf_processor = hf_model.processor
hf_model.processor = lambda **kw: NestedInputs(
hf_processor(**kw) # type: ignore
)
hf_model = hf_runner(model, dtype=dtype, postprocess_inputs=_wrap_inputs)
with hf_model, torch.no_grad():
hf_outputs_per_image = [
hf_model.generate_greedy_logprobs_limit(prompts,
max_tokens,
@ -224,11 +212,8 @@ def run_multi_image_test(
for prompts, images in inputs_per_case
]
with hf_runner(model, dtype=dtype) as hf_model, torch.no_grad():
hf_processor = hf_model.processor
hf_model.processor = lambda **kw: NestedInputs(
hf_processor(**kw) # type: ignore
)
hf_model = hf_runner(model, dtype=dtype, postprocess_inputs=_wrap_inputs)
with hf_model, torch.no_grad():
hf_outputs_per_case = [
hf_model.generate_greedy_logprobs_limit(prompts,
max_tokens,

View File

@ -1,3 +1,5 @@
from typing import Optional
import torch
from vllm import LLM, ModelRegistry, SamplingParams
@ -7,8 +9,11 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
class MyOPTForCausalLM(OPTForCausalLM):
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
# this dummy model always predicts the first token
logits = super().compute_logits(hidden_states, sampling_metadata)
logits.zero_()

View File

@ -19,7 +19,7 @@ def tensor_model_parallel_all_gather(input_: torch.Tensor,
def tensor_model_parallel_gather(input_: torch.Tensor,
dst: int = 0,
dim: int = -1) -> torch.Tensor:
dim: int = -1) -> Optional[torch.Tensor]:
"""Gather the input tensor across model parallel group."""
return get_tp_group().gather(input_, dst, dim)

View File

@ -329,7 +329,7 @@ class GroupCoordinator:
def gather(self,
input_: torch.Tensor,
dst: int = 0,
dim: int = -1) -> torch.Tensor:
dim: int = -1) -> Optional[torch.Tensor]:
"""
NOTE: We assume that the input tensor is on the same device across
all the ranks.

View File

@ -50,7 +50,7 @@ class LogitsProcessor(nn.Module):
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
embedding_bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
) -> Optional[torch.Tensor]:
if self.logits_as_input:
logits = hidden_states
else:
@ -73,14 +73,18 @@ class LogitsProcessor(nn.Module):
return logits
def _get_logits(self, hidden_states: torch.Tensor,
lm_head: VocabParallelEmbedding,
embedding_bias: Optional[torch.Tensor]) -> torch.Tensor:
def _get_logits(
self,
hidden_states: torch.Tensor,
lm_head: VocabParallelEmbedding,
embedding_bias: Optional[torch.Tensor],
) -> Optional[torch.Tensor]:
# Get the logits for the next tokens.
logits = lm_head.linear_method.apply(lm_head,
hidden_states,
bias=embedding_bias)
if self.use_gather:
# None may be returned for rank > 0
logits = tensor_model_parallel_gather(logits)
else:
# Gather is not supported for some devices such as TPUs.

View File

@ -19,6 +19,7 @@ from tqdm.auto import tqdm
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
from vllm.config import LoadConfig, ModelConfig
from vllm.distributed import get_tensor_model_parallel_rank
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import (QuantizationConfig,
get_quantization_config)
@ -514,8 +515,30 @@ def convert_pyslice_to_tensor(x: Any) -> torch.Tensor:
def default_weight_loader(param: torch.Tensor,
loaded_weight: torch.Tensor) -> None:
"""Default weight loader."""
assert param.size() == loaded_weight.size()
param.data.copy_(loaded_weight)
try:
assert param.size() == loaded_weight.size(), (
f"Attempted to load weight ({loaded_weight.size()}) "
f"into parameter ({param.size()})")
param.data.copy_(loaded_weight)
except Exception:
# NOTE: This exception handler is added so that a breakpoint can be set
# to debug weight loading issues.
raise
def row_parallel_weight_loader(param: torch.Tensor,
loaded_weight: torch.Tensor) -> None:
"""Load weights that are row-parallelized."""
tp_rank = get_tensor_model_parallel_rank()
shard_dim = 0 if param.dim() != 1 else None
if shard_dim is not None:
shard_size = param.data.shape[shard_dim]
start_idx = tp_rank * shard_size
loaded_weight = loaded_weight.narrow(shard_dim, start_idx, shard_size)
return default_weight_loader(param, loaded_weight)
def initialize_dummy_weights(
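
A self-contained sketch of the sharding arithmetic row_parallel_weight_loader performs, with made-up sizes and a hypothetical tensor-parallel group of 4: a parameter with more than one dimension is narrowed along dim 0 to the rows owned by the current rank, while a 1-D parameter would be copied whole:

import torch

tp_size, tp_rank = 4, 1                          # hypothetical TP group; this process is rank 1
full_weight = torch.arange(32.0).reshape(8, 4)   # unsharded tensor from the checkpoint
param = torch.empty(8 // tp_size, 4)             # this rank holds 2 of the 8 rows

shard_size = param.shape[0]                      # shard_dim = 0 because param.dim() != 1
start_idx = tp_rank * shard_size
param.copy_(full_weight.narrow(0, start_idx, shard_size))
assert torch.equal(param, full_weight[2:4])      # rank 1 owns rows 2 and 3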

View File

@ -433,8 +433,11 @@ class ArcticForCausalLM(nn.Module):
attn_metadata)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits

View File

@ -346,8 +346,11 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA):
attn_metadata)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits

View File

@ -872,8 +872,11 @@ class BartForConditionalGeneration(nn.Module):
return self.model(input_ids, positions, encoder_input_ids,
encoder_positions, kv_caches, attn_metadata)
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits

View File

@ -637,8 +637,11 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsVision):
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.get_lm_head(), hidden_states,
sampling_metadata)
return logits

View File

@ -292,8 +292,11 @@ class BloomForCausalLM(nn.Module):
attn_metadata)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits

View File

@ -25,8 +25,10 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, row_parallel_weight_loader)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_weight_attrs
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.image import (cached_get_tokenizer,
repeat_and_pad_image_tokens)
@ -141,6 +143,11 @@ class ChameleonLayerNorm(nn.LayerNorm):
super().__init__(hidden_size, *args, **kwargs)
self.normalized_shape = (hidden_size[-1], )
set_weight_attrs(self.weight,
{"weight_loader": row_parallel_weight_loader})
set_weight_attrs(self.bias,
{"weight_loader": row_parallel_weight_loader})
def forward(self, hidden_states):
hidden_states = F.layer_norm(hidden_states,
self.normalized_shape,
@ -697,6 +704,8 @@ class ChameleonVQVAEEncoder(nn.Module):
)
def forward(self, pixel_values: torch.Tensor):
pixel_values = pixel_values.to(self.conv_in.weight.dtype)
# downsampling
hidden_states = [self.conv_in(pixel_values)]
for i_level in range(self.num_resolutions):
@ -959,15 +968,19 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsVision):
attn_metadata)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
# Disallow image tokens, which do not include the special
# begin-image and end-image tokens
image_tokens = self.model.vocabulary_mapping.image_tokens
logits[:, image_tokens] = torch.finfo(logits.dtype).min
if logits is not None:
image_tokens = self.model.vocabulary_mapping.image_tokens
logits[:, image_tokens] = torch.finfo(logits.dtype).min
return logits
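
For context on the commit title: with TP > 1, the per-rank ChameleonLayerNorm weight for q_norm/k_norm only covers the attention heads assigned to that rank, while the checkpoint stores the weight for all heads, so the plain size assertion in default_weight_loader fails. A hedged sketch with made-up sizes showing the mismatch and the narrowed copy that row_parallel_weight_loader performs instead:

import torch

# Hypothetical sizes: 32 attention heads, head_dim 128, tensor parallel size 4.
num_heads, head_dim, tp_size, tp_rank = 32, 128, 4, 2
checkpoint_weight = torch.randn(num_heads, head_dim)     # full weight in the checkpoint
param = torch.empty(num_heads // tp_size, head_dim)      # per-rank layernorm weight

assert param.size() != checkpoint_weight.size()          # default_weight_loader would reject this
shard = checkpoint_weight.narrow(0, tp_rank * param.size(0), param.size(0))
param.copy_(shard)                                       # row-parallel load succeeds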

View File

@ -372,8 +372,11 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA):
attn_metadata)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits

View File

@ -25,13 +25,11 @@ from typing import Iterable, List, Optional, Set, Tuple
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn.parameter import Parameter
from transformers import CohereConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
@ -43,7 +41,8 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, row_parallel_weight_loader)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_weight_attrs
from vllm.sequence import IntermediateTensors, SamplerOutput
@ -67,25 +66,14 @@ class LayerNorm(nn.Module):
super().__init__()
self.weight = nn.Parameter(torch.ones(param_shape))
self.variance_epsilon = eps
set_weight_attrs(self.weight, {"weight_loader": self.weight_loader})
set_weight_attrs(self.weight,
{"weight_loader": row_parallel_weight_loader})
def forward(self, hidden_states, residuals=None):
hidden_states = layer_norm_func(hidden_states, self.weight,
self.variance_epsilon)
return hidden_states, residuals
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
tp_rank = get_tensor_model_parallel_rank()
shard_dim = 0 if param.dim() != 1 else None
param_data = param.data
if shard_dim is not None:
shard_size = param_data.shape[shard_dim]
start_idx = tp_rank * shard_size
loaded_weight = loaded_weight.narrow(shard_dim, start_idx,
shard_size)
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
# Copied from transformers.models.llama.modeling_llama.LlamaMLP Llama->Cohere
class CohereMLP(nn.Module):
@ -359,8 +347,11 @@ class CohereForCausalLM(nn.Module):
attn_metadata)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
is_not_lora = hasattr(self.model.embed_tokens, 'weight')
if is_not_lora:
logits = self.logits_processor(self.model.embed_tokens,

View File

@ -388,8 +388,11 @@ class DbrxForCausalLM(nn.Module):
attn_metadata)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits

View File

@ -395,8 +395,11 @@ class DeepseekForCausalLM(nn.Module):
attn_metadata)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits

View File

@ -505,8 +505,11 @@ class DeepseekV2ForCausalLM(nn.Module):
attn_metadata, intermediate_tensors)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits

View File

@ -420,8 +420,11 @@ class FalconForCausalLM(nn.Module):
)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits

View File

@ -287,8 +287,11 @@ class FuyuForCausalLM(nn.Module, SupportsVision):
)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.language_model.logits_processor(
self.language_model.lm_head, hidden_states, sampling_metadata)
return logits

View File

@ -352,8 +352,11 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA):
attn_metadata)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.model.embed_tokens, hidden_states,
sampling_metadata)
return logits

View File

@ -343,8 +343,11 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA):
attn_metadata)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.model.embed_tokens, hidden_states,
sampling_metadata)
return logits

View File

@ -265,8 +265,11 @@ class GPT2LMHeadModel(nn.Module):
attn_metadata, intermediate_tensors)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits

View File

@ -279,8 +279,11 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA):
attn_metadata)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits

View File

@ -246,8 +246,11 @@ class GPTJForCausalLM(nn.Module):
attn_metadata)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata, self.lm_head.bias)
return logits

View File

@ -258,8 +258,11 @@ class GPTNeoXForCausalLM(nn.Module):
attn_metadata)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.embed_out, hidden_states,
sampling_metadata)
return logits

View File

@ -279,8 +279,11 @@ class InternLM2ForCausalLM(nn.Module):
attn_metadata)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.output, hidden_states,
sampling_metadata)
return logits

View File

@ -466,8 +466,11 @@ class InternVLChatModel(nn.Module, SupportsVision):
inputs_embeds=inputs_embeds)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
return self.language_model.compute_logits(hidden_states,
sampling_metadata)

View File

@ -295,8 +295,11 @@ class JAISLMHeadModel(nn.Module):
attn_metadata)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits

View File

@ -861,8 +861,11 @@ class JambaForCausalLM(nn.Module, HasInnerState):
dtype=dtype,
device="cuda"))
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits

View File

@ -430,8 +430,11 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
attn_metadata, intermediate_tensors)
return model_output
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits

View File

@ -355,8 +355,11 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
return self.language_model.compute_logits(hidden_states,
sampling_metadata)

View File

@ -588,8 +588,11 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
return self.language_model.compute_logits(hidden_states,
sampling_metadata)

View File

@ -65,22 +65,28 @@ class Medusa(nn.Module):
def compute_logits(
self, hidden_states: List[torch.Tensor],
sampling_metadata: SamplingMetadata) -> List[torch.Tensor]:
logits = []
logits_lst: List[torch.Tensor] = []
for hs, lm_head in zip(hidden_states, self.lm_heads):
_logits = self.logits_processor(lm_head, hs, sampling_metadata)
if _logits is None:
# _logits should only be None on rank > 0, in which case
# the same holds for every lm_head
assert len(logits_lst) == 0
continue
if self.token_map is None:
logits.append(_logits)
logits_lst.append(_logits)
else:
logits.append(-torch.inf * torch.ones(
logits_lst.append(-torch.inf * torch.ones(
size=(*_logits.shape[:-1], self.orig_vocab_size),
device=_logits.device,
dtype=_logits.dtype))
logits[-1][..., self.token_map] = _logits
logits_lst[-1][..., self.token_map] = _logits
return logits
return logits_lst
def sample(
self,

View File

@ -470,8 +470,11 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA):
attn_metadata, intermediate_tensors)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
hidden_states = hidden_states / self.scale_width
if self.config.tie_word_embeddings:
lm_head = self.model.embed_tokens

View File

@ -630,8 +630,11 @@ class MiniCPMVBaseModel(nn.Module, SupportsVision):
)
return output
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits

View File

@ -375,8 +375,11 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
attn_metadata, intermediate_tensors)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits

View File

@ -362,8 +362,11 @@ class MixtralForCausalLM(nn.Module):
attn_metadata)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits

View File

@ -279,8 +279,11 @@ class MPTForCausalLM(nn.Module):
attn_metadata)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits

View File

@ -453,8 +453,11 @@ class NemotronForCausalLM(nn.Module, SupportsLoRA):
attn_metadata, intermediate_tensors)
return model_output
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits

View File

@ -311,8 +311,11 @@ class OlmoForCausalLM(nn.Module):
)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits

View File

@ -323,8 +323,11 @@ class OPTForCausalLM(nn.Module):
attn_metadata)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits

View File

@ -277,8 +277,11 @@ class OrionForCausalLM(nn.Module):
attn_metadata)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits

View File

@ -262,8 +262,11 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsVision):
return hidden_states
# Copied from vllm/model_executor/models/gemma.py
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.language_model.embed_tokens,
hidden_states, sampling_metadata)
return logits

View File

@ -285,8 +285,11 @@ class PersimmonForCausalLM(nn.Module):
)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits

View File

@ -286,8 +286,11 @@ class PhiForCausalLM(nn.Module, SupportsLoRA):
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata, self.lm_head.bias)
return logits

View File

@ -399,8 +399,11 @@ class Phi3SmallForCausalLM(nn.Module):
def get_decoder(self):
return self.model
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
if self.dummy_token_indices is not None and logits is not None:

View File

@ -584,8 +584,11 @@ class Phi3VForCausalLM(nn.Module, SupportsVision):
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits

View File

@ -281,8 +281,11 @@ class QWenLMHeadModel(nn.Module):
device=device),
})
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits

View File

@ -362,8 +362,11 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA):
attn_metadata, intermediate_tensors)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits

View File

@ -400,8 +400,11 @@ class Qwen2MoeForCausalLM(nn.Module):
attn_metadata, intermediate_tensors)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits

View File

@ -258,8 +258,11 @@ class StablelmForCausalLM(nn.Module):
attn_metadata)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits

View File

@ -268,8 +268,11 @@ class Starcoder2ForCausalLM(nn.Module):
attn_metadata)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits

View File

@ -328,8 +328,11 @@ class XverseForCausalLM(nn.Module, SupportsLoRA):
attn_metadata)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits

View File

@ -1,6 +1,8 @@
import time
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
from typing import List, Optional
from typing import Sequence as GenericSequence
from typing import Union
from vllm.lora.request import LoRARequest
from vllm.sequence import (PromptLogprobs, RequestMetrics, SampleLogprobs,
@ -28,7 +30,7 @@ class CompletionOutput:
index: int
text: str
token_ids: Tuple[int, ...]
token_ids: GenericSequence[int]
cumulative_logprob: Optional[float]
logprobs: Optional[SampleLogprobs]
finish_reason: Optional[str] = None
@ -139,7 +141,7 @@ class RequestOutput:
CompletionOutput(
seqs.index(seq),
seq.get_output_text_to_return(text_buffer_length),
seq.data._output_token_ids, # type: ignore
seq.data._output_token_ids,
seq.get_cumulative_logprob() if include_logprobs else None,
seq.output_logprobs if include_logprobs else None,
SequenceStatus.get_finished_reason(seq.status),