[Misc] Remove dangling references to SamplingType.BEAM (#13402)

This commit is contained in:
Harry Mellor 2025-02-18 03:52:35 +00:00 committed by GitHub
parent b3942e157e
commit efbe854448


@@ -68,7 +68,6 @@ class SampleResultArgsType:
    sample_results_dict: SampleResultsDictType
    sampling_metadata: SamplingMetadata
    greedy_samples: Optional[torch.Tensor]
    beam_search_logprobs: Optional[torch.Tensor]
# Union of non-deferred (single-step scheduling)
@@ -510,74 +509,6 @@ def _random_sample(
    return results

def _beam_search_sample(
    selected_seq_groups: List[SequenceGroupToSample],
    logprobs: torch.Tensor,
) -> SampleResultType:
    """Run beam search sampling on the given samples.

    Args:
        selected_seq_groups: A list of batched sequence groups.
        logprobs: (num_selected_samples, vocab_size,) A tensor of logprobs
            at the selected sample indices.
    Returns:
        Tuple of (next_token_ids, parent_ids). The length of the returned list
        is the same as the length of selected_seq_groups. If the corresponding
        seq_group has do_sample=False, the tuple contains ([], []).
    """
    # We sample 2 * beam_width candidates to make sure that with high
    # probability we can get `beam_width` candidates in addition to
    # the finished sequences for the next iteration. See
    # https://github.com/tensorflow/tensor2tensor/blob/bafdc1b67730430d38d6ab802cbd51f9d053ba2e/tensor2tensor/utils/beam_search.py#L557-L563
    # for details. See also HF reference:
    # https://github.com/huggingface/transformers/blob/a4dd53d88e4852f023332d284ff07a01afcd5681/src/transformers/generation/utils.py#L3063-L3065
    #
    # NOTE: Beam search is not vectorized, so its speed can be slower than
    # other sampling methods.
    sample_idx = 0
    results: SampleResultType = []
    for seq_group in selected_seq_groups:
        if not seq_group.do_sample:
            results.append(([], []))
            continue

        is_prompt = seq_group.is_prompt
        seq_ids, sampling_params = seq_group.seq_ids, seq_group.sampling_params
        num_parent_seqs = len(seq_ids)
        beam_width = sampling_params.n
        seq_group_logprobs = logprobs[sample_idx:sample_idx + num_parent_seqs]
        if is_prompt:
            # Prompt phase.
            assert num_parent_seqs == 1, (
                "Prompt input should have only one seq.")
            parent_ids = [0] * (2 * beam_width)
            _, next_token_ids = torch.topk(seq_group_logprobs[0],
                                           2 * beam_width)
            next_token_ids = next_token_ids.tolist()
        else:
            # Generation phase.
            cumulative_logprobs: List[float] = [
                seq_group.seq_data[seq_id].cumulative_logprob
                for seq_id in seq_ids
            ]
            cumulative_logprobs_tensor = torch.tensor(
                cumulative_logprobs,
                dtype=torch.float,
                device=seq_group_logprobs.device)
            seq_group_logprobs = (seq_group_logprobs +
                                  cumulative_logprobs_tensor.unsqueeze(dim=1))
            _, topk_ids = torch.topk(seq_group_logprobs.flatten(),
                                     2 * beam_width)
            topk_ids = topk_ids.tolist()
            vocab_size = seq_group_logprobs.size(-1)
            parent_ids = [i // vocab_size for i in topk_ids]
            next_token_ids = [i % vocab_size for i in topk_ids]
        results.append((next_token_ids, parent_ids))
        sample_idx += num_parent_seqs
    assert sample_idx == logprobs.size(0)
    return results
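For context, a minimal standalone sketch of the flattened top-k decoding used in the generation phase above; the beam width, vocab size, and scores are made up for illustration:

import torch

# Toy sizes; real values come from sampling_params.n and the model vocab.
beam_width, vocab_size = 2, 5
# One row of cumulative logprobs per parent beam: (beam_width, vocab_size).
scores = torch.tensor([[0.1, 0.9, 0.2, 0.3, 0.4],
                       [0.8, 0.0, 0.7, 0.6, 0.5]])

# Flatten to (beam_width * vocab_size,) and keep 2 * beam_width candidates.
_, topk_ids = torch.topk(scores.flatten(), 2 * beam_width)
topk_ids = topk_ids.tolist()

# Each flat index decodes back into (parent beam, next token).
parent_ids = [i // vocab_size for i in topk_ids]      # [0, 1, 1, 1]
next_token_ids = [i % vocab_size for i in topk_ids]   # [1, 0, 2, 3]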
# torch.multinomial forces a GPU<->CPU sync.
# Therefore, we use an optimized implementation instead.
# Note that we always sample with replacement.
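For context on the comment above, here is a minimal sketch of the exponential-race (Gumbel-style) trick commonly used to sample with replacement from row-wise probabilities without torch.multinomial's device sync. This is an illustrative stand-in, not necessarily vLLM's exact implementation, and the helper name is made up:

import torch

def _multinomial_sketch(probs: torch.Tensor, num_samples: int) -> torch.Tensor:
    """Illustrative sketch only: sample with replacement from each row of
    `probs` (shape [num_seqs, vocab_size]) without torch.multinomial.
    Relies on argmax(probs / Exp(1)) being distributed according to probs."""
    # Give every draw its own row of exponential noise.
    probs = probs.repeat_interleave(num_samples, dim=0)
    q = torch.empty_like(probs).exponential_(1.0)
    return probs.div(q).argmax(dim=-1).view(-1, num_samples)

# Usage: two sequences over a 10-token vocab, three samples each -> shape (2, 3).
probs = torch.softmax(torch.randn(2, 10), dim=-1)
samples = _multinomial_sketch(probs, num_samples=3)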
@@ -666,14 +597,12 @@ def get_pythonized_sample_results(
        sampling_metadata,
        greedy_samples,
        multinomial_samples,
        beam_search_logprobs,
        sample_results_dict,
    ) = (
        sample_result_args.sample_metadata,
        sample_result_args.sampling_metadata,
        sample_result_args.greedy_samples,
        sample_result_args.multinomial_samples,
        sample_result_args.beam_search_logprobs,
        sample_result_args.sample_results_dict,
    )
@@ -686,9 +615,6 @@ def get_pythonized_sample_results(
        elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
            sample_results = _random_sample(seq_groups,
                                            multinomial_samples[sampling_type])
        elif sampling_type == SamplingType.BEAM:
            sample_results = _beam_search_sample(seq_groups,
                                                 beam_search_logprobs)
        sample_results_dict.update(zip(seq_group_id, sample_results))

    return [
@@ -731,7 +657,6 @@ def _sample_with_torch(
    sample_metadata: SampleMetadataType = {}
    multinomial_samples: MultinomialSamplesType = {}
    greedy_samples: Optional[torch.Tensor] = None
    beam_search_logprobs: Optional[torch.Tensor] = None

    # Create output tensor for sampled token ids.
    if include_gpu_probs_tensor:
@@ -800,8 +725,6 @@ def _sample_with_torch(
                sampled_token_ids_tensor[long_sample_indices] = \
                    multinomial_samples[sampling_type].to(torch.long)

        elif sampling_type == SamplingType.BEAM:
            beam_search_logprobs = logprobs[sample_indices]
        else:
            raise ValueError(f"Unsupported sampling type: {sampling_type}")
@@ -812,7 +735,6 @@ def _sample_with_torch(
        sample_metadata=sample_metadata,
        multinomial_samples=multinomial_samples,
        greedy_samples=greedy_samples,
        beam_search_logprobs=beam_search_logprobs,
        sample_results_dict=sample_results_dict)

    if not sampling_metadata.skip_sampler_cpu_output: