mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-15 09:15:55 +08:00
Migrate logits computation and gather to model_runner (#3233)
This commit is contained in:
parent
6e435de766
commit
f1c0fc3919
@ -49,6 +49,9 @@ steps:
|
|||||||
- label: Samplers Test
|
- label: Samplers Test
|
||||||
command: pytest -v -s samplers
|
command: pytest -v -s samplers
|
||||||
|
|
||||||
|
- label: LogitsProcessor Test
|
||||||
|
command: pytest -v -s test_logits_processor.py
|
||||||
|
|
||||||
- label: Worker Test
|
- label: Worker Test
|
||||||
command: pytest -v -s worker
|
command: pytest -v -s worker
|
||||||
|
|
||||||
|
|||||||
@ -13,6 +13,7 @@ from huggingface_hub import snapshot_download
|
|||||||
import vllm
|
import vllm
|
||||||
from vllm.config import LoRAConfig
|
from vllm.config import LoRAConfig
|
||||||
from vllm.model_executor.layers.sampler import Sampler
|
from vllm.model_executor.layers.sampler import Sampler
|
||||||
|
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||||
from vllm.model_executor.model_loader import get_model
|
from vllm.model_executor.model_loader import get_model
|
||||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||||
MergedColumnParallelLinear,
|
MergedColumnParallelLinear,
|
||||||
@ -85,7 +86,8 @@ def dummy_model() -> nn.Module:
|
|||||||
("outact", nn.Sigmoid()),
|
("outact", nn.Sigmoid()),
|
||||||
# Special handling for lm_head & sampler
|
# Special handling for lm_head & sampler
|
||||||
("lm_head", ParallelLMHead(512, 10)),
|
("lm_head", ParallelLMHead(512, 10)),
|
||||||
("sampler", Sampler(512))
|
("logits_processor", LogitsProcessor(512)),
|
||||||
|
("sampler", Sampler())
|
||||||
]))
|
]))
|
||||||
model.config = MagicMock()
|
model.config = MagicMock()
|
||||||
return model
|
return model
|
||||||
@ -110,7 +112,8 @@ def dummy_model_gate_up() -> nn.Module:
|
|||||||
("outact", nn.Sigmoid()),
|
("outact", nn.Sigmoid()),
|
||||||
# Special handling for lm_head & sampler
|
# Special handling for lm_head & sampler
|
||||||
("lm_head", ParallelLMHead(512, 10)),
|
("lm_head", ParallelLMHead(512, 10)),
|
||||||
("sampler", Sampler(512))
|
("logits_processor", LogitsProcessor(512)),
|
||||||
|
("sampler", Sampler())
|
||||||
]))
|
]))
|
||||||
model.config = MagicMock()
|
model.config = MagicMock()
|
||||||
return model
|
return model
|
||||||
|
|||||||
@ -13,14 +13,14 @@ from vllm.lora.layers import (
|
|||||||
QKVParallelLinearWithLora,
|
QKVParallelLinearWithLora,
|
||||||
VocabParallelEmbeddingWithLoRA,
|
VocabParallelEmbeddingWithLoRA,
|
||||||
RowParallelLinearWithLoRA,
|
RowParallelLinearWithLoRA,
|
||||||
SamplerWithLoRA,
|
LogitsProcessorWithLoRA,
|
||||||
LoRAMapping,
|
LoRAMapping,
|
||||||
BaseLayerWithLoRA,
|
BaseLayerWithLoRA,
|
||||||
)
|
)
|
||||||
from vllm.lora.models import (LoRALayerWeights, convert_mapping,
|
from vllm.lora.models import (LoRALayerWeights, convert_mapping,
|
||||||
PackedLoRALayerWeights)
|
PackedLoRALayerWeights)
|
||||||
from vllm.config import LoRAConfig
|
from vllm.config import LoRAConfig
|
||||||
from vllm.model_executor.layers.sampler import Sampler
|
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||||
MergedColumnParallelLinear,
|
MergedColumnParallelLinear,
|
||||||
RowParallelLinear,
|
RowParallelLinear,
|
||||||
@ -394,7 +394,7 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device) -> None:
|
|||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
|
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
|
||||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||||
def test_lm_head_sampler(dist_init, num_loras, device) -> None:
|
def test_lm_head_logits_processor(dist_init, num_loras, device) -> None:
|
||||||
|
|
||||||
torch.set_default_device(device)
|
torch.set_default_device(device)
|
||||||
max_loras = 8
|
max_loras = 8
|
||||||
@ -402,28 +402,29 @@ def test_lm_head_sampler(dist_init, num_loras, device) -> None:
|
|||||||
max_lora_rank=8,
|
max_lora_rank=8,
|
||||||
lora_dtype=torch.float16)
|
lora_dtype=torch.float16)
|
||||||
|
|
||||||
def create_random_sampler_layer():
|
def _pretest():
|
||||||
linear = ParallelLMHead(32000 + lora_config.lora_extra_vocab_size,
|
linear = ParallelLMHead(32000 + lora_config.lora_extra_vocab_size,
|
||||||
1024, 32000)
|
1024, 32000)
|
||||||
linear.weight.data = torch.rand_like(linear.weight.data)
|
linear.weight.data = torch.rand_like(linear.weight.data)
|
||||||
linear.weight.data[:, 32000:] = 0
|
linear.weight.data[:, 32000:] = 0
|
||||||
sampler = Sampler(32000 + lora_config.lora_extra_vocab_size, 32000)
|
logits_processor = LogitsProcessor(
|
||||||
lora_sampler = SamplerWithLoRA(sampler, 1024, linear.weight.dtype,
|
32000 + lora_config.lora_extra_vocab_size, 32000)
|
||||||
linear.weight.device)
|
lora_logits_processor = LogitsProcessorWithLoRA(
|
||||||
lora_sampler.create_lora_weights(max_loras, lora_config)
|
logits_processor, 1024, linear.weight.dtype, linear.weight.device)
|
||||||
|
lora_logits_processor.create_lora_weights(max_loras, lora_config)
|
||||||
|
|
||||||
return linear, sampler, lora_sampler
|
return linear, logits_processor, lora_logits_processor
|
||||||
|
|
||||||
for i in range(10):
|
for i in range(10):
|
||||||
set_random_seed(i)
|
set_random_seed(i)
|
||||||
|
|
||||||
id_to_index = get_random_id_to_index(num_loras, max_loras)
|
id_to_index = get_random_id_to_index(num_loras, max_loras)
|
||||||
linear, sampler, lora_sampler = create_random_sampler_layer()
|
linear, logits_processor, lora_logits_processor = _pretest()
|
||||||
|
|
||||||
# NOTE: all the generated loras share the same embeddings tensor.
|
# NOTE: all the generated loras share the same embeddings tensor.
|
||||||
lora_dict, _ = populate_loras(
|
lora_dict, _ = populate_loras(
|
||||||
id_to_index,
|
id_to_index,
|
||||||
layer=lora_sampler,
|
layer=lora_logits_processor,
|
||||||
layer_weights=linear.weight,
|
layer_weights=linear.weight,
|
||||||
generate_embeddings_tensor=1024,
|
generate_embeddings_tensor=1024,
|
||||||
)
|
)
|
||||||
@ -447,34 +448,37 @@ def test_lm_head_sampler(dist_init, num_loras, device) -> None:
|
|||||||
32000,
|
32000,
|
||||||
lora_config.lora_extra_vocab_size,
|
lora_config.lora_extra_vocab_size,
|
||||||
)
|
)
|
||||||
lora_sampler.set_mapping(*mapping_info, )
|
lora_logits_processor.set_mapping(*mapping_info, )
|
||||||
|
|
||||||
lora_result = lora_sampler._get_logits(hidden_states=torch.cat(inputs),
|
lora_result = lora_logits_processor._get_logits(
|
||||||
embedding=linear.weight,
|
hidden_states=torch.cat(inputs),
|
||||||
embedding_bias=None)
|
embedding=linear.weight,
|
||||||
|
embedding_bias=None)
|
||||||
|
|
||||||
original_weight = linear.weight.clone()
|
original_weight = linear.weight.clone()
|
||||||
|
|
||||||
linear.weight[sampler.org_vocab_size:sampler.org_vocab_size +
|
linear.weight[logits_processor.
|
||||||
|
org_vocab_size:logits_processor.org_vocab_size +
|
||||||
embeddings_tensor_len] = embeddings_tensor
|
embeddings_tensor_len] = embeddings_tensor
|
||||||
|
|
||||||
sampler.org_vocab_size = 32000 + lora_config.lora_extra_vocab_size
|
logits_processor.org_vocab_size = (32000 +
|
||||||
|
lora_config.lora_extra_vocab_size)
|
||||||
expected_results = []
|
expected_results = []
|
||||||
for input_, lora_id in zip(inputs, prompt_mapping):
|
for input_, lora_id in zip(inputs, prompt_mapping):
|
||||||
lora = lora_dict[lora_id]
|
lora = lora_dict[lora_id]
|
||||||
result = sampler._get_logits(hidden_states=input_,
|
result = logits_processor._get_logits(hidden_states=input_,
|
||||||
embedding=linear.weight,
|
embedding=linear.weight,
|
||||||
embedding_bias=None)
|
embedding_bias=None)
|
||||||
result[:, 32000 + embeddings_tensor_len:] = float("-inf")
|
result[:, 32000 + embeddings_tensor_len:] = float("-inf")
|
||||||
result += input_ @ lora.lora_a @ lora.lora_b * lora.scaling
|
result += input_ @ lora.lora_a @ lora.lora_b * lora.scaling
|
||||||
expected_results.append(result)
|
expected_results.append(result)
|
||||||
expected_result = torch.cat(expected_results)
|
expected_result = torch.cat(expected_results)
|
||||||
sampler.org_vocab_size = 32000
|
logits_processor.org_vocab_size = 32000
|
||||||
|
|
||||||
# Check that resetting the lora weights succeeds
|
# Check that resetting the lora weights succeeds
|
||||||
|
|
||||||
for slot_idx in range(max_loras):
|
for slot_idx in range(max_loras):
|
||||||
lora_sampler.reset_lora(slot_idx)
|
lora_logits_processor.reset_lora(slot_idx)
|
||||||
|
|
||||||
inputs, index_mapping, prompt_mapping = create_random_inputs(
|
inputs, index_mapping, prompt_mapping = create_random_inputs(
|
||||||
active_lora_ids=[0],
|
active_lora_ids=[0],
|
||||||
@ -488,14 +492,16 @@ def test_lm_head_sampler(dist_init, num_loras, device) -> None:
|
|||||||
mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras,
|
mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras,
|
||||||
32000,
|
32000,
|
||||||
lora_config.lora_extra_vocab_size)
|
lora_config.lora_extra_vocab_size)
|
||||||
lora_sampler.set_mapping(*mapping_info, )
|
lora_logits_processor.set_mapping(*mapping_info, )
|
||||||
|
|
||||||
lora_result = lora_sampler._get_logits(hidden_states=torch.cat(inputs),
|
lora_result = lora_logits_processor._get_logits(
|
||||||
embedding=original_weight,
|
hidden_states=torch.cat(inputs),
|
||||||
embedding_bias=None)[:, :32000]
|
embedding=original_weight,
|
||||||
expected_result = sampler._get_logits(hidden_states=torch.cat(inputs),
|
embedding_bias=None)[:, :32000]
|
||||||
embedding=original_weight,
|
expected_result = logits_processor._get_logits(
|
||||||
embedding_bias=None)
|
hidden_states=torch.cat(inputs),
|
||||||
|
embedding=original_weight,
|
||||||
|
embedding_bias=None)
|
||||||
|
|
||||||
rtol, atol = TOLERANCES[lora_result.dtype]
|
rtol, atol = TOLERANCES[lora_result.dtype]
|
||||||
assert torch.allclose(lora_result,
|
assert torch.allclose(lora_result,
|
||||||
|
|||||||
@ -15,17 +15,12 @@ from vllm.worker.model_runner import ModelRunner
|
|||||||
|
|
||||||
class MockLogitsSampler(Sampler):
|
class MockLogitsSampler(Sampler):
|
||||||
|
|
||||||
def __init__(self, vocab_size: int, fake_logits: torch.Tensor):
|
def __init__(self, fake_logits: torch.Tensor):
|
||||||
super().__init__(vocab_size=vocab_size)
|
super().__init__()
|
||||||
self.fake_logits = fake_logits
|
self.fake_logits = fake_logits
|
||||||
|
|
||||||
def forward(self, *args, **kwargs):
|
def forward(self, *args, **kwargs):
|
||||||
with patch(
|
return super().forward(*args, **kwargs)
|
||||||
"vllm.model_executor.layers.sampler._prune_hidden_states",
|
|
||||||
lambda x, y: x), patch(
|
|
||||||
"vllm.model_executor.layers.sampler.Sampler._get_logits",
|
|
||||||
lambda *args, **kwargs: self.fake_logits):
|
|
||||||
return super().forward(*args, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
def _prepare_test(
|
def _prepare_test(
|
||||||
@ -36,7 +31,7 @@ def _prepare_test(
|
|||||||
fake_logits = torch.full((batch_size, vocab_size),
|
fake_logits = torch.full((batch_size, vocab_size),
|
||||||
1e-2,
|
1e-2,
|
||||||
dtype=input_tensor.dtype)
|
dtype=input_tensor.dtype)
|
||||||
sampler = MockLogitsSampler(32000, fake_logits)
|
sampler = MockLogitsSampler(fake_logits)
|
||||||
model_runner = ModelRunner(None, None, None, None, None)
|
model_runner = ModelRunner(None, None, None, None, None)
|
||||||
return input_tensor, fake_logits, sampler, model_runner
|
return input_tensor, fake_logits, sampler, model_runner
|
||||||
|
|
||||||
@ -70,9 +65,7 @@ def _do_sample(
|
|||||||
sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
|
sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
|
||||||
prompt_lens,
|
prompt_lens,
|
||||||
subquery_lens=prompt_lens)
|
subquery_lens=prompt_lens)
|
||||||
return sampler(embedding=None,
|
return sampler(logits=input_tensor, sampling_metadata=sampling_metadata)
|
||||||
hidden_states=input_tensor,
|
|
||||||
sampling_metadata=sampling_metadata)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
|
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
|
||||||
@ -85,8 +78,8 @@ def test_sampler_all_greedy(seed: int, device: str):
|
|||||||
batch_size)
|
batch_size)
|
||||||
|
|
||||||
sampling_params = SamplingParams(temperature=0)
|
sampling_params = SamplingParams(temperature=0)
|
||||||
sampler_output = _do_sample(batch_size, input_tensor, sampler,
|
sampler_output = _do_sample(batch_size, fake_logits, sampler, model_runner,
|
||||||
model_runner, sampling_params)
|
sampling_params)
|
||||||
expected = torch.argmax(fake_logits, dim=-1)
|
expected = torch.argmax(fake_logits, dim=-1)
|
||||||
for i, sequence_output in enumerate(sampler_output):
|
for i, sequence_output in enumerate(sampler_output):
|
||||||
for nth_output in sequence_output.samples:
|
for nth_output in sequence_output.samples:
|
||||||
@ -111,8 +104,8 @@ def test_sampler_all_random(seed: int, device: str):
|
|||||||
temperature=1.0,
|
temperature=1.0,
|
||||||
n=random.randint(1, 10),
|
n=random.randint(1, 10),
|
||||||
)
|
)
|
||||||
sampler_output = _do_sample(batch_size, input_tensor, sampler,
|
sampler_output = _do_sample(batch_size, fake_logits, sampler, model_runner,
|
||||||
model_runner, sampling_params)
|
sampling_params)
|
||||||
|
|
||||||
for i, sequence_output in enumerate(sampler_output):
|
for i, sequence_output in enumerate(sampler_output):
|
||||||
for nth_output in sequence_output.samples:
|
for nth_output in sequence_output.samples:
|
||||||
@ -127,8 +120,7 @@ def test_sampler_all_random_seed(seed: int, device: str):
|
|||||||
set_random_seed(seed)
|
set_random_seed(seed)
|
||||||
torch.set_default_device(device)
|
torch.set_default_device(device)
|
||||||
batch_size = random.randint(1, 256)
|
batch_size = random.randint(1, 256)
|
||||||
input_tensor, fake_logits, sampler, model_runner = _prepare_test(
|
_, fake_logits, sampler, model_runner = _prepare_test(batch_size)
|
||||||
batch_size)
|
|
||||||
|
|
||||||
for i in range(batch_size):
|
for i in range(batch_size):
|
||||||
fake_logits[i, i] = 1e2
|
fake_logits[i, i] = 1e2
|
||||||
@ -138,8 +130,8 @@ def test_sampler_all_random_seed(seed: int, device: str):
|
|||||||
n=random.randint(1, 10),
|
n=random.randint(1, 10),
|
||||||
seed=random.randint(0, 10000),
|
seed=random.randint(0, 10000),
|
||||||
)
|
)
|
||||||
sampler_output = _do_sample(batch_size, input_tensor, sampler,
|
sampler_output = _do_sample(batch_size, fake_logits, sampler, model_runner,
|
||||||
model_runner, sampling_params)
|
sampling_params)
|
||||||
|
|
||||||
for i, sequence_output in enumerate(sampler_output):
|
for i, sequence_output in enumerate(sampler_output):
|
||||||
for nth_output in sequence_output.samples:
|
for nth_output in sequence_output.samples:
|
||||||
@ -154,18 +146,17 @@ def test_sampler_all_random_seed_deterministic(seed: int, device: str):
|
|||||||
set_random_seed(seed)
|
set_random_seed(seed)
|
||||||
torch.set_default_device(device)
|
torch.set_default_device(device)
|
||||||
batch_size = random.randint(1, 256)
|
batch_size = random.randint(1, 256)
|
||||||
input_tensor, fake_logits, sampler, model_runner = _prepare_test(
|
_, fake_logits, sampler, model_runner = _prepare_test(batch_size)
|
||||||
batch_size)
|
|
||||||
|
|
||||||
sampling_params = SamplingParams(
|
sampling_params = SamplingParams(
|
||||||
temperature=1.0,
|
temperature=1.0,
|
||||||
n=random.randint(1, 10),
|
n=random.randint(1, 10),
|
||||||
seed=random.randint(0, 10000),
|
seed=random.randint(0, 10000),
|
||||||
)
|
)
|
||||||
first_sampler_output = _do_sample(batch_size, input_tensor, sampler,
|
first_sampler_output = _do_sample(batch_size, fake_logits, sampler,
|
||||||
model_runner, sampling_params)
|
model_runner, sampling_params)
|
||||||
|
|
||||||
second_sampler_output = _do_sample(batch_size, input_tensor, sampler,
|
second_sampler_output = _do_sample(batch_size, fake_logits, sampler,
|
||||||
model_runner, sampling_params)
|
model_runner, sampling_params)
|
||||||
|
|
||||||
assert first_sampler_output == second_sampler_output
|
assert first_sampler_output == second_sampler_output
|
||||||
@ -179,15 +170,14 @@ def test_sampler_all_beam(seed: int, device: str):
|
|||||||
set_random_seed(seed)
|
set_random_seed(seed)
|
||||||
torch.set_default_device(device)
|
torch.set_default_device(device)
|
||||||
batch_size = random.randint(1, 256)
|
batch_size = random.randint(1, 256)
|
||||||
input_tensor, _, sampler, model_runner = _prepare_test(batch_size)
|
_, fake_logits, sampler, model_runner = _prepare_test(batch_size)
|
||||||
|
|
||||||
sampling_params = SamplingParams(
|
sampling_params = SamplingParams(
|
||||||
temperature=0,
|
temperature=0,
|
||||||
best_of=2,
|
best_of=2,
|
||||||
use_beam_search=True,
|
use_beam_search=True,
|
||||||
)
|
)
|
||||||
_do_sample(batch_size, input_tensor, sampler, model_runner,
|
_do_sample(batch_size, fake_logits, sampler, model_runner, sampling_params)
|
||||||
sampling_params)
|
|
||||||
# no assertion here as I am not sure how to determine whether
|
# no assertion here as I am not sure how to determine whether
|
||||||
# the outputs are expected - in other words, this just tests
|
# the outputs are expected - in other words, this just tests
|
||||||
# whether there are no exceptions in the sampler
|
# whether there are no exceptions in the sampler
|
||||||
@ -246,8 +236,7 @@ def test_sampler_mixed(seed: int, device: str):
|
|||||||
def test_sampling(model_runner: ModelRunner):
|
def test_sampling(model_runner: ModelRunner):
|
||||||
sampling_metadata = model_runner._prepare_sample(
|
sampling_metadata = model_runner._prepare_sample(
|
||||||
seq_group_metadata_list, prompt_lens, subquery_lens=prompt_lens)
|
seq_group_metadata_list, prompt_lens, subquery_lens=prompt_lens)
|
||||||
sampler_output = sampler(embedding=None,
|
sampler_output = sampler(logits=fake_logits,
|
||||||
hidden_states=input_tensor,
|
|
||||||
sampling_metadata=sampling_metadata)
|
sampling_metadata=sampling_metadata)
|
||||||
|
|
||||||
for i, (sequence_output, metadata) in enumerate(
|
for i, (sequence_output, metadata) in enumerate(
|
||||||
@ -294,48 +283,6 @@ def test_sampler_mixed(seed: int, device: str):
|
|||||||
del model_runner
|
del model_runner
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
|
|
||||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
|
||||||
def test_sampler_logits_processors(seed: int, device: str):
|
|
||||||
set_random_seed(seed)
|
|
||||||
torch.set_default_device(device)
|
|
||||||
batch_size = random.randint(1, 256)
|
|
||||||
input_tensor, _, sampler, model_runner = _prepare_test(batch_size)
|
|
||||||
|
|
||||||
# This sample logits processor gives maximum score to the i-th token,
|
|
||||||
# where i is the length of the input sequence.
|
|
||||||
# We therefore expect the output token sequence to be [0, 1, 2, ...]
|
|
||||||
def pick_ith(token_ids, logits):
|
|
||||||
logits[len(token_ids)] = torch.finfo(logits.dtype).max
|
|
||||||
return logits
|
|
||||||
|
|
||||||
seq_group_metadata_list = []
|
|
||||||
prompt_lens = []
|
|
||||||
for i in range(batch_size):
|
|
||||||
seq_group_metadata_list.append(
|
|
||||||
SequenceGroupMetadata(
|
|
||||||
request_id=f"test_{i}",
|
|
||||||
is_prompt=True,
|
|
||||||
seq_data={0: SequenceData([1, 2, 3])},
|
|
||||||
sampling_params=SamplingParams(temperature=0,
|
|
||||||
logits_processors=[pick_ith]),
|
|
||||||
block_tables={0: [1]},
|
|
||||||
))
|
|
||||||
prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
|
|
||||||
|
|
||||||
sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
|
|
||||||
prompt_lens,
|
|
||||||
subquery_lens=prompt_lens)
|
|
||||||
sampler_output = sampler(embedding=None,
|
|
||||||
hidden_states=input_tensor,
|
|
||||||
sampling_metadata=sampling_metadata)
|
|
||||||
for _, sequence_output in enumerate(sampler_output):
|
|
||||||
for idx, nth_output in enumerate(sequence_output.samples):
|
|
||||||
assert nth_output.output_token == idx
|
|
||||||
|
|
||||||
del model_runner
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
|
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
|
||||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||||
def test_sampler_top_k_top_p(seed: int, device: str):
|
def test_sampler_top_k_top_p(seed: int, device: str):
|
||||||
@ -352,7 +299,7 @@ def test_sampler_top_k_top_p(seed: int, device: str):
|
|||||||
size=(batch_size, vocab_size),
|
size=(batch_size, vocab_size),
|
||||||
device=input_tensor.device,
|
device=input_tensor.device,
|
||||||
dtype=input_tensor.dtype)
|
dtype=input_tensor.dtype)
|
||||||
sampler = MockLogitsSampler(32000, fake_logits)
|
sampler = MockLogitsSampler(fake_logits)
|
||||||
model_runner = ModelRunner(None, None, None, None, None)
|
model_runner = ModelRunner(None, None, None, None, None)
|
||||||
|
|
||||||
generation_model = GenerationMixin()
|
generation_model = GenerationMixin()
|
||||||
@ -391,9 +338,7 @@ def test_sampler_top_k_top_p(seed: int, device: str):
|
|||||||
return [[prob.topk(1, dim=-1).indices.tolist(), [0]] for prob in probs]
|
return [[prob.topk(1, dim=-1).indices.tolist(), [0]] for prob in probs]
|
||||||
|
|
||||||
with patch("vllm.model_executor.layers.sampler._sample", mock_sample):
|
with patch("vllm.model_executor.layers.sampler._sample", mock_sample):
|
||||||
sampler(embedding=None,
|
sampler(logits=fake_logits, sampling_metadata=sampling_metadata)
|
||||||
hidden_states=input_tensor,
|
|
||||||
sampling_metadata=sampling_metadata)
|
|
||||||
hf_probs = warpers(torch.zeros_like(fake_logits), fake_logits.clone())
|
hf_probs = warpers(torch.zeros_like(fake_logits), fake_logits.clone())
|
||||||
hf_probs = torch.softmax(hf_probs, dim=-1, dtype=torch.float)
|
hf_probs = torch.softmax(hf_probs, dim=-1, dtype=torch.float)
|
||||||
assert torch.allclose(hf_probs, sample_probs, atol=1e-5)
|
assert torch.allclose(hf_probs, sample_probs, atol=1e-5)
|
||||||
|
|||||||
94
tests/test_logits_processor.py
Normal file
94
tests/test_logits_processor.py
Normal file
@ -0,0 +1,94 @@
|
|||||||
|
import random
|
||||||
|
from typing import Tuple
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||||
|
from vllm.model_executor.utils import set_random_seed
|
||||||
|
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
|
||||||
|
from vllm.worker.model_runner import ModelRunner
|
||||||
|
|
||||||
|
|
||||||
|
class MockLogitsProcessor(LogitsProcessor):
|
||||||
|
|
||||||
|
def __init__(self, vocab_size: int, scale: float,
|
||||||
|
fake_logits: torch.Tensor):
|
||||||
|
super().__init__(vocab_size=vocab_size, scale=scale)
|
||||||
|
self.fake_logits = fake_logits.clone()
|
||||||
|
|
||||||
|
def forward(self, *args, **kwargs):
|
||||||
|
with patch(
|
||||||
|
"vllm.model_executor.layers.logits_processor._prune_hidden_states",
|
||||||
|
lambda x, y: x
|
||||||
|
), patch(
|
||||||
|
"vllm.model_executor.layers.logits_processor.LogitsProcessor._get_logits",
|
||||||
|
lambda *args, **kwargs: self.fake_logits):
|
||||||
|
return super().forward(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def _prepare_test(
|
||||||
|
batch_size: int
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor, MockLogitsProcessor, ModelRunner]:
|
||||||
|
vocab_size = 32000
|
||||||
|
input_tensor = torch.rand((batch_size, 1024), dtype=torch.float16)
|
||||||
|
fake_logits = torch.full((batch_size, vocab_size),
|
||||||
|
1e-2,
|
||||||
|
dtype=input_tensor.dtype)
|
||||||
|
logits_processor = MockLogitsProcessor(32000, 0.5, fake_logits)
|
||||||
|
model_runner = ModelRunner(None, None, None, None, None)
|
||||||
|
return input_tensor, fake_logits, logits_processor, model_runner
|
||||||
|
|
||||||
|
|
||||||
|
RANDOM_SEEDS = list(range(128))
|
||||||
|
CUDA_DEVICES = [
|
||||||
|
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
|
||||||
|
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||||
|
def test_logits_processors(seed: int, device: str):
|
||||||
|
set_random_seed(seed)
|
||||||
|
torch.set_default_device(device)
|
||||||
|
batch_size = random.randint(1, 256)
|
||||||
|
input_tensor, fake_logits, logits_processor, model_runner = _prepare_test(
|
||||||
|
batch_size)
|
||||||
|
|
||||||
|
# This sample logits processor gives infinite score to the i-th token,
|
||||||
|
# where i is the length of the input sequence.
|
||||||
|
# We therefore expect the output token sequence to be [0, 1, 2, ...]
|
||||||
|
def pick_ith(token_ids, logits):
|
||||||
|
logits[len(token_ids)] = float("inf")
|
||||||
|
return logits
|
||||||
|
|
||||||
|
seq_group_metadata_list = []
|
||||||
|
prompt_lens = []
|
||||||
|
for i in range(batch_size):
|
||||||
|
seq_group_metadata_list.append(
|
||||||
|
SequenceGroupMetadata(
|
||||||
|
request_id=f"test_{i}",
|
||||||
|
is_prompt=True,
|
||||||
|
seq_data={0: SequenceData([1, 2, 3])},
|
||||||
|
sampling_params=SamplingParams(temperature=0,
|
||||||
|
logits_processors=[pick_ith]),
|
||||||
|
block_tables={0: [1]},
|
||||||
|
))
|
||||||
|
prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
|
||||||
|
|
||||||
|
sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
|
||||||
|
prompt_lens,
|
||||||
|
subquery_lens=prompt_lens)
|
||||||
|
logits_processor_output = logits_processor(
|
||||||
|
embedding=None,
|
||||||
|
hidden_states=input_tensor,
|
||||||
|
sampling_metadata=sampling_metadata)
|
||||||
|
|
||||||
|
assert torch.isinf(logits_processor_output[:, 0]).all()
|
||||||
|
|
||||||
|
fake_logits *= logits_processor.scale
|
||||||
|
assert torch.allclose(logits_processor_output[:, 1], fake_logits[:, 1],
|
||||||
|
1e-4)
|
||||||
|
|
||||||
|
del model_runner
|
||||||
@ -10,7 +10,6 @@ from transformers import PretrainedConfig
|
|||||||
|
|
||||||
from vllm.config import LoRAConfig
|
from vllm.config import LoRAConfig
|
||||||
from vllm.lora.punica import add_lora, add_lora_slice, bgmv
|
from vllm.lora.punica import add_lora, add_lora_slice, bgmv
|
||||||
from vllm.model_executor.layers.sampler import Sampler
|
|
||||||
from vllm.model_executor.parallel_utils.communication_op import (
|
from vllm.model_executor.parallel_utils.communication_op import (
|
||||||
tensor_model_parallel_all_gather,
|
tensor_model_parallel_all_gather,
|
||||||
tensor_model_parallel_all_reduce,
|
tensor_model_parallel_all_reduce,
|
||||||
@ -20,6 +19,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
|||||||
RowParallelLinear,
|
RowParallelLinear,
|
||||||
QKVParallelLinear,
|
QKVParallelLinear,
|
||||||
MergedColumnParallelLinear)
|
MergedColumnParallelLinear)
|
||||||
|
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||||
VocabParallelEmbedding, ParallelLMHead)
|
VocabParallelEmbedding, ParallelLMHead)
|
||||||
from vllm.model_executor.parallel_utils.parallel_state import (
|
from vllm.model_executor.parallel_utils.parallel_state import (
|
||||||
@ -783,11 +783,11 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
|
|||||||
return self.base_layer.weight
|
return self.base_layer.weight
|
||||||
|
|
||||||
|
|
||||||
class SamplerWithLoRA(BaseLayerWithLoRA):
|
class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
base_layer: Sampler,
|
base_layer: LogitsProcessor,
|
||||||
hidden_size: int,
|
hidden_size: int,
|
||||||
dtype: torch.dtype,
|
dtype: torch.dtype,
|
||||||
device: torch.device,
|
device: torch.device,
|
||||||
@ -806,6 +806,10 @@ class SamplerWithLoRA(BaseLayerWithLoRA):
|
|||||||
def vocab_size(self):
|
def vocab_size(self):
|
||||||
return self.base_layer.vocab_size
|
return self.base_layer.vocab_size
|
||||||
|
|
||||||
|
@property
|
||||||
|
def scale(self):
|
||||||
|
return self.base_layer.scale
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def org_vocab_size(self):
|
def org_vocab_size(self):
|
||||||
return self.base_layer.org_vocab_size
|
return self.base_layer.org_vocab_size
|
||||||
@ -968,14 +972,14 @@ def from_layer(
|
|||||||
return layer
|
return layer
|
||||||
|
|
||||||
|
|
||||||
def from_layer_sampler(
|
def from_layer_logits_processor(
|
||||||
layer: Sampler,
|
layer: LogitsProcessor,
|
||||||
lm_head: ParallelLMHead,
|
lm_head: ParallelLMHead,
|
||||||
max_loras: int,
|
max_loras: int,
|
||||||
lora_config: LoRAConfig,
|
lora_config: LoRAConfig,
|
||||||
model_config: Optional[PretrainedConfig] = None,
|
model_config: Optional[PretrainedConfig] = None,
|
||||||
) -> SamplerWithLoRA:
|
) -> LogitsProcessorWithLoRA:
|
||||||
ret = SamplerWithLoRA(layer, lm_head.embedding_dim, lm_head.weight.dtype,
|
ret = LogitsProcessorWithLoRA(layer, lm_head.embedding_dim,
|
||||||
lm_head.weight.device)
|
lm_head.weight.dtype, lm_head.weight.device)
|
||||||
ret.create_lora_weights(max_loras, lora_config, model_config)
|
ret.create_lora_weights(max_loras, lora_config, model_config)
|
||||||
return ret
|
return ret
|
||||||
|
|||||||
@ -14,7 +14,7 @@ from vllm.config import LoRAConfig
|
|||||||
from vllm.utils import LRUCache, in_wsl
|
from vllm.utils import LRUCache, in_wsl
|
||||||
|
|
||||||
from vllm.lora.layers import (BaseLayerWithLoRA, LoRAMapping, from_layer,
|
from vllm.lora.layers import (BaseLayerWithLoRA, LoRAMapping, from_layer,
|
||||||
from_layer_sampler)
|
from_layer_logits_processor)
|
||||||
from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
|
from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
|
||||||
from vllm.lora.utils import parse_fine_tuned_lora_name, replace_submodule
|
from vllm.lora.utils import parse_fine_tuned_lora_name, replace_submodule
|
||||||
|
|
||||||
@ -421,11 +421,14 @@ class LoRAModelManager:
|
|||||||
self.model.config))
|
self.model.config))
|
||||||
# (yard1): TODO make this more robust
|
# (yard1): TODO make this more robust
|
||||||
if "lm_head" in module_name:
|
if "lm_head" in module_name:
|
||||||
sampler_module = self.model.get_submodule("sampler")
|
logits_processor_module = self.model.get_submodule(
|
||||||
|
"logits_processor")
|
||||||
new_module = replace_submodule(
|
new_module = replace_submodule(
|
||||||
self.model, "sampler",
|
self.model, "logits_processor",
|
||||||
from_layer_sampler(sampler_module, module, self.lora_slots,
|
from_layer_logits_processor(logits_processor_module,
|
||||||
self.lora_config, self.model.config))
|
module, self.lora_slots,
|
||||||
|
self.lora_config,
|
||||||
|
self.model.config))
|
||||||
self.register_module(module_name, new_module)
|
self.register_module(module_name, new_module)
|
||||||
self._register_packed_modules(module_name)
|
self._register_packed_modules(module_name)
|
||||||
new_module.set_mapping(self.base_indices, self.sampler_indices,
|
new_module.set_mapping(self.base_indices, self.sampler_indices,
|
||||||
|
|||||||
106
vllm/model_executor/layers/logits_processor.py
Normal file
106
vllm/model_executor/layers/logits_processor.py
Normal file
@ -0,0 +1,106 @@
|
|||||||
|
"""A layer that compute logits from hidden_stats."""
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from vllm.utils import is_neuron
|
||||||
|
|
||||||
|
from vllm.model_executor.parallel_utils.communication_op import (
|
||||||
|
tensor_model_parallel_gather)
|
||||||
|
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||||
|
|
||||||
|
|
||||||
|
class LogitsProcessor(nn.Module):
    """Turn model hidden states into logits and apply logits processors.

    The forward pass does the following:
    1. Select the hidden-state rows named by the sampling metadata and
       multiply them with the transposed embedding matrix to get logits.
    2. Scale the logits if a non-trivial scale was configured.
    3. Apply the logits processors attached to the sampling params (if any).
    """

    def __init__(self,
                 vocab_size: int,
                 org_vocab_size: Optional[int] = None,
                 scale: Optional[float] = 1.0) -> None:
        """
        Args:
            vocab_size: Vocabulary size; stored on the layer as-is.
            org_vocab_size: Original vocabulary size (without LoRA). Logits
                are truncated to this many columns. Falls back to
                `vocab_size` when it is None (or 0).
            scale: A scaling factor to apply to the logits. `None` and `1.0`
                both mean "do not scale".
        """
        super().__init__()
        self.scale = scale
        self.vocab_size = vocab_size
        # Transformers-neuronx generates outputs as logits directly.
        self.logits_as_hidden_states = is_neuron()
        # Original vocabulary size (without LoRA).
        self.org_vocab_size = org_vocab_size or vocab_size

    def forward(
        self,
        embedding: torch.Tensor,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
        embedding_bias: Optional[torch.Tensor] = None,
    ) -> Optional[torch.Tensor]:
        """Compute, scale and post-process the logits.

        Args:
            embedding: Weight matrix whose transpose is multiplied with the
                selected hidden states.
            hidden_states: Model output. When `logits_as_hidden_states` is
                set it is used as the logits directly.
            sampling_metadata: Supplies the row indices to keep and the
                per-sequence-group logits processors.
            embedding_bias: Optional bias added to the logits before the
                tensor-parallel gather.

        Returns:
            The processed logits, or None when `_get_logits` produced no
            tensor (the gather can return None; presumably on non-driver
            ranks -- confirm against `tensor_model_parallel_gather`).
        """
        if self.logits_as_hidden_states:
            logits = hidden_states
        else:
            hidden_states = _prune_hidden_states(hidden_states,
                                                 sampling_metadata)

            # Get the logits for the next tokens.
            logits = self._get_logits(hidden_states, embedding, embedding_bias)

        if logits is not None:
            # `scale` is Optional: treat None like 1.0 instead of crashing on
            # `tensor *= None`, and skip the multiply when it would be a
            # no-op. The multiply is in place; on the
            # `logits_as_hidden_states` path it therefore modifies the
            # caller's tensor.
            if self.scale is not None and self.scale != 1.0:
                logits *= self.scale

            # Apply logits processors (if any).
            logits = _apply_logits_processors(logits, sampling_metadata)

        return logits

    def _get_logits(
            self, hidden_states: torch.Tensor, embedding: torch.Tensor,
            embedding_bias: Optional[torch.Tensor]
    ) -> Optional[torch.Tensor]:
        """Project hidden states onto the vocabulary and gather the result.

        Returns None when `tensor_model_parallel_gather` returns None.
        """
        # Get the logits for the next tokens.
        logits = torch.matmul(hidden_states, embedding.t())
        if embedding_bias is not None:
            logits += embedding_bias
        logits = tensor_model_parallel_gather(logits)
        # Remove paddings in vocab (if any).
        if logits is not None:
            logits = logits[:, :self.org_vocab_size]
        return logits
|
||||||
|
|
||||||
|
|
||||||
|
def _prune_hidden_states(
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
sampling_metadata: SamplingMetadata,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
|
||||||
|
return hidden_states.index_select(0,
|
||||||
|
sampling_metadata.selected_token_indices)
|
||||||
|
|
||||||
|
|
||||||
|
def _apply_logits_processors(
|
||||||
|
logits: torch.Tensor,
|
||||||
|
sampling_metadata: SamplingMetadata,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
logits_row_idx = 0
|
||||||
|
found_logits_processors = False
|
||||||
|
for seq_ids, sampling_params in sampling_metadata.seq_groups:
|
||||||
|
logits_processors = sampling_params.logits_processors
|
||||||
|
if logits_processors:
|
||||||
|
found_logits_processors = True
|
||||||
|
for seq_id in seq_ids:
|
||||||
|
logits_row = logits[logits_row_idx]
|
||||||
|
token_ids = sampling_metadata.seq_data[seq_id].output_token_ids
|
||||||
|
for logits_processor in logits_processors:
|
||||||
|
logits_row = logits_processor(token_ids, logits_row)
|
||||||
|
logits[logits_row_idx] = logits_row
|
||||||
|
logits_row_idx += 1
|
||||||
|
else:
|
||||||
|
logits_row_idx += len(seq_ids)
|
||||||
|
if found_logits_processors:
|
||||||
|
assert logits_row_idx == logits.shape[0]
|
||||||
|
return logits
|
||||||
@ -4,8 +4,6 @@ from typing import Dict, List, Optional, Tuple
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
from vllm.model_executor.parallel_utils.communication_op import (
|
|
||||||
tensor_model_parallel_gather)
|
|
||||||
from vllm.model_executor.sampling_metadata import (SamplingMetadata,
|
from vllm.model_executor.sampling_metadata import (SamplingMetadata,
|
||||||
SamplingTensors)
|
SamplingTensors)
|
||||||
from vllm.sampling_params import SamplingParams, SamplingType
|
from vllm.sampling_params import SamplingParams, SamplingType
|
||||||
@ -13,7 +11,6 @@ from vllm.sequence import (Logprob, PromptLogprobs, SampleLogprobs,
|
|||||||
SamplerOutput, SequenceData, SequenceGroupOutput,
|
SamplerOutput, SequenceData, SequenceGroupOutput,
|
||||||
SequenceOutput)
|
SequenceOutput)
|
||||||
from vllm.model_executor.layers.ops.sample import (sample as sample_triton)
|
from vllm.model_executor.layers.ops.sample import (sample as sample_triton)
|
||||||
from vllm.utils import is_neuron
|
|
||||||
|
|
||||||
|
|
||||||
class Sampler(nn.Module):
|
class Sampler(nn.Module):
|
||||||
@ -31,58 +28,14 @@ class Sampler(nn.Module):
|
|||||||
parameters (e.g., sampling method, temperature, top-p, top-k, etc.).
|
parameters (e.g., sampling method, temperature, top-p, top-k, etc.).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self,
|
|
||||||
vocab_size: int,
|
|
||||||
org_vocab_size: Optional[int] = None) -> None:
|
|
||||||
super().__init__()
|
|
||||||
self.vocab_size = vocab_size
|
|
||||||
# Transformers-neuronx generate outputs as logits directly.
|
|
||||||
self.logits_as_hidden_states = is_neuron()
|
|
||||||
# original vocabulary size (without LoRA).
|
|
||||||
self.org_vocab_size = org_vocab_size or vocab_size
|
|
||||||
|
|
||||||
def _get_logits(self, hidden_states: torch.Tensor, embedding: torch.Tensor,
|
|
||||||
embedding_bias: Optional[torch.Tensor]) -> torch.Tensor:
|
|
||||||
# Get the logits for the next tokens.
|
|
||||||
logits = torch.matmul(hidden_states, embedding.t())
|
|
||||||
if embedding_bias is not None:
|
|
||||||
logits += embedding_bias
|
|
||||||
logits = tensor_model_parallel_gather(logits)
|
|
||||||
# Remove paddings in vocab (if any).
|
|
||||||
if logits is not None:
|
|
||||||
logits = logits[:, :self.org_vocab_size]
|
|
||||||
return logits
|
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
embedding: torch.Tensor,
|
logits: torch.Tensor,
|
||||||
hidden_states: torch.Tensor,
|
|
||||||
sampling_metadata: SamplingMetadata,
|
sampling_metadata: SamplingMetadata,
|
||||||
embedding_bias: Optional[torch.Tensor] = None,
|
|
||||||
) -> Optional[SamplerOutput]:
|
) -> Optional[SamplerOutput]:
|
||||||
# Get the hidden states that we use for sampling.
|
|
||||||
if self.logits_as_hidden_states:
|
|
||||||
logits = hidden_states
|
|
||||||
else:
|
|
||||||
hidden_states = _prune_hidden_states(hidden_states,
|
|
||||||
sampling_metadata)
|
|
||||||
|
|
||||||
# Get the logits for the next tokens.
|
|
||||||
logits = self._get_logits(hidden_states, embedding, embedding_bias)
|
|
||||||
|
|
||||||
# Only perform sampling in the driver worker.
|
|
||||||
# Note: `_get_logits` is still distributed across TP workers because
|
|
||||||
# the `embedding` weight is distributed across TP workers.
|
|
||||||
# TODO(zhuohan): Change the get_logits part to a separate stage.
|
|
||||||
if not sampling_metadata.perform_sampling:
|
|
||||||
return None
|
|
||||||
|
|
||||||
assert logits is not None
|
assert logits is not None
|
||||||
_, vocab_size = logits.shape
|
_, vocab_size = logits.shape
|
||||||
|
|
||||||
# Apply logits processors (if any).
|
|
||||||
logits = _apply_logits_processors(logits, sampling_metadata)
|
|
||||||
|
|
||||||
# Prepare sampling tensors with pinned memory to avoid blocking.
|
# Prepare sampling tensors with pinned memory to avoid blocking.
|
||||||
(sampling_tensors, do_penalties, do_top_p_top_k,
|
(sampling_tensors, do_penalties, do_top_p_top_k,
|
||||||
do_min_p) = SamplingTensors.from_sampling_metadata(
|
do_min_p) = SamplingTensors.from_sampling_metadata(
|
||||||
@ -124,14 +77,6 @@ class Sampler(nn.Module):
|
|||||||
prompt_logprobs, sample_logprobs)
|
prompt_logprobs, sample_logprobs)
|
||||||
|
|
||||||
|
|
||||||
def _prune_hidden_states(
|
|
||||||
hidden_states: torch.Tensor,
|
|
||||||
sampling_metadata: SamplingMetadata,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
return hidden_states.index_select(0,
|
|
||||||
sampling_metadata.selected_token_indices)
|
|
||||||
|
|
||||||
|
|
||||||
def _get_bin_counts_and_mask(
|
def _get_bin_counts_and_mask(
|
||||||
tokens: torch.Tensor,
|
tokens: torch.Tensor,
|
||||||
vocab_size: int,
|
vocab_size: int,
|
||||||
@ -149,30 +94,6 @@ def _get_bin_counts_and_mask(
|
|||||||
return bin_counts, mask
|
return bin_counts, mask
|
||||||
|
|
||||||
|
|
||||||
def _apply_logits_processors(
|
|
||||||
logits: torch.Tensor,
|
|
||||||
sampling_metadata: SamplingMetadata,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
logits_row_idx = 0
|
|
||||||
found_logits_processors = False
|
|
||||||
for seq_ids, sampling_params in sampling_metadata.seq_groups:
|
|
||||||
logits_processors = sampling_params.logits_processors
|
|
||||||
if logits_processors:
|
|
||||||
found_logits_processors = True
|
|
||||||
for seq_id in seq_ids:
|
|
||||||
logits_row = logits[logits_row_idx]
|
|
||||||
token_ids = sampling_metadata.seq_data[seq_id].output_token_ids
|
|
||||||
for logits_processor in logits_processors:
|
|
||||||
logits_row = logits_processor(token_ids, logits_row)
|
|
||||||
logits[logits_row_idx] = logits_row
|
|
||||||
logits_row_idx += 1
|
|
||||||
else:
|
|
||||||
logits_row_idx += len(seq_ids)
|
|
||||||
if found_logits_processors:
|
|
||||||
assert logits_row_idx == logits.shape[0]
|
|
||||||
return logits
|
|
||||||
|
|
||||||
|
|
||||||
def _apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
|
def _apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
|
||||||
output_tokens_tensor: torch.Tensor,
|
output_tokens_tensor: torch.Tensor,
|
||||||
presence_penalties: torch.Tensor,
|
presence_penalties: torch.Tensor,
|
||||||
|
|||||||
@ -34,6 +34,7 @@ from vllm.model_executor.layers.linear import (LinearMethodBase,
|
|||||||
QKVParallelLinear,
|
QKVParallelLinear,
|
||||||
RowParallelLinear)
|
RowParallelLinear)
|
||||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||||
|
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||||
from vllm.model_executor.layers.sampler import Sampler
|
from vllm.model_executor.layers.sampler import Sampler
|
||||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||||
VocabParallelEmbedding, ParallelLMHead)
|
VocabParallelEmbedding, ParallelLMHead)
|
||||||
@ -295,7 +296,8 @@ class BaiChuanBaseForCausalLM(nn.Module):
|
|||||||
self.linear_method = linear_method
|
self.linear_method = linear_method
|
||||||
self.model = BaiChuanModel(config, position_embedding, linear_method)
|
self.model = BaiChuanModel(config, position_embedding, linear_method)
|
||||||
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
||||||
self.sampler = Sampler(config.vocab_size)
|
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||||
|
self.sampler = Sampler()
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@ -308,13 +310,18 @@ class BaiChuanBaseForCausalLM(nn.Module):
|
|||||||
input_metadata)
|
input_metadata)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
def compute_logits(self, hidden_states: torch.Tensor,
                   sampling_metadata: SamplingMetadata) -> torch.Tensor:
    """Compute logits by passing the LM-head weight to the logits processor.

    Args:
        hidden_states: Output of the model's forward pass.
        sampling_metadata: Forwarded unchanged to `self.logits_processor`.

    Returns:
        Whatever `self.logits_processor` returns for these inputs.
    """
    head_weight = self.lm_head.weight
    return self.logits_processor(head_weight, hidden_states,
                                 sampling_metadata)
