Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-10 04:15:01 +08:00)

Migrate logits computation and gather to model_runner (#3233)
parent 6e435de766, commit f1c0fc3919

@@ -49,6 +49,9 @@ steps:
 - label: Samplers Test
   command: pytest -v -s samplers
 
+- label: LogitsProcessor Test
+  command: pytest -v -s test_logits_processor.py
+
 - label: Worker Test
   command: pytest -v -s worker

@@ -13,6 +13,7 @@ from huggingface_hub import snapshot_download
 import vllm
 from vllm.config import LoRAConfig
 from vllm.model_executor.layers.sampler import Sampler
+from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                MergedColumnParallelLinear,

@@ -85,7 +86,8 @@ def dummy_model() -> nn.Module:
         ("outact", nn.Sigmoid()),
         # Special handling for lm_head & sampler
         ("lm_head", ParallelLMHead(512, 10)),
-        ("sampler", Sampler(512))
+        ("logits_processor", LogitsProcessor(512)),
+        ("sampler", Sampler())
     ]))
     model.config = MagicMock()
     return model

@@ -110,7 +112,8 @@ def dummy_model_gate_up() -> nn.Module:
         ("outact", nn.Sigmoid()),
         # Special handling for lm_head & sampler
         ("lm_head", ParallelLMHead(512, 10)),
-        ("sampler", Sampler(512))
+        ("logits_processor", LogitsProcessor(512)),
+        ("sampler", Sampler())
     ]))
     model.config = MagicMock()
     return model

@@ -13,14 +13,14 @@ from vllm.lora.layers import (
     QKVParallelLinearWithLora,
     VocabParallelEmbeddingWithLoRA,
     RowParallelLinearWithLoRA,
-    SamplerWithLoRA,
+    LogitsProcessorWithLoRA,
     LoRAMapping,
     BaseLayerWithLoRA,
 )
 from vllm.lora.models import (LoRALayerWeights, convert_mapping,
                               PackedLoRALayerWeights)
 from vllm.config import LoRAConfig
-from vllm.model_executor.layers.sampler import Sampler
+from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                MergedColumnParallelLinear,
                                                RowParallelLinear,

@@ -394,7 +394,7 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device) -> None:
 @torch.inference_mode()
 @pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
 @pytest.mark.parametrize("device", CUDA_DEVICES)
-def test_lm_head_sampler(dist_init, num_loras, device) -> None:
+def test_lm_head_logits_processor(dist_init, num_loras, device) -> None:
 
     torch.set_default_device(device)
     max_loras = 8

@@ -402,28 +402,29 @@ def test_lm_head_sampler(dist_init, num_loras, device) -> None:
                              max_lora_rank=8,
                              lora_dtype=torch.float16)
 
-    def create_random_sampler_layer():
+    def _pretest():
         linear = ParallelLMHead(32000 + lora_config.lora_extra_vocab_size,
                                 1024, 32000)
         linear.weight.data = torch.rand_like(linear.weight.data)
         linear.weight.data[:, 32000:] = 0
-        sampler = Sampler(32000 + lora_config.lora_extra_vocab_size, 32000)
-        lora_sampler = SamplerWithLoRA(sampler, 1024, linear.weight.dtype,
-                                       linear.weight.device)
-        lora_sampler.create_lora_weights(max_loras, lora_config)
+        logits_processor = LogitsProcessor(
+            32000 + lora_config.lora_extra_vocab_size, 32000)
+        lora_logits_processor = LogitsProcessorWithLoRA(
+            logits_processor, 1024, linear.weight.dtype, linear.weight.device)
+        lora_logits_processor.create_lora_weights(max_loras, lora_config)
 
-        return linear, sampler, lora_sampler
+        return linear, logits_processor, lora_logits_processor
 
     for i in range(10):
         set_random_seed(i)
 
         id_to_index = get_random_id_to_index(num_loras, max_loras)
-        linear, sampler, lora_sampler = create_random_sampler_layer()
+        linear, logits_processor, lora_logits_processor = _pretest()
 
         # NOTE: all the generated loras share the same embeddings tensor.
         lora_dict, _ = populate_loras(
             id_to_index,
-            layer=lora_sampler,
+            layer=lora_logits_processor,
             layer_weights=linear.weight,
             generate_embeddings_tensor=1024,
         )

@@ -447,34 +448,37 @@ def test_lm_head_sampler(dist_init, num_loras, device) -> None:
                                        32000,
                                        lora_config.lora_extra_vocab_size,
         )
-        lora_sampler.set_mapping(*mapping_info, )
+        lora_logits_processor.set_mapping(*mapping_info, )
 
-        lora_result = lora_sampler._get_logits(hidden_states=torch.cat(inputs),
-                                               embedding=linear.weight,
-                                               embedding_bias=None)
+        lora_result = lora_logits_processor._get_logits(
+            hidden_states=torch.cat(inputs),
+            embedding=linear.weight,
+            embedding_bias=None)
 
         original_weight = linear.weight.clone()
 
-        linear.weight[sampler.org_vocab_size:sampler.org_vocab_size +
+        linear.weight[logits_processor.
+                      org_vocab_size:logits_processor.org_vocab_size +
                       embeddings_tensor_len] = embeddings_tensor
 
-        sampler.org_vocab_size = 32000 + lora_config.lora_extra_vocab_size
+        logits_processor.org_vocab_size = (32000 +
+                                           lora_config.lora_extra_vocab_size)
         expected_results = []
         for input_, lora_id in zip(inputs, prompt_mapping):
             lora = lora_dict[lora_id]
-            result = sampler._get_logits(hidden_states=input_,
-                                         embedding=linear.weight,
-                                         embedding_bias=None)
+            result = logits_processor._get_logits(hidden_states=input_,
+                                                  embedding=linear.weight,
+                                                  embedding_bias=None)
             result[:, 32000 + embeddings_tensor_len:] = float("-inf")
             result += input_ @ lora.lora_a @ lora.lora_b * lora.scaling
             expected_results.append(result)
         expected_result = torch.cat(expected_results)
-        sampler.org_vocab_size = 32000
+        logits_processor.org_vocab_size = 32000
 
         # Check that resetting the lora weights succeeds
 
         for slot_idx in range(max_loras):
-            lora_sampler.reset_lora(slot_idx)
+            lora_logits_processor.reset_lora(slot_idx)
 
         inputs, index_mapping, prompt_mapping = create_random_inputs(
             active_lora_ids=[0],

@@ -488,14 +492,16 @@ def test_lm_head_sampler(dist_init, num_loras, device) -> None:
         mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras,
                                        32000,
                                        lora_config.lora_extra_vocab_size)
-        lora_sampler.set_mapping(*mapping_info, )
+        lora_logits_processor.set_mapping(*mapping_info, )
 
-        lora_result = lora_sampler._get_logits(hidden_states=torch.cat(inputs),
-                                               embedding=original_weight,
-                                               embedding_bias=None)[:, :32000]
-        expected_result = sampler._get_logits(hidden_states=torch.cat(inputs),
-                                              embedding=original_weight,
-                                              embedding_bias=None)
+        lora_result = lora_logits_processor._get_logits(
+            hidden_states=torch.cat(inputs),
+            embedding=original_weight,
+            embedding_bias=None)[:, :32000]
+        expected_result = logits_processor._get_logits(
+            hidden_states=torch.cat(inputs),
+            embedding=original_weight,
+            embedding_bias=None)
 
         rtol, atol = TOLERANCES[lora_result.dtype]
         assert torch.allclose(lora_result,

@@ -15,17 +15,12 @@ from vllm.worker.model_runner import ModelRunner
 
 class MockLogitsSampler(Sampler):
 
-    def __init__(self, vocab_size: int, fake_logits: torch.Tensor):
-        super().__init__(vocab_size=vocab_size)
+    def __init__(self, fake_logits: torch.Tensor):
+        super().__init__()
         self.fake_logits = fake_logits
 
     def forward(self, *args, **kwargs):
-        with patch(
-                "vllm.model_executor.layers.sampler._prune_hidden_states",
-                lambda x, y: x), patch(
-                    "vllm.model_executor.layers.sampler.Sampler._get_logits",
-                    lambda *args, **kwargs: self.fake_logits):
-            return super().forward(*args, **kwargs)
+        return super().forward(*args, **kwargs)
 
 
 def _prepare_test(

@@ -36,7 +31,7 @@ def _prepare_test(
     fake_logits = torch.full((batch_size, vocab_size),
                              1e-2,
                              dtype=input_tensor.dtype)
-    sampler = MockLogitsSampler(32000, fake_logits)
+    sampler = MockLogitsSampler(fake_logits)
     model_runner = ModelRunner(None, None, None, None, None)
     return input_tensor, fake_logits, sampler, model_runner

@@ -70,9 +65,7 @@ def _do_sample(
     sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
                                                      prompt_lens,
                                                      subquery_lens=prompt_lens)
-    return sampler(embedding=None,
-                   hidden_states=input_tensor,
-                   sampling_metadata=sampling_metadata)
+    return sampler(logits=input_tensor, sampling_metadata=sampling_metadata)
 
 
 @pytest.mark.parametrize("seed", RANDOM_SEEDS)

@@ -85,8 +78,8 @@ def test_sampler_all_greedy(seed: int, device: str):
                                                          batch_size)
 
     sampling_params = SamplingParams(temperature=0)
-    sampler_output = _do_sample(batch_size, input_tensor, sampler,
-                                model_runner, sampling_params)
+    sampler_output = _do_sample(batch_size, fake_logits, sampler, model_runner,
+                                sampling_params)
     expected = torch.argmax(fake_logits, dim=-1)
     for i, sequence_output in enumerate(sampler_output):
         for nth_output in sequence_output.samples:

@@ -111,8 +104,8 @@ def test_sampler_all_random(seed: int, device: str):
         temperature=1.0,
         n=random.randint(1, 10),
     )
-    sampler_output = _do_sample(batch_size, input_tensor, sampler,
-                                model_runner, sampling_params)
+    sampler_output = _do_sample(batch_size, fake_logits, sampler, model_runner,
+                                sampling_params)
 
     for i, sequence_output in enumerate(sampler_output):
         for nth_output in sequence_output.samples:

@@ -127,8 +120,7 @@ def test_sampler_all_random_seed(seed: int, device: str):
     set_random_seed(seed)
     torch.set_default_device(device)
     batch_size = random.randint(1, 256)
-    input_tensor, fake_logits, sampler, model_runner = _prepare_test(
-        batch_size)
+    _, fake_logits, sampler, model_runner = _prepare_test(batch_size)
 
     for i in range(batch_size):
         fake_logits[i, i] = 1e2

@@ -138,8 +130,8 @@ def test_sampler_all_random_seed(seed: int, device: str):
         n=random.randint(1, 10),
         seed=random.randint(0, 10000),
     )
-    sampler_output = _do_sample(batch_size, input_tensor, sampler,
-                                model_runner, sampling_params)
+    sampler_output = _do_sample(batch_size, fake_logits, sampler, model_runner,
+                                sampling_params)
 
     for i, sequence_output in enumerate(sampler_output):
         for nth_output in sequence_output.samples:

@@ -154,18 +146,17 @@ def test_sampler_all_random_seed_deterministic(seed: int, device: str):
     set_random_seed(seed)
     torch.set_default_device(device)
     batch_size = random.randint(1, 256)
-    input_tensor, fake_logits, sampler, model_runner = _prepare_test(
-        batch_size)
+    _, fake_logits, sampler, model_runner = _prepare_test(batch_size)
 
     sampling_params = SamplingParams(
         temperature=1.0,
         n=random.randint(1, 10),
         seed=random.randint(0, 10000),
     )
-    first_sampler_output = _do_sample(batch_size, input_tensor, sampler,
+    first_sampler_output = _do_sample(batch_size, fake_logits, sampler,
                                       model_runner, sampling_params)
 
-    second_sampler_output = _do_sample(batch_size, input_tensor, sampler,
+    second_sampler_output = _do_sample(batch_size, fake_logits, sampler,
                                        model_runner, sampling_params)
 
     assert first_sampler_output == second_sampler_output

@@ -179,15 +170,14 @@ def test_sampler_all_beam(seed: int, device: str):
     set_random_seed(seed)
     torch.set_default_device(device)
     batch_size = random.randint(1, 256)
-    input_tensor, _, sampler, model_runner = _prepare_test(batch_size)
+    _, fake_logits, sampler, model_runner = _prepare_test(batch_size)
 
     sampling_params = SamplingParams(
         temperature=0,
         best_of=2,
         use_beam_search=True,
     )
-    _do_sample(batch_size, input_tensor, sampler, model_runner,
-               sampling_params)
+    _do_sample(batch_size, fake_logits, sampler, model_runner, sampling_params)
     # no assertion here as I am not sure how to determine whether
     # the outputs are expected - in other words, this just tests
     # whether there are no exceptions in the sampler

@@ -246,8 +236,7 @@ def test_sampler_mixed(seed: int, device: str):
     def test_sampling(model_runner: ModelRunner):
         sampling_metadata = model_runner._prepare_sample(
             seq_group_metadata_list, prompt_lens, subquery_lens=prompt_lens)
-        sampler_output = sampler(embedding=None,
-                                 hidden_states=input_tensor,
+        sampler_output = sampler(logits=fake_logits,
                                  sampling_metadata=sampling_metadata)
 
         for i, (sequence_output, metadata) in enumerate(

@@ -294,48 +283,6 @@ def test_sampler_mixed(seed: int, device: str):
     del model_runner
 
 
-@pytest.mark.parametrize("seed", RANDOM_SEEDS)
-@pytest.mark.parametrize("device", CUDA_DEVICES)
-def test_sampler_logits_processors(seed: int, device: str):
-    set_random_seed(seed)
-    torch.set_default_device(device)
-    batch_size = random.randint(1, 256)
-    input_tensor, _, sampler, model_runner = _prepare_test(batch_size)
-
-    # This sample logits processor gives maximum score to the i-th token,
-    # where i is the length of the input sequence.
-    # We therefore expect the output token sequence to be [0, 1, 2, ...]
-    def pick_ith(token_ids, logits):
-        logits[len(token_ids)] = torch.finfo(logits.dtype).max
-        return logits
-
-    seq_group_metadata_list = []
-    prompt_lens = []
-    for i in range(batch_size):
-        seq_group_metadata_list.append(
-            SequenceGroupMetadata(
-                request_id=f"test_{i}",
-                is_prompt=True,
-                seq_data={0: SequenceData([1, 2, 3])},
-                sampling_params=SamplingParams(temperature=0,
-                                               logits_processors=[pick_ith]),
-                block_tables={0: [1]},
-            ))
-        prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
-
-    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
-                                                     prompt_lens,
-                                                     subquery_lens=prompt_lens)
-    sampler_output = sampler(embedding=None,
-                             hidden_states=input_tensor,
-                             sampling_metadata=sampling_metadata)
-    for _, sequence_output in enumerate(sampler_output):
-        for idx, nth_output in enumerate(sequence_output.samples):
-            assert nth_output.output_token == idx
-
-    del model_runner
-
-
 @pytest.mark.parametrize("seed", RANDOM_SEEDS)
 @pytest.mark.parametrize("device", CUDA_DEVICES)
 def test_sampler_top_k_top_p(seed: int, device: str):

@@ -352,7 +299,7 @@ def test_sampler_top_k_top_p(seed: int, device: str):
                                size=(batch_size, vocab_size),
                                device=input_tensor.device,
                                dtype=input_tensor.dtype)
-    sampler = MockLogitsSampler(32000, fake_logits)
+    sampler = MockLogitsSampler(fake_logits)
     model_runner = ModelRunner(None, None, None, None, None)
 
     generation_model = GenerationMixin()

@@ -391,9 +338,7 @@ def test_sampler_top_k_top_p(seed: int, device: str):
         return [[prob.topk(1, dim=-1).indices.tolist(), [0]] for prob in probs]
 
     with patch("vllm.model_executor.layers.sampler._sample", mock_sample):
-        sampler(embedding=None,
-                hidden_states=input_tensor,
-                sampling_metadata=sampling_metadata)
+        sampler(logits=fake_logits, sampling_metadata=sampling_metadata)
     hf_probs = warpers(torch.zeros_like(fake_logits), fake_logits.clone())
     hf_probs = torch.softmax(hf_probs, dim=-1, dtype=torch.float)
     assert torch.allclose(hf_probs, sample_probs, atol=1e-5)

tests/test_logits_processor.py (new file, +94 lines)
@@ -0,0 +1,94 @@
+import random
+from typing import Tuple
+from unittest.mock import patch
+
+import pytest
+import torch
+
+from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.utils import set_random_seed
+from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
+from vllm.worker.model_runner import ModelRunner
+
+
+class MockLogitsProcessor(LogitsProcessor):
+
+    def __init__(self, vocab_size: int, scale: float,
+                 fake_logits: torch.Tensor):
+        super().__init__(vocab_size=vocab_size, scale=scale)
+        self.fake_logits = fake_logits.clone()
+
+    def forward(self, *args, **kwargs):
+        with patch(
+                "vllm.model_executor.layers.logits_processor._prune_hidden_states",
+                lambda x, y: x
+        ), patch(
+                "vllm.model_executor.layers.logits_processor.LogitsProcessor._get_logits",
+                lambda *args, **kwargs: self.fake_logits):
+            return super().forward(*args, **kwargs)
+
+
+def _prepare_test(
+        batch_size: int
+) -> Tuple[torch.Tensor, torch.Tensor, MockLogitsProcessor, ModelRunner]:
+    vocab_size = 32000
+    input_tensor = torch.rand((batch_size, 1024), dtype=torch.float16)
+    fake_logits = torch.full((batch_size, vocab_size),
+                             1e-2,
+                             dtype=input_tensor.dtype)
+    logits_processor = MockLogitsProcessor(32000, 0.5, fake_logits)
+    model_runner = ModelRunner(None, None, None, None, None)
+    return input_tensor, fake_logits, logits_processor, model_runner
+
+
+RANDOM_SEEDS = list(range(128))
+CUDA_DEVICES = [
+    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
+]
+
+
+@pytest.mark.parametrize("seed", RANDOM_SEEDS)
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+def test_logits_processors(seed: int, device: str):
+    set_random_seed(seed)
+    torch.set_default_device(device)
+    batch_size = random.randint(1, 256)
+    input_tensor, fake_logits, logits_processor, model_runner = _prepare_test(
+        batch_size)
+
+    # This sample logits processor gives infinite score to the i-th token,
+    # where i is the length of the input sequence.
+    # We therefore expect the output token sequence to be [0, 1, 2, ...]
+    def pick_ith(token_ids, logits):
+        logits[len(token_ids)] = float("inf")
+        return logits
+
+    seq_group_metadata_list = []
+    prompt_lens = []
+    for i in range(batch_size):
+        seq_group_metadata_list.append(
+            SequenceGroupMetadata(
+                request_id=f"test_{i}",
+                is_prompt=True,
+                seq_data={0: SequenceData([1, 2, 3])},
+                sampling_params=SamplingParams(temperature=0,
+                                               logits_processors=[pick_ith]),
+                block_tables={0: [1]},
+            ))
+        prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
+
+    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
+                                                     prompt_lens,
+                                                     subquery_lens=prompt_lens)
+    logits_processor_output = logits_processor(
+        embedding=None,
+        hidden_states=input_tensor,
+        sampling_metadata=sampling_metadata)
+
+    assert torch.isinf(logits_processor_output[:, 0]).all()
+
+    fake_logits *= logits_processor.scale
+    assert torch.allclose(logits_processor_output[:, 1], fake_logits[:, 1],
+                          1e-4)
+
+    del model_runner

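For orientation: the new test above exercises the path by which per-request logits processors reach the layer. Any callable that takes the previously generated token ids plus one logits row and returns the modified row can be supplied through SamplingParams, as pick_ith is above. A minimal sketch of supplying one (ban_token_0 is a hypothetical example, not part of this diff):

def ban_token_0(token_ids, logits):
    logits[0] = float("-inf")  # forbid token id 0 on every decoding step
    return logits

params = SamplingParams(temperature=0, logits_processors=[ban_token_0])
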
@@ -10,7 +10,6 @@ from transformers import PretrainedConfig
 
 from vllm.config import LoRAConfig
 from vllm.lora.punica import add_lora, add_lora_slice, bgmv
-from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.parallel_utils.communication_op import (
     tensor_model_parallel_all_gather,
     tensor_model_parallel_all_reduce,

@@ -20,6 +19,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                RowParallelLinear,
                                                QKVParallelLinear,
                                                MergedColumnParallelLinear)
+from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding, ParallelLMHead)
 from vllm.model_executor.parallel_utils.parallel_state import (

@@ -783,11 +783,11 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
         return self.base_layer.weight
 
 
-class SamplerWithLoRA(BaseLayerWithLoRA):
+class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
 
     def __init__(
         self,
-        base_layer: Sampler,
+        base_layer: LogitsProcessor,
         hidden_size: int,
         dtype: torch.dtype,
         device: torch.device,

@@ -806,6 +806,10 @@ class SamplerWithLoRA(BaseLayerWithLoRA):
     def vocab_size(self):
         return self.base_layer.vocab_size
 
+    @property
+    def scale(self):
+        return self.base_layer.scale
+
     @property
     def org_vocab_size(self):
         return self.base_layer.org_vocab_size

@@ -968,14 +972,14 @@ def from_layer(
     return layer
 
 
-def from_layer_sampler(
-    layer: Sampler,
+def from_layer_logits_processor(
+    layer: LogitsProcessor,
     lm_head: ParallelLMHead,
     max_loras: int,
     lora_config: LoRAConfig,
     model_config: Optional[PretrainedConfig] = None,
-) -> SamplerWithLoRA:
-    ret = SamplerWithLoRA(layer, lm_head.embedding_dim, lm_head.weight.dtype,
-                          lm_head.weight.device)
+) -> LogitsProcessorWithLoRA:
+    ret = LogitsProcessorWithLoRA(layer, lm_head.embedding_dim,
+                                  lm_head.weight.dtype, lm_head.weight.device)
     ret.create_lora_weights(max_loras, lora_config, model_config)
     return ret

@@ -14,7 +14,7 @@ from vllm.config import LoRAConfig
 from vllm.utils import LRUCache, in_wsl
 
 from vllm.lora.layers import (BaseLayerWithLoRA, LoRAMapping, from_layer,
-                              from_layer_sampler)
+                              from_layer_logits_processor)
 from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
 from vllm.lora.utils import parse_fine_tuned_lora_name, replace_submodule

@@ -421,11 +421,14 @@ class LoRAModelManager:
                             self.model.config))
             # (yard1): TODO make this more robust
             if "lm_head" in module_name:
-                sampler_module = self.model.get_submodule("sampler")
+                logits_processor_module = self.model.get_submodule(
+                    "logits_processor")
                 new_module = replace_submodule(
-                    self.model, "sampler",
-                    from_layer_sampler(sampler_module, module, self.lora_slots,
-                                       self.lora_config, self.model.config))
+                    self.model, "logits_processor",
+                    from_layer_logits_processor(logits_processor_module,
+                                                module, self.lora_slots,
+                                                self.lora_config,
+                                                self.model.config))
             self.register_module(module_name, new_module)
             self._register_packed_modules(module_name)
             new_module.set_mapping(self.base_indices, self.sampler_indices,

vllm/model_executor/layers/logits_processor.py (new file, +106 lines)
@@ -0,0 +1,106 @@
+"""A layer that computes logits from hidden states."""
+from typing import Optional
+
+import torch
+import torch.nn as nn
+
+from vllm.utils import is_neuron
+
+from vllm.model_executor.parallel_utils.communication_op import (
+    tensor_model_parallel_gather)
+from vllm.model_executor.sampling_metadata import SamplingMetadata
+
+
+class LogitsProcessor(nn.Module):
+    """Process logits and apply logits processors from sampling metadata.
+
+    This layer does the following:
+    1. Gather logits from model hidden_states.
+    2. Scale logits if needed.
+    3. Apply logits processors (if any).
+    """
+
+    def __init__(self,
+                 vocab_size: int,
+                 org_vocab_size: Optional[int] = None,
+                 scale: Optional[float] = 1.0) -> None:
+        """
+        Args:
+            scale: A scaling factor to apply to the logits.
+        """
+        super().__init__()
+        self.scale = scale
+        self.vocab_size = vocab_size
+        # Transformers-neuronx generates outputs as logits directly.
+        self.logits_as_hidden_states = is_neuron()
+        # Original vocabulary size (without LoRA).
+        self.org_vocab_size = org_vocab_size or vocab_size
+
+    def forward(
+        self,
+        embedding: torch.Tensor,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+        embedding_bias: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        if self.logits_as_hidden_states:
+            logits = hidden_states
+        else:
+            hidden_states = _prune_hidden_states(hidden_states,
+                                                 sampling_metadata)
+
+            # Get the logits for the next tokens.
+            logits = self._get_logits(hidden_states, embedding, embedding_bias)
+
+        if logits is not None:
+            logits *= self.scale
+
+            # Apply logits processors (if any).
+            logits = _apply_logits_processors(logits, sampling_metadata)
+
+        return logits
+
+    def _get_logits(self, hidden_states: torch.Tensor, embedding: torch.Tensor,
+                    embedding_bias: Optional[torch.Tensor]) -> torch.Tensor:
+        # Get the logits for the next tokens.
+        logits = torch.matmul(hidden_states, embedding.t())
+        if embedding_bias is not None:
+            logits += embedding_bias
+        logits = tensor_model_parallel_gather(logits)
+        # Remove paddings in vocab (if any).
+        if logits is not None:
+            logits = logits[:, :self.org_vocab_size]
+        return logits
+
+
+def _prune_hidden_states(
+    hidden_states: torch.Tensor,
+    sampling_metadata: SamplingMetadata,
+) -> torch.Tensor:
+    hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
+    return hidden_states.index_select(0,
+                                      sampling_metadata.selected_token_indices)
+
+
+def _apply_logits_processors(
+    logits: torch.Tensor,
+    sampling_metadata: SamplingMetadata,
+) -> torch.Tensor:
+    logits_row_idx = 0
+    found_logits_processors = False
+    for seq_ids, sampling_params in sampling_metadata.seq_groups:
+        logits_processors = sampling_params.logits_processors
+        if logits_processors:
+            found_logits_processors = True
+            for seq_id in seq_ids:
+                logits_row = logits[logits_row_idx]
+                token_ids = sampling_metadata.seq_data[seq_id].output_token_ids
+                for logits_processor in logits_processors:
+                    logits_row = logits_processor(token_ids, logits_row)
+                logits[logits_row_idx] = logits_row
+                logits_row_idx += 1
+        else:
+            logits_row_idx += len(seq_ids)
+    if found_logits_processors:
+        assert logits_row_idx == logits.shape[0]
+    return logits

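Taken with the per-model changes below, logits computation now runs as its own stage between the model forward pass and sampling. A minimal sketch of the resulting call order, assuming a model with the compute_logits()/sample() split introduced in this commit (the model_runner wiring itself is outside this excerpt, so treat the surrounding names as illustrative):

# 1. Run the model to get hidden states.
hidden_states = model(input_ids, positions, kv_caches, input_metadata)
# 2. Gather logits across TP workers, scale them, and apply any per-request
#    logits processors (LogitsProcessor.forward, via compute_logits).
logits = model.compute_logits(hidden_states, sampling_metadata)
# 3. Sample next tokens from the precomputed logits (driver worker only).
output = model.sample(logits=logits, sampling_metadata=sampling_metadata)
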
@@ -4,8 +4,6 @@ from typing import Dict, List, Optional, Tuple
 import torch
 import torch.nn as nn
 
-from vllm.model_executor.parallel_utils.communication_op import (
-    tensor_model_parallel_gather)
 from vllm.model_executor.sampling_metadata import (SamplingMetadata,
                                                    SamplingTensors)
 from vllm.sampling_params import SamplingParams, SamplingType

@@ -13,7 +11,6 @@ from vllm.sequence import (Logprob, PromptLogprobs, SampleLogprobs,
                            SamplerOutput, SequenceData, SequenceGroupOutput,
                            SequenceOutput)
 from vllm.model_executor.layers.ops.sample import (sample as sample_triton)
-from vllm.utils import is_neuron
 
 
 class Sampler(nn.Module):

@@ -31,58 +28,14 @@ class Sampler(nn.Module):
     parameters (e.g., sampling method, temperature, top-p, top-k, etc.).
     """
 
-    def __init__(self,
-                 vocab_size: int,
-                 org_vocab_size: Optional[int] = None) -> None:
-        super().__init__()
-        self.vocab_size = vocab_size
-        # Transformers-neuronx generate outputs as logits directly.
-        self.logits_as_hidden_states = is_neuron()
-        # original vocabulary size (without LoRA).
-        self.org_vocab_size = org_vocab_size or vocab_size
-
-    def _get_logits(self, hidden_states: torch.Tensor, embedding: torch.Tensor,
-                    embedding_bias: Optional[torch.Tensor]) -> torch.Tensor:
-        # Get the logits for the next tokens.
-        logits = torch.matmul(hidden_states, embedding.t())
-        if embedding_bias is not None:
-            logits += embedding_bias
-        logits = tensor_model_parallel_gather(logits)
-        # Remove paddings in vocab (if any).
-        if logits is not None:
-            logits = logits[:, :self.org_vocab_size]
-        return logits
-
     def forward(
         self,
-        embedding: torch.Tensor,
-        hidden_states: torch.Tensor,
+        logits: torch.Tensor,
         sampling_metadata: SamplingMetadata,
-        embedding_bias: Optional[torch.Tensor] = None,
     ) -> Optional[SamplerOutput]:
-        # Get the hidden states that we use for sampling.
-        if self.logits_as_hidden_states:
-            logits = hidden_states
-        else:
-            hidden_states = _prune_hidden_states(hidden_states,
-                                                 sampling_metadata)
-
-            # Get the logits for the next tokens.
-            logits = self._get_logits(hidden_states, embedding, embedding_bias)
-
         # Only perform sampling in the driver worker.
-        # Note: `_get_logits` is still distributed across TP workers because
-        # the `embedding` weight is distributed across TP workers.
         # TODO(zhuohan): Change the get_logits part to a separate stage.
         if not sampling_metadata.perform_sampling:
             return None
 
         assert logits is not None
         _, vocab_size = logits.shape
 
-        # Apply logits processors (if any).
-        logits = _apply_logits_processors(logits, sampling_metadata)
-
         # Prepare sampling tensors with pinned memory to avoid blocking.
         (sampling_tensors, do_penalties, do_top_p_top_k,
          do_min_p) = SamplingTensors.from_sampling_metadata(

@@ -124,14 +77,6 @@ class Sampler(nn.Module):
                            prompt_logprobs, sample_logprobs)
 
 
-def _prune_hidden_states(
-    hidden_states: torch.Tensor,
-    sampling_metadata: SamplingMetadata,
-) -> torch.Tensor:
-    return hidden_states.index_select(0,
-                                      sampling_metadata.selected_token_indices)
-
-
 def _get_bin_counts_and_mask(
     tokens: torch.Tensor,
     vocab_size: int,

@@ -149,30 +94,6 @@ def _get_bin_counts_and_mask(
     return bin_counts, mask
 
 
-def _apply_logits_processors(
-    logits: torch.Tensor,
-    sampling_metadata: SamplingMetadata,
-) -> torch.Tensor:
-    logits_row_idx = 0
-    found_logits_processors = False
-    for seq_ids, sampling_params in sampling_metadata.seq_groups:
-        logits_processors = sampling_params.logits_processors
-        if logits_processors:
-            found_logits_processors = True
-            for seq_id in seq_ids:
-                logits_row = logits[logits_row_idx]
-                token_ids = sampling_metadata.seq_data[seq_id].output_token_ids
-                for logits_processor in logits_processors:
-                    logits_row = logits_processor(token_ids, logits_row)
-                logits[logits_row_idx] = logits_row
-                logits_row_idx += 1
-        else:
-            logits_row_idx += len(seq_ids)
-    if found_logits_processors:
-        assert logits_row_idx == logits.shape[0]
-    return logits
-
-
 def _apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
                      output_tokens_tensor: torch.Tensor,
                      presence_penalties: torch.Tensor,

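The net effect on the Sampler's interface: TP gathering, vocab-padding removal, and logits-processor application all move to LogitsProcessor, so the sampler now consumes ready-made logits. A before/after sketch of a typical call site, mirroring the model changes below:

# Before: the sampler computed logits itself from the lm_head weight.
# next_tokens = self.sampler(self.lm_head.weight, hidden_states,
#                            sampling_metadata)
# After: it only samples from logits produced upstream.
next_tokens = self.sampler(logits, sampling_metadata)
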
@@ -34,6 +34,7 @@ from vllm.model_executor.layers.linear import (LinearMethodBase,
                                                QKVParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.rotary_embedding import get_rope
+from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding, ParallelLMHead)

@@ -295,7 +296,8 @@ class BaiChuanBaseForCausalLM(nn.Module):
         self.linear_method = linear_method
         self.model = BaiChuanModel(config, position_embedding, linear_method)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
-        self.sampler = Sampler(config.vocab_size)
+        self.logits_processor = LogitsProcessor(config.vocab_size)
+        self.sampler = Sampler()
 
     def forward(
         self,

@@ -308,13 +310,18 @@ class BaiChuanBaseForCausalLM(nn.Module):
                                    input_metadata)
         return hidden_states
 
+    def compute_logits(self, hidden_states: torch.Tensor,
+                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+        logits = self.logits_processor(self.lm_head.weight, hidden_states,
+                                       sampling_metadata)
+        return logits
+
     def sample(
         self,
-        hidden_states: torch.Tensor,
+        logits: torch.Tensor,
         sampling_metadata: SamplingMetadata,
     ) -> Optional[SamplerOutput]:
-        next_tokens = self.sampler(self.lm_head.weight, hidden_states,
-                                   sampling_metadata)
+        next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
     def load_weights(self,

@@ -30,6 +30,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                LinearMethodBase,
                                                QKVParallelLinear,
                                                RowParallelLinear)
+from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)

@@ -273,7 +274,8 @@ class BloomForCausalLM(nn.Module):
         self.linear_method = linear_method
         self.transformer = BloomModel(config, linear_method)
         self.lm_head_weight = self.transformer.word_embeddings.weight
-        self.sampler = Sampler(config.vocab_size)
+        self.logits_processor = LogitsProcessor(config.vocab_size)
+        self.sampler = Sampler()
 
     def forward(
         self,

@@ -286,13 +288,18 @@ class BloomForCausalLM(nn.Module):
                                          input_metadata)
         return hidden_states
 
+    def compute_logits(self, hidden_states: torch.Tensor,
+                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+        logits = self.logits_processor(self.lm_head_weight, hidden_states,
+                                       sampling_metadata)
+        return logits
+
     def sample(
         self,
-        hidden_states: torch.Tensor,
+        logits: torch.Tensor,
         sampling_metadata: SamplingMetadata,
     ) -> Optional[SamplerOutput]:
-        next_tokens = self.sampler(self.lm_head_weight, hidden_states,
-                                   sampling_metadata)
+        next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
     def load_weights(self,

@@ -17,6 +17,7 @@ from vllm.model_executor.layers.linear import (LinearMethodBase,
                                                QKVParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.rotary_embedding import get_rope
+from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding, ParallelLMHead)

@@ -332,7 +333,8 @@ class ChatGLMForCausalLM(nn.Module):
         self.linear_method = linear_method
         self.transformer = ChatGLMModel(config, linear_method)
         self.lm_head_weight = self.transformer.output_layer.weight
-        self.sampler = Sampler(config.padded_vocab_size)
+        self.logits_processor = LogitsProcessor(config.padded_vocab_size)
+        self.sampler = Sampler()
 
     def forward(
         self,

@@ -345,13 +347,18 @@ class ChatGLMForCausalLM(nn.Module):
                                          input_metadata)
         return hidden_states
 
+    def compute_logits(self, hidden_states: torch.Tensor,
+                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+        logits = self.logits_processor(self.lm_head_weight, hidden_states,
+                                       sampling_metadata)
+        return logits
+
     def sample(
         self,
-        hidden_states: torch.Tensor,
+        logits: torch.Tensor,
         sampling_metadata: SamplingMetadata,
     ) -> Optional[SamplerOutput]:
-        next_tokens = self.sampler(self.lm_head_weight, hidden_states,
-                                   sampling_metadata)
+        next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
     def load_weights(self,

@@ -38,6 +38,7 @@ from vllm.model_executor.layers.linear import (LinearMethodBase,
                                                QKVParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.rotary_embedding import get_rope
+from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding, ParallelLMHead)

@@ -372,7 +373,8 @@ class DeepseekForCausalLM(nn.Module):
         self.linear_method = linear_method
         self.model = DeepseekModel(config, linear_method)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
-        self.sampler = Sampler(config.vocab_size)
+        self.logits_processor = LogitsProcessor(config.vocab_size)
+        self.sampler = Sampler()
 
     def forward(
         self,

@@ -385,13 +387,18 @@ class DeepseekForCausalLM(nn.Module):
                                    input_metadata)
         return hidden_states
 
+    def compute_logits(self, hidden_states: torch.Tensor,
+                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+        logits = self.logits_processor(self.lm_head.weight, hidden_states,
+                                       sampling_metadata)
+        return logits
+
     def sample(
         self,
-        hidden_states: Optional[torch.Tensor],
+        logits: Optional[torch.Tensor],
         sampling_metadata: SamplingMetadata,
     ) -> Optional[SamplerOutput]:
-        next_tokens = self.sampler(self.lm_head.weight, hidden_states,
-                                   sampling_metadata)
+        next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
     def load_weights(self,

@@ -34,6 +34,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                QKVParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.rotary_embedding import get_rope
+from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding, ParallelLMHead)

@@ -373,7 +374,8 @@ class FalconForCausalLM(nn.Module):
             config.vocab_size,
             config.hidden_size,
         )
-        self.sampler = Sampler(config.vocab_size)
+        self.logits_processor = LogitsProcessor(config.vocab_size)
+        self.sampler = Sampler()
 
     def forward(
         self,

@@ -390,13 +392,18 @@ class FalconForCausalLM(nn.Module):
         )
         return hidden_states
 
+    def compute_logits(self, hidden_states: torch.Tensor,
+                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+        logits = self.logits_processor(self.lm_head.weight, hidden_states,
+                                       sampling_metadata)
+        return logits
+
     def sample(
         self,
-        hidden_states: torch.Tensor,
+        logits: torch.Tensor,
         sampling_metadata: SamplingMetadata,
     ) -> Optional[SamplerOutput]:
-        next_tokens = self.sampler(self.lm_head.weight, hidden_states,
-                                   sampling_metadata)
+        next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
     def load_weights(self,

@@ -30,6 +30,7 @@ from vllm.model_executor.layers.linear import (LinearMethodBase,
                                                QKVParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.rotary_embedding import get_rope
+from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)

@@ -281,7 +282,8 @@ class GemmaForCausalLM(nn.Module):
         self.config = config
         self.linear_method = linear_method
         self.model = GemmaModel(config, linear_method)
-        self.sampler = Sampler(config.vocab_size)
+        self.logits_processor = LogitsProcessor(config.vocab_size)
+        self.sampler = Sampler()
 
     @torch.no_grad()
     def forward(

@@ -295,13 +297,18 @@ class GemmaForCausalLM(nn.Module):
                                    input_metadata)
         return hidden_states
 
+    def compute_logits(self, hidden_states: torch.Tensor,
+                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+        logits = self.logits_processor(self.model.embed_tokens.weight,
+                                       hidden_states, sampling_metadata)
+        return logits
+
     def sample(
         self,
-        hidden_states: torch.Tensor,
+        logits: torch.Tensor,
         sampling_metadata: SamplingMetadata,
     ) -> Optional[SamplerOutput]:
-        next_tokens = self.sampler(self.model.embed_tokens.weight,
-                                   hidden_states, sampling_metadata)
+        next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
     def load_weights(self,

@@ -30,6 +30,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                LinearMethodBase,
                                                QKVParallelLinear,
                                                RowParallelLinear)
+from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)

@@ -216,7 +217,8 @@ class GPT2LMHeadModel(nn.Module):
         self.linear_method = linear_method
         self.transformer = GPT2Model(config, linear_method)
         self.lm_head_weight = self.transformer.wte.weight
-        self.sampler = Sampler(config.vocab_size)
+        self.logits_processor = LogitsProcessor(config.vocab_size)
+        self.sampler = Sampler()
 
     def forward(
         self,

@@ -229,12 +231,18 @@ class GPT2LMHeadModel(nn.Module):
                                          input_metadata)
         return hidden_states
 
+    def compute_logits(self, hidden_states: torch.Tensor,
+                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+        logits = self.logits_processor(self.lm_head_weight, hidden_states,
+                                       sampling_metadata)
+        return logits
+
     def sample(
         self,
-        hidden_states: torch.Tensor,
+        logits: torch.Tensor,
         sampling_metadata: SamplingMetadata,
     ) -> Optional[SamplerOutput]:
-        next_tokens = self.sampler(self.lm_head_weight, hidden_states,
+        next_tokens = self.sampler(self.lm_head_weight, logits,
                                    sampling_metadata)
         return next_tokens

@@ -31,6 +31,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                LinearMethodBase,
                                                QKVParallelLinear,
                                                RowParallelLinear)
+from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)

@@ -237,7 +238,8 @@ class GPTBigCodeForCausalLM(nn.Module):
         self.linear_method = linear_method
         self.transformer = GPTBigCodeModel(config, linear_method)
         self.lm_head_weight = self.transformer.wte.weight
-        self.sampler = Sampler(config.vocab_size)
+        self.logits_processor = LogitsProcessor(config.vocab_size)
+        self.sampler = Sampler()
 
     def forward(
         self,

@@ -250,13 +252,18 @@ class GPTBigCodeForCausalLM(nn.Module):
                                          input_metadata)
         return hidden_states
 
+    def compute_logits(self, hidden_states: torch.Tensor,
+                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+        logits = self.logits_processor(self.lm_head_weight, hidden_states,
+                                       sampling_metadata)
+        return logits
+
     def sample(
         self,
-        hidden_states: torch.Tensor,
+        logits: torch.Tensor,
         sampling_metadata: SamplingMetadata,
     ) -> Optional[SamplerOutput]:
-        next_tokens = self.sampler(self.lm_head_weight, hidden_states,
-                                   sampling_metadata)
+        next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
     def load_weights(self,

@@ -30,6 +30,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                QKVParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.rotary_embedding import get_rope
+from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding, ParallelLMHead)

@@ -224,7 +225,8 @@ class GPTJForCausalLM(nn.Module):
             config.n_embd,
             bias=True,
         )
-        self.sampler = Sampler(config.vocab_size)
+        self.logits_processor = LogitsProcessor(config.vocab_size)
+        self.sampler = Sampler()
 
     def forward(
         self,

@@ -237,13 +239,18 @@ class GPTJForCausalLM(nn.Module):
                                    input_metadata)
         return hidden_states
 
+    def compute_logits(self, hidden_states: torch.Tensor,
+                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+        logits = self.logits_processor(self.lm_head.weight, hidden_states,
+                                       sampling_metadata, self.lm_head.bias)
+        return logits
+
     def sample(
         self,
-        hidden_states: torch.Tensor,
+        logits: torch.Tensor,
         sampling_metadata: SamplingMetadata,
     ) -> Optional[SamplerOutput]:
-        next_tokens = self.sampler(self.lm_head.weight, hidden_states,
-                                   sampling_metadata, self.lm_head.bias)
+        next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
     def load_weights(self,

@@ -30,6 +30,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                QKVParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.rotary_embedding import get_rope
+from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding, ParallelLMHead)

@@ -238,7 +239,8 @@ class GPTNeoXForCausalLM(nn.Module):
             config.vocab_size,
             config.hidden_size,
         )
-        self.sampler = Sampler(config.vocab_size)
+        self.logits_processor = LogitsProcessor(config.vocab_size)
+        self.sampler = Sampler()
 
     def forward(
         self,

@@ -251,13 +253,18 @@ class GPTNeoXForCausalLM(nn.Module):
                                    input_metadata)
         return hidden_states
 
+    def compute_logits(self, hidden_states: torch.Tensor,
+                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+        logits = self.logits_processor(self.embed_out.weight, hidden_states,
+                                       sampling_metadata)
+        return logits
+
     def sample(
         self,
-        hidden_states: torch.Tensor,
+        logits: torch.Tensor,
         sampling_metadata: SamplingMetadata,
     ) -> Optional[SamplerOutput]:
-        next_tokens = self.sampler(self.embed_out.weight, hidden_states,
-                                   sampling_metadata)
+        next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
     def load_weights(self,

@@ -14,6 +14,7 @@ from vllm.model_executor.layers.linear import (LinearMethodBase,
                                                QKVParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.rotary_embedding import get_rope
+from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding, ParallelLMHead)

@@ -250,7 +251,8 @@ class InternLM2ForCausalLM(nn.Module):
         self.linear_method = linear_method
         self.model = InternLM2Model(config, linear_method)
         self.output = ParallelLMHead(config.vocab_size, config.hidden_size)
-        self.sampler = Sampler(config.vocab_size)
+        self.logits_processor = LogitsProcessor(config.vocab_size)
+        self.sampler = Sampler()
 
     def forward(
         self,

@@ -263,13 +265,18 @@ class InternLM2ForCausalLM(nn.Module):
                                    input_metadata)
         return hidden_states
 
+    def compute_logits(self, hidden_states: torch.Tensor,
+                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+        logits = self.logits_processor(self.output.weight, hidden_states,
+                                       sampling_metadata)
+        return logits
+
     def sample(
         self,
-        hidden_states: torch.Tensor,
+        logits: torch.Tensor,
         sampling_metadata: SamplingMetadata,
     ) -> Optional[SamplerOutput]:
-        next_tokens = self.sampler(self.output.weight, hidden_states,
-                                   sampling_metadata)
+        next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
     def load_weights(self,

@@ -37,6 +37,7 @@ from vllm.model_executor.layers.linear import (LinearMethodBase,
                                                QKVParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.rotary_embedding import get_rope
+from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding, ParallelLMHead, DEFAULT_VOCAB_PADDING_SIZE)

@@ -325,7 +326,11 @@ class LlamaForCausalLM(nn.Module):
             # compatibility
             if not lora_config else lora_config.lora_vocab_padding_size,
         )
-        self.sampler = Sampler(self.unpadded_vocab_size, config.vocab_size)
+
+        logit_scale = getattr(config, "logit_scale", 1.0)
+        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
+                                                config.vocab_size, logit_scale)
+        self.sampler = Sampler()
 
     def forward(
         self,

@@ -338,13 +343,18 @@ class LlamaForCausalLM(nn.Module):
                                    input_metadata)
         return hidden_states
 
+    def compute_logits(self, hidden_states: torch.Tensor,
+                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+        logits = self.logits_processor(self.lm_head.weight, hidden_states,
+                                       sampling_metadata)
+        return logits
+
     def sample(
         self,
-        hidden_states: torch.Tensor,
+        logits: torch.Tensor,
         sampling_metadata: SamplingMetadata,
     ) -> Optional[SamplerOutput]:
-        next_tokens = self.sampler(self.lm_head.weight, hidden_states,
-                                   sampling_metadata)
+        next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
     def load_weights(self,

@@ -37,6 +37,7 @@ from vllm.model_executor.layers.linear import (LinearMethodBase,
                                                ReplicatedLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.rotary_embedding import get_rope
+from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding, ParallelLMHead, DEFAULT_VOCAB_PADDING_SIZE)

@@ -369,7 +370,9 @@ class MixtralForCausalLM(nn.Module):
             # compatibility
             if not lora_config else lora_config.lora_vocab_padding_size,
         )
-        self.sampler = Sampler(self.unpadded_vocab_size, config.vocab_size)
+        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
+                                                config.vocab_size)
+        self.sampler = Sampler()
 
     def forward(
         self,

@@ -382,13 +385,18 @@ class MixtralForCausalLM(nn.Module):
                                    input_metadata)
         return hidden_states
 
+    def compute_logits(self, hidden_states: torch.Tensor,
+                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+        logits = self.logits_processor(self.lm_head.weight, hidden_states,
+                                       sampling_metadata)
+        return logits
+
     def sample(
         self,
-        hidden_states: Optional[torch.Tensor],
+        logits: Optional[torch.Tensor],
         sampling_metadata: SamplingMetadata,
     ) -> Optional[SamplerOutput]:
-        next_tokens = self.sampler(self.lm_head.weight, hidden_states,
-                                   sampling_metadata)
+        next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
     def load_weights(self,

@@ -39,6 +39,7 @@ from vllm.model_executor.layers.linear import (LinearMethodBase,
                                                QKVParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.rotary_embedding import get_rope
+from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding, ParallelLMHead)

@@ -344,7 +345,8 @@ class MixtralForCausalLM(nn.Module):
         self.linear_method = linear_method
         self.model = MixtralModel(config, linear_method)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
-        self.sampler = Sampler(config.vocab_size)
+        self.logits_processor = LogitsProcessor(config.vocab_size)
+        self.sampler = Sampler()
 
     def forward(
         self,

@@ -357,13 +359,18 @@ class MixtralForCausalLM(nn.Module):
                                    input_metadata)
         return hidden_states
 
+    def compute_logits(self, hidden_states: torch.Tensor,
+                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+        logits = self.logits_processor(self.lm_head.weight, hidden_states,
+                                       sampling_metadata)
+        return logits
+
     def sample(
         self,
-        hidden_states: Optional[torch.Tensor],
+        logits: Optional[torch.Tensor],
         sampling_metadata: SamplingMetadata,
     ) -> Optional[SamplerOutput]:
-        next_tokens = self.sampler(self.lm_head.weight, hidden_states,
-                                   sampling_metadata)
+        next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
     def load_weights(self,

@@ -13,6 +13,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                LinearMethodBase,
                                                QKVParallelLinear,
                                                RowParallelLinear)
+from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)

@@ -259,7 +260,8 @@ class MPTForCausalLM(nn.Module):
 
         self.transformer = MPTModel(config, linear_method)
         self.lm_head_weight = self.transformer.wte.weight
-        self.sampler = Sampler(config.vocab_size)
+        self.logits_processor = LogitsProcessor(config.vocab_size)
+        self.sampler = Sampler()
 
     def forward(
         self,

@@ -272,13 +274,18 @@ class MPTForCausalLM(nn.Module):
                                          input_metadata)
         return hidden_states
 
+    def compute_logits(self, hidden_states: torch.Tensor,
+                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+        logits = self.logits_processor(self.lm_head_weight, hidden_states,
+                                       sampling_metadata)
+        return logits
+
     def sample(
         self,
-        hidden_states: torch.Tensor,
+        logits: torch.Tensor,
         sampling_metadata: SamplingMetadata,
     ) -> Optional[SamplerOutput]:
-        next_tokens = self.sampler(self.lm_head_weight, hidden_states,
-                                   sampling_metadata)
+        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens
 
     def load_weights(self,

@ -7,6 +7,7 @@ from torch import nn
from transformers import LlamaConfig

from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput
@ -25,7 +26,8 @@ class LlamaForCausalLM(nn.Module):
        self.config = config
        self.linear_method = linear_method
        self.model = None
        self.sampler = Sampler(config.vocab_size)
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()

    def forward(
        self,
@ -45,13 +47,18 @@ class LlamaForCausalLM(nn.Module):
                            start_ids=seq_ids.flatten())
        return logits

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        logits = self.logits_processor(self.model.chkpt_model.lm_head,
                                       hidden_states, sampling_metadata)
        return logits

    def sample(
        self,
        hidden_states: torch.Tensor,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(self.model.chkpt_model.lm_head,
                                   hidden_states, sampling_metadata)
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self,
@ -6,6 +6,7 @@ from torch import nn
from transformers import MistralConfig

from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput
@ -26,7 +27,8 @@ class MistralForCausalLM(nn.Module):
        self.linear_method = linear_method
        self.model = None
        self.lm_head = None
        self.sampler = Sampler(config.vocab_size)
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()

    def forward(
        self,
@ -48,13 +50,18 @@ class MistralForCausalLM(nn.Module):
                            start_ids=seq_ids)
        return logits

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        logits = self.logits_processor(self.model.chkpt_model.lm_head,
                                       hidden_states, sampling_metadata)
        return logits

    def sample(
        self,
        hidden_states: torch.Tensor,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(self.model.chkpt_model.lm_head,
                                   hidden_states, sampling_metadata)
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self,
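Unlike the other models in this commit, the two Neuron-backend classes above hand the whole self.model.chkpt_model.lm_head module, not a bare weight tensor, to the logits processor (and previously to the sampler). A toy helper that tolerates either form, assuming only that a module exposes a .weight attribute (resolve_lm_head_weight is an illustrative name, not a vLLM function):

# Toy helper normalizing a module-or-tensor LM head to a weight tensor;
# illustrative only, not part of vLLM.
from typing import Union

import torch
import torch.nn as nn


def resolve_lm_head_weight(
        head: Union[nn.Module, torch.Tensor]) -> torch.Tensor:
    if isinstance(head, torch.Tensor):
        return head
    # Assumes the module carries its projection matrix as .weight.
    return head.weight


lm_head = nn.Linear(8, 16, bias=False)
from_module = resolve_lm_head_weight(lm_head)         # module input
from_tensor = resolve_lm_head_weight(lm_head.weight)  # tensor input
assert from_module.shape == from_tensor.shape == (16, 8)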
@ -51,6 +51,7 @@ from vllm.model_executor.layers.linear import (
    RowParallelLinear,
)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
@ -336,7 +337,8 @@ class OLMoForCausalLM(nn.Module):
        self.lm_head_weight = (self.model.transformer.wte.weight
                               if config.weight_tying else
                               self.model.transformer.ff_out.weight)
        self.sampler = Sampler(config.vocab_size)
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()

    def forward(
        self,
@ -353,13 +355,18 @@ class OLMoForCausalLM(nn.Module):
        )
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        logits = self.logits_processor(self.lm_head_weight, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        hidden_states: torch.Tensor,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(self.lm_head_weight, hidden_states,
                                   sampling_metadata)
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(
@ -31,6 +31,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               ReplicatedLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
@ -292,7 +293,8 @@ class OPTForCausalLM(nn.Module):
        self.linear_method = linear_method
        self.model = OPTModel(config, linear_method)
        self.lm_head_weight = self.model.decoder.embed_tokens.weight
        self.sampler = Sampler(config.vocab_size)
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()

    def forward(
        self,
@ -305,13 +307,18 @@ class OPTForCausalLM(nn.Module):
                                   input_metadata)
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        logits = self.logits_processor(self.lm_head_weight, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        hidden_states: torch.Tensor,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(self.lm_head_weight, hidden_states,
                                   sampling_metadata)
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self,
@ -18,6 +18,7 @@ from vllm.model_executor.layers.linear import (LinearMethodBase,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding, ParallelLMHead)
@ -256,7 +257,8 @@ class OrionForCausalLM(nn.Module):
        self.linear_method = linear_method
        self.model = OrionModel(config, linear_method)
        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
        self.sampler = Sampler(config.vocab_size)
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()

    def forward(
        self,
@ -269,13 +271,18 @@ class OrionForCausalLM(nn.Module):
                                   input_metadata)
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        logits = self.logits_processor(self.lm_head.weight, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        hidden_states: torch.Tensor,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(self.lm_head.weight, hidden_states,
                                   sampling_metadata)
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self,
@ -49,6 +49,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding, ParallelLMHead)
@ -240,7 +241,8 @@ class PhiForCausalLM(nn.Module):
        self.lm_head = ParallelLMHead(config.vocab_size,
                                      config.hidden_size,
                                      bias=True)
        self.sampler = Sampler(config.vocab_size)
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()

    def forward(
        self,
@ -254,14 +256,18 @@ class PhiForCausalLM(nn.Module):

        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        logits = self.logits_processor(self.lm_head.weight, hidden_states,
                                       sampling_metadata, self.lm_head.bias)
        return logits

    def sample(
        self,
        hidden_states: torch.Tensor,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        head = self.lm_head
        next_tokens = self.sampler(head.weight, hidden_states,
                                   sampling_metadata, head.bias)
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self,
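PhiForCausalLM is the one model in this commit whose LM head carries a bias, so its compute_logits() forwards self.lm_head.bias as an extra argument. A toy illustration of what that optional-bias projection amounts to (project_logits is an illustrative name, not the real LogitsProcessor):

# Toy optional-bias logits projection mirroring Phi's extra bias argument;
# illustrative only.
from typing import Optional

import torch


def project_logits(lm_head_weight: torch.Tensor,
                   hidden_states: torch.Tensor,
                   bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    logits = hidden_states @ lm_head_weight.t()
    if bias is not None:
        logits = logits + bias  # bias shape: [vocab_size]
    return logits


weight = torch.randn(16, 8)  # [vocab_size, hidden_size]
bias = torch.randn(16)       # [vocab_size]
hidden = torch.randn(4, 8)   # [num_tokens, hidden_size]
logits = project_logits(weight, hidden, bias)
assert logits.shape == (4, 16)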
@ -19,6 +19,7 @@ from vllm.model_executor.layers.linear import (LinearMethodBase,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding, ParallelLMHead)
@ -230,7 +231,8 @@ class QWenLMHeadModel(nn.Module):
        self.linear_method = linear_method
        self.transformer = QWenModel(config, linear_method)
        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
        self.sampler = Sampler(config.vocab_size)
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()

    def forward(
        self,
@ -243,13 +245,18 @@ class QWenLMHeadModel(nn.Module):
                                         input_metadata)
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        logits = self.logits_processor(self.lm_head.weight, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        hidden_states: torch.Tensor,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(self.lm_head.weight, hidden_states,
                                   sampling_metadata)
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self,
@ -37,6 +37,7 @@ from vllm.model_executor.layers.linear import (LinearMethodBase,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding, ParallelLMHead)
@ -300,11 +301,15 @@ class Qwen2ForCausalLM(nn.Module):
        self.linear_method = linear_method
        self.model = Qwen2Model(config, linear_method)

        if not config.tie_word_embeddings:
        if config.tie_word_embeddings:
            self.lm_head_weight = self.model.embed_tokens.weight
        else:
            self.lm_head = ParallelLMHead(config.vocab_size,
                                          config.hidden_size)
            self.lm_head_weight = self.lm_head.weight

        self.sampler = Sampler(config.vocab_size)
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()

    def forward(
        self,
@ -317,17 +322,18 @@ class Qwen2ForCausalLM(nn.Module):
                                   input_metadata)
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        logits = self.logits_processor(self.lm_head_weight, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        hidden_states: torch.Tensor,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        if self.config.tie_word_embeddings:
            lm_head_weight = self.model.embed_tokens.weight
        else:
            lm_head_weight = self.lm_head.weight
        next_tokens = self.sampler(lm_head_weight, hidden_states,
                                   sampling_metadata)
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self,
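Qwen2 gets a slightly larger change: the tie_word_embeddings branch that used to run on every sample() call moves into __init__, so the LM head weight is resolved once at construction and sample() no longer touches weights at all. A sketch of that pattern, assuming illustrative names (TiedHeadModel is not a vLLM class):

# Sketch of resolving a tied LM head once at construction time instead of
# per sampling call; illustrative stand-in, not vLLM code.
import torch
import torch.nn as nn


class TiedHeadModel(nn.Module):
    def __init__(self, vocab_size: int, hidden_size: int,
                 tie_word_embeddings: bool):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, hidden_size)
        if tie_word_embeddings:
            # Reuse the input embedding matrix as the output projection.
            self.lm_head_weight = self.embed_tokens.weight
        else:
            self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
            self.lm_head_weight = self.lm_head.weight

    def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # No branching here: the weight was picked once in __init__.
        return hidden_states @ self.lm_head_weight.t()


tied = TiedHeadModel(16, 8, tie_word_embeddings=True)
logits = tied.compute_logits(torch.randn(4, 8))
assert logits.shape == (4, 16)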
@ -33,6 +33,7 @@ from vllm.model_executor.layers.linear import (LinearMethodBase,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding, ParallelLMHead)
@ -238,7 +239,8 @@ class StablelmForCausalLM(nn.Module):
        self.linear_method = linear_method
        self.model = StableLMEpochModel(config, linear_method)
        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
        self.sampler = Sampler(config.vocab_size)
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()

    def forward(
        self,
@ -251,13 +253,18 @@ class StablelmForCausalLM(nn.Module):
                                   input_metadata)
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        logits = self.logits_processor(self.lm_head.weight, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        hidden_states: torch.Tensor,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(self.lm_head.weight, hidden_states,
                                   sampling_metadata)
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self,
@ -32,6 +32,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               LinearMethodBase,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding, ParallelLMHead, DEFAULT_VOCAB_PADDING_SIZE)
@ -254,7 +255,9 @@ class Starcoder2ForCausalLM(nn.Module):
            padding_size=DEFAULT_VOCAB_PADDING_SIZE,
        )
        self.lm_head_weight = self.lm_head.weight
        self.sampler = Sampler(self.unpadded_vocab_size, config.vocab_size)
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size)
        self.sampler = Sampler()

    def forward(
        self,
@ -267,13 +270,18 @@ class Starcoder2ForCausalLM(nn.Module):
                                   input_metadata)
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        logits = self.logits_processor(self.lm_head_weight, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        hidden_states: Optional[torch.Tensor],
        logits: Optional[torch.Tensor],
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(self.lm_head_weight, hidden_states,
                                   sampling_metadata)
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self,
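Starcoder2 shows the two-argument form, LogitsProcessor(self.unpadded_vocab_size, config.vocab_size): its vocabulary is padded for efficient tensor-parallel sharding, so logits for the padding entries must be dropped before sampling. A toy sketch of that trimming step, assuming illustrative names (project_and_trim is not vLLM's API):

# Toy sketch of the padded-vocab case: project against the padded weight,
# then keep only the first org_vocab_size columns; illustrative only.
import torch


def project_and_trim(lm_head_weight: torch.Tensor,
                     hidden_states: torch.Tensor,
                     org_vocab_size: int) -> torch.Tensor:
    logits = hidden_states @ lm_head_weight.t()  # [tokens, padded_vocab]
    return logits[:, :org_vocab_size]            # drop padding entries


padded_weight = torch.randn(20, 8)  # vocab padded from 16 up to 20
hidden = torch.randn(4, 8)
logits = project_and_trim(padded_weight, hidden, org_vocab_size=16)
assert logits.shape == (4, 16)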
@ -613,9 +613,16 @@ class ModelRunner:
            input_metadata=input_metadata,
        )

        # Compute the logits.
        logits = self.model.compute_logits(hidden_states, sampling_metadata)

        # Only perform sampling in the driver worker.
        if not sampling_metadata.perform_sampling:
            return None

        # Sample the next token.
        output = self.model.sample(
            hidden_states=hidden_states,
            logits=logits,
            sampling_metadata=sampling_metadata,
        )
        return output
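This ModelRunner hunk is where the commit title lands: logits are now computed before the driver-worker check, so every rank participates in the logits computation and gather, while sampling still runs only on the driver. A hedged sketch of that control flow, assuming a model exposing compute_logits()/sample() like the TinyModel sketch earlier (execute_model here is an illustrative free function, not the real method):

# Hedged sketch of the new execute-model flow; names are illustrative.
from typing import Optional

import torch


def execute_model(model, hidden_states: torch.Tensor,
                  perform_sampling: bool) -> Optional[torch.Tensor]:
    # Every rank runs the logits computation (and, in the real code, the
    # cross-rank gather) before the driver check.
    logits = model.compute_logits(hidden_states)

    # Non-driver workers stop here, as in the ModelRunner hunk above.
    if not perform_sampling:
        return None

    # Only the driver samples the next tokens.
    return model.sample(logits)

Moving the early return between the two steps is what makes the split pay off: the expensive, collective part happens everywhere, and the sequential sampling logic stays on one worker.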