[Misc] Remove dangling references to SamplingType.BEAM (#13402)

commit efbe854448
parent b3942e157e
@@ -68,7 +68,6 @@ class SampleResultArgsType:
     sample_results_dict: SampleResultsDictType
     sampling_metadata: SamplingMetadata
     greedy_samples: Optional[torch.Tensor]
-    beam_search_logprobs: Optional[torch.Tensor]
 
 
 # Union of non-deferred (single-step scheduling)
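(The fields above let sampler-output pythonization be deferred: under multi-step scheduling the sampler keeps raw GPU tensors and converts them to Python lists only when results are actually consumed. A minimal standalone sketch of that pattern, with hypothetical names, not vLLM's actual code:

    from dataclasses import dataclass
    from typing import List, Optional

    import torch

    # Hypothetical stand-in for the role SampleResultArgsType plays:
    # hold raw GPU results now, pythonize later.
    @dataclass
    class DeferredSampleArgs:
        greedy_samples: Optional[torch.Tensor] = None

        def pythonize(self) -> List[int]:
            # .tolist() is the GPU->CPU sync point; deferring this call
            # lets the scheduler keep the GPU busy in the meantime.
            assert self.greedy_samples is not None
            return self.greedy_samples.tolist()
)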
@@ -510,74 +509,6 @@ def _random_sample(
     return results
 
 
-def _beam_search_sample(
-    selected_seq_groups: List[SequenceGroupToSample],
-    logprobs: torch.Tensor,
-) -> SampleResultType:
-    """Run beam sampling on the given samples.
-
-    Args:
-        selected_seq_groups: A list of batched sequence groups.
-        logprobs: (num_selected_samples, vocab_size,) A tensor of logprobs
-            on the selected sample indices.
-    Returns:
-        Tuple of (next_token_ids, parent_ids). The length of the returned
-        list is the same as the length of selected_seq_groups. If the
-        corresponding seq_group has do_sample=False, the tuple contains
-        ([], []).
-    """
-    # We sample 2 * beam_width candidates to make sure that with high
-    # probability we can get `beam_width` candidates in addition to
-    # the finished sequences for the next iteration. See
-    # https://github.com/tensorflow/tensor2tensor/blob/bafdc1b67730430d38d6ab802cbd51f9d053ba2e/tensor2tensor/utils/beam_search.py#L557-L563
-    # for details. See also HF reference:
-    # https://github.com/huggingface/transformers/blob/a4dd53d88e4852f023332d284ff07a01afcd5681/src/transformers/generation/utils.py#L3063-L3065
-    #
-    # NOTE: Beam search is not vectorized, so its speed can be slower than
-    # other sampling methods.
-    sample_idx = 0
-    results: SampleResultType = []
-    for seq_group in selected_seq_groups:
-        if not seq_group.do_sample:
-            results.append(([], []))
-            continue
-
-        is_prompt = seq_group.is_prompt
-        seq_ids, sampling_params = seq_group.seq_ids, seq_group.sampling_params
-        num_parent_seqs = len(seq_ids)
-        beam_width = sampling_params.n
-        seq_group_logprobs = logprobs[sample_idx:sample_idx + num_parent_seqs]
-        if is_prompt:
-            # Prompt phase.
-            assert num_parent_seqs == 1, (
-                "Prompt input should have only one seq.")
-            parent_ids = [0] * (2 * beam_width)
-            _, next_token_ids = torch.topk(seq_group_logprobs[0],
-                                           2 * beam_width)
-            next_token_ids = next_token_ids.tolist()
-        else:
-            # Generation phase.
-            cumulative_logprobs: List[float] = [
-                seq_group.seq_data[seq_id].cumulative_logprob
-                for seq_id in seq_ids
-            ]
-            cumulative_logprobs_tensor = torch.tensor(
-                cumulative_logprobs,
-                dtype=torch.float,
-                device=seq_group_logprobs.device)
-            seq_group_logprobs = (seq_group_logprobs +
-                                  cumulative_logprobs_tensor.unsqueeze(dim=1))
-            _, topk_ids = torch.topk(seq_group_logprobs.flatten(),
-                                     2 * beam_width)
-            topk_ids = topk_ids.tolist()
-            vocab_size = seq_group_logprobs.size(-1)
-            parent_ids = [i // vocab_size for i in topk_ids]
-            next_token_ids = [i % vocab_size for i in topk_ids]
-        results.append((next_token_ids, parent_ids))
-        sample_idx += num_parent_seqs
-    assert sample_idx == logprobs.size(0)
-    return results
-
-
 # torch.multinomial forces a GPU<->CPU sync.
 # Therefore, we use an optimized implementation instead.
 # Note that we always sample with replacement.
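(For reference, the generation-phase arithmetic of the function deleted above reduces to one topk over flattened (beam, vocab) scores, with parent beams and tokens recovered by integer division and modulo. A self-contained sketch with illustrative names, not vLLM code:

    import torch

    def expand_beams(logprobs: torch.Tensor,
                     cumulative_logprobs: torch.Tensor,
                     beam_width: int):
        # logprobs: (num_beams, vocab_size) per-step log-probabilities.
        # cumulative_logprobs: (num_beams,) running totals per beam.
        scores = logprobs + cumulative_logprobs.unsqueeze(dim=1)
        # 2 * beam_width candidates, as in the deleted function, so that
        # finished sequences can be set aside and beam_width live beams remain.
        _, topk_ids = torch.topk(scores.flatten(), 2 * beam_width)
        vocab_size = logprobs.size(-1)
        parent_ids = [i // vocab_size for i in topk_ids.tolist()]
        next_token_ids = [i % vocab_size for i in topk_ids.tolist()]
        return next_token_ids, parent_ids

    # Example: two live beams over a 6-token vocabulary, beam_width = 2.
    lp = torch.log_softmax(torch.randn(2, 6), dim=-1)
    cum = torch.tensor([-1.2, -0.7])
    tokens, parents = expand_beams(lp, cum, beam_width=2)
)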
@@ -666,14 +597,12 @@ def get_pythonized_sample_results(
         sampling_metadata,
         greedy_samples,
         multinomial_samples,
-        beam_search_logprobs,
         sample_results_dict,
     ) = (
         sample_result_args.sample_metadata,
         sample_result_args.sampling_metadata,
         sample_result_args.greedy_samples,
         sample_result_args.multinomial_samples,
-        sample_result_args.beam_search_logprobs,
         sample_result_args.sample_results_dict,
     )
 
@@ -686,9 +615,6 @@ def get_pythonized_sample_results(
         elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
             sample_results = _random_sample(seq_groups,
                                             multinomial_samples[sampling_type])
-        elif sampling_type == SamplingType.BEAM:
-            sample_results = _beam_search_sample(seq_groups,
-                                                 beam_search_logprobs)
         sample_results_dict.update(zip(seq_group_id, sample_results))
 
     return [
@@ -731,7 +657,6 @@ def _sample_with_torch(
     sample_metadata: SampleMetadataType = {}
     multinomial_samples: MultinomialSamplesType = {}
     greedy_samples: Optional[torch.Tensor] = None
-    beam_search_logprobs: Optional[torch.Tensor] = None
 
     # Create output tensor for sampled token ids.
     if include_gpu_probs_tensor:
@@ -800,8 +725,6 @@ def _sample_with_torch(
                 sampled_token_ids_tensor[long_sample_indices] = \
                     multinomial_samples[sampling_type].to(torch.long)
 
-        elif sampling_type == SamplingType.BEAM:
-            beam_search_logprobs = logprobs[sample_indices]
         else:
             raise ValueError(f"Unsupported sampling type: {sampling_type}")
 
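(The multinomial_samples consumed in this hunk come from the sampler's sync-free replacement for torch.multinomial noted earlier. The standard trick, and plausibly what the unshown helper does, though that is an assumption here, is to divide probabilities by i.i.d. Exponential(1) noise and take an argmax, which selects index i with probability p_i (the Gumbel-max trick in exponential form) entirely on-device:

    import torch

    def multinomial_with_replacement(probs: torch.Tensor,
                                     num_samples: int) -> torch.Tensor:
        # probs: (batch, vocab_size), each row summing to 1.
        if num_samples > 1:
            probs = probs.repeat_interleave(num_samples, dim=0)
        q = torch.empty_like(probs).exponential_(1.0)
        # argmax(p / E) with E ~ Exp(1) draws i with probability p_i,
        # avoiding the GPU<->CPU sync that torch.multinomial incurs.
        return probs.div(q).argmax(dim=1).view(-1, num_samples)
)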
@@ -812,7 +735,6 @@ def _sample_with_torch(
         sample_metadata=sample_metadata,
         multinomial_samples=multinomial_samples,
         greedy_samples=greedy_samples,
-        beam_search_logprobs=beam_search_logprobs,
         sample_results_dict=sample_results_dict)
 
     if not sampling_metadata.skip_sampler_cpu_output:
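(With these dead branches gone, SamplingType.BEAM is no longer handled inside the sampler; beam search in vLLM of this vintage lives above it instead. A hedged usage sketch, where the exact API, LLM.beam_search with BeamSearchParams, is an assumption to verify against the installed version:

    from vllm import LLM
    from vllm.sampling_params import BeamSearchParams

    # Assumed API: check that LLM.beam_search and BeamSearchParams exist
    # in your vLLM version before relying on this.
    llm = LLM(model="facebook/opt-125m")
    params = BeamSearchParams(beam_width=4, max_tokens=32)
    outputs = llm.beam_search([{"prompt": "The capital of France is"}], params)
    print(outputs[0])
)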