|
||||||
|
|
||||||
def sample(
|
def sample(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
logits: torch.Tensor,
|
||||||
sampling_metadata: SamplingMetadata,
|
sampling_metadata: SamplingMetadata,
|
||||||
) -> Optional[SamplerOutput]:
|
) -> Optional[SamplerOutput]:
|
||||||
next_tokens = self.sampler(self.lm_head.weight, hidden_states,
|
next_tokens = self.sampler(logits, sampling_metadata)
|
||||||
sampling_metadata)
|
|
||||||
return next_tokens
|
return next_tokens
|
||||||
|
|
||||||
def load_weights(self,
|
def load_weights(self,
|
||||||
|
|||||||
@ -30,6 +30,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
|||||||
LinearMethodBase,
|
LinearMethodBase,
|
||||||
QKVParallelLinear,
|
QKVParallelLinear,
|
||||||
RowParallelLinear)
|
RowParallelLinear)
|
||||||
|
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||||
from vllm.model_executor.layers.sampler import Sampler
|
from vllm.model_executor.layers.sampler import Sampler
|
||||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||||
VocabParallelEmbedding)
|
VocabParallelEmbedding)
|
||||||
@ -273,7 +274,8 @@ class BloomForCausalLM(nn.Module):
|
|||||||
self.linear_method = linear_method
|
self.linear_method = linear_method
|
||||||
self.transformer = BloomModel(config, linear_method)
|
self.transformer = BloomModel(config, linear_method)
|
||||||
self.lm_head_weight = self.transformer.word_embeddings.weight
|
self.lm_head_weight = self.transformer.word_embeddings.weight
|
||||||
self.sampler = Sampler(config.vocab_size)
|
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||||
|
self.sampler = Sampler()
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@ -286,13 +288,18 @@ class BloomForCausalLM(nn.Module):
|
|||||||
input_metadata)
|
input_metadata)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
def compute_logits(self, hidden_states: torch.Tensor,
                   sampling_metadata: SamplingMetadata) -> torch.Tensor:
    """Compute logits using `self.lm_head_weight` as the projection matrix.

    Args:
        hidden_states: Output of the model's forward pass.
        sampling_metadata: Forwarded unchanged to `self.logits_processor`.

    Returns:
        Whatever `self.logits_processor` returns for these inputs.
    """
    return self.logits_processor(
        self.lm_head_weight,
        hidden_states,
        sampling_metadata,
    )
|
||||||
|
|
||||||
def sample(
|
def sample(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
logits: torch.Tensor,
|
||||||
sampling_metadata: SamplingMetadata,
|
sampling_metadata: SamplingMetadata,
|
||||||
) -> Optional[SamplerOutput]:
|
) -> Optional[SamplerOutput]:
|
||||||
next_tokens = self.sampler(self.lm_head_weight, hidden_states,
|
next_tokens = self.sampler(logits, sampling_metadata)
|
||||||
sampling_metadata)
|
|
||||||
return next_tokens
|
return next_tokens
|
||||||
|
|
||||||
def load_weights(self,
|
def load_weights(self,
|
||||||
|
|||||||
@ -17,6 +17,7 @@ from vllm.model_executor.layers.linear import (LinearMethodBase,
|
|||||||
QKVParallelLinear,
|
QKVParallelLinear,
|
||||||
RowParallelLinear)
|
RowParallelLinear)
|
||||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||||
|
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||||
from vllm.model_executor.layers.sampler import Sampler
|
from vllm.model_executor.layers.sampler import Sampler
|
||||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||||
VocabParallelEmbedding, ParallelLMHead)
|
VocabParallelEmbedding, ParallelLMHead)
|
||||||
@ -332,7 +333,8 @@ class ChatGLMForCausalLM(nn.Module):
|
|||||||
self.linear_method = linear_method
|
self.linear_method = linear_method
|
||||||
self.transformer = ChatGLMModel(config, linear_method)
|
self.transformer = ChatGLMModel(config, linear_method)
|
||||||
self.lm_head_weight = self.transformer.output_layer.weight
|
self.lm_head_weight = self.transformer.output_layer.weight
|
||||||
self.sampler = Sampler(config.padded_vocab_size)
|
self.logits_processor = LogitsProcessor(config.padded_vocab_size)
|
||||||
|
self.sampler = Sampler()
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@ -345,13 +347,18 @@ class ChatGLMForCausalLM(nn.Module):
|
|||||||
input_metadata)
|
input_metadata)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
def compute_logits(self, hidden_states: torch.Tensor,
                   sampling_metadata: SamplingMetadata) -> torch.Tensor:
    """Compute logits using `self.lm_head_weight` as the projection matrix.

    Args:
        hidden_states: Output of the model's forward pass.
        sampling_metadata: Forwarded unchanged to `self.logits_processor`.

    Returns:
        Whatever `self.logits_processor` returns for these inputs.
    """
    projection = self.lm_head_weight
    return self.logits_processor(projection, hidden_states,
                                 sampling_metadata)
|
||||||
|
|
||||||
def sample(
|
def sample(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
logits: torch.Tensor,
|
||||||
sampling_metadata: SamplingMetadata,
|
sampling_metadata: SamplingMetadata,
|
||||||
) -> Optional[SamplerOutput]:
|
) -> Optional[SamplerOutput]:
|
||||||
next_tokens = self.sampler(self.lm_head_weight, hidden_states,
|
next_tokens = self.sampler(logits, sampling_metadata)
|
||||||
sampling_metadata)
|
|
||||||
return next_tokens
|
return next_tokens
|
||||||
|
|
||||||
def load_weights(self,
|
def load_weights(self,
|
||||||
|
|||||||
@ -38,6 +38,7 @@ from vllm.model_executor.layers.linear import (LinearMethodBase,
|
|||||||
QKVParallelLinear,
|
QKVParallelLinear,
|
||||||
RowParallelLinear)
|
RowParallelLinear)
|
||||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||||
|
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||||
from vllm.model_executor.layers.sampler import Sampler
|
from vllm.model_executor.layers.sampler import Sampler
|
||||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||||
VocabParallelEmbedding, ParallelLMHead)
|
VocabParallelEmbedding, ParallelLMHead)
|
||||||
@ -372,7 +373,8 @@ class DeepseekForCausalLM(nn.Module):
|
|||||||
self.linear_method = linear_method
|
self.linear_method = linear_method
|
||||||
self.model = DeepseekModel(config, linear_method)
|
self.model = DeepseekModel(config, linear_method)
|
||||||
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
||||||
self.sampler = Sampler(config.vocab_size)
|
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||||
|
self.sampler = Sampler()
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@ -385,13 +387,18 @@ class DeepseekForCausalLM(nn.Module):
|
|||||||
input_metadata)
|
input_metadata)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
def compute_logits(self, hidden_states: torch.Tensor,
                   sampling_metadata: SamplingMetadata) -> torch.Tensor:
    """Compute logits by passing the LM-head weight to the logits processor.

    Args:
        hidden_states: Output of the model's forward pass.
        sampling_metadata: Forwarded unchanged to `self.logits_processor`.

    Returns:
        Whatever `self.logits_processor` returns for these inputs.
    """
    return self.logits_processor(
        self.lm_head.weight,
        hidden_states,
        sampling_metadata,
    )
|
||||||
|
|
||||||
def sample(
|
def sample(
|
||||||
self,
|
self,
|
||||||
hidden_states: Optional[torch.Tensor],
|
logits: Optional[torch.Tensor],
|
||||||
sampling_metadata: SamplingMetadata,
|
sampling_metadata: SamplingMetadata,
|
||||||
) -> Optional[SamplerOutput]:
|
) -> Optional[SamplerOutput]:
|
||||||
next_tokens = self.sampler(self.lm_head.weight, hidden_states,
|
next_tokens = self.sampler(logits, sampling_metadata)
|
||||||
sampling_metadata)
|
|
||||||
return next_tokens
|
return next_tokens
|
||||||
|
|
||||||
def load_weights(self,
|
def load_weights(self,
|
||||||
|
|||||||
@ -34,6 +34,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
|||||||
QKVParallelLinear,
|
QKVParallelLinear,
|
||||||
RowParallelLinear)
|
RowParallelLinear)
|
||||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||||
|
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||||
from vllm.model_executor.layers.sampler import Sampler
|
from vllm.model_executor.layers.sampler import Sampler
|
||||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||||
VocabParallelEmbedding, ParallelLMHead)
|
VocabParallelEmbedding, ParallelLMHead)
|
||||||
@ -373,7 +374,8 @@ class FalconForCausalLM(nn.Module):
|
|||||||
config.vocab_size,
|
config.vocab_size,
|
||||||
config.hidden_size,
|
config.hidden_size,
|
||||||
)
|
)
|
||||||
self.sampler = Sampler(config.vocab_size)
|
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||||
|
self.sampler = Sampler()
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@ -390,13 +392,18 @@ class FalconForCausalLM(nn.Module):
|
|||||||
)
|
)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
def compute_logits(self, hidden_states: torch.Tensor,
                   sampling_metadata: SamplingMetadata) -> torch.Tensor:
    """Compute logits by passing the LM-head weight to the logits processor.

    Args:
        hidden_states: Output of the model's forward pass.
        sampling_metadata: Forwarded unchanged to `self.logits_processor`.

    Returns:
        Whatever `self.logits_processor` returns for these inputs.
    """
    lm_head_weight = self.lm_head.weight
    return self.logits_processor(lm_head_weight, hidden_states,
                                 sampling_metadata)
|
||||||
|
|
||||||
def sample(
|
def sample(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
logits: torch.Tensor,
|
||||||
sampling_metadata: SamplingMetadata,
|
sampling_metadata: SamplingMetadata,
|
||||||
) -> Optional[SamplerOutput]:
|
) -> Optional[SamplerOutput]:
|
||||||
next_tokens = self.sampler(self.lm_head.weight, hidden_states,
|
next_tokens = self.sampler(logits, sampling_metadata)
|
||||||
sampling_metadata)
|
|
||||||
return next_tokens
|
return next_tokens
|
||||||
|
|
||||||
def load_weights(self,
|
def load_weights(self,
|
||||||
|
|||||||
@ -30,6 +30,7 @@ from vllm.model_executor.layers.linear import (LinearMethodBase,
|
|||||||
QKVParallelLinear,
|
QKVParallelLinear,
|
||||||
RowParallelLinear)
|
RowParallelLinear)
|
||||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||||
|
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||||
from vllm.model_executor.layers.sampler import Sampler
|
from vllm.model_executor.layers.sampler import Sampler
|
||||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||||
VocabParallelEmbedding)
|
VocabParallelEmbedding)
|
||||||
@ -281,7 +282,8 @@ class GemmaForCausalLM(nn.Module):
|
|||||||
self.config = config
|
self.config = config
|
||||||
self.linear_method = linear_method
|
self.linear_method = linear_method
|
||||||
self.model = GemmaModel(config, linear_method)
|
self.model = GemmaModel(config, linear_method)
|
||||||
self.sampler = Sampler(config.vocab_size)
|
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||||
|
self.sampler = Sampler()
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def forward(
|
def forward(
|
||||||
@ -295,13 +297,18 @@ class GemmaForCausalLM(nn.Module):
|
|||||||
input_metadata)
|
input_metadata)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
def compute_logits(self, hidden_states: torch.Tensor,
                   sampling_metadata: SamplingMetadata) -> torch.Tensor:
    """Compute logits using the input embedding matrix as the projection.

    The weight handed to the logits processor is
    `self.model.embed_tokens.weight`.

    Args:
        hidden_states: Output of the model's forward pass.
        sampling_metadata: Forwarded unchanged to `self.logits_processor`.

    Returns:
        Whatever `self.logits_processor` returns for these inputs.
    """
    embed_weight = self.model.embed_tokens.weight
    return self.logits_processor(embed_weight, hidden_states,
                                 sampling_metadata)
|
||||||
|
|
||||||
def sample(
|
def sample(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
logits: torch.Tensor,
|
||||||
sampling_metadata: SamplingMetadata,
|
sampling_metadata: SamplingMetadata,
|
||||||
) -> Optional[SamplerOutput]:
|
) -> Optional[SamplerOutput]:
|
||||||
next_tokens = self.sampler(self.model.embed_tokens.weight,
|
next_tokens = self.sampler(logits, sampling_metadata)
|
||||||
hidden_states, sampling_metadata)
|
|
||||||
return next_tokens
|
return next_tokens
|
||||||
|
|
||||||
def load_weights(self,
|
def load_weights(self,
|
||||||
|
|||||||
@ -30,6 +30,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
|||||||
LinearMethodBase,
|
LinearMethodBase,
|
||||||
QKVParallelLinear,
|
QKVParallelLinear,
|
||||||
RowParallelLinear)
|
RowParallelLinear)
|
||||||
|
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||||
from vllm.model_executor.layers.sampler import Sampler
|
from vllm.model_executor.layers.sampler import Sampler
|
||||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||||
VocabParallelEmbedding)
|
VocabParallelEmbedding)
|
||||||
@ -216,7 +217,8 @@ class GPT2LMHeadModel(nn.Module):
|
|||||||
self.linear_method = linear_method
|
self.linear_method = linear_method
|
||||||
self.transformer = GPT2Model(config, linear_method)
|
self.transformer = GPT2Model(config, linear_method)
|
||||||
self.lm_head_weight = self.transformer.wte.weight
|
self.lm_head_weight = self.transformer.wte.weight
|
||||||
self.sampler = Sampler(config.vocab_size)
|
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||||
|
self.sampler = Sampler()
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@ -229,12 +231,18 @@ class GPT2LMHeadModel(nn.Module):
|
|||||||
input_metadata)
|
input_metadata)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
def compute_logits(self, hidden_states: torch.Tensor,
                   sampling_metadata: SamplingMetadata) -> torch.Tensor:
    """Compute logits using `self.lm_head_weight` as the projection matrix.

    Args:
        hidden_states: Output of the model's forward pass.
        sampling_metadata: Forwarded unchanged to `self.logits_processor`.

    Returns:
        Whatever `self.logits_processor` returns for these inputs.
    """
    return self.logits_processor(
        self.lm_head_weight,
        hidden_states,
        sampling_metadata,
    )
|
||||||
|
|
||||||
def sample(
    self,
    logits: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
    """Sample next tokens from already-computed logits.

    Args:
        logits: Logits produced by `compute_logits`.
        sampling_metadata: Forwarded unchanged to `self.sampler`.

    Returns:
        Whatever `self.sampler` returns for these inputs.
    """
    # The sampler takes only (logits, sampling_metadata); the tied embedding
    # weight is consumed earlier, in `compute_logits`, and must not be
    # passed here.
    next_tokens = self.sampler(logits, sampling_metadata)
    return next_tokens
|
||||||
|
|
||||||
|
|||||||
@ -31,6 +31,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
|||||||
LinearMethodBase,
|
LinearMethodBase,
|
||||||
QKVParallelLinear,
|
QKVParallelLinear,
|
||||||
RowParallelLinear)
|
RowParallelLinear)
|
||||||
|
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||||
from vllm.model_executor.layers.sampler import Sampler
|
from vllm.model_executor.layers.sampler import Sampler
|
||||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||||
VocabParallelEmbedding)
|
VocabParallelEmbedding)
|
||||||
@ -237,7 +238,8 @@ class GPTBigCodeForCausalLM(nn.Module):
|
|||||||
self.linear_method = linear_method
|
self.linear_method = linear_method
|
||||||
self.transformer = GPTBigCodeModel(config, linear_method)
|
self.transformer = GPTBigCodeModel(config, linear_method)
|
||||||
self.lm_head_weight = self.transformer.wte.weight
|
self.lm_head_weight = self.transformer.wte.weight
|
||||||
self.sampler = Sampler(config.vocab_size)
|
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||||
|
self.sampler = Sampler()
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@ -250,13 +252,18 @@ class GPTBigCodeForCausalLM(nn.Module):
|
|||||||
input_metadata)
|
input_metadata)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
def compute_logits(self, hidden_states: torch.Tensor,
                   sampling_metadata: SamplingMetadata) -> torch.Tensor:
    """Compute logits using `self.lm_head_weight` as the projection matrix.

    Args:
        hidden_states: Output of the model's forward pass.
        sampling_metadata: Forwarded unchanged to `self.logits_processor`.

    Returns:
        Whatever `self.logits_processor` returns for these inputs.
    """
    tied_weight = self.lm_head_weight
    return self.logits_processor(tied_weight, hidden_states,
                                 sampling_metadata)
|
||||||
|
|
||||||
def sample(
|
def sample(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
logits: torch.Tensor,
|
||||||
sampling_metadata: SamplingMetadata,
|
sampling_metadata: SamplingMetadata,
|
||||||
) -> Optional[SamplerOutput]:
|
) -> Optional[SamplerOutput]:
|
||||||
next_tokens = self.sampler(self.lm_head_weight, hidden_states,
|
next_tokens = self.sampler(logits, sampling_metadata)
|
||||||
sampling_metadata)
|
|
||||||
return next_tokens
|
return next_tokens
|
||||||
|
|
||||||
def load_weights(self,
|
def load_weights(self,
|
||||||
|
|||||||
@ -30,6 +30,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
|||||||
QKVParallelLinear,
|
QKVParallelLinear,
|
||||||
RowParallelLinear)
|
RowParallelLinear)
|
||||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||||
|
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||||
from vllm.model_executor.layers.sampler import Sampler
|
from vllm.model_executor.layers.sampler import Sampler
|
||||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||||
VocabParallelEmbedding, ParallelLMHead)
|
VocabParallelEmbedding, ParallelLMHead)
|
||||||
@ -224,7 +225,8 @@ class GPTJForCausalLM(nn.Module):
|
|||||||
config.n_embd,
|
config.n_embd,
|
||||||
bias=True,
|
bias=True,
|
||||||
)
|
)
|
||||||
self.sampler = Sampler(config.vocab_size)
|
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||||
|
self.sampler = Sampler()
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@ -237,13 +239,18 @@ class GPTJForCausalLM(nn.Module):
|
|||||||
input_metadata)
|
input_metadata)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
def compute_logits(self, hidden_states: torch.Tensor,
                   sampling_metadata: SamplingMetadata) -> torch.Tensor:
    """Compute logits from the LM head's weight and bias.

    Both `self.lm_head.weight` and `self.lm_head.bias` are passed to the
    logits processor; the bias is the fourth positional argument.

    Args:
        hidden_states: Output of the model's forward pass.
        sampling_metadata: Forwarded unchanged to `self.logits_processor`.

    Returns:
        Whatever `self.logits_processor` returns for these inputs.
    """
    head = self.lm_head
    return self.logits_processor(head.weight, hidden_states,
                                 sampling_metadata, head.bias)
|
||||||
|
|
||||||
def sample(
|
def sample(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
logits: torch.Tensor,
|
||||||
sampling_metadata: SamplingMetadata,
|
sampling_metadata: SamplingMetadata,
|
||||||
) -> Optional[SamplerOutput]:
|
) -> Optional[SamplerOutput]:
|
||||||
next_tokens = self.sampler(self.lm_head.weight, hidden_states,
|
next_tokens = self.sampler(logits, sampling_metadata)
|
||||||
sampling_metadata, self.lm_head.bias)
|
|
||||||
return next_tokens
|
return next_tokens
|
||||||
|
|
||||||
def load_weights(self,
|
def load_weights(self,
|
||||||
|
|||||||
@ -30,6 +30,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
|||||||
QKVParallelLinear,
|
QKVParallelLinear,
|
||||||
RowParallelLinear)
|
RowParallelLinear)
|
||||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||||
|
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||||
from vllm.model_executor.layers.sampler import Sampler
|
from vllm.model_executor.layers.sampler import Sampler
|
||||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||||
VocabParallelEmbedding, ParallelLMHead)
|
VocabParallelEmbedding, ParallelLMHead)
|
||||||
@ -238,7 +239,8 @@ class GPTNeoXForCausalLM(nn.Module):
|
|||||||
config.vocab_size,
|
config.vocab_size,
|
||||||
config.hidden_size,
|
config.hidden_size,
|
||||||
)
|
)
|
||||||
self.sampler = Sampler(config.vocab_size)
|
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||||
|
self.sampler = Sampler()
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@ -251,13 +253,18 @@ class GPTNeoXForCausalLM(nn.Module):
|
|||||||
input_metadata)
|
input_metadata)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
def compute_logits(self, hidden_states: torch.Tensor,
|
||||||
|
sampling_metadata: SamplingMetadata) -> torch.Tensor:
|
||||||
|
logits = self.logits_processor(self.embed_out.weight, hidden_states,
|
||||||
|
sampling_metadata)
|
||||||
|
return logits
|
||||||
|
|
||||||
def sample(
|
def sample(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
logits: torch.Tensor,
|
||||||
sampling_metadata: SamplingMetadata,
|
sampling_metadata: SamplingMetadata,
|
||||||
) -> Optional[SamplerOutput]:
|
) -> Optional[SamplerOutput]:
|
||||||
next_tokens = self.sampler(self.embed_out.weight, hidden_states,
|
next_tokens = self.sampler(logits, sampling_metadata)
|
||||||
sampling_metadata)
|
|
||||||
return next_tokens
|
return next_tokens
|
||||||
|
|
||||||
def load_weights(self,
|
def load_weights(self,
|
||||||
|
|||||||
@ -14,6 +14,7 @@ from vllm.model_executor.layers.linear import (LinearMethodBase,
|
|||||||
QKVParallelLinear,
|
QKVParallelLinear,
|
||||||
RowParallelLinear)
|
RowParallelLinear)
|
||||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||||
|
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||||
from vllm.model_executor.layers.sampler import Sampler
|
from vllm.model_executor.layers.sampler import Sampler
|
||||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||||
VocabParallelEmbedding, ParallelLMHead)
|
VocabParallelEmbedding, ParallelLMHead)
|
||||||
@ -250,7 +251,8 @@ class InternLM2ForCausalLM(nn.Module):
|
|||||||
self.linear_method = linear_method
|
self.linear_method = linear_method
|
||||||
self.model = InternLM2Model(config, linear_method)
|
self.model = InternLM2Model(config, linear_method)
|
||||||
self.output = ParallelLMHead(config.vocab_size, config.hidden_size)
|
self.output = ParallelLMHead(config.vocab_size, config.hidden_size)
|
||||||
self.sampler = Sampler(config.vocab_size)
|
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||||
|
self.sampler = Sampler()
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@ -263,13 +265,18 @@ class InternLM2ForCausalLM(nn.Module):
|
|||||||
input_metadata)
|
input_metadata)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
def compute_logits(self, hidden_states: torch.Tensor,
|
||||||
|
sampling_metadata: SamplingMetadata) -> torch.Tensor:
|
||||||
|
logits = self.logits_processor(self.output.weight, hidden_states,
|
||||||
|
sampling_metadata)
|
||||||
|
return logits
|
||||||
|
|
||||||
def sample(
|
def sample(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
logits: torch.Tensor,
|
||||||
sampling_metadata: SamplingMetadata,
|
sampling_metadata: SamplingMetadata,
|
||||||
) -> Optional[SamplerOutput]:
|
) -> Optional[SamplerOutput]:
|
||||||
next_tokens = self.sampler(self.output.weight, hidden_states,
|
next_tokens = self.sampler(logits, sampling_metadata)
|
||||||
sampling_metadata)
|
|
||||||
return next_tokens
|
return next_tokens
|
||||||
|
|
||||||
def load_weights(self,
|
def load_weights(self,
|
||||||
|
|||||||
@ -37,6 +37,7 @@ from vllm.model_executor.layers.linear import (LinearMethodBase,
|
|||||||
QKVParallelLinear,
|
QKVParallelLinear,
|
||||||
RowParallelLinear)
|
RowParallelLinear)
|
||||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||||
|
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||||
from vllm.model_executor.layers.sampler import Sampler
|
from vllm.model_executor.layers.sampler import Sampler
|
||||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||||
VocabParallelEmbedding, ParallelLMHead, DEFAULT_VOCAB_PADDING_SIZE)
|
VocabParallelEmbedding, ParallelLMHead, DEFAULT_VOCAB_PADDING_SIZE)
|
||||||
@ -325,7 +326,11 @@ class LlamaForCausalLM(nn.Module):
|
|||||||
# compatibility
|
# compatibility
|
||||||
if not lora_config else lora_config.lora_vocab_padding_size,
|
if not lora_config else lora_config.lora_vocab_padding_size,
|
||||||
)
|
)
|
||||||
self.sampler = Sampler(self.unpadded_vocab_size, config.vocab_size)
|
|
||||||
|
logit_scale = getattr(config, "logit_scale", 1.0)
|
||||||
|
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
|
||||||
|
config.vocab_size, logit_scale)
|
||||||
|
self.sampler = Sampler()
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@ -338,13 +343,18 @@ class LlamaForCausalLM(nn.Module):
|
|||||||
input_metadata)
|
input_metadata)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
def compute_logits(self, hidden_states: torch.Tensor,
|
||||||
|
sampling_metadata: SamplingMetadata) -> torch.Tensor:
|
||||||
|
logits = self.logits_processor(self.lm_head.weight, hidden_states,
|
||||||
|
sampling_metadata)
|
||||||
|
return logits
|
||||||
|
|
||||||
def sample(
|
def sample(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
logits: torch.Tensor,
|
||||||
sampling_metadata: SamplingMetadata,
|
sampling_metadata: SamplingMetadata,
|
||||||
) -> Optional[SamplerOutput]:
|
) -> Optional[SamplerOutput]:
|
||||||
next_tokens = self.sampler(self.lm_head.weight, hidden_states,
|
next_tokens = self.sampler(logits, sampling_metadata)
|
||||||
sampling_metadata)
|
|
||||||
return next_tokens
|
return next_tokens
|
||||||
|
|
||||||
def load_weights(self,
|
def load_weights(self,
|
||||||
|
|||||||
@ -37,6 +37,7 @@ from vllm.model_executor.layers.linear import (LinearMethodBase,
|
|||||||
ReplicatedLinear,
|
ReplicatedLinear,
|
||||||
RowParallelLinear)
|
RowParallelLinear)
|
||||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||||
|
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||||
from vllm.model_executor.layers.sampler import Sampler
|
from vllm.model_executor.layers.sampler import Sampler
|
||||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||||
VocabParallelEmbedding, ParallelLMHead, DEFAULT_VOCAB_PADDING_SIZE)
|
VocabParallelEmbedding, ParallelLMHead, DEFAULT_VOCAB_PADDING_SIZE)
|
||||||
@ -369,7 +370,9 @@ class MixtralForCausalLM(nn.Module):
|
|||||||
# compatibility
|
# compatibility
|
||||||
if not lora_config else lora_config.lora_vocab_padding_size,
|
if not lora_config else lora_config.lora_vocab_padding_size,
|
||||||
)
|
)
|
||||||
self.sampler = Sampler(self.unpadded_vocab_size, config.vocab_size)
|
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
|
||||||
|
config.vocab_size)
|
||||||
|
self.sampler = Sampler()
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@ -382,13 +385,18 @@ class MixtralForCausalLM(nn.Module):
|
|||||||
input_metadata)
|
input_metadata)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
def compute_logits(self, hidden_states: torch.Tensor,
|
||||||
|
sampling_metadata: SamplingMetadata) -> torch.Tensor:
|
||||||
|
logits = self.logits_processor(self.lm_head.weight, hidden_states,
|
||||||
|
sampling_metadata)
|
||||||
|
return logits
|
||||||
|
|
||||||
def sample(
|
def sample(
|
||||||
self,
|
self,
|
||||||
hidden_states: Optional[torch.Tensor],
|
logits: Optional[torch.Tensor],
|
||||||
sampling_metadata: SamplingMetadata,
|
sampling_metadata: SamplingMetadata,
|
||||||
) -> Optional[SamplerOutput]:
|
) -> Optional[SamplerOutput]:
|
||||||
next_tokens = self.sampler(self.lm_head.weight, hidden_states,
|
next_tokens = self.sampler(logits, sampling_metadata)
|
||||||
sampling_metadata)
|
|
||||||
return next_tokens
|
return next_tokens
|
||||||
|
|
||||||
def load_weights(self,
|
def load_weights(self,
|
||||||
|
|||||||
@ -39,6 +39,7 @@ from vllm.model_executor.layers.linear import (LinearMethodBase,
|
|||||||
QKVParallelLinear,
|
QKVParallelLinear,
|
||||||
RowParallelLinear)
|
RowParallelLinear)
|
||||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||||
|
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||||
from vllm.model_executor.layers.sampler import Sampler
|
from vllm.model_executor.layers.sampler import Sampler
|
||||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||||
VocabParallelEmbedding, ParallelLMHead)
|
VocabParallelEmbedding, ParallelLMHead)
|
||||||
@ -344,7 +345,8 @@ class MixtralForCausalLM(nn.Module):
|
|||||||
self.linear_method = linear_method
|
self.linear_method = linear_method
|
||||||
self.model = MixtralModel(config, linear_method)
|
self.model = MixtralModel(config, linear_method)
|
||||||
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
||||||
self.sampler = Sampler(config.vocab_size)
|
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||||
|
self.sampler = Sampler()
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@ -357,13 +359,18 @@ class MixtralForCausalLM(nn.Module):
|
|||||||
input_metadata)
|
input_metadata)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
def compute_logits(self, hidden_states: torch.Tensor,
|
||||||
|
sampling_metadata: SamplingMetadata) -> torch.Tensor:
|
||||||
|
logits = self.logits_processor(self.lm_head.weight, hidden_states,
|
||||||
|
sampling_metadata)
|
||||||
|
return logits
|
||||||
|
|
||||||
def sample(
|
def sample(
|
||||||
self,
|
self,
|
||||||
hidden_states: Optional[torch.Tensor],
|
logits: Optional[torch.Tensor],
|
||||||
sampling_metadata: SamplingMetadata,
|
sampling_metadata: SamplingMetadata,
|
||||||
) -> Optional[SamplerOutput]:
|
) -> Optional[SamplerOutput]:
|
||||||
next_tokens = self.sampler(self.lm_head.weight, hidden_states,
|
next_tokens = self.sampler(logits, sampling_metadata)
|
||||||
sampling_metadata)
|
|
||||||
return next_tokens
|
return next_tokens
|
||||||
|
|
||||||
def load_weights(self,
|
def load_weights(self,
|
||||||
|
|||||||
@ -13,6 +13,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
|||||||
LinearMethodBase,
|
LinearMethodBase,
|
||||||
QKVParallelLinear,
|
QKVParallelLinear,
|
||||||
RowParallelLinear)
|
RowParallelLinear)
|
||||||
|
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||||
from vllm.model_executor.layers.sampler import Sampler
|
from vllm.model_executor.layers.sampler import Sampler
|
||||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||||
VocabParallelEmbedding)
|
VocabParallelEmbedding)
|
||||||
@ -259,7 +260,8 @@ class MPTForCausalLM(nn.Module):
|
|||||||
|
|
||||||
self.transformer = MPTModel(config, linear_method)
|
self.transformer = MPTModel(config, linear_method)
|
||||||
self.lm_head_weight = self.transformer.wte.weight
|
self.lm_head_weight = self.transformer.wte.weight
|
||||||
self.sampler = Sampler(config.vocab_size)
|
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||||
|
self.sampler = Sampler()
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@ -272,13 +274,18 @@ class MPTForCausalLM(nn.Module):
|
|||||||
input_metadata)
|
input_metadata)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
def compute_logits(self, hidden_states: torch.Tensor,
|
||||||
|
sampling_metadata: SamplingMetadata) -> torch.Tensor:
|
||||||
|
logits = self.logits_processor(self.lm_head_weight, hidden_states,
|
||||||
|
sampling_metadata)
|
||||||
|
return logits
|
||||||
|
|
||||||
def sample(
|
def sample(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
logits: torch.Tensor,
|
||||||
sampling_metadata: SamplingMetadata,
|
sampling_metadata: SamplingMetadata,
|
||||||
) -> Optional[SamplerOutput]:
|
) -> Optional[SamplerOutput]:
|
||||||
next_tokens = self.sampler(self.lm_head_weight, hidden_states,
|
next_tokens = self.sampler(logits, sampling_metadata)
|
||||||
sampling_metadata)
|
|
||||||
return next_tokens
|
return next_tokens
|
||||||
|
|
||||||
def load_weights(self,
|
def load_weights(self,
|
||||||
|
|||||||
@ -7,6 +7,7 @@ from torch import nn
|
|||||||
from transformers import LlamaConfig
|
from transformers import LlamaConfig
|
||||||
|
|
||||||
from vllm.model_executor.input_metadata import InputMetadata
|
from vllm.model_executor.input_metadata import InputMetadata
|
||||||
|
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||||
from vllm.model_executor.layers.sampler import Sampler
|
from vllm.model_executor.layers.sampler import Sampler
|
||||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||||
from vllm.sequence import SamplerOutput
|
from vllm.sequence import SamplerOutput
|
||||||
@ -25,7 +26,8 @@ class LlamaForCausalLM(nn.Module):
|
|||||||
self.config = config
|
self.config = config
|
||||||
self.linear_method = linear_method
|
self.linear_method = linear_method
|
||||||
self.model = None
|
self.model = None
|
||||||
self.sampler = Sampler(config.vocab_size)
|
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||||
|
self.sampler = Sampler()
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@ -45,13 +47,18 @@ class LlamaForCausalLM(nn.Module):
|
|||||||
start_ids=seq_ids.flatten())
|
start_ids=seq_ids.flatten())
|
||||||
return logits
|
return logits
|
||||||
|
|
||||||
|
def compute_logits(self, hidden_states: torch.Tensor,
|
||||||
|
sampling_metadata: SamplingMetadata) -> torch.Tensor:
|
||||||
|
logits = self.logits_processor(self.model.chkpt_model.lm_head,
|
||||||
|
hidden_states, sampling_metadata)
|
||||||
|
return logits
|
||||||
|
|
||||||
def sample(
|
def sample(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
logits: torch.Tensor,
|
||||||
sampling_metadata: SamplingMetadata,
|
sampling_metadata: SamplingMetadata,
|
||||||
) -> Optional[SamplerOutput]:
|
) -> Optional[SamplerOutput]:
|
||||||
next_tokens = self.sampler(self.model.chkpt_model.lm_head,
|
next_tokens = self.sampler(logits, sampling_metadata)
|
||||||
hidden_states, sampling_metadata)
|
|
||||||
return next_tokens
|
return next_tokens
|
||||||
|
|
||||||
def load_weights(self,
|
def load_weights(self,
|
||||||
|
|||||||
@ -6,6 +6,7 @@ from torch import nn
|
|||||||
from transformers import MistralConfig
|
from transformers import MistralConfig
|
||||||
|
|
||||||
from vllm.model_executor.input_metadata import InputMetadata
|
from vllm.model_executor.input_metadata import InputMetadata
|
||||||
|
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||||
from vllm.model_executor.layers.sampler import Sampler
|
from vllm.model_executor.layers.sampler import Sampler
|
||||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||||
from vllm.sequence import SamplerOutput
|
from vllm.sequence import SamplerOutput
|
||||||
@ -26,7 +27,8 @@ class MistralForCausalLM(nn.Module):
|
|||||||
self.linear_method = linear_method
|
self.linear_method = linear_method
|
||||||
self.model = None
|
self.model = None
|
||||||
self.lm_head = None
|
self.lm_head = None
|
||||||
self.sampler = Sampler(config.vocab_size)
|
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||||
|
self.sampler = Sampler()
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@ -48,13 +50,18 @@ class MistralForCausalLM(nn.Module):
|
|||||||
start_ids=seq_ids)
|
start_ids=seq_ids)
|
||||||
return logits
|
return logits
|
||||||
|
|
||||||
|
def compute_logits(self, hidden_states: torch.Tensor,
|
||||||
|
sampling_metadata: SamplingMetadata) -> torch.Tensor:
|
||||||
|
logits = self.logits_processor(self.model.chkpt_model.lm_head,
|
||||||
|
hidden_states, sampling_metadata)
|
||||||
|
return logits
|
||||||
|
|
||||||
def sample(
|
def sample(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
logits: torch.Tensor,
|
||||||
sampling_metadata: SamplingMetadata,
|
sampling_metadata: SamplingMetadata,
|
||||||
) -> Optional[SamplerOutput]:
|
) -> Optional[SamplerOutput]:
|
||||||
next_tokens = self.sampler(self.model.chkpt_model.lm_head,
|
next_tokens = self.sampler(logits, sampling_metadata)
|
||||||
hidden_states, sampling_metadata)
|
|
||||||
return next_tokens
|
return next_tokens
|
||||||
|
|
||||||
def load_weights(self,
|
def load_weights(self,
|
||||||
|
|||||||
@ -51,6 +51,7 @@ from vllm.model_executor.layers.linear import (
|
|||||||
RowParallelLinear,
|
RowParallelLinear,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||||
|
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||||
from vllm.model_executor.layers.sampler import Sampler
|
from vllm.model_executor.layers.sampler import Sampler
|
||||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||||
VocabParallelEmbedding)
|
VocabParallelEmbedding)
|
||||||
@ -336,7 +337,8 @@ class OLMoForCausalLM(nn.Module):
|
|||||||
self.lm_head_weight = (self.model.transformer.wte.weight
|
self.lm_head_weight = (self.model.transformer.wte.weight
|
||||||
if config.weight_tying else
|
if config.weight_tying else
|
||||||
self.model.transformer.ff_out.weight)
|
self.model.transformer.ff_out.weight)
|
||||||
self.sampler = Sampler(config.vocab_size)
|
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||||
|
self.sampler = Sampler()
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@ -353,13 +355,18 @@ class OLMoForCausalLM(nn.Module):
|
|||||||
)
|
)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
def compute_logits(self, hidden_states: torch.Tensor,
|
||||||
|
sampling_metadata: SamplingMetadata) -> torch.Tensor:
|
||||||
|
logits = self.logits_processor(self.lm_head_weight, hidden_states,
|
||||||
|
sampling_metadata)
|
||||||
|
return logits
|
||||||
|
|
||||||
def sample(
|
def sample(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
logits: torch.Tensor,
|
||||||
sampling_metadata: SamplingMetadata,
|
sampling_metadata: SamplingMetadata,
|
||||||
) -> Optional[SamplerOutput]:
|
) -> Optional[SamplerOutput]:
|
||||||
next_tokens = self.sampler(self.lm_head_weight, hidden_states,
|
next_tokens = self.sampler(logits, sampling_metadata)
|
||||||
sampling_metadata)
|
|
||||||
return next_tokens
|
return next_tokens
|
||||||
|
|
||||||
def load_weights(
|
def load_weights(
|
||||||
|
|||||||
@ -31,6 +31,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
|||||||
QKVParallelLinear,
|
QKVParallelLinear,
|
||||||
ReplicatedLinear,
|
ReplicatedLinear,
|
||||||
RowParallelLinear)
|
RowParallelLinear)
|
||||||
|
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||||
from vllm.model_executor.layers.sampler import Sampler
|
from vllm.model_executor.layers.sampler import Sampler
|
||||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||||
VocabParallelEmbedding)
|
VocabParallelEmbedding)
|
||||||
@ -292,7 +293,8 @@ class OPTForCausalLM(nn.Module):
|
|||||||
self.linear_method = linear_method
|
self.linear_method = linear_method
|
||||||
self.model = OPTModel(config, linear_method)
|
self.model = OPTModel(config, linear_method)
|
||||||
self.lm_head_weight = self.model.decoder.embed_tokens.weight
|
self.lm_head_weight = self.model.decoder.embed_tokens.weight
|
||||||
self.sampler = Sampler(config.vocab_size)
|
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||||
|
self.sampler = Sampler()
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@ -305,13 +307,18 @@ class OPTForCausalLM(nn.Module):
|
|||||||
input_metadata)
|
input_metadata)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
def compute_logits(self, hidden_states: torch.Tensor,
|
||||||
|
sampling_metadata: SamplingMetadata) -> torch.Tensor:
|
||||||
|
logits = self.logits_processor(self.lm_head_weight, hidden_states,
|
||||||
|
sampling_metadata)
|
||||||
|
return logits
|
||||||
|
|
||||||
def sample(
|
def sample(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
logits: torch.Tensor,
|
||||||
sampling_metadata: SamplingMetadata,
|
sampling_metadata: SamplingMetadata,
|
||||||
) -> Optional[SamplerOutput]:
|
) -> Optional[SamplerOutput]:
|
||||||
next_tokens = self.sampler(self.lm_head_weight, hidden_states,
|
next_tokens = self.sampler(logits, sampling_metadata)
|
||||||
sampling_metadata)
|
|
||||||
return next_tokens
|
return next_tokens
|
||||||
|
|
||||||
def load_weights(self,
|
def load_weights(self,
|
||||||
|
|||||||
@ -18,6 +18,7 @@ from vllm.model_executor.layers.linear import (LinearMethodBase,
|
|||||||
QKVParallelLinear,
|
QKVParallelLinear,
|
||||||
RowParallelLinear)
|
RowParallelLinear)
|
||||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||||
|
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||||
from vllm.model_executor.layers.sampler import Sampler
|
from vllm.model_executor.layers.sampler import Sampler
|
||||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||||
VocabParallelEmbedding, ParallelLMHead)
|
VocabParallelEmbedding, ParallelLMHead)
|
||||||
@ -256,7 +257,8 @@ class OrionForCausalLM(nn.Module):
|
|||||||
self.linear_method = linear_method
|
self.linear_method = linear_method
|
||||||
self.model = OrionModel(config, linear_method)
|
self.model = OrionModel(config, linear_method)
|
||||||
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
||||||
self.sampler = Sampler(config.vocab_size)
|
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||||
|
self.sampler = Sampler()
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@ -269,13 +271,18 @@ class OrionForCausalLM(nn.Module):
|
|||||||
input_metadata)
|
input_metadata)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
def compute_logits(self, hidden_states: torch.Tensor,
|
||||||
|
sampling_metadata: SamplingMetadata) -> torch.Tensor:
|
||||||
|
logits = self.logits_processor(self.lm_head.weight, hidden_states,
|
||||||
|
sampling_metadata)
|
||||||
|
return logits
|
||||||
|
|
||||||
def sample(
|
def sample(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
logits: torch.Tensor,
|
||||||
sampling_metadata: SamplingMetadata,
|
sampling_metadata: SamplingMetadata,
|
||||||
) -> Optional[SamplerOutput]:
|
) -> Optional[SamplerOutput]:
|
||||||
next_tokens = self.sampler(self.lm_head.weight, hidden_states,
|
next_tokens = self.sampler(logits, sampling_metadata)
|
||||||
sampling_metadata)
|
|
||||||
return next_tokens
|
return next_tokens
|
||||||
|
|
||||||
def load_weights(self,
|
def load_weights(self,
|
||||||
|
|||||||
@ -49,6 +49,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
|||||||
QKVParallelLinear,
|
QKVParallelLinear,
|
||||||
RowParallelLinear)
|
RowParallelLinear)
|
||||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||||
|
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||||
from vllm.model_executor.layers.sampler import Sampler
|
from vllm.model_executor.layers.sampler import Sampler
|
||||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||||
VocabParallelEmbedding, ParallelLMHead)
|
VocabParallelEmbedding, ParallelLMHead)
|
||||||
@ -240,7 +241,8 @@ class PhiForCausalLM(nn.Module):
|
|||||||
self.lm_head = ParallelLMHead(config.vocab_size,
|
self.lm_head = ParallelLMHead(config.vocab_size,
|
||||||
config.hidden_size,
|
config.hidden_size,
|
||||||
bias=True)
|
bias=True)
|
||||||
self.sampler = Sampler(config.vocab_size)
|
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||||
|
self.sampler = Sampler()
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@ -254,14 +256,18 @@ class PhiForCausalLM(nn.Module):
|
|||||||
|
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
def compute_logits(self, hidden_states: torch.Tensor,
|
||||||
|
sampling_metadata: SamplingMetadata) -> torch.Tensor:
|
||||||
|
logits = self.logits_processor(self.lm_head.weight, hidden_states,
|
||||||
|
sampling_metadata, self.lm_head.bias)
|
||||||
|
return logits
|
||||||
|
|
||||||
def sample(
|
def sample(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
logits: torch.Tensor,
|
||||||
sampling_metadata: SamplingMetadata,
|
sampling_metadata: SamplingMetadata,
|
||||||
) -> Optional[SamplerOutput]:
|
) -> Optional[SamplerOutput]:
|
||||||
head = self.lm_head
|
next_tokens = self.sampler(logits, sampling_metadata)
|
||||||
next_tokens = self.sampler(head.weight, hidden_states,
|
|
||||||
sampling_metadata, head.bias)
|
|
||||||
return next_tokens
|
return next_tokens
|
||||||
|
|
||||||
def load_weights(self,
|
def load_weights(self,
|
||||||
|
|||||||
@ -19,6 +19,7 @@ from vllm.model_executor.layers.linear import (LinearMethodBase,
|
|||||||
QKVParallelLinear,
|
QKVParallelLinear,
|
||||||
RowParallelLinear)
|
RowParallelLinear)
|
||||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||||
|
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||||
from vllm.model_executor.layers.sampler import Sampler
|
from vllm.model_executor.layers.sampler import Sampler
|
||||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||||
VocabParallelEmbedding, ParallelLMHead)
|
VocabParallelEmbedding, ParallelLMHead)
|
||||||
@ -230,7 +231,8 @@ class QWenLMHeadModel(nn.Module):
|
|||||||
self.linear_method = linear_method
|
self.linear_method = linear_method
|
||||||
self.transformer = QWenModel(config, linear_method)
|
self.transformer = QWenModel(config, linear_method)
|
||||||
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
||||||
self.sampler = Sampler(config.vocab_size)
|
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||||
|
self.sampler = Sampler()
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@ -243,13 +245,18 @@ class QWenLMHeadModel(nn.Module):
|
|||||||
input_metadata)
|
input_metadata)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
def compute_logits(self, hidden_states: torch.Tensor,
|
||||||
|
sampling_metadata: SamplingMetadata) -> torch.Tensor:
|
||||||
|
logits = self.logits_processor(self.lm_head.weight, hidden_states,
|
||||||
|
sampling_metadata)
|
||||||
|
return logits
|
||||||
|
|
||||||
def sample(
|
def sample(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
logits: torch.Tensor,
|
||||||
sampling_metadata: SamplingMetadata,
|
sampling_metadata: SamplingMetadata,
|
||||||
) -> Optional[SamplerOutput]:
|
) -> Optional[SamplerOutput]:
|
||||||
next_tokens = self.sampler(self.lm_head.weight, hidden_states,
|
next_tokens = self.sampler(logits, sampling_metadata)
|
||||||
sampling_metadata)
|
|
||||||
return next_tokens
|
return next_tokens
|
||||||
|
|
||||||
def load_weights(self,
|
def load_weights(self,
|
||||||
|
|||||||
@ -37,6 +37,7 @@ from vllm.model_executor.layers.linear import (LinearMethodBase,
|
|||||||
QKVParallelLinear,
|
QKVParallelLinear,
|
||||||
RowParallelLinear)
|
RowParallelLinear)
|
||||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||||
|
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||||
from vllm.model_executor.layers.sampler import Sampler
|
from vllm.model_executor.layers.sampler import Sampler
|
||||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||||
VocabParallelEmbedding, ParallelLMHead)
|
VocabParallelEmbedding, ParallelLMHead)
|
||||||
@ -300,11 +301,15 @@ class Qwen2ForCausalLM(nn.Module):
|
|||||||
self.linear_method = linear_method
|
self.linear_method = linear_method
|
||||||
self.model = Qwen2Model(config, linear_method)
|
self.model = Qwen2Model(config, linear_method)
|
||||||
|
|
||||||
if not config.tie_word_embeddings:
|
if config.tie_word_embeddings:
|
||||||
|
self.lm_head_weight = self.model.embed_tokens.weight
|
||||||
|
else:
|
||||||
self.lm_head = ParallelLMHead(config.vocab_size,
|
self.lm_head = ParallelLMHead(config.vocab_size,
|
||||||
config.hidden_size)
|
config.hidden_size)
|
||||||
|
self.lm_head_weight = self.lm_head.weight
|
||||||
|
|
||||||
self.sampler = Sampler(config.vocab_size)
|
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||||
|
self.sampler = Sampler()
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@ -317,17 +322,18 @@ class Qwen2ForCausalLM(nn.Module):
|
|||||||
input_metadata)
|
input_metadata)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
def compute_logits(self, hidden_states: torch.Tensor,
|
||||||
|
sampling_metadata: SamplingMetadata) -> torch.Tensor:
|
||||||
|
logits = self.logits_processor(self.lm_head_weight, hidden_states,
|
||||||
|
sampling_metadata)
|
||||||
|
return logits
|
||||||
|
|
||||||
def sample(
|
def sample(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
logits: torch.Tensor,
|
||||||
sampling_metadata: SamplingMetadata,
|
sampling_metadata: SamplingMetadata,
|
||||||
) -> Optional[SamplerOutput]:
|
) -> Optional[SamplerOutput]:
|
||||||
if self.config.tie_word_embeddings:
|
next_tokens = self.sampler(logits, sampling_metadata)
|
||||||
lm_head_weight = self.model.embed_tokens.weight
|
|
||||||
else:
|
|
||||||
lm_head_weight = self.lm_head.weight
|
|
||||||
next_tokens = self.sampler(lm_head_weight, hidden_states,
|
|
||||||
sampling_metadata)
|
|
||||||
return next_tokens
|
return next_tokens
|
||||||
|
|
||||||
def load_weights(self,
|
def load_weights(self,
|
||||||
|
|||||||
@ -33,6 +33,7 @@ from vllm.model_executor.layers.linear import (LinearMethodBase,
|
|||||||
QKVParallelLinear,
|
QKVParallelLinear,
|
||||||
RowParallelLinear)
|
RowParallelLinear)
|
||||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||||
|
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||||
from vllm.model_executor.layers.sampler import Sampler
|
from vllm.model_executor.layers.sampler import Sampler
|
||||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||||
VocabParallelEmbedding, ParallelLMHead)
|
VocabParallelEmbedding, ParallelLMHead)
|
||||||
@ -238,7 +239,8 @@ class StablelmForCausalLM(nn.Module):
|
|||||||
self.linear_method = linear_method
|
self.linear_method = linear_method
|
||||||
self.model = StableLMEpochModel(config, linear_method)
|
self.model = StableLMEpochModel(config, linear_method)
|
||||||
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
||||||
self.sampler = Sampler(config.vocab_size)
|
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||||
|
self.sampler = Sampler()
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@ -251,13 +253,18 @@ class StablelmForCausalLM(nn.Module):
|
|||||||
input_metadata)
|
input_metadata)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
def compute_logits(self, hidden_states: torch.Tensor,
|
||||||
|
sampling_metadata: SamplingMetadata) -> torch.Tensor:
|
||||||
|
logits = self.logits_processor(self.lm_head.weight, hidden_states,
|
||||||
|
sampling_metadata)
|
||||||
|
return logits
|
||||||
|
|
||||||
def sample(
|
def sample(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
logits: torch.Tensor,
|
||||||
sampling_metadata: SamplingMetadata,
|
sampling_metadata: SamplingMetadata,
|
||||||
) -> Optional[SamplerOutput]:
|
) -> Optional[SamplerOutput]:
|
||||||
next_tokens = self.sampler(self.lm_head.weight, hidden_states,
|
next_tokens = self.sampler(logits, sampling_metadata)
|
||||||
sampling_metadata)
|
|
||||||
return next_tokens
|
return next_tokens
|
||||||
|
|
||||||
def load_weights(self,
|
def load_weights(self,
|
||||||
|
|||||||
@ -32,6 +32,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
|||||||
LinearMethodBase,
|
LinearMethodBase,
|
||||||
QKVParallelLinear,
|
QKVParallelLinear,
|
||||||
RowParallelLinear)
|
RowParallelLinear)
|
||||||
|
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||||
from vllm.model_executor.layers.sampler import Sampler
|
from vllm.model_executor.layers.sampler import Sampler
|
||||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||||
VocabParallelEmbedding, ParallelLMHead, DEFAULT_VOCAB_PADDING_SIZE)
|
VocabParallelEmbedding, ParallelLMHead, DEFAULT_VOCAB_PADDING_SIZE)
|
||||||
@ -254,7 +255,9 @@ class Starcoder2ForCausalLM(nn.Module):
|
|||||||
padding_size=DEFAULT_VOCAB_PADDING_SIZE,
|
padding_size=DEFAULT_VOCAB_PADDING_SIZE,
|
||||||
)
|
)
|
||||||
self.lm_head_weight = self.lm_head.weight
|
self.lm_head_weight = self.lm_head.weight
|
||||||
self.sampler = Sampler(self.unpadded_vocab_size, config.vocab_size)
|
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
|
||||||
|
config.vocab_size)
|
||||||
|
self.sampler = Sampler()
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@ -267,13 +270,18 @@ class Starcoder2ForCausalLM(nn.Module):
|
|||||||
input_metadata)
|
input_metadata)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
def compute_logits(self, hidden_states: torch.Tensor,
|
||||||
|
sampling_metadata: SamplingMetadata) -> torch.Tensor:
|
||||||
|
logits = self.logits_processor(self.lm_head_weight, hidden_states,
|
||||||
|
sampling_metadata)
|
||||||
|
return logits
|
||||||
|
|
||||||
def sample(
|
def sample(
|
||||||
self,
|
self,
|
||||||
hidden_states: Optional[torch.Tensor],
|
logits: Optional[torch.Tensor],
|
||||||
sampling_metadata: SamplingMetadata,
|
sampling_metadata: SamplingMetadata,
|
||||||
) -> Optional[SamplerOutput]:
|
) -> Optional[SamplerOutput]:
|
||||||
next_tokens = self.sampler(self.lm_head_weight, hidden_states,
|
next_tokens = self.sampler(logits, sampling_metadata)
|
||||||
sampling_metadata)
|
|
||||||
return next_tokens
|
return next_tokens
|
||||||
|
|
||||||
def load_weights(self,
|
def load_weights(self,
|
||||||
|
|||||||
@ -613,9 +613,16 @@ class ModelRunner:
|
|||||||
input_metadata=input_metadata,
|
input_metadata=input_metadata,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Compute the logits.
|
||||||
|
logits = self.model.compute_logits(hidden_states, sampling_metadata)
|
||||||
|
|
||||||
|
# Only perform sampling in the driver worker.
|
||||||
|
if not sampling_metadata.perform_sampling:
|
||||||
|
return None
|
||||||
|
|
||||||
# Sample the next token.
|
# Sample the next token.
|
||||||
output = self.model.sample(
|
output = self.model.sample(
|
||||||
hidden_states=hidden_states,
|
logits=logits,
|
||||||
sampling_metadata=sampling_metadata,
|
sampling_metadata=sampling_metadata,
|
||||||
)
|
)
|
||||||
return output
|
return output
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user