Migrate logits computation and gather to model_runner (#3233)

This commit is contained in:
Roy 2024-03-21 07:25:01 +08:00 committed by GitHub
parent 6e435de766
commit f1c0fc3919
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
35 changed files with 576 additions and 305 deletions

View File

@ -49,6 +49,9 @@ steps:
- label: Samplers Test
command: pytest -v -s samplers
- label: LogitsProcessor Test
command: pytest -v -s test_logits_processor.py
- label: Worker Test
command: pytest -v -s worker

View File

@ -13,6 +13,7 @@ from huggingface_hub import snapshot_download
import vllm
from vllm.config import LoRAConfig
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear,
@ -85,7 +86,8 @@ def dummy_model() -> nn.Module:
("outact", nn.Sigmoid()),
# Special handling for lm_head & sampler
("lm_head", ParallelLMHead(512, 10)),
("sampler", Sampler(512))
("logits_processor", LogitsProcessor(512)),
("sampler", Sampler())
]))
model.config = MagicMock()
return model
@ -110,7 +112,8 @@ def dummy_model_gate_up() -> nn.Module:
("outact", nn.Sigmoid()),
# Special handling for lm_head & sampler
("lm_head", ParallelLMHead(512, 10)),
("sampler", Sampler(512))
("logits_processor", LogitsProcessor(512)),
("sampler", Sampler())
]))
model.config = MagicMock()
return model

View File

