Migrate logits computation and gather to model_runner (#3233)

This commit is contained in:
Roy 2024-03-21 07:25:01 +08:00 committed by GitHub
parent 6e435de766
commit f1c0fc3919
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
35 changed files with 576 additions and 305 deletions

View File

@ -49,6 +49,9 @@ steps:
- label: Samplers Test - label: Samplers Test
command: pytest -v -s samplers command: pytest -v -s samplers
- label: LogitsProcessor Test
command: pytest -v -s test_logits_processor.py
- label: Worker Test - label: Worker Test
command: pytest -v -s worker command: pytest -v -s worker

View File

@ -13,6 +13,7 @@ from huggingface_hub import snapshot_download
import vllm import vllm
from vllm.config import LoRAConfig from vllm.config import LoRAConfig
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.model_loader import get_model from vllm.model_executor.model_loader import get_model
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear, MergedColumnParallelLinear,
@ -85,7 +86,8 @@ def dummy_model() -> nn.Module:
("outact", nn.Sigmoid()), ("outact", nn.Sigmoid()),
# Special handling for lm_head & sampler # Special handling for lm_head & sampler
("lm_head", ParallelLMHead(512, 10)), ("lm_head", ParallelLMHead(512, 10)),
("sampler", Sampler(512)) ("logits_processor", LogitsProcessor(512)),
("sampler", Sampler())
])) ]))
model.config = MagicMock() model.config = MagicMock()
return model return model
@ -110,7 +112,8 @@ def dummy_model_gate_up() -> nn.Module:
("outact", nn.Sigmoid()), ("outact", nn.Sigmoid()),
# Special handling for lm_head & sampler # Special handling for lm_head & sampler
("lm_head", ParallelLMHead(512, 10)), ("lm_head", ParallelLMHead(512, 10)),
("sampler", Sampler(512)) ("logits_processor", LogitsProcessor(512)),
("sampler", Sampler())
])) ]))
model.config = MagicMock() model.config = MagicMock()
return model return model

View File

@ -13,14 +13,14 @@ from vllm.lora.layers import (
QKVParallelLinearWithLora, QKVParallelLinearWithLora,
VocabParallelEmbeddingWithLoRA, VocabParallelEmbeddingWithLoRA,
RowParallelLinearWithLoRA, RowParallelLinearWithLoRA,
SamplerWithLoRA, LogitsProcessorWithLoRA,
LoRAMapping, LoRAMapping,
BaseLayerWithLoRA, BaseLayerWithLoRA,
) )
from vllm.lora.models import (LoRALayerWeights, convert_mapping, from vllm.lora.models import (LoRALayerWeights, convert_mapping,
PackedLoRALayerWeights) PackedLoRALayerWeights)
from vllm.config import LoRAConfig from vllm.config import LoRAConfig
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear, MergedColumnParallelLinear,
RowParallelLinear, RowParallelLinear,
@ -394,7 +394,7 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device) -> None:
@torch.inference_mode() @torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) @pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
@pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("device", CUDA_DEVICES)
def test_lm_head_sampler(dist_init, num_loras, device) -> None: def test_lm_head_logits_processor(dist_init, num_loras, device) -> None:
torch.set_default_device(device) torch.set_default_device(device)
max_loras = 8 max_loras = 8
@ -402,28 +402,29 @@ def test_lm_head_sampler(dist_init, num_loras, device) -> None:
max_lora_rank=8, max_lora_rank=8,
lora_dtype=torch.float16) lora_dtype=torch.float16)
def create_random_sampler_layer(): def _pretest():
linear = ParallelLMHead(32000 + lora_config.lora_extra_vocab_size, linear = ParallelLMHead(32000 + lora_config.lora_extra_vocab_size,
1024, 32000) 1024, 32000)
linear.weight.data = torch.rand_like(linear.weight.data) linear.weight.data = torch.rand_like(linear.weight.data)
linear.weight.data[:, 32000:] = 0 linear.weight.data[:, 32000:] = 0
sampler = Sampler(32000 + lora_config.lora_extra_vocab_size, 32000) logits_processor = LogitsProcessor(
lora_sampler = SamplerWithLoRA(sampler, 1024, linear.weight.dtype, 32000 + lora_config.lora_extra_vocab_size, 32000)
linear.weight.device) lora_logits_processor = LogitsProcessorWithLoRA(
lora_sampler.create_lora_weights(max_loras, lora_config) logits_processor, 1024, linear.weight.dtype, linear.weight.device)
lora_logits_processor.create_lora_weights(max_loras, lora_config)
return linear, sampler, lora_sampler return linear, logits_processor, lora_logits_processor
for i in range(10): for i in range(10):
set_random_seed(i) set_random_seed(i)
id_to_index = get_random_id_to_index(num_loras, max_loras) id_to_index = get_random_id_to_index(num_loras, max_loras)
linear, sampler, lora_sampler = create_random_sampler_layer() linear, logits_processor, lora_logits_processor = _pretest()
# NOTE: all the generated loras share the same embeddings tensor. # NOTE: all the generated loras share the same embeddings tensor.
lora_dict, _ = populate_loras( lora_dict, _ = populate_loras(
id_to_index, id_to_index,
layer=lora_sampler, layer=lora_logits_processor,
layer_weights=linear.weight, layer_weights=linear.weight,
generate_embeddings_tensor=1024, generate_embeddings_tensor=1024,
) )
@ -447,34 +448,37 @@ def test_lm_head_sampler(dist_init, num_loras, device) -> None:
32000, 32000,
lora_config.lora_extra_vocab_size, lora_config.lora_extra_vocab_size,
) )
lora_sampler.set_mapping(*mapping_info, ) lora_logits_processor.set_mapping(*mapping_info, )
lora_result = lora_sampler._get_logits(hidden_states=torch.cat(inputs), lora_result = lora_logits_processor._get_logits(
embedding=linear.weight, hidden_states=torch.cat(inputs),
embedding_bias=None) embedding=linear.weight,
embedding_bias=None)
original_weight = linear.weight.clone() original_weight = linear.weight.clone()
linear.weight[sampler.org_vocab_size:sampler.org_vocab_size + linear.weight[logits_processor.
org_vocab_size:logits_processor.org_vocab_size +
embeddings_tensor_len] = embeddings_tensor embeddings_tensor_len] = embeddings_tensor
sampler.org_vocab_size = 32000 + lora_config.lora_extra_vocab_size logits_processor.org_vocab_size = (32000 +
lora_config.lora_extra_vocab_size)
expected_results = [] expected_results = []
for input_, lora_id in zip(inputs, prompt_mapping): for input_, lora_id in zip(inputs, prompt_mapping):
lora = lora_dict[lora_id] lora = lora_dict[lora_id]
result = sampler._get_logits(hidden_states=input_, result = logits_processor._get_logits(hidden_states=input_,
embedding=linear.weight, embedding=linear.weight,
embedding_bias=None) embedding_bias=None)
result[:, 32000 + embeddings_tensor_len:] = float("-inf") result[:, 32000 + embeddings_tensor_len:] = float("-inf")
result += input_ @ lora.lora_a @ lora.lora_b * lora.scaling result += input_ @ lora.lora_a @ lora.lora_b * lora.scaling
expected_results.append(result) expected_results.append(result)
expected_result = torch.cat(expected_results) expected_result = torch.cat(expected_results)
sampler.org_vocab_size = 32000 logits_processor.org_vocab_size = 32000
# Check that resetting the lora weights succeeds # Check that resetting the lora weights succeeds
for slot_idx in range(max_loras): for slot_idx in range(max_loras):
lora_sampler.reset_lora(slot_idx) lora_logits_processor.reset_lora(slot_idx)
inputs, index_mapping, prompt_mapping = create_random_inputs( inputs, index_mapping, prompt_mapping = create_random_inputs(
active_lora_ids=[0], active_lora_ids=[0],
@ -488,14 +492,16 @@ def test_lm_head_sampler(dist_init, num_loras, device) -> None:
mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras, mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras,
32000, 32000,
lora_config.lora_extra_vocab_size) lora_config.lora_extra_vocab_size)
lora_sampler.set_mapping(*mapping_info, ) lora_logits_processor.set_mapping(*mapping_info, )
lora_result = lora_sampler._get_logits(hidden_states=torch.cat(inputs), lora_result = lora_logits_processor._get_logits(
embedding=original_weight, hidden_states=torch.cat(inputs),
embedding_bias=None)[:, :32000] embedding=original_weight,
expected_result = sampler._get_logits(hidden_states=torch.cat(inputs), embedding_bias=None)[:, :32000]
embedding=original_weight, expected_result = logits_processor._get_logits(
embedding_bias=None) hidden_states=torch.cat(inputs),
embedding=original_weight,
embedding_bias=None)
rtol, atol = TOLERANCES[lora_result.dtype] rtol, atol = TOLERANCES[lora_result.dtype]
assert torch.allclose(lora_result, assert torch.allclose(lora_result,

View File

@ -15,17 +15,12 @@ from vllm.worker.model_runner import ModelRunner
class MockLogitsSampler(Sampler): class MockLogitsSampler(Sampler):
def __init__(self, vocab_size: int, fake_logits: torch.Tensor): def __init__(self, fake_logits: torch.Tensor):
super().__init__(vocab_size=vocab_size) super().__init__()
self.fake_logits = fake_logits self.fake_logits = fake_logits
def forward(self, *args, **kwargs): def forward(self, *args, **kwargs):
with patch( return super().forward(*args, **kwargs)
"vllm.model_executor.layers.sampler._prune_hidden_states",
lambda x, y: x), patch(
"vllm.model_executor.layers.sampler.Sampler._get_logits",
lambda *args, **kwargs: self.fake_logits):
return super().forward(*args, **kwargs)
def _prepare_test( def _prepare_test(
@ -36,7 +31,7 @@ def _prepare_test(
fake_logits = torch.full((batch_size, vocab_size), fake_logits = torch.full((batch_size, vocab_size),
1e-2, 1e-2,
dtype=input_tensor.dtype) dtype=input_tensor.dtype)
sampler = MockLogitsSampler(32000, fake_logits) sampler = MockLogitsSampler(fake_logits)
model_runner = ModelRunner(None, None, None, None, None) model_runner = ModelRunner(None, None, None, None, None)
return input_tensor, fake_logits, sampler, model_runner return input_tensor, fake_logits, sampler, model_runner
@ -70,9 +65,7 @@ def _do_sample(
sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list, sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
prompt_lens, prompt_lens,
subquery_lens=prompt_lens) subquery_lens=prompt_lens)
return sampler(embedding=None, return sampler(logits=input_tensor, sampling_metadata=sampling_metadata)
hidden_states=input_tensor,
sampling_metadata=sampling_metadata)
@pytest.mark.parametrize("seed", RANDOM_SEEDS) @pytest.mark.parametrize("seed", RANDOM_SEEDS)
@ -85,8 +78,8 @@ def test_sampler_all_greedy(seed: int, device: str):
batch_size) batch_size)
sampling_params = SamplingParams(temperature=0) sampling_params = SamplingParams(temperature=0)
sampler_output = _do_sample(batch_size, input_tensor, sampler, sampler_output = _do_sample(batch_size, fake_logits, sampler, model_runner,
model_runner, sampling_params) sampling_params)
expected = torch.argmax(fake_logits, dim=-1) expected = torch.argmax(fake_logits, dim=-1)
for i, sequence_output in enumerate(sampler_output): for i, sequence_output in enumerate(sampler_output):
for nth_output in sequence_output.samples: for nth_output in sequence_output.samples:
@ -111,8 +104,8 @@ def test_sampler_all_random(seed: int, device: str):
temperature=1.0, temperature=1.0,
n=random.randint(1, 10), n=random.randint(1, 10),
) )
sampler_output = _do_sample(batch_size, input_tensor, sampler, sampler_output = _do_sample(batch_size, fake_logits, sampler, model_runner,
model_runner, sampling_params) sampling_params)
for i, sequence_output in enumerate(sampler_output): for i, sequence_output in enumerate(sampler_output):
for nth_output in sequence_output.samples: for nth_output in sequence_output.samples:
@ -127,8 +120,7 @@ def test_sampler_all_random_seed(seed: int, device: str):
set_random_seed(seed) set_random_seed(seed)
torch.set_default_device(device) torch.set_default_device(device)
batch_size = random.randint(1, 256) batch_size = random.randint(1, 256)
input_tensor, fake_logits, sampler, model_runner = _prepare_test( _, fake_logits, sampler, model_runner = _prepare_test(batch_size)
batch_size)
for i in range(batch_size): for i in range(batch_size):
fake_logits[i, i] = 1e2 fake_logits[i, i] = 1e2
@ -138,8 +130,8 @@ def test_sampler_all_random_seed(seed: int, device: str):
n=random.randint(1, 10), n=random.randint(1, 10),
seed=random.randint(0, 10000), seed=random.randint(0, 10000),
) )
sampler_output = _do_sample(batch_size, input_tensor, sampler, sampler_output = _do_sample(batch_size, fake_logits, sampler, model_runner,
model_runner, sampling_params) sampling_params)
for i, sequence_output in enumerate(sampler_output): for i, sequence_output in enumerate(sampler_output):
for nth_output in sequence_output.samples: for nth_output in sequence_output.samples:
@ -154,18 +146,17 @@ def test_sampler_all_random_seed_deterministic(seed: int, device: str):
set_random_seed(seed) set_random_seed(seed)
torch.set_default_device(device) torch.set_default_device(device)
batch_size = random.randint(1, 256) batch_size = random.randint(1, 256)
input_tensor, fake_logits, sampler, model_runner = _prepare_test( _, fake_logits, sampler, model_runner = _prepare_test(batch_size)
batch_size)
sampling_params = SamplingParams( sampling_params = SamplingParams(
temperature=1.0, temperature=1.0,
n=random.randint(1, 10), n=random.randint(1, 10),
seed=random.randint(0, 10000), seed=random.randint(0, 10000),
) )
first_sampler_output = _do_sample(batch_size, input_tensor, sampler, first_sampler_output = _do_sample(batch_size, fake_logits, sampler,
model_runner, sampling_params) model_runner, sampling_params)
second_sampler_output = _do_sample(batch_size, input_tensor, sampler, second_sampler_output = _do_sample(batch_size, fake_logits, sampler,
model_runner, sampling_params) model_runner, sampling_params)
assert first_sampler_output == second_sampler_output assert first_sampler_output == second_sampler_output
@ -179,15 +170,14 @@ def test_sampler_all_beam(seed: int, device: str):
set_random_seed(seed) set_random_seed(seed)
torch.set_default_device(device) torch.set_default_device(device)
batch_size = random.randint(1, 256) batch_size = random.randint(1, 256)
input_tensor, _, sampler, model_runner = _prepare_test(batch_size) _, fake_logits, sampler, model_runner = _prepare_test(batch_size)
sampling_params = SamplingParams( sampling_params = SamplingParams(
temperature=0, temperature=0,
best_of=2, best_of=2,
use_beam_search=True, use_beam_search=True,
) )
_do_sample(batch_size, input_tensor, sampler, model_runner, _do_sample(batch_size, fake_logits, sampler, model_runner, sampling_params)
sampling_params)
# no assertion here as I am not sure how to determine whether # no assertion here as I am not sure how to determine whether
# the outputs are expected - in other words, this just tests # the outputs are expected - in other words, this just tests
# whether there are no exceptions in the sampler # whether there are no exceptions in the sampler
@ -246,8 +236,7 @@ def test_sampler_mixed(seed: int, device: str):
def test_sampling(model_runner: ModelRunner): def test_sampling(model_runner: ModelRunner):
sampling_metadata = model_runner._prepare_sample( sampling_metadata = model_runner._prepare_sample(
seq_group_metadata_list, prompt_lens, subquery_lens=prompt_lens) seq_group_metadata_list, prompt_lens, subquery_lens=prompt_lens)
sampler_output = sampler(embedding=None, sampler_output = sampler(logits=fake_logits,
hidden_states=input_tensor,
sampling_metadata=sampling_metadata) sampling_metadata=sampling_metadata)
for i, (sequence_output, metadata) in enumerate( for i, (sequence_output, metadata) in enumerate(
@ -294,48 +283,6 @@ def test_sampler_mixed(seed: int, device: str):
del model_runner del model_runner
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_logits_processors(seed: int, device: str):
set_random_seed(seed)
torch.set_default_device(device)
batch_size = random.randint(1, 256)
input_tensor, _, sampler, model_runner = _prepare_test(batch_size)
# This sample logits processor gives maximum score to the i-th token,
# where i is the length of the input sequence.
# We therefore expect the output token sequence to be [0, 1, 2, ...]
def pick_ith(token_ids, logits):
logits[len(token_ids)] = torch.finfo(logits.dtype).max
return logits
seq_group_metadata_list = []
prompt_lens = []
for i in range(batch_size):
seq_group_metadata_list.append(
SequenceGroupMetadata(
request_id=f"test_{i}",
is_prompt=True,
seq_data={0: SequenceData([1, 2, 3])},
sampling_params=SamplingParams(temperature=0,
logits_processors=[pick_ith]),
block_tables={0: [1]},
))
prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
prompt_lens,
subquery_lens=prompt_lens)
sampler_output = sampler(embedding=None,
hidden_states=input_tensor,
sampling_metadata=sampling_metadata)
for _, sequence_output in enumerate(sampler_output):
for idx, nth_output in enumerate(sequence_output.samples):
assert nth_output.output_token == idx
del model_runner
@pytest.mark.parametrize("seed", RANDOM_SEEDS) @pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_top_k_top_p(seed: int, device: str): def test_sampler_top_k_top_p(seed: int, device: str):
@ -352,7 +299,7 @@ def test_sampler_top_k_top_p(seed: int, device: str):
size=(batch_size, vocab_size), size=(batch_size, vocab_size),
device=input_tensor.device, device=input_tensor.device,
dtype=input_tensor.dtype) dtype=input_tensor.dtype)
sampler = MockLogitsSampler(32000, fake_logits) sampler = MockLogitsSampler(fake_logits)
model_runner = ModelRunner(None, None, None, None, None) model_runner = ModelRunner(None, None, None, None, None)
generation_model = GenerationMixin() generation_model = GenerationMixin()
@ -391,9 +338,7 @@ def test_sampler_top_k_top_p(seed: int, device: str):
return [[prob.topk(1, dim=-1).indices.tolist(), [0]] for prob in probs] return [[prob.topk(1, dim=-1).indices.tolist(), [0]] for prob in probs]
with patch("vllm.model_executor.layers.sampler._sample", mock_sample): with patch("vllm.model_executor.layers.sampler._sample", mock_sample):
sampler(embedding=None, sampler(logits=fake_logits, sampling_metadata=sampling_metadata)
hidden_states=input_tensor,
sampling_metadata=sampling_metadata)
hf_probs = warpers(torch.zeros_like(fake_logits), fake_logits.clone()) hf_probs = warpers(torch.zeros_like(fake_logits), fake_logits.clone())
hf_probs = torch.softmax(hf_probs, dim=-1, dtype=torch.float) hf_probs = torch.softmax(hf_probs, dim=-1, dtype=torch.float)
assert torch.allclose(hf_probs, sample_probs, atol=1e-5) assert torch.allclose(hf_probs, sample_probs, atol=1e-5)

View File

@ -0,0 +1,94 @@
import random
from typing import Tuple
from unittest.mock import patch
import pytest
import torch
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.utils import set_random_seed
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
from vllm.worker.model_runner import ModelRunner
class MockLogitsProcessor(LogitsProcessor):
    """A LogitsProcessor test double that short-circuits the real logits
    computation and always yields a canned logits tensor.

    Pruning and the lm_head projection are patched out so `forward` exercises
    only the scaling / logits-processor logic of the base class.
    """

    def __init__(self, vocab_size: int, scale: float,
                 fake_logits: torch.Tensor):
        super().__init__(vocab_size=vocab_size, scale=scale)
        # Keep a private copy so in-place ops in forward() cannot leak back
        # into the caller's tensor.
        self.fake_logits = fake_logits.clone()

    def forward(self, *args, **kwargs):
        # Identity-prune, and replace _get_logits with the canned tensor.
        prune_patch = patch(
            "vllm.model_executor.layers.logits_processor._prune_hidden_states",
            lambda x, y: x)
        get_logits_patch = patch(
            "vllm.model_executor.layers.logits_processor.LogitsProcessor._get_logits",
            lambda *args, **kwargs: self.fake_logits)
        with prune_patch, get_logits_patch:
            return super().forward(*args, **kwargs)
def _prepare_test(
    batch_size: int
) -> Tuple[torch.Tensor, torch.Tensor, MockLogitsProcessor, ModelRunner]:
    """Build the shared fixtures for a logits-processor test.

    Returns a tuple of (input hidden states, canned logits, mock logits
    processor with scale 0.5, bare ModelRunner).
    """
    vocab_size = 32000
    hidden = torch.rand((batch_size, 1024), dtype=torch.float16)
    # Uniform small logits so any processor-induced change is visible.
    canned_logits = torch.full((batch_size, vocab_size),
                               1e-2,
                               dtype=hidden.dtype)
    processor = MockLogitsProcessor(vocab_size, 0.5, canned_logits)
    runner = ModelRunner(None, None, None, None, None)
    return hidden, canned_logits, processor, runner
# Seeds for reproducible parametrized runs.
RANDOM_SEEDS = [*range(128)]
# Exercise one device when only one GPU is visible, otherwise the first two.
_NUM_TEST_DEVICES = 1 if torch.cuda.device_count() == 1 else 2
CUDA_DEVICES = [f"cuda:{idx}" for idx in range(_NUM_TEST_DEVICES)]
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_logits_processors(seed: int, device: str):
    """End-to-end check that per-request logits processors run and that
    untouched logits are still scaled by the processor's scale factor."""
    set_random_seed(seed)
    torch.set_default_device(device)
    batch_size = random.randint(1, 256)
    input_tensor, fake_logits, logits_processor, model_runner = _prepare_test(
        batch_size)

    # This sample logits processor gives infinite score to the i-th token,
    # where i is the length of the input sequence.
    # We therefore expect the output token sequence to be [0, 1, 2, ...]
    def pick_ith(token_ids, logits):
        logits[len(token_ids)] = float("inf")
        return logits

    seq_group_metadata_list = []
    prompt_lens = []
    for request_idx in range(batch_size):
        group = SequenceGroupMetadata(
            request_id=f"test_{request_idx}",
            is_prompt=True,
            seq_data={0: SequenceData([1, 2, 3])},
            sampling_params=SamplingParams(temperature=0,
                                           logits_processors=[pick_ith]),
            block_tables={0: [1]},
        )
        seq_group_metadata_list.append(group)
        prompt_lens.append(group.seq_data[0].get_len())

    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
                                                     prompt_lens,
                                                     subquery_lens=prompt_lens)
    logits_processor_output = logits_processor(
        embedding=None,
        hidden_states=input_tensor,
        sampling_metadata=sampling_metadata)

    # Every prompt has length 3, so pick_ith pinned column 0 of each row
    # (no output tokens yet) -- wait: it pins index len(token_ids); the
    # NOTE(review): column 0 is asserted infinite, presumably because
    # token_ids here are the *output* token ids (empty for prompts).
    assert torch.isinf(logits_processor_output[:, 0]).all()

    # Columns the processor did not touch must only have been scaled.
    fake_logits *= logits_processor.scale
    assert torch.allclose(logits_processor_output[:, 1], fake_logits[:, 1],
                          1e-4)

    del model_runner

View File