@ -13,14 +13,14 @@ from vllm.lora.layers import (
QKVParallelLinearWithLora,
VocabParallelEmbeddingWithLoRA,
RowParallelLinearWithLoRA,
SamplerWithLoRA,
LogitsProcessorWithLoRA,
LoRAMapping,
BaseLayerWithLoRA,
)
from vllm.lora.models import (LoRALayerWeights, convert_mapping,
PackedLoRALayerWeights)
from vllm.config import LoRAConfig
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear,
RowParallelLinear,
@ -394,7 +394,7 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device) -> None:
@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_lm_head_sampler(dist_init, num_loras, device) -> None:
def test_lm_head_logits_processor(dist_init, num_loras, device) -> None:
torch.set_default_device(device)
max_loras = 8
@ -402,28 +402,29 @@ def test_lm_head_sampler(dist_init, num_loras, device) -> None:
max_lora_rank=8,
lora_dtype=torch.float16)
def create_random_sampler_layer():
def _pretest():
linear = ParallelLMHead(32000 + lora_config.lora_extra_vocab_size,
1024, 32000)
linear.weight.data = torch.rand_like(linear.weight.data)
linear.weight.data[:, 32000:] = 0
sampler = Sampler(32000 + lora_config.lora_extra_vocab_size, 32000)
lora_sampler = SamplerWithLoRA(sampler, 1024, linear.weight.dtype,
linear.weight.device)
lora_sampler.create_lora_weights(max_loras, lora_config)
logits_processor = LogitsProcessor(
32000 + lora_config.lora_extra_vocab_size, 32000)
lora_logits_processor = LogitsProcessorWithLoRA(
logits_processor, 1024, linear.weight.dtype, linear.weight.device)
lora_logits_processor.create_lora_weights(max_loras, lora_config)
return linear, sampler, lora_sampler
return linear, logits_processor, lora_logits_processor
for i in range(10):
set_random_seed(i)
id_to_index = get_random_id_to_index(num_loras, max_loras)
linear, sampler, lora_sampler = create_random_sampler_layer()
linear, logits_processor, lora_logits_processor = _pretest()
# NOTE: all the generated loras share the same embeddings tensor.
lora_dict, _ = populate_loras(
id_to_index,
layer=lora_sampler,
layer=lora_logits_processor,
layer_weights=linear.weight,
generate_embeddings_tensor=1024,
)
@ -447,34 +448,37 @@ def test_lm_head_sampler(dist_init, num_loras, device) -> None:
32000,
lora_config.lora_extra_vocab_size,
)
lora_sampler.set_mapping(*mapping_info, )
lora_logits_processor.set_mapping(*mapping_info, )
lora_result = lora_sampler._get_logits(hidden_states=torch.cat(inputs),
embedding=linear.weight,
embedding_bias=None)
lora_result = lora_logits_processor._get_logits(
hidden_states=torch.cat(inputs),
embedding=linear.weight,
embedding_bias=None)
original_weight = linear.weight.clone()
linear.weight[sampler.org_vocab_size:sampler.org_vocab_size +
linear.weight[logits_processor.
org_vocab_size:logits_processor.org_vocab_size +
embeddings_tensor_len] = embeddings_tensor
sampler.org_vocab_size = 32000 + lora_config.lora_extra_vocab_size
logits_processor.org_vocab_size = (32000 +
lora_config.lora_extra_vocab_size)
expected_results = []
for input_, lora_id in zip(inputs, prompt_mapping):
lora = lora_dict[lora_id]
result = sampler._get_logits(hidden_states=input_,
embedding=linear.weight,
embedding_bias=None)
result = logits_processor._get_logits(hidden_states=input_,
embedding=linear.weight,
embedding_bias=None)
result[:, 32000 + embeddings_tensor_len:] = float("-inf")
result += input_ @ lora.lora_a @ lora.lora_b * lora.scaling
expected_results.append(result)
expected_result = torch.cat(expected_results)
sampler.org_vocab_size = 32000
logits_processor.org_vocab_size = 32000
# Check that resetting the lora weights succeeds
for slot_idx in range(max_loras):
lora_sampler.reset_lora(slot_idx)
lora_logits_processor.reset_lora(slot_idx)
inputs, index_mapping, prompt_mapping = create_random_inputs(
active_lora_ids=[0],
@ -488,14 +492,16 @@ def test_lm_head_sampler(dist_init, num_loras, device) -> None:
mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras,
32000,
lora_config.lora_extra_vocab_size)
lora_sampler.set_mapping(*mapping_info, )
lora_logits_processor.set_mapping(*mapping_info, )
lora_result = lora_sampler._get_logits(hidden_states=torch.cat(inputs),
embedding=original_weight,
embedding_bias=None)[:, :32000]
expected_result = sampler._get_logits(hidden_states=torch.cat(inputs),
embedding=original_weight,
embedding_bias=None)
lora_result = lora_logits_processor._get_logits(
hidden_states=torch.cat(inputs),
embedding=original_weight,
embedding_bias=None)[:, :32000]
expected_result = logits_processor._get_logits(
hidden_states=torch.cat(inputs),
embedding=original_weight,
embedding_bias=None)
rtol, atol = TOLERANCES[lora_result.dtype]
assert torch.allclose(lora_result,

View File

@ -15,17 +15,12 @@ from vllm.worker.model_runner import ModelRunner
class MockLogitsSampler(Sampler):
def __init__(self, vocab_size: int, fake_logits: torch.Tensor):
super().__init__(vocab_size=vocab_size)
def __init__(self, fake_logits: torch.Tensor):
super().__init__()
self.fake_logits = fake_logits
def forward(self, *args, **kwargs):
with patch(
"vllm.model_executor.layers.sampler._prune_hidden_states",
lambda x, y: x), patch(
"vllm.model_executor.layers.sampler.Sampler._get_logits",
lambda *args, **kwargs: self.fake_logits):
return super().forward(*args, **kwargs)
return super().forward(*args, **kwargs)
def _prepare_test(
@ -36,7 +31,7 @@ def _prepare_test(
fake_logits = torch.full((batch_size, vocab_size),
1e-2,
dtype=input_tensor.dtype)
sampler = MockLogitsSampler(32000, fake_logits)
sampler = MockLogitsSampler(fake_logits)
model_runner = ModelRunner(None, None, None, None, None)
return input_tensor, fake_logits, sampler, model_runner
@ -70,9 +65,7 @@ def _do_sample(
sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
prompt_lens,
subquery_lens=prompt_lens)
return sampler(embedding=None,
hidden_states=input_tensor,
sampling_metadata=sampling_metadata)
return sampler(logits=input_tensor, sampling_metadata=sampling_metadata)
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@ -85,8 +78,8 @@ def test_sampler_all_greedy(seed: int, device: str):
batch_size)
sampling_params = SamplingParams(temperature=0)
sampler_output = _do_sample(batch_size, input_tensor, sampler,
model_runner, sampling_params)
sampler_output = _do_sample(batch_size, fake_logits, sampler, model_runner,
sampling_params)
expected = torch.argmax(fake_logits, dim=-1)
for i, sequence_output in enumerate(sampler_output):
for nth_output in sequence_output.samples:
@ -111,8 +104,8 @@ def test_sampler_all_random(seed: int, device: str):
temperature=1.0,
n=random.randint(1, 10),
)
sampler_output = _do_sample(batch_size, input_tensor, sampler,
model_runner, sampling_params)
sampler_output = _do_sample(batch_size, fake_logits, sampler, model_runner,
sampling_params)
for i, sequence_output in enumerate(sampler_output):
for nth_output in sequence_output.samples:
@ -127,8 +120,7 @@ def test_sampler_all_random_seed(seed: int, device: str):
set_random_seed(seed)
torch.set_default_device(device)
batch_size = random.randint(1, 256)
input_tensor, fake_logits, sampler, model_runner = _prepare_test(
batch_size)
_, fake_logits, sampler, model_runner = _prepare_test(batch_size)
for i in range(batch_size):
fake_logits[i, i] = 1e2
@ -138,8 +130,8 @@ def test_sampler_all_random_seed(seed: int, device: str):
n=random.randint(1, 10),
seed=random.randint(0, 10000),
)
sampler_output = _do_sample(batch_size, input_tensor, sampler,
model_runner, sampling_params)
sampler_output = _do_sample(batch_size, fake_logits, sampler, model_runner,
sampling_params)
for i, sequence_output in enumerate(sampler_output):
for nth_output in sequence_output.samples:
@ -154,18 +146,17 @@ def test_sampler_all_random_seed_deterministic(seed: int, device: str):
set_random_seed(seed)
torch.set_default_device(device)
batch_size = random.randint(1, 256)
input_tensor, fake_logits, sampler, model_runner = _prepare_test(
batch_size)
_, fake_logits, sampler, model_runner = _prepare_test(batch_size)
sampling_params = SamplingParams(
temperature=1.0,
n=random.randint(1, 10),
seed=random.randint(0, 10000),
)
first_sampler_output = _do_sample(batch_size, input_tensor, sampler,
first_sampler_output = _do_sample(batch_size, fake_logits, sampler,
model_runner, sampling_params)
second_sampler_output = _do_sample(batch_size, input_tensor, sampler,
second_sampler_output = _do_sample(batch_size, fake_logits, sampler,
model_runner, sampling_params)
assert first_sampler_output == second_sampler_output
@ -179,15 +170,14 @@ def test_sampler_all_beam(seed: int, device: str):
set_random_seed(seed)
torch.set_default_device(device)
batch_size = random.randint(1, 256)
input_tensor, _, sampler, model_runner = _prepare_test(batch_size)
_, fake_logits, sampler, model_runner = _prepare_test(batch_size)
sampling_params = SamplingParams(
temperature=0,
best_of=2,
use_beam_search=True,
)
_do_sample(batch_size, input_tensor, sampler, model_runner,
sampling_params)
_do_sample(batch_size, fake_logits, sampler, model_runner, sampling_params)
# no assertion here as I am not sure how to determine whether
# the outputs are expected - in other words, this just tests
# whether there are no exceptions in the sampler
@ -246,8 +236,7 @@ def test_sampler_mixed(seed: int, device: str):
def test_sampling(model_runner: ModelRunner):
sampling_metadata = model_runner._prepare_sample(
seq_group_metadata_list, prompt_lens, subquery_lens=prompt_lens)
sampler_output = sampler(embedding=None,
hidden_states=input_tensor,
sampler_output = sampler(logits=fake_logits,
sampling_metadata=sampling_metadata)
for i, (sequence_output, metadata) in enumerate(
@ -294,48 +283,6 @@ def test_sampler_mixed(seed: int, device: str):
del model_runner
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_logits_processors(seed: int, device: str):
set_random_seed(seed)
torch.set_default_device(device)
batch_size = random.randint(1, 256)
input_tensor, _, sampler, model_runner = _prepare_test(batch_size)
# This sample logits processor gives maximum score to the i-th token,
# where i is the length of the input sequence.
# We therefore expect the output token sequence to be [0, 1, 2, ...]
def pick_ith(token_ids, logits):
logits[len(token_ids)] = torch.finfo(logits.dtype).max
return logits
seq_group_metadata_list = []
prompt_lens = []
for i in range(batch_size):
seq_group_metadata_list.append(
SequenceGroupMetadata(
request_id=f"test_{i}",
is_prompt=True,
seq_data={0: SequenceData([1, 2, 3])},
sampling_params=SamplingParams(temperature=0,
logits_processors=[pick_ith]),
block_tables={0: [1]},
))
prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
prompt_lens,
subquery_lens=prompt_lens)
sampler_output = sampler(embedding=None,
hidden_states=input_tensor,
sampling_metadata=sampling_metadata)
for _, sequence_output in enumerate(sampler_output):
for idx, nth_output in enumerate(sequence_output.samples):
assert nth_output.output_token == idx
del model_runner
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_top_k_top_p(seed: int, device: str):
@ -352,7 +299,7 @@ def test_sampler_top_k_top_p(seed: int, device: str):
size=(batch_size, vocab_size),
device=input_tensor.device,
dtype=input_tensor.dtype)
sampler = MockLogitsSampler(32000, fake_logits)
sampler = MockLogitsSampler(fake_logits)
model_runner = ModelRunner(None, None, None, None, None)
generation_model = GenerationMixin()
@ -391,9 +338,7 @@ def test_sampler_top_k_top_p(seed: int, device: str):
return [[prob.topk(1, dim=-1).indices.tolist(), [0]] for prob in probs]
with patch("vllm.model_executor.layers.sampler._sample", mock_sample):
sampler(embedding=None,
hidden_states=input_tensor,
sampling_metadata=sampling_metadata)
sampler(logits=fake_logits, sampling_metadata=sampling_metadata)
hf_probs = warpers(torch.zeros_like(fake_logits), fake_logits.clone())
hf_probs = torch.softmax(hf_probs, dim=-1, dtype=torch.float)
assert torch.allclose(hf_probs, sample_probs, atol=1e-5)

View File

@ -0,0 +1,94 @@
import random
from typing import Tuple
from unittest.mock import patch
import pytest
import torch
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.utils import set_random_seed
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
from vllm.worker.model_runner import ModelRunner
class MockLogitsProcessor(LogitsProcessor):
def __init__(self, vocab_size: int, scale: float,
fake_logits: torch.Tensor):
super().__init__(vocab_size=vocab_size, scale=scale)
self.fake_logits = fake_logits.clone()
def forward(self, *args, **kwargs):
with patch(
"vllm.model_executor.layers.logits_processor._prune_hidden_states",
lambda x, y: x
), patch(
"vllm.model_executor.layers.logits_processor.LogitsProcessor._get_logits",
lambda *args, **kwargs: self.fake_logits):
return super().forward(*args, **kwargs)
def _prepare_test(
batch_size: int
) -> Tuple[torch.Tensor, torch.Tensor, MockLogitsProcessor, ModelRunner]:
vocab_size = 32000
input_tensor = torch.rand((batch_size, 1024), dtype=torch.float16)
fake_logits = torch.full((batch_size, vocab_size),
1e-2,
dtype=input_tensor.dtype)
logits_processor = MockLogitsProcessor(32000, 0.5, fake_logits)
model_runner = ModelRunner(None, None, None, None, None)
return input_tensor, fake_logits, logits_processor, model_runner
RANDOM_SEEDS = list(range(128))
CUDA_DEVICES = [
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_logits_processors(seed: int, device: str):
set_random_seed(seed)
torch.set_default_device(device)
batch_size = random.randint(1, 256)
input_tensor, fake_logits, logits_processor, model_runner = _prepare_test(
batch_size)
# This sample logits processor gives infinite score to the i-th token,
# where i is the length of the input sequence.
# We therefore expect the output token sequence to be [0, 1, 2, ...]
def pick_ith(token_ids, logits):
logits[len(token_ids)] = float("inf")
return logits
seq_group_metadata_list = []
prompt_lens = []
for i in range(batch_size):
seq_group_metadata_list.append(
SequenceGroupMetadata(
request_id=f"test_{i}",
is_prompt=True,
seq_data={0: SequenceData([1, 2, 3])},
sampling_params=SamplingParams(temperature=0,
logits_processors=[pick_ith]),
block_tables={0: [1]},
))
prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
prompt_lens,
subquery_lens=prompt_lens)
logits_processor_output = logits_processor(
embedding=None,
hidden_states=input_tensor,
sampling_metadata=sampling_metadata)
assert torch.isinf(logits_processor_output[:, 0]).all()
fake_logits *= logits_processor.scale
assert torch.allclose(logits_processor_output[:, 1], fake_logits[:, 1],
1e-4)
del model_runner

View File

@ -10,7 +10,6 @@ from transformers import PretrainedConfig
from vllm.config import LoRAConfig
from vllm.lora.punica import add_lora, add_lora_slice, bgmv
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.parallel_utils.communication_op import (
tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce,
@ -20,6 +19,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear,
QKVParallelLinear,
MergedColumnParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, ParallelLMHead)
from vllm.model_executor.parallel_utils.parallel_state import (
@ -783,11 +783,11 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
return self.base_layer.weight
class SamplerWithLoRA(BaseLayerWithLoRA):
class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
def __init__(
self,
base_layer: Sampler,
base_layer: LogitsProcessor,
hidden_size: int,
dtype: torch.dtype,
device: torch.device,
@ -806,6 +806,10 @@ class SamplerWithLoRA(BaseLayerWithLoRA):
def vocab_size(self):
return self.base_layer.vocab_size
@property
def scale(self):
return self.base_layer.scale
@property
def org_vocab_size(self):
return self.base_layer.org_vocab_size
@ -968,14 +972,14 @@ def from_layer(
return layer
def from_layer_sampler(
layer: Sampler,
def from_layer_logits_processor(
layer: LogitsProcessor,
lm_head: ParallelLMHead,
max_loras: int,
lora_config: LoRAConfig,
model_config: Optional[PretrainedConfig] = None,
) -> SamplerWithLoRA:
ret = SamplerWithLoRA(layer, lm_head.embedding_dim, lm_head.weight.dtype,
lm_head.weight.device)
) -> LogitsProcessorWithLoRA:
ret = LogitsProcessorWithLoRA(layer, lm_head.embedding_dim,
lm_head.weight.dtype, lm_head.weight.device)
ret.create_lora_weights(max_loras, lora_config, model_config)
return ret

View File

@ -14,7 +14,7 @@ from vllm.config import LoRAConfig
from vllm.utils import LRUCache, in_wsl
from vllm.lora.layers import (BaseLayerWithLoRA, LoRAMapping, from_layer,
from_layer_sampler)
from_layer_logits_processor)
from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
from vllm.lora.utils import parse_fine_tuned_lora_name, replace_submodule
@ -421,11 +421,14 @@ class LoRAModelManager:
self.model.config))
# (yard1): TODO make this more robust
if "lm_head" in module_name:
sampler_module = self.model.get_submodule("sampler")
logits_processor_module = self.model.get_submodule(
"logits_processor")
new_module = replace_submodule(
self.model, "sampler",
from_layer_sampler(sampler_module, module, self.lora_slots,
self.lora_config, self.model.config))
self.model, "logits_processor",
from_layer_logits_processor(logits_processor_module,
module, self.lora_slots,
self.lora_config,
self.model.config))
self.register_module(module_name, new_module)
self._register_packed_modules(module_name)
new_module.set_mapping(self.base_indices, self.sampler_indices,

View File

@ -0,0 +1,106 @@
"""A layer that compute logits from hidden_stats."""
from typing import Optional
import torch
import torch.nn as nn
from vllm.utils import is_neuron
from vllm.model_executor.parallel_utils.communication_op import (
tensor_model_parallel_gather)
from vllm.model_executor.sampling_metadata import SamplingMetadata
class LogitsProcessor(nn.Module):
"""Process logits and apply logits processors from sampling metadata.
This layer does the following:
1. Gather logits from model hidden_states.
2. Scale logits if needed.
3. Apply logits processors (if any).
"""
def __init__(self,
vocab_size: int,
org_vocab_size: Optional[int] = None,
scale: Optional[float] = 1.0) -> None:
"""
Args:
scale: A scaling factor to apply to the logits.
"""
super().__init__()
self.scale = scale
self.vocab_size = vocab_size
# Transformers-neuronx generate outputs as logits directly.
self.logits_as_hidden_states = is_neuron()
# original vocabulary size (without LoRA).
self.org_vocab_size = org_vocab_size or vocab_size
def forward(
self,
embedding: torch.Tensor,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
embedding_bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if self.logits_as_hidden_states:
logits = hidden_states
else:
hidden_states = _prune_hidden_states(hidden_states,
sampling_metadata)
# Get the logits for the next tokens.
logits = self._get_logits(hidden_states, embedding, embedding_bias)
if logits is not None:
logits *= self.scale
# Apply logits processors (if any).
logits = _apply_logits_processors(logits, sampling_metadata)
return logits
def _get_logits(self, hidden_states: torch.Tensor, embedding: torch.Tensor,
embedding_bias: Optional[torch.Tensor]) -> torch.Tensor:
# Get the logits for the next tokens.
logits = torch.matmul(hidden_states, embedding.t())
if embedding_bias is not None:
logits += embedding_bias
logits = tensor_model_parallel_gather(logits)
# Remove paddings in vocab (if any).
if logits is not None:
logits = logits[:, :self.org_vocab_size]
return logits
def _prune_hidden_states(
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
return hidden_states.index_select(0,
sampling_metadata.selected_token_indices)
def _apply_logits_processors(
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
logits_row_idx = 0
found_logits_processors = False
for seq_ids, sampling_params in sampling_metadata.seq_groups:
logits_processors = sampling_params.logits_processors
if logits_processors:
found_logits_processors = True
for seq_id in seq_ids:
logits_row = logits[logits_row_idx]
token_ids = sampling_metadata.seq_data[seq_id].output_token_ids
for logits_processor in logits_processors:
logits_row = logits_processor(token_ids, logits_row)
logits[logits_row_idx] = logits_row
logits_row_idx += 1
else:
logits_row_idx += len(seq_ids)
if found_logits_processors:
assert logits_row_idx == logits.shape[0]
return logits

View File

@ -4,8 +4,6 @@ from typing import Dict, List, Optional, Tuple
import torch
import torch.nn as nn
from vllm.model_executor.parallel_utils.communication_op import (
tensor_model_parallel_gather)
from vllm.model_executor.sampling_metadata import (SamplingMetadata,
SamplingTensors)
from vllm.sampling_params import SamplingParams, SamplingType
@ -13,7 +11,6 @@ from vllm.sequence import (Logprob, PromptLogprobs, SampleLogprobs,
SamplerOutput, SequenceData, SequenceGroupOutput,
SequenceOutput)
from vllm.model_executor.layers.ops.sample import (sample as sample_triton)
from vllm.utils import is_neuron
class Sampler(nn.Module):
@ -31,58 +28,14 @@ class Sampler(nn.Module):
parameters (e.g., sampling method, temperature, top-p, top-k, etc.).
"""
def __init__(self,
vocab_size: int,
org_vocab_size: Optional[int] = None) -> None:
super().__init__()
self.vocab_size = vocab_size
# Transformers-neuronx generate outputs as logits directly.
self.logits_as_hidden_states = is_neuron()
# original vocabulary size (without LoRA).
self.org_vocab_size = org_vocab_size or vocab_size
def _get_logits(self, hidden_states: torch.Tensor, embedding: torch.Tensor,
embedding_bias: Optional[torch.Tensor]) -> torch.Tensor:
# Get the logits for the next tokens.
logits = torch.matmul(hidden_states, embedding.t())
if embedding_bias is not None:
logits += embedding_bias
logits = tensor_model_parallel_gather(logits)
# Remove paddings in vocab (if any).
if logits is not None:
logits = logits[:, :self.org_vocab_size]
return logits
def forward(
self,
embedding: torch.Tensor,
hidden_states: torch.Tensor,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
embedding_bias: Optional[torch.Tensor] = None,
) -> Optional[SamplerOutput]:
# Get the hidden states that we use for sampling.
if self.logits_as_hidden_states:
logits = hidden_states
else:
hidden_states = _prune_hidden_states(hidden_states,
sampling_metadata)
# Get the logits for the next tokens.
logits = self._get_logits(hidden_states, embedding, embedding_bias)
# Only perform sampling in the driver worker.
# Note: `_get_logits` is still distributed across TP workers because
# the `embedding` weight is distributed across TP workers.
# TODO(zhuohan): Change the get_logits part to a separate stage.
if not sampling_metadata.perform_sampling:
return None
assert logits is not None
_, vocab_size = logits.shape
# Apply logits processors (if any).
logits = _apply_logits_processors(logits, sampling_metadata)
# Prepare sampling tensors with pinned memory to avoid blocking.
(sampling_tensors, do_penalties, do_top_p_top_k,
do_min_p) = SamplingTensors.from_sampling_metadata(
@ -124,14 +77,6 @@ class Sampler(nn.Module):
prompt_logprobs, sample_logprobs)
def _prune_hidden_states(
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
return hidden_states.index_select(0,
sampling_metadata.selected_token_indices)
def _get_bin_counts_and_mask(
tokens: torch.Tensor,
vocab_size: int,
@ -149,30 +94,6 @@ def _get_bin_counts_and_mask(
return bin_counts, mask
def _apply_logits_processors(
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
logits_row_idx = 0
found_logits_processors = False
for seq_ids, sampling_params in sampling_metadata.seq_groups:
logits_processors = sampling_params.logits_processors
if logits_processors:
found_logits_processors = True
for seq_id in seq_ids:
logits_row = logits[logits_row_idx]
token_ids = sampling_metadata.seq_data[seq_id].output_token_ids
for logits_processor in logits_processors:
logits_row = logits_processor(token_ids, logits_row)
logits[logits_row_idx] = logits_row
logits_row_idx += 1
else:
logits_row_idx += len(seq_ids)
if found_logits_processors:
assert logits_row_idx == logits.shape[0]
return logits
def _apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
output_tokens_tensor: torch.Tensor,
presence_penalties: torch.Tensor,

View File

@ -34,6 +34,7 @@ from vllm.model_executor.layers.linear import (LinearMethodBase,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, ParallelLMHead)
@ -295,7 +296,8 @@ class BaiChuanBaseForCausalLM(nn.Module):
self.linear_method = linear_method
self.model = BaiChuanModel(config, position_embedding, linear_method)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.sampler = Sampler(config.vocab_size)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
def forward(
self,
@ -308,13 +310,18 @@ class BaiChuanBaseForCausalLM(nn.Module):
input_metadata)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head.weight, hidden_states,
sampling_metadata)
return logits
def sample(
self,
hidden_states: torch.Tensor,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(self.lm_head.weight, hidden_states,
sampling_metadata)
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self,

View File

@ -30,6 +30,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearMethodBase,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
@ -273,7 +274,8 @@ class BloomForCausalLM(nn.Module):
self.linear_method = linear_method
self.transformer = BloomModel(config, linear_method)
self.lm_head_weight = self.transformer.word_embeddings.weight
self.sampler = Sampler(config.vocab_size)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
def forward(
self,
@ -286,13 +288,18 @@ class BloomForCausalLM(nn.Module):
input_metadata)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head_weight, hidden_states,
sampling_metadata)
return logits
def sample(
self,
hidden_states: torch.Tensor,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(self.lm_head_weight, hidden_states,
sampling_metadata)
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self,

View File

@ -17,6 +17,7 @@ from vllm.model_executor.layers.linear import (LinearMethodBase,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, ParallelLMHead)
@ -332,7 +333,8 @@ class ChatGLMForCausalLM(nn.Module):
self.linear_method = linear_method
self.transformer = ChatGLMModel(config, linear_method)
self.lm_head_weight = self.transformer.output_layer.weight
self.sampler = Sampler(config.padded_vocab_size)
self.logits_processor = LogitsProcessor(config.padded_vocab_size)
self.sampler = Sampler()
def forward(
self,
@ -345,13 +347,18 @@ class ChatGLMForCausalLM(nn.Module):
input_metadata)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head_weight, hidden_states,
sampling_metadata)
return logits
def sample(
self,
hidden_states: torch.Tensor,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(self.lm_head_weight, hidden_states,
sampling_metadata)
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self,

View File

@ -38,6 +38,7 @@ from vllm.model_executor.layers.linear import (LinearMethodBase,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, ParallelLMHead)
@ -372,7 +373,8 @@ class DeepseekForCausalLM(nn.Module):
self.linear_method = linear_method
self.model = DeepseekModel(config, linear_method)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.sampler = Sampler(config.vocab_size)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
def forward(
self,
@ -385,13 +387,18 @@ class DeepseekForCausalLM(nn.Module):
input_metadata)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head.weight, hidden_states,
sampling_metadata)
return logits
def sample(
self,
hidden_states: Optional[torch.Tensor],
logits: Optional[torch.Tensor],
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(self.lm_head.weight, hidden_states,
sampling_metadata)
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self,

View File

@ -34,6 +34,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, ParallelLMHead)
@ -373,7 +374,8 @@ class FalconForCausalLM(nn.Module):
config.vocab_size,
config.hidden_size,
)
self.sampler = Sampler(config.vocab_size)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
def forward(
self,
@ -390,13 +392,18 @@ class FalconForCausalLM(nn.Module):
)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head.weight, hidden_states,
sampling_metadata)
return logits
def sample(
self,
hidden_states: torch.Tensor,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(self.lm_head.weight, hidden_states,
sampling_metadata)
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self,

View File

@ -30,6 +30,7 @@ from vllm.model_executor.layers.linear import (LinearMethodBase,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
@ -281,7 +282,8 @@ class GemmaForCausalLM(nn.Module):
self.config = config
self.linear_method = linear_method
self.model = GemmaModel(config, linear_method)
self.sampler = Sampler(config.vocab_size)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
@torch.no_grad()
def forward(
@ -295,13 +297,18 @@ class GemmaForCausalLM(nn.Module):
input_metadata)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.model.embed_tokens.weight,
hidden_states, sampling_metadata)
return logits
def sample(
self,
hidden_states: torch.Tensor,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(self.model.embed_tokens.weight,
hidden_states, sampling_metadata)
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self,

View File

@ -30,6 +30,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearMethodBase,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
@ -216,7 +217,8 @@ class GPT2LMHeadModel(nn.Module):
self.linear_method = linear_method
self.transformer = GPT2Model(config, linear_method)
self.lm_head_weight = self.transformer.wte.weight
self.sampler = Sampler(config.vocab_size)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
def forward(
self,
@ -229,12 +231,18 @@ class GPT2LMHeadModel(nn.Module):
input_metadata)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head_weight, hidden_states,
sampling_metadata)
return logits
def sample(
self,
hidden_states: torch.Tensor,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(self.lm_head_weight, hidden_states,
next_tokens = self.sampler(self.lm_head_weight, logits,
sampling_metadata)
return next_tokens

View File

@ -31,6 +31,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearMethodBase,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
@ -237,7 +238,8 @@ class GPTBigCodeForCausalLM(nn.Module):
self.linear_method = linear_method
self.transformer = GPTBigCodeModel(config, linear_method)
self.lm_head_weight = self.transformer.wte.weight
self.sampler = Sampler(config.vocab_size)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
def forward(
self,
@ -250,13 +252,18 @@ class GPTBigCodeForCausalLM(nn.Module):
input_metadata)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head_weight, hidden_states,
sampling_metadata)
return logits
def sample(
self,
hidden_states: torch.Tensor,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(self.lm_head_weight, hidden_states,
sampling_metadata)
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self,

View File

@ -30,6 +30,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, ParallelLMHead)
@ -224,7 +225,8 @@ class GPTJForCausalLM(nn.Module):
config.n_embd,
bias=True,
)
self.sampler = Sampler(config.vocab_size)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
def forward(
self,
@ -237,13 +239,18 @@ class GPTJForCausalLM(nn.Module):
input_metadata)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head.weight, hidden_states,
sampling_metadata, self.lm_head.bias)
return logits
def sample(
self,
hidden_states: torch.Tensor,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(self.lm_head.weight, hidden_states,
sampling_metadata, self.lm_head.bias)
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self,

View File

@ -30,6 +30,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, ParallelLMHead)
@ -238,7 +239,8 @@ class GPTNeoXForCausalLM(nn.Module):
config.vocab_size,
config.hidden_size,
)
self.sampler = Sampler(config.vocab_size)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
def forward(
self,
@ -251,13 +253,18 @@ class GPTNeoXForCausalLM(nn.Module):
input_metadata)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.embed_out.weight, hidden_states,
sampling_metadata)
return logits
def sample(
self,
hidden_states: torch.Tensor,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(self.embed_out.weight, hidden_states,
sampling_metadata)
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self,

View File

@ -14,6 +14,7 @@ from vllm.model_executor.layers.linear import (LinearMethodBase,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, ParallelLMHead)
@ -250,7 +251,8 @@ class InternLM2ForCausalLM(nn.Module):
self.linear_method = linear_method
self.model = InternLM2Model(config, linear_method)
self.output = ParallelLMHead(config.vocab_size, config.hidden_size)
self.sampler = Sampler(config.vocab_size)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
def forward(
self,
@ -263,13 +265,18 @@ class InternLM2ForCausalLM(nn.Module):
input_metadata)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.output.weight, hidden_states,
sampling_metadata)
return logits
def sample(
self,
hidden_states: torch.Tensor,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(self.output.weight, hidden_states,
sampling_metadata)
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self,

View File

@ -37,6 +37,7 @@ from vllm.model_executor.layers.linear import (LinearMethodBase,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, ParallelLMHead, DEFAULT_VOCAB_PADDING_SIZE)
@ -325,7 +326,11 @@ class LlamaForCausalLM(nn.Module):
# compatibility
if not lora_config else lora_config.lora_vocab_padding_size,
)
self.sampler = Sampler(self.unpadded_vocab_size, config.vocab_size)
logit_scale = getattr(config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size, logit_scale)
self.sampler = Sampler()
def forward(
self,
@ -338,13 +343,18 @@ class LlamaForCausalLM(nn.Module):
input_metadata)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head.weight, hidden_states,
sampling_metadata)
return logits
def sample(
self,
hidden_states: torch.Tensor,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(self.lm_head.weight, hidden_states,
sampling_metadata)
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self,

View File

@ -37,6 +37,7 @@ from vllm.model_executor.layers.linear import (LinearMethodBase,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, ParallelLMHead, DEFAULT_VOCAB_PADDING_SIZE)
@ -369,7 +370,9 @@ class MixtralForCausalLM(nn.Module):
# compatibility
if not lora_config else lora_config.lora_vocab_padding_size,
)
self.sampler = Sampler(self.unpadded_vocab_size, config.vocab_size)
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size)
self.sampler = Sampler()
def forward(
self,
@ -382,13 +385,18 @@ class MixtralForCausalLM(nn.Module):
input_metadata)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head.weight, hidden_states,
sampling_metadata)
return logits
def sample(
self,
hidden_states: Optional[torch.Tensor],
logits: Optional[torch.Tensor],
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(self.lm_head.weight, hidden_states,
sampling_metadata)
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self,

View File

@ -39,6 +39,7 @@ from vllm.model_executor.layers.linear import (LinearMethodBase,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, ParallelLMHead)
@ -344,7 +345,8 @@ class MixtralForCausalLM(nn.Module):
self.linear_method = linear_method
self.model = MixtralModel(config, linear_method)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.sampler = Sampler(config.vocab_size)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
def forward(
self,
@ -357,13 +359,18 @@ class MixtralForCausalLM(nn.Module):
input_metadata)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head.weight, hidden_states,
sampling_metadata)
return logits
def sample(
self,
hidden_states: Optional[torch.Tensor],
logits: Optional[torch.Tensor],
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(self.lm_head.weight, hidden_states,
sampling_metadata)
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self,

View File

@ -13,6 +13,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearMethodBase,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
@ -259,7 +260,8 @@ class MPTForCausalLM(nn.Module):
self.transformer = MPTModel(config, linear_method)
self.lm_head_weight = self.transformer.wte.weight
self.sampler = Sampler(config.vocab_size)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
def forward(
self,
@ -272,13 +274,18 @@ class MPTForCausalLM(nn.Module):
input_metadata)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head_weight, hidden_states,
sampling_metadata)
return logits
def sample(
self,
hidden_states: torch.Tensor,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(self.lm_head_weight, hidden_states,
sampling_metadata)
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self,

View File

@ -7,6 +7,7 @@ from torch import nn
from transformers import LlamaConfig
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput
@ -25,7 +26,8 @@ class LlamaForCausalLM(nn.Module):
self.config = config
self.linear_method = linear_method
self.model = None
self.sampler = Sampler(config.vocab_size)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
def forward(
self,
@ -45,13 +47,18 @@ class LlamaForCausalLM(nn.Module):
start_ids=seq_ids.flatten())
return logits
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.model.chkpt_model.lm_head,
hidden_states, sampling_metadata)
return logits
def sample(
self,
hidden_states: torch.Tensor,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(self.model.chkpt_model.lm_head,
hidden_states, sampling_metadata)
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self,

View File

@ -6,6 +6,7 @@ from torch import nn
from transformers import MistralConfig
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput
@ -26,7 +27,8 @@ class MistralForCausalLM(nn.Module):
self.linear_method = linear_method
self.model = None
self.lm_head = None
self.sampler = Sampler(config.vocab_size)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
def forward(
self,
@ -48,13 +50,18 @@ class MistralForCausalLM(nn.Module):
start_ids=seq_ids)
return logits
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.model.chkpt_model.lm_head,
hidden_states, sampling_metadata)
return logits
def sample(
self,
hidden_states: torch.Tensor,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(self.model.chkpt_model.lm_head,
hidden_states, sampling_metadata)
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self,

View File

@ -51,6 +51,7 @@ from vllm.model_executor.layers.linear import (
RowParallelLinear,
)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
@ -336,7 +337,8 @@ class OLMoForCausalLM(nn.Module):
self.lm_head_weight = (self.model.transformer.wte.weight
if config.weight_tying else
self.model.transformer.ff_out.weight)
self.sampler = Sampler(config.vocab_size)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
def forward(
self,
@ -353,13 +355,18 @@ class OLMoForCausalLM(nn.Module):
)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head_weight, hidden_states,
sampling_metadata)
return logits
def sample(
self,
hidden_states: torch.Tensor,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(self.lm_head_weight, hidden_states,
sampling_metadata)
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(

View File

@ -31,6 +31,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
@ -292,7 +293,8 @@ class OPTForCausalLM(nn.Module):
self.linear_method = linear_method
self.model = OPTModel(config, linear_method)
self.lm_head_weight = self.model.decoder.embed_tokens.weight
self.sampler = Sampler(config.vocab_size)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
def forward(
self,
@ -305,13 +307,18 @@ class OPTForCausalLM(nn.Module):
input_metadata)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head_weight, hidden_states,
sampling_metadata)
return logits
def sample(
self,
hidden_states: torch.Tensor,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(self.lm_head_weight, hidden_states,
sampling_metadata)
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self,

View File

@ -18,6 +18,7 @@ from vllm.model_executor.layers.linear import (LinearMethodBase,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, ParallelLMHead)
@ -256,7 +257,8 @@ class OrionForCausalLM(nn.Module):
self.linear_method = linear_method
self.model = OrionModel(config, linear_method)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.sampler = Sampler(config.vocab_size)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
def forward(
self,
@ -269,13 +271,18 @@ class OrionForCausalLM(nn.Module):
input_metadata)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head.weight, hidden_states,
sampling_metadata)
return logits
def sample(
self,
hidden_states: torch.Tensor,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(self.lm_head.weight, hidden_states,
sampling_metadata)
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self,

View File

@ -49,6 +49,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, ParallelLMHead)
@ -240,7 +241,8 @@ class PhiForCausalLM(nn.Module):
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
bias=True)
self.sampler = Sampler(config.vocab_size)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
def forward(
self,
@ -254,14 +256,18 @@ class PhiForCausalLM(nn.Module):
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head.weight, hidden_states,
sampling_metadata, self.lm_head.bias)
return logits
def sample(
self,
hidden_states: torch.Tensor,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
head = self.lm_head
next_tokens = self.sampler(head.weight, hidden_states,
sampling_metadata, head.bias)
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self,

View File

@ -19,6 +19,7 @@ from vllm.model_executor.layers.linear import (LinearMethodBase,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, ParallelLMHead)
@ -230,7 +231,8 @@ class QWenLMHeadModel(nn.Module):
self.linear_method = linear_method
self.transformer = QWenModel(config, linear_method)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.sampler = Sampler(config.vocab_size)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
def forward(
self,
@ -243,13 +245,18 @@ class QWenLMHeadModel(nn.Module):
input_metadata)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head.weight, hidden_states,
sampling_metadata)
return logits
def sample(
self,
hidden_states: torch.Tensor,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(self.lm_head.weight, hidden_states,
sampling_metadata)
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self,

View File

@ -37,6 +37,7 @@ from vllm.model_executor.layers.linear import (LinearMethodBase,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, ParallelLMHead)
@ -300,11 +301,15 @@ class Qwen2ForCausalLM(nn.Module):
self.linear_method = linear_method
self.model = Qwen2Model(config, linear_method)
if not config.tie_word_embeddings:
if config.tie_word_embeddings:
self.lm_head_weight = self.model.embed_tokens.weight
else:
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size)
self.lm_head_weight = self.lm_head.weight
self.sampler = Sampler(config.vocab_size)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
def forward(
self,
@ -317,17 +322,18 @@ class Qwen2ForCausalLM(nn.Module):
input_metadata)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head_weight, hidden_states,
sampling_metadata)
return logits
def sample(
self,
hidden_states: torch.Tensor,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
if self.config.tie_word_embeddings:
lm_head_weight = self.model.embed_tokens.weight
else:
lm_head_weight = self.lm_head.weight
next_tokens = self.sampler(lm_head_weight, hidden_states,
sampling_metadata)
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self,

View File

@ -33,6 +33,7 @@ from vllm.model_executor.layers.linear import (LinearMethodBase,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, ParallelLMHead)
@ -238,7 +239,8 @@ class StablelmForCausalLM(nn.Module):
self.linear_method = linear_method
self.model = StableLMEpochModel(config, linear_method)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.sampler = Sampler(config.vocab_size)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
def forward(
self,
@ -251,13 +253,18 @@ class StablelmForCausalLM(nn.Module):
input_metadata)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head.weight, hidden_states,
sampling_metadata)
return logits
def sample(
self,
hidden_states: torch.Tensor,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(self.lm_head.weight, hidden_states,
sampling_metadata)
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self,

View File

@ -32,6 +32,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearMethodBase,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, ParallelLMHead, DEFAULT_VOCAB_PADDING_SIZE)
@ -254,7 +255,9 @@ class Starcoder2ForCausalLM(nn.Module):
padding_size=DEFAULT_VOCAB_PADDING_SIZE,
)
self.lm_head_weight = self.lm_head.weight
self.sampler = Sampler(self.unpadded_vocab_size, config.vocab_size)
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size)
self.sampler = Sampler()
def forward(
self,
@ -267,13 +270,18 @@ class Starcoder2ForCausalLM(nn.Module):
input_metadata)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head_weight, hidden_states,
sampling_metadata)
return logits
def sample(
self,
hidden_states: Optional[torch.Tensor],
logits: Optional[torch.Tensor],
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(self.lm_head_weight, hidden_states,
sampling_metadata)
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self,

View File

@ -613,9 +613,16 @@ class ModelRunner:
input_metadata=input_metadata,
)
# Compute the logits.
logits = self.model.compute_logits(hidden_states, sampling_metadata)
# Only perform sampling in the driver worker.
if not sampling_metadata.perform_sampling:
return None
# Sample the next token.
output = self.model.sample(
hidden_states=hidden_states,
logits=logits,
sampling_metadata=sampling_metadata,
)
return output