@ -10,7 +10,6 @@ from transformers import PretrainedConfig
from vllm.config import LoRAConfig from vllm.config import LoRAConfig
from vllm.lora.punica import add_lora, add_lora_slice, bgmv from vllm.lora.punica import add_lora, add_lora_slice, bgmv
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.parallel_utils.communication_op import ( from vllm.model_executor.parallel_utils.communication_op import (
tensor_model_parallel_all_gather, tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce, tensor_model_parallel_all_reduce,
@ -20,6 +19,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear, RowParallelLinear,
QKVParallelLinear, QKVParallelLinear,
MergedColumnParallelLinear) MergedColumnParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, ParallelLMHead) VocabParallelEmbedding, ParallelLMHead)
from vllm.model_executor.parallel_utils.parallel_state import ( from vllm.model_executor.parallel_utils.parallel_state import (
@ -783,11 +783,11 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
return self.base_layer.weight return self.base_layer.weight
class SamplerWithLoRA(BaseLayerWithLoRA): class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
def __init__( def __init__(
self, self,
base_layer: Sampler, base_layer: LogitsProcessor,
hidden_size: int, hidden_size: int,
dtype: torch.dtype, dtype: torch.dtype,
device: torch.device, device: torch.device,
@ -806,6 +806,10 @@ class SamplerWithLoRA(BaseLayerWithLoRA):
def vocab_size(self): def vocab_size(self):
return self.base_layer.vocab_size return self.base_layer.vocab_size
@property
def scale(self):
return self.base_layer.scale
@property @property
def org_vocab_size(self): def org_vocab_size(self):
return self.base_layer.org_vocab_size return self.base_layer.org_vocab_size
@ -968,14 +972,14 @@ def from_layer(
return layer return layer
def from_layer_sampler( def from_layer_logits_processor(
layer: Sampler, layer: LogitsProcessor,
lm_head: ParallelLMHead, lm_head: ParallelLMHead,
max_loras: int, max_loras: int,
lora_config: LoRAConfig, lora_config: LoRAConfig,
model_config: Optional[PretrainedConfig] = None, model_config: Optional[PretrainedConfig] = None,
) -> SamplerWithLoRA: ) -> LogitsProcessorWithLoRA:
ret = SamplerWithLoRA(layer, lm_head.embedding_dim, lm_head.weight.dtype, ret = LogitsProcessorWithLoRA(layer, lm_head.embedding_dim,
lm_head.weight.device) lm_head.weight.dtype, lm_head.weight.device)
ret.create_lora_weights(max_loras, lora_config, model_config) ret.create_lora_weights(max_loras, lora_config, model_config)
return ret return ret

View File

@ -14,7 +14,7 @@ from vllm.config import LoRAConfig
from vllm.utils import LRUCache, in_wsl from vllm.utils import LRUCache, in_wsl
from vllm.lora.layers import (BaseLayerWithLoRA, LoRAMapping, from_layer, from vllm.lora.layers import (BaseLayerWithLoRA, LoRAMapping, from_layer,
from_layer_sampler) from_layer_logits_processor)
from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
from vllm.lora.utils import parse_fine_tuned_lora_name, replace_submodule from vllm.lora.utils import parse_fine_tuned_lora_name, replace_submodule
@ -421,11 +421,14 @@ class LoRAModelManager:
self.model.config)) self.model.config))
# (yard1): TODO make this more robust # (yard1): TODO make this more robust
if "lm_head" in module_name: if "lm_head" in module_name:
sampler_module = self.model.get_submodule("sampler") logits_processor_module = self.model.get_submodule(
"logits_processor")
new_module = replace_submodule( new_module = replace_submodule(
self.model, "sampler", self.model, "logits_processor",
from_layer_sampler(sampler_module, module, self.lora_slots, from_layer_logits_processor(logits_processor_module,
self.lora_config, self.model.config)) module, self.lora_slots,
self.lora_config,
self.model.config))
self.register_module(module_name, new_module) self.register_module(module_name, new_module)
self._register_packed_modules(module_name) self._register_packed_modules(module_name)
new_module.set_mapping(self.base_indices, self.sampler_indices, new_module.set_mapping(self.base_indices, self.sampler_indices,

View File

@ -0,0 +1,106 @@
"""A layer that compute logits from hidden_stats."""
from typing import Optional
import torch
import torch.nn as nn
from vllm.utils import is_neuron
from vllm.model_executor.parallel_utils.communication_op import (
tensor_model_parallel_gather)
from vllm.model_executor.sampling_metadata import SamplingMetadata
class LogitsProcessor(nn.Module):
    """Process logits and apply logits processors from sampling metadata.

    This layer does the following:
    1. Gather logits from model hidden_states.
    2. Scale logits if needed.
    3. Apply logits processors (if any).
    """

    def __init__(self,
                 vocab_size: int,
                 org_vocab_size: Optional[int] = None,
                 scale: Optional[float] = 1.0) -> None:
        """
        Args:
            vocab_size: Vocabulary size used for the logits (may include
                extra LoRA vocab padding).
            org_vocab_size: Original vocabulary size without LoRA additions;
                defaults to ``vocab_size`` when not given.
            scale: A scaling factor to apply to the logits.
        """
        super().__init__()
        self.scale = scale
        self.vocab_size = vocab_size
        # Transformers-neuronx generate outputs as logits directly.
        self.logits_as_hidden_states = is_neuron()
        # original vocabulary size (without LoRA).
        self.org_vocab_size = org_vocab_size or vocab_size

    def forward(
        self,
        embedding: torch.Tensor,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
        embedding_bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        if self.logits_as_hidden_states:
            # On neuron the model already emits logits, so use them as-is.
            logits = hidden_states
        else:
            # Keep only the rows selected for sampling before projecting.
            hidden_states = _prune_hidden_states(hidden_states,
                                                 sampling_metadata)

            # Get the logits for the next tokens.
            logits = self._get_logits(hidden_states, embedding, embedding_bias)

        # NOTE(review): logits can be None after the tensor-parallel gather
        # (presumably on non-driver workers) -- nothing to scale or process.
        if logits is not None:
            logits *= self.scale

            # Apply logits processors (if any).
            logits = _apply_logits_processors(logits, sampling_metadata)

        return logits

    def _get_logits(self, hidden_states: torch.Tensor, embedding: torch.Tensor,
                    embedding_bias: Optional[torch.Tensor]) -> torch.Tensor:
        # Get the logits for the next tokens: project hidden states onto the
        # (transposed) embedding matrix.
        logits = torch.matmul(hidden_states, embedding.t())
        if embedding_bias is not None:
            logits += embedding_bias
        # Gather the vocab-sharded logits across tensor-parallel ranks;
        # may return None (only one rank receives the gathered result).
        logits = tensor_model_parallel_gather(logits)
        # Remove paddings in vocab (if any).
        if logits is not None:
            logits = logits[:, :self.org_vocab_size]
        return logits
def _prune_hidden_states(
    hidden_states: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
    """Flatten hidden states to 2-D and keep only the rows whose tokens
    were selected for sampling."""
    flattened = hidden_states.view(-1, hidden_states.shape[-1])
    selected = sampling_metadata.selected_token_indices
    return flattened.index_select(0, selected)
def _apply_logits_processors(
    logits: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
    """Run each sequence's logits processors on its logits row, in place.

    Rows are consumed in sequence-group order; groups without processors
    simply advance the row cursor.
    """
    row = 0
    saw_processors = False
    for seq_ids, sampling_params in sampling_metadata.seq_groups:
        processors = sampling_params.logits_processors
        if not processors:
            row += len(seq_ids)
            continue
        saw_processors = True
        for seq_id in seq_ids:
            token_ids = sampling_metadata.seq_data[seq_id].output_token_ids
            processed_row = logits[row]
            for processor in processors:
                processed_row = processor(token_ids, processed_row)
            logits[row] = processed_row
            row += 1
    if saw_processors:
        # Every logits row must belong to exactly one sequence.
        assert row == logits.shape[0]
    return logits

View File

@ -4,8 +4,6 @@ from typing import Dict, List, Optional, Tuple
import torch import torch
import torch.nn as nn import torch.nn as nn
from vllm.model_executor.parallel_utils.communication_op import (
tensor_model_parallel_gather)
from vllm.model_executor.sampling_metadata import (SamplingMetadata, from vllm.model_executor.sampling_metadata import (SamplingMetadata,
SamplingTensors) SamplingTensors)
from vllm.sampling_params import SamplingParams, SamplingType from vllm.sampling_params import SamplingParams, SamplingType
@ -13,7 +11,6 @@ from vllm.sequence import (Logprob, PromptLogprobs, SampleLogprobs,
SamplerOutput, SequenceData, SequenceGroupOutput, SamplerOutput, SequenceData, SequenceGroupOutput,
SequenceOutput) SequenceOutput)
from vllm.model_executor.layers.ops.sample import (sample as sample_triton) from vllm.model_executor.layers.ops.sample import (sample as sample_triton)
from vllm.utils import is_neuron
class Sampler(nn.Module): class Sampler(nn.Module):
@ -31,58 +28,14 @@ class Sampler(nn.Module):
parameters (e.g., sampling method, temperature, top-p, top-k, etc.). parameters (e.g., sampling method, temperature, top-p, top-k, etc.).
""" """
def __init__(self,
vocab_size: int,
org_vocab_size: Optional[int] = None) -> None:
super().__init__()
self.vocab_size = vocab_size
# Transformers-neuronx generate outputs as logits directly.
self.logits_as_hidden_states = is_neuron()
# original vocabulary size (without LoRA).
self.org_vocab_size = org_vocab_size or vocab_size
def _get_logits(self, hidden_states: torch.Tensor, embedding: torch.Tensor,
embedding_bias: Optional[torch.Tensor]) -> torch.Tensor:
# Get the logits for the next tokens.
logits = torch.matmul(hidden_states, embedding.t())
if embedding_bias is not None:
logits += embedding_bias
logits = tensor_model_parallel_gather(logits)
# Remove paddings in vocab (if any).
if logits is not None:
logits = logits[:, :self.org_vocab_size]
return logits
def forward( def forward(
self, self,
embedding: torch.Tensor, logits: torch.Tensor,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
embedding_bias: Optional[torch.Tensor] = None,
) -> Optional[SamplerOutput]: ) -> Optional[SamplerOutput]:
# Get the hidden states that we use for sampling.
if self.logits_as_hidden_states:
logits = hidden_states
else:
hidden_states = _prune_hidden_states(hidden_states,
sampling_metadata)
# Get the logits for the next tokens.
logits = self._get_logits(hidden_states, embedding, embedding_bias)
# Only perform sampling in the driver worker.
# Note: `_get_logits` is still distributed across TP workers because
# the `embedding` weight is distributed across TP workers.
# TODO(zhuohan): Change the get_logits part to a separate stage.
if not sampling_metadata.perform_sampling:
return None
assert logits is not None assert logits is not None
_, vocab_size = logits.shape _, vocab_size = logits.shape
# Apply logits processors (if any).
logits = _apply_logits_processors(logits, sampling_metadata)
# Prepare sampling tensors with pinned memory to avoid blocking. # Prepare sampling tensors with pinned memory to avoid blocking.
(sampling_tensors, do_penalties, do_top_p_top_k, (sampling_tensors, do_penalties, do_top_p_top_k,
do_min_p) = SamplingTensors.from_sampling_metadata( do_min_p) = SamplingTensors.from_sampling_metadata(
@ -124,14 +77,6 @@ class Sampler(nn.Module):
prompt_logprobs, sample_logprobs) prompt_logprobs, sample_logprobs)
def _prune_hidden_states(
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
return hidden_states.index_select(0,
sampling_metadata.selected_token_indices)
def _get_bin_counts_and_mask( def _get_bin_counts_and_mask(
tokens: torch.Tensor, tokens: torch.Tensor,
vocab_size: int, vocab_size: int,
@ -149,30 +94,6 @@ def _get_bin_counts_and_mask(
return bin_counts, mask return bin_counts, mask
def _apply_logits_processors(
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
logits_row_idx = 0
found_logits_processors = False
for seq_ids, sampling_params in sampling_metadata.seq_groups:
logits_processors = sampling_params.logits_processors
if logits_processors:
found_logits_processors = True
for seq_id in seq_ids:
logits_row = logits[logits_row_idx]
token_ids = sampling_metadata.seq_data[seq_id].output_token_ids
for logits_processor in logits_processors:
logits_row = logits_processor(token_ids, logits_row)
logits[logits_row_idx] = logits_row
logits_row_idx += 1
else:
logits_row_idx += len(seq_ids)
if found_logits_processors:
assert logits_row_idx == logits.shape[0]
return logits
def _apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor, def _apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
output_tokens_tensor: torch.Tensor, output_tokens_tensor: torch.Tensor,
presence_penalties: torch.Tensor, presence_penalties: torch.Tensor,

View File

@ -34,6 +34,7 @@ from vllm.model_executor.layers.linear import (LinearMethodBase,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, ParallelLMHead) VocabParallelEmbedding, ParallelLMHead)
@ -295,7 +296,8 @@ class BaiChuanBaseForCausalLM(nn.Module):
self.linear_method = linear_method self.linear_method = linear_method
self.model = BaiChuanModel(config, position_embedding, linear_method) self.model = BaiChuanModel(config, position_embedding, linear_method)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.sampler = Sampler(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
def forward( def forward(
self, self,
@ -308,13 +310,18 @@ class BaiChuanBaseForCausalLM(nn.Module):
input_metadata) input_metadata)
return hidden_states return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head.weight, hidden_states,
sampling_metadata)
return logits
def sample( def sample(
self, self,
hidden_states: torch.Tensor, logits: torch.Tensor,
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]: ) -> Optional[SamplerOutput]:
next_tokens = self.sampler(self.lm_head.weight, hidden_states, next_tokens = self.sampler(logits, sampling_metadata)
sampling_metadata)
return next_tokens return next_tokens
def load_weights(self, def load_weights(self,

View File

@ -30,6 +30,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearMethodBase, LinearMethodBase,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding) VocabParallelEmbedding)
@ -273,7 +274,8 @@ class BloomForCausalLM(nn.Module):
self.linear_method = linear_method self.linear_method = linear_method
self.transformer = BloomModel(config, linear_method) self.transformer = BloomModel(config, linear_method)
self.lm_head_weight = self.transformer.word_embeddings.weight self.lm_head_weight = self.transformer.word_embeddings.weight
self.sampler = Sampler(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
def forward( def forward(
self, self,
@ -286,13 +288,18 @@ class BloomForCausalLM(nn.Module):
input_metadata) input_metadata)
return hidden_states return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head_weight, hidden_states,
sampling_metadata)
return logits
def sample( def sample(
self, self,
hidden_states: torch.Tensor, logits: torch.Tensor,
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]: ) -> Optional[SamplerOutput]:
next_tokens = self.sampler(self.lm_head_weight, hidden_states, next_tokens = self.sampler(logits, sampling_metadata)
sampling_metadata)
return next_tokens return next_tokens
def load_weights(self, def load_weights(self,

View File

@ -17,6 +17,7 @@ from vllm.model_executor.layers.linear import (LinearMethodBase,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, ParallelLMHead) VocabParallelEmbedding, ParallelLMHead)
@ -332,7 +333,8 @@ class ChatGLMForCausalLM(nn.Module):
self.linear_method = linear_method self.linear_method = linear_method
self.transformer = ChatGLMModel(config, linear_method) self.transformer = ChatGLMModel(config, linear_method)
self.lm_head_weight = self.transformer.output_layer.weight self.lm_head_weight = self.transformer.output_layer.weight
self.sampler = Sampler(config.padded_vocab_size) self.logits_processor = LogitsProcessor(config.padded_vocab_size)
self.sampler = Sampler()
def forward( def forward(
self, self,
@ -345,13 +347,18 @@ class ChatGLMForCausalLM(nn.Module):
input_metadata) input_metadata)
return hidden_states return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head_weight, hidden_states,
sampling_metadata)
return logits
def sample( def sample(
self, self,
hidden_states: torch.Tensor, logits: torch.Tensor,
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]: ) -> Optional[SamplerOutput]:
next_tokens = self.sampler(self.lm_head_weight, hidden_states, next_tokens = self.sampler(logits, sampling_metadata)
sampling_metadata)
return next_tokens return next_tokens
def load_weights(self, def load_weights(self,

View File

@ -38,6 +38,7 @@ from vllm.model_executor.layers.linear import (LinearMethodBase,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, ParallelLMHead) VocabParallelEmbedding, ParallelLMHead)
@ -372,7 +373,8 @@ class DeepseekForCausalLM(nn.Module):
self.linear_method = linear_method self.linear_method = linear_method
self.model = DeepseekModel(config, linear_method) self.model = DeepseekModel(config, linear_method)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.sampler = Sampler(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
def forward( def forward(
self, self,
@ -385,13 +387,18 @@ class DeepseekForCausalLM(nn.Module):
input_metadata) input_metadata)
return hidden_states return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head.weight, hidden_states,
sampling_metadata)
return logits
def sample( def sample(
self, self,
hidden_states: Optional[torch.Tensor], logits: Optional[torch.Tensor],
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]: ) -> Optional[SamplerOutput]:
next_tokens = self.sampler(self.lm_head.weight, hidden_states, next_tokens = self.sampler(logits, sampling_metadata)
sampling_metadata)
return next_tokens return next_tokens
def load_weights(self, def load_weights(self,

View File

@ -34,6 +34,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, ParallelLMHead) VocabParallelEmbedding, ParallelLMHead)
@ -373,7 +374,8 @@ class FalconForCausalLM(nn.Module):
config.vocab_size, config.vocab_size,
config.hidden_size, config.hidden_size,
) )
self.sampler = Sampler(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
def forward( def forward(
self, self,
@ -390,13 +392,18 @@ class FalconForCausalLM(nn.Module):
) )
return hidden_states return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head.weight, hidden_states,
sampling_metadata)
return logits
def sample( def sample(
self, self,
hidden_states: torch.Tensor, logits: torch.Tensor,
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]: ) -> Optional[SamplerOutput]:
next_tokens = self.sampler(self.lm_head.weight, hidden_states, next_tokens = self.sampler(logits, sampling_metadata)
sampling_metadata)
return next_tokens return next_tokens
def load_weights(self, def load_weights(self,

View File

@ -30,6 +30,7 @@ from vllm.model_executor.layers.linear import (LinearMethodBase,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding) VocabParallelEmbedding)
@ -281,7 +282,8 @@ class GemmaForCausalLM(nn.Module):
self.config = config self.config = config
self.linear_method = linear_method self.linear_method = linear_method
self.model = GemmaModel(config, linear_method) self.model = GemmaModel(config, linear_method)
self.sampler = Sampler(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
@torch.no_grad() @torch.no_grad()
def forward( def forward(
@ -295,13 +297,18 @@ class GemmaForCausalLM(nn.Module):
input_metadata) input_metadata)
return hidden_states return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.model.embed_tokens.weight,
hidden_states, sampling_metadata)
return logits
def sample( def sample(
self, self,
hidden_states: torch.Tensor, logits: torch.Tensor,
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]: ) -> Optional[SamplerOutput]:
next_tokens = self.sampler(self.model.embed_tokens.weight, next_tokens = self.sampler(logits, sampling_metadata)
hidden_states, sampling_metadata)
return next_tokens return next_tokens
def load_weights(self, def load_weights(self,

View File

@ -30,6 +30,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearMethodBase, LinearMethodBase,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding) VocabParallelEmbedding)
@ -216,7 +217,8 @@ class GPT2LMHeadModel(nn.Module):
self.linear_method = linear_method self.linear_method = linear_method
self.transformer = GPT2Model(config, linear_method) self.transformer = GPT2Model(config, linear_method)
self.lm_head_weight = self.transformer.wte.weight self.lm_head_weight = self.transformer.wte.weight
self.sampler = Sampler(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
def forward( def forward(
self, self,
@ -229,12 +231,18 @@ class GPT2LMHeadModel(nn.Module):
input_metadata) input_metadata)
return hidden_states return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head_weight, hidden_states,
sampling_metadata)
return logits
def sample( def sample(
self, self,
hidden_states: torch.Tensor, logits: torch.Tensor,
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]: ) -> Optional[SamplerOutput]:
next_tokens = self.sampler(self.lm_head_weight, hidden_states, next_tokens = self.sampler(self.lm_head_weight, logits,
sampling_metadata) sampling_metadata)
return next_tokens return next_tokens

View File

@ -31,6 +31,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearMethodBase, LinearMethodBase,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding) VocabParallelEmbedding)
@ -237,7 +238,8 @@ class GPTBigCodeForCausalLM(nn.Module):
self.linear_method = linear_method self.linear_method = linear_method
self.transformer = GPTBigCodeModel(config, linear_method) self.transformer = GPTBigCodeModel(config, linear_method)
self.lm_head_weight = self.transformer.wte.weight self.lm_head_weight = self.transformer.wte.weight
self.sampler = Sampler(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
def forward( def forward(
self, self,
@ -250,13 +252,18 @@ class GPTBigCodeForCausalLM(nn.Module):
input_metadata) input_metadata)
return hidden_states return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head_weight, hidden_states,
sampling_metadata)
return logits
def sample( def sample(
self, self,
hidden_states: torch.Tensor, logits: torch.Tensor,
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]: ) -> Optional[SamplerOutput]:
next_tokens = self.sampler(self.lm_head_weight, hidden_states, next_tokens = self.sampler(logits, sampling_metadata)
sampling_metadata)
return next_tokens return next_tokens
def load_weights(self, def load_weights(self,

View File

@ -30,6 +30,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, ParallelLMHead) VocabParallelEmbedding, ParallelLMHead)
@ -224,7 +225,8 @@ class GPTJForCausalLM(nn.Module):
config.n_embd, config.n_embd,
bias=True, bias=True,
) )
self.sampler = Sampler(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
def forward( def forward(
self, self,
@ -237,13 +239,18 @@ class GPTJForCausalLM(nn.Module):
input_metadata) input_metadata)
return hidden_states return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head.weight, hidden_states,
sampling_metadata, self.lm_head.bias)
return logits
def sample( def sample(
self, self,
hidden_states: torch.Tensor, logits: torch.Tensor,
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]: ) -> Optional[SamplerOutput]:
next_tokens = self.sampler(self.lm_head.weight, hidden_states, next_tokens = self.sampler(logits, sampling_metadata)
sampling_metadata, self.lm_head.bias)
return next_tokens return next_tokens
def load_weights(self, def load_weights(self,

View File

@ -30,6 +30,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, ParallelLMHead) VocabParallelEmbedding, ParallelLMHead)
@ -238,7 +239,8 @@ class GPTNeoXForCausalLM(nn.Module):
config.vocab_size, config.vocab_size,
config.hidden_size, config.hidden_size,
) )
self.sampler = Sampler(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
def forward( def forward(
self, self,
@ -251,13 +253,18 @@ class GPTNeoXForCausalLM(nn.Module):
input_metadata) input_metadata)
return hidden_states return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.embed_out.weight, hidden_states,
sampling_metadata)
return logits
def sample( def sample(
self, self,
hidden_states: torch.Tensor, logits: torch.Tensor,
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]: ) -> Optional[SamplerOutput]:
next_tokens = self.sampler(self.embed_out.weight, hidden_states, next_tokens = self.sampler(logits, sampling_metadata)
sampling_metadata)
return next_tokens return next_tokens
def load_weights(self, def load_weights(self,

View File

@ -14,6 +14,7 @@ from vllm.model_executor.layers.linear import (LinearMethodBase,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, ParallelLMHead) VocabParallelEmbedding, ParallelLMHead)
@ -250,7 +251,8 @@ class InternLM2ForCausalLM(nn.Module):
self.linear_method = linear_method self.linear_method = linear_method
self.model = InternLM2Model(config, linear_method) self.model = InternLM2Model(config, linear_method)
self.output = ParallelLMHead(config.vocab_size, config.hidden_size) self.output = ParallelLMHead(config.vocab_size, config.hidden_size)
self.sampler = Sampler(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
def forward( def forward(
self, self,
@ -263,13 +265,18 @@ class InternLM2ForCausalLM(nn.Module):
input_metadata) input_metadata)
return hidden_states return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.output.weight, hidden_states,
sampling_metadata)
return logits
def sample( def sample(
self, self,
hidden_states: torch.Tensor, logits: torch.Tensor,
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]: ) -> Optional[SamplerOutput]:
next_tokens = self.sampler(self.output.weight, hidden_states, next_tokens = self.sampler(logits, sampling_metadata)
sampling_metadata)
return next_tokens return next_tokens
def load_weights(self, def load_weights(self,

View File

@ -37,6 +37,7 @@ from vllm.model_executor.layers.linear import (LinearMethodBase,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, ParallelLMHead, DEFAULT_VOCAB_PADDING_SIZE) VocabParallelEmbedding, ParallelLMHead, DEFAULT_VOCAB_PADDING_SIZE)
@ -325,7 +326,11 @@ class LlamaForCausalLM(nn.Module):
# compatibility # compatibility
if not lora_config else lora_config.lora_vocab_padding_size, if not lora_config else lora_config.lora_vocab_padding_size,
) )
self.sampler = Sampler(self.unpadded_vocab_size, config.vocab_size)
logit_scale = getattr(config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size, logit_scale)
self.sampler = Sampler()
def forward( def forward(
self, self,
@ -338,13 +343,18 @@ class LlamaForCausalLM(nn.Module):
input_metadata) input_metadata)
return hidden_states return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head.weight, hidden_states,
sampling_metadata)
return logits
def sample( def sample(
self, self,
hidden_states: torch.Tensor, logits: torch.Tensor,
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]: ) -> Optional[SamplerOutput]:
next_tokens = self.sampler(self.lm_head.weight, hidden_states, next_tokens = self.sampler(logits, sampling_metadata)
sampling_metadata)
return next_tokens return next_tokens
def load_weights(self, def load_weights(self,

View File

@ -37,6 +37,7 @@ from vllm.model_executor.layers.linear import (LinearMethodBase,
ReplicatedLinear, ReplicatedLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, ParallelLMHead, DEFAULT_VOCAB_PADDING_SIZE) VocabParallelEmbedding, ParallelLMHead, DEFAULT_VOCAB_PADDING_SIZE)
@ -369,7 +370,9 @@ class MixtralForCausalLM(nn.Module):
# compatibility # compatibility
if not lora_config else lora_config.lora_vocab_padding_size, if not lora_config else lora_config.lora_vocab_padding_size,
) )
self.sampler = Sampler(self.unpadded_vocab_size, config.vocab_size) self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size)
self.sampler = Sampler()
def forward( def forward(
self, self,
@ -382,13 +385,18 @@ class MixtralForCausalLM(nn.Module):
input_metadata) input_metadata)
return hidden_states return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head.weight, hidden_states,
sampling_metadata)
return logits
def sample( def sample(
self, self,
hidden_states: Optional[torch.Tensor], logits: Optional[torch.Tensor],
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]: ) -> Optional[SamplerOutput]:
next_tokens = self.sampler(self.lm_head.weight, hidden_states, next_tokens = self.sampler(logits, sampling_metadata)
sampling_metadata)
return next_tokens return next_tokens
def load_weights(self, def load_weights(self,

View File

@ -39,6 +39,7 @@ from vllm.model_executor.layers.linear import (LinearMethodBase,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, ParallelLMHead) VocabParallelEmbedding, ParallelLMHead)
@ -344,7 +345,8 @@ class MixtralForCausalLM(nn.Module):
self.linear_method = linear_method self.linear_method = linear_method
self.model = MixtralModel(config, linear_method) self.model = MixtralModel(config, linear_method)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.sampler = Sampler(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
def forward( def forward(
self, self,
@ -357,13 +359,18 @@ class MixtralForCausalLM(nn.Module):
input_metadata) input_metadata)
return hidden_states return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head.weight, hidden_states,
sampling_metadata)
return logits
def sample( def sample(
self, self,
hidden_states: Optional[torch.Tensor], logits: Optional[torch.Tensor],
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]: ) -> Optional[SamplerOutput]:
next_tokens = self.sampler(self.lm_head.weight, hidden_states, next_tokens = self.sampler(logits, sampling_metadata)
sampling_metadata)
return next_tokens return next_tokens
def load_weights(self, def load_weights(self,

View File

@ -13,6 +13,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearMethodBase, LinearMethodBase,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding) VocabParallelEmbedding)
@ -259,7 +260,8 @@ class MPTForCausalLM(nn.Module):
self.transformer = MPTModel(config, linear_method) self.transformer = MPTModel(config, linear_method)
self.lm_head_weight = self.transformer.wte.weight self.lm_head_weight = self.transformer.wte.weight
self.sampler = Sampler(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
def forward( def forward(
self, self,
@ -272,13 +274,18 @@ class MPTForCausalLM(nn.Module):
input_metadata) input_metadata)
return hidden_states return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head_weight, hidden_states,
sampling_metadata)
return logits
def sample( def sample(
self, self,
hidden_states: torch.Tensor, logits: torch.Tensor,
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]: ) -> Optional[SamplerOutput]:
next_tokens = self.sampler(self.lm_head_weight, hidden_states, next_tokens = self.sampler(logits, sampling_metadata)
sampling_metadata)
return next_tokens return next_tokens
def load_weights(self, def load_weights(self,

View File

@ -7,6 +7,7 @@ from torch import nn
from transformers import LlamaConfig from transformers import LlamaConfig
from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput from vllm.sequence import SamplerOutput
@ -25,7 +26,8 @@ class LlamaForCausalLM(nn.Module):
self.config = config self.config = config
self.linear_method = linear_method self.linear_method = linear_method
self.model = None self.model = None
self.sampler = Sampler(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
def forward( def forward(
self, self,
@ -45,13 +47,18 @@ class LlamaForCausalLM(nn.Module):
start_ids=seq_ids.flatten()) start_ids=seq_ids.flatten())
return logits return logits
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.model.chkpt_model.lm_head,
hidden_states, sampling_metadata)
return logits
def sample( def sample(
self, self,
hidden_states: torch.Tensor, logits: torch.Tensor,
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]: ) -> Optional[SamplerOutput]:
next_tokens = self.sampler(self.model.chkpt_model.lm_head, next_tokens = self.sampler(logits, sampling_metadata)
hidden_states, sampling_metadata)
return next_tokens return next_tokens
def load_weights(self, def load_weights(self,

View File

@ -6,6 +6,7 @@ from torch import nn
from transformers import MistralConfig from transformers import MistralConfig
from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput from vllm.sequence import SamplerOutput
@ -26,7 +27,8 @@ class MistralForCausalLM(nn.Module):
self.linear_method = linear_method self.linear_method = linear_method
self.model = None self.model = None
self.lm_head = None self.lm_head = None
self.sampler = Sampler(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
def forward( def forward(
self, self,
@ -48,13 +50,18 @@ class MistralForCausalLM(nn.Module):
start_ids=seq_ids) start_ids=seq_ids)
return logits return logits
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.model.chkpt_model.lm_head,
hidden_states, sampling_metadata)
return logits
def sample( def sample(
self, self,
hidden_states: torch.Tensor, logits: torch.Tensor,
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]: ) -> Optional[SamplerOutput]:
next_tokens = self.sampler(self.model.chkpt_model.lm_head, next_tokens = self.sampler(logits, sampling_metadata)
hidden_states, sampling_metadata)
return next_tokens return next_tokens
def load_weights(self, def load_weights(self,

View File

@ -51,6 +51,7 @@ from vllm.model_executor.layers.linear import (
RowParallelLinear, RowParallelLinear,
) )
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding) VocabParallelEmbedding)
@ -336,7 +337,8 @@ class OLMoForCausalLM(nn.Module):
self.lm_head_weight = (self.model.transformer.wte.weight self.lm_head_weight = (self.model.transformer.wte.weight
if config.weight_tying else if config.weight_tying else
self.model.transformer.ff_out.weight) self.model.transformer.ff_out.weight)
self.sampler = Sampler(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
def forward( def forward(
self, self,
@ -353,13 +355,18 @@ class OLMoForCausalLM(nn.Module):
) )
return hidden_states return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head_weight, hidden_states,
sampling_metadata)
return logits
def sample( def sample(
self, self,
hidden_states: torch.Tensor, logits: torch.Tensor,
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]: ) -> Optional[SamplerOutput]:
next_tokens = self.sampler(self.lm_head_weight, hidden_states, next_tokens = self.sampler(logits, sampling_metadata)
sampling_metadata)
return next_tokens return next_tokens
def load_weights( def load_weights(

View File

@ -31,6 +31,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
ReplicatedLinear, ReplicatedLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding) VocabParallelEmbedding)
@ -292,7 +293,8 @@ class OPTForCausalLM(nn.Module):
self.linear_method = linear_method self.linear_method = linear_method
self.model = OPTModel(config, linear_method) self.model = OPTModel(config, linear_method)
self.lm_head_weight = self.model.decoder.embed_tokens.weight self.lm_head_weight = self.model.decoder.embed_tokens.weight
self.sampler = Sampler(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
def forward( def forward(
self, self,
@ -305,13 +307,18 @@ class OPTForCausalLM(nn.Module):
input_metadata) input_metadata)
return hidden_states return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head_weight, hidden_states,
sampling_metadata)
return logits
def sample( def sample(
self, self,
hidden_states: torch.Tensor, logits: torch.Tensor,
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]: ) -> Optional[SamplerOutput]:
next_tokens = self.sampler(self.lm_head_weight, hidden_states, next_tokens = self.sampler(logits, sampling_metadata)
sampling_metadata)
return next_tokens return next_tokens
def load_weights(self, def load_weights(self,

View File

@ -18,6 +18,7 @@ from vllm.model_executor.layers.linear import (LinearMethodBase,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, ParallelLMHead) VocabParallelEmbedding, ParallelLMHead)
@ -256,7 +257,8 @@ class OrionForCausalLM(nn.Module):
self.linear_method = linear_method self.linear_method = linear_method
self.model = OrionModel(config, linear_method) self.model = OrionModel(config, linear_method)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.sampler = Sampler(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
def forward( def forward(
self, self,
@ -269,13 +271,18 @@ class OrionForCausalLM(nn.Module):
input_metadata) input_metadata)
return hidden_states return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head.weight, hidden_states,
sampling_metadata)
return logits
def sample( def sample(
self, self,
hidden_states: torch.Tensor, logits: torch.Tensor,
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]: ) -> Optional[SamplerOutput]:
next_tokens = self.sampler(self.lm_head.weight, hidden_states, next_tokens = self.sampler(logits, sampling_metadata)
sampling_metadata)
return next_tokens return next_tokens
def load_weights(self, def load_weights(self,

View File

@ -49,6 +49,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, ParallelLMHead) VocabParallelEmbedding, ParallelLMHead)
@ -240,7 +241,8 @@ class PhiForCausalLM(nn.Module):
self.lm_head = ParallelLMHead(config.vocab_size, self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size, config.hidden_size,
bias=True) bias=True)
self.sampler = Sampler(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
def forward( def forward(
self, self,
@ -254,14 +256,18 @@ class PhiForCausalLM(nn.Module):
return hidden_states return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head.weight, hidden_states,
sampling_metadata, self.lm_head.bias)
return logits
def sample( def sample(
self, self,
hidden_states: torch.Tensor, logits: torch.Tensor,
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]: ) -> Optional[SamplerOutput]:
head = self.lm_head next_tokens = self.sampler(logits, sampling_metadata)
next_tokens = self.sampler(head.weight, hidden_states,
sampling_metadata, head.bias)
return next_tokens return next_tokens
def load_weights(self, def load_weights(self,

View File

@ -19,6 +19,7 @@ from vllm.model_executor.layers.linear import (LinearMethodBase,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, ParallelLMHead) VocabParallelEmbedding, ParallelLMHead)
@ -230,7 +231,8 @@ class QWenLMHeadModel(nn.Module):
self.linear_method = linear_method self.linear_method = linear_method
self.transformer = QWenModel(config, linear_method) self.transformer = QWenModel(config, linear_method)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.sampler = Sampler(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
def forward( def forward(
self, self,
@ -243,13 +245,18 @@ class QWenLMHeadModel(nn.Module):
input_metadata) input_metadata)
return hidden_states return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head.weight, hidden_states,
sampling_metadata)
return logits
def sample( def sample(
self, self,
hidden_states: torch.Tensor, logits: torch.Tensor,
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]: ) -> Optional[SamplerOutput]:
next_tokens = self.sampler(self.lm_head.weight, hidden_states, next_tokens = self.sampler(logits, sampling_metadata)
sampling_metadata)
return next_tokens return next_tokens
def load_weights(self, def load_weights(self,

View File

@ -37,6 +37,7 @@ from vllm.model_executor.layers.linear import (LinearMethodBase,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, ParallelLMHead) VocabParallelEmbedding, ParallelLMHead)
@ -300,11 +301,15 @@ class Qwen2ForCausalLM(nn.Module):
self.linear_method = linear_method self.linear_method = linear_method
self.model = Qwen2Model(config, linear_method) self.model = Qwen2Model(config, linear_method)
if not config.tie_word_embeddings: if config.tie_word_embeddings:
self.lm_head_weight = self.model.embed_tokens.weight
else:
self.lm_head = ParallelLMHead(config.vocab_size, self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size) config.hidden_size)
self.lm_head_weight = self.lm_head.weight
self.sampler = Sampler(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
def forward( def forward(
self, self,
@ -317,17 +322,18 @@ class Qwen2ForCausalLM(nn.Module):
input_metadata) input_metadata)
return hidden_states return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head_weight, hidden_states,
sampling_metadata)
return logits
def sample( def sample(
self, self,
hidden_states: torch.Tensor, logits: torch.Tensor,
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]: ) -> Optional[SamplerOutput]:
if self.config.tie_word_embeddings: next_tokens = self.sampler(logits, sampling_metadata)
lm_head_weight = self.model.embed_tokens.weight
else:
lm_head_weight = self.lm_head.weight
next_tokens = self.sampler(lm_head_weight, hidden_states,
sampling_metadata)
return next_tokens return next_tokens
def load_weights(self, def load_weights(self,

View File

@ -33,6 +33,7 @@ from vllm.model_executor.layers.linear import (LinearMethodBase,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, ParallelLMHead) VocabParallelEmbedding, ParallelLMHead)
@ -238,7 +239,8 @@ class StablelmForCausalLM(nn.Module):
self.linear_method = linear_method self.linear_method = linear_method
self.model = StableLMEpochModel(config, linear_method) self.model = StableLMEpochModel(config, linear_method)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.sampler = Sampler(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
def forward( def forward(
self, self,
@ -251,13 +253,18 @@ class StablelmForCausalLM(nn.Module):
input_metadata) input_metadata)
return hidden_states return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head.weight, hidden_states,
sampling_metadata)
return logits
def sample( def sample(
self, self,
hidden_states: torch.Tensor, logits: torch.Tensor,
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]: ) -> Optional[SamplerOutput]:
next_tokens = self.sampler(self.lm_head.weight, hidden_states, next_tokens = self.sampler(logits, sampling_metadata)
sampling_metadata)
return next_tokens return next_tokens
def load_weights(self, def load_weights(self,

View File

@ -32,6 +32,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearMethodBase, LinearMethodBase,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, ParallelLMHead, DEFAULT_VOCAB_PADDING_SIZE) VocabParallelEmbedding, ParallelLMHead, DEFAULT_VOCAB_PADDING_SIZE)
@ -254,7 +255,9 @@ class Starcoder2ForCausalLM(nn.Module):
padding_size=DEFAULT_VOCAB_PADDING_SIZE, padding_size=DEFAULT_VOCAB_PADDING_SIZE,
) )
self.lm_head_weight = self.lm_head.weight self.lm_head_weight = self.lm_head.weight
self.sampler = Sampler(self.unpadded_vocab_size, config.vocab_size) self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size)
self.sampler = Sampler()
def forward( def forward(
self, self,
@ -267,13 +270,18 @@ class Starcoder2ForCausalLM(nn.Module):
input_metadata) input_metadata)
return hidden_states return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head_weight, hidden_states,
sampling_metadata)
return logits
def sample( def sample(
self, self,
hidden_states: Optional[torch.Tensor], logits: Optional[torch.Tensor],
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]: ) -> Optional[SamplerOutput]:
next_tokens = self.sampler(self.lm_head_weight, hidden_states, next_tokens = self.sampler(logits, sampling_metadata)
sampling_metadata)
return next_tokens return next_tokens
def load_weights(self, def load_weights(self,

View File

@ -613,9 +613,16 @@ class ModelRunner:
input_metadata=input_metadata, input_metadata=input_metadata,
) )
# Compute the logits.
logits = self.model.compute_logits(hidden_states, sampling_metadata)
# Only perform sampling in the driver worker.
if not sampling_metadata.perform_sampling:
return None
# Sample the next token. # Sample the next token.
output = self.model.sample( output = self.model.sample(
hidden_states=hidden_states, logits=logits,
sampling_metadata=sampling_metadata, sampling_metadata=sampling_metadata,
) )
return output return output