mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-16 04:45:01 +08:00
[core] move parallel sampling out from vllm core (#9302)
This commit is contained in:
parent ef7faad1b8
commit 76a5e13270
@@ -340,6 +340,40 @@ async def test_completion_streaming(client: openai.AsyncOpenAI,
     assert "".join(chunks) == single_output
 
 
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "model_name",
+    [MODEL_NAME, "zephyr-lora", "zephyr-pa"],
+)
+async def test_parallel_streaming(client: openai.AsyncOpenAI, model_name: str):
+    """Streaming for parallel sampling.
+    The tokens from multiple samples are flattened into a single stream,
+    with an index to indicate which sample the token belongs to.
+    """
+
+    prompt = "What is an LLM?"
+    n = 3
+    max_tokens = 5
+
+    stream = await client.completions.create(model=model_name,
+                                             prompt=prompt,
+                                             max_tokens=max_tokens,
+                                             n=n,
+                                             stream=True)
+    chunks: List[List[str]] = [[] for i in range(n)]
+    finish_reason_count = 0
+    async for chunk in stream:
+        index = chunk.choices[0].index
+        text = chunk.choices[0].text
+        chunks[index].append(text)
+        if chunk.choices[0].finish_reason is not None:
+            finish_reason_count += 1
+    assert finish_reason_count == n
+    for chunk in chunks:
+        assert len(chunk) == max_tokens
+        print("".join(chunk))
+
+
 @pytest.mark.asyncio
 @pytest.mark.parametrize(
     "model_name",
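As a point of comparison, the non-streaming form of the same request returns all n samples at once, one per choice; the streaming test above simply demultiplexes that set by `choice.index`. A minimal sketch, assuming a vLLM OpenAI-compatible server is already running; the base URL, api_key, and model name are placeholders, not part of this commit:

import asyncio

import openai

client = openai.AsyncOpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")


async def parallel_non_streaming(model_name: str) -> None:
    # Non-streaming counterpart of test_parallel_streaming: the n samples
    # arrive together, one per entry in completion.choices.
    completion = await client.completions.create(model=model_name,
                                                 prompt="What is an LLM?",
                                                 max_tokens=5,
                                                 n=3)
    assert len(completion.choices) == 3
    for choice in completion.choices:
        print(choice.index, choice.text)


# asyncio.run(parallel_non_streaming("your-model-name"))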
@@ -44,8 +44,10 @@ from vllm.pooling_params import PoolingParams
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sampling_params import RequestOutputKind, SamplingParams
 from vllm.sequence import (EmbeddingSequenceGroupOutput, ExecuteModelRequest,
-                           Sequence, SequenceGroup, SequenceGroupMetadata,
-                           SequenceGroupOutput, SequenceStatus)
+                           ParallelSampleSequenceGroup, Sequence,
+                           SequenceGroup, SequenceGroupBase,
+                           SequenceGroupMetadata, SequenceGroupOutput,
+                           SequenceStatus)
 from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context,
                           init_tracer)
 from vllm.transformers_utils.config import try_get_generation_config
@@ -474,6 +476,8 @@ class LLMEngine:
             ),
         ))
 
+        self.seq_id_to_seq_group: Dict[str, SequenceGroupBase] = {}
+
     def _initialize_kv_caches(self) -> None:
         """Initialize the KV cache in the worker(s).
 
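This map is how split sub-requests find their way back to a shared parent group. A rough illustration; the request id "req-0" and the n=3 are hypothetical, but the `_parallel_sample_{i}` suffix is the one introduced later in this commit:

from vllm.sequence import ParallelSampleSequenceGroup

# Illustrative contents of engine.seq_id_to_seq_group after a hypothetical
# request "req-0" is submitted with SamplingParams(n=3): every sub-request
# id maps to the same group object.
group = ParallelSampleSequenceGroup("req-0")
seq_id_to_seq_group = {f"req-0_parallel_sample_{i}": group for i in range(3)}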
@@ -642,7 +646,10 @@ class LLMEngine:
         prompt_adapter_request: Optional[PromptAdapterRequest],
         trace_headers: Optional[Mapping[str, str]] = None,
         priority: int = 0,
-    ) -> None:
+    ) -> SequenceGroup:
+        """Add a processed request to the engine's request pool.
+        Return the created sequence group.
+        """
         self._validate_model_inputs(processed_inputs)
         # Create the sequences.
         block_size = self.cache_config.block_size
@@ -696,6 +703,8 @@ class LLMEngine:
         min_cost_scheduler = self.scheduler[costs.index(min(costs))]
         min_cost_scheduler.add_seq_group(seq_group)
 
+        return seq_group
+
     def stop_remote_worker_execution_loop(self) -> None:
         self.model_executor.stop_remote_worker_execution_loop()
 
@@ -711,7 +720,7 @@ class LLMEngine:
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         priority: int = 0,
-    ) -> None:
+    ) -> Optional[SequenceGroup]:
         ...
 
     @overload
@@ -725,7 +734,7 @@ class LLMEngine:
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         priority: int = 0,
-    ) -> None:
+    ) -> Optional[SequenceGroup]:
         ...
 
     @deprecate_kwargs(
@@ -744,7 +753,7 @@ class LLMEngine:
         priority: int = 0,
         *,
         inputs: Optional[PromptType] = None,  # DEPRECATED
-    ) -> None:
+    ) -> Optional[SequenceGroup]:
         """Add a request to the engine's request pool.
 
         The request is added to the request pool and will be processed by the
@@ -788,6 +797,22 @@ class LLMEngine:
         >>> # continue the request processing
         >>> ...
         """
+
+        if isinstance(params, SamplingParams) and params.n > 1:
+            ParallelSampleSequenceGroup.add_request(
+                request_id,
+                self,
+                params,
+                prompt=prompt,
+                arrival_time=arrival_time,
+                lora_request=lora_request,
+                trace_headers=trace_headers,
+                prompt_adapter_request=prompt_adapter_request,
+                priority=priority,
+                inputs=inputs,
+            )
+            return None
+
         if inputs is not None:
             prompt = inputs
         assert prompt is not None and params is not None
@@ -818,7 +843,7 @@ class LLMEngine:
         processed_inputs["mm_processor_kwargs"] = preprocessed_inputs.get(
             "mm_processor_kwargs")
 
-        self._add_processed_request(
+        return self._add_processed_request(
            request_id=request_id,
            processed_inputs=processed_inputs,
            params=params,
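From the user's point of view nothing changes: a request with n > 1 is now split inside add_request and reassembled on the output side, so the public API still returns one RequestOutput carrying n completions. A usage sketch; the model name is only an example, not something this commit prescribes:

from vllm import LLM, SamplingParams

# Parallel sampling through the public API after this refactor: the engine
# internally fans the request out into n single-sample requests and merges
# them back before the output reaches the caller.
llm = LLM(model="facebook/opt-125m")
params = SamplingParams(n=3, max_tokens=5, temperature=0.8)

for request_output in llm.generate(["What is an LLM?"], params):
    assert len(request_output.outputs) == 3  # one CompletionOutput per sample
    for completion in request_output.outputs:
        print(completion.index, completion.text)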
@@ -1135,7 +1160,9 @@ class LLMEngine:
                 seq_group = scheduled_seq_group.seq_group
                 seq_group.maybe_set_first_token_time(now)
                 request_output = RequestOutputFactory.create(
-                    seq_group, use_cache=self.use_cached_outputs)
+                    seq_group,
+                    self.seq_id_to_seq_group,
+                    use_cache=self.use_cached_outputs)
                 if request_output:
                     ctx.request_outputs.append(request_output)
 
@@ -1175,7 +1202,9 @@ class LLMEngine:
                 seq_group = scheduled_seq_group.seq_group
                 seq_group.maybe_set_first_token_time(now)
                 request_output = RequestOutputFactory.create(
-                    seq_group, use_cache=self.use_cached_outputs)
+                    seq_group,
+                    self.seq_id_to_seq_group,
+                    use_cache=self.use_cached_outputs)
                 if request_output:
                     ctx.request_outputs.append(request_output)
 
@@ -1194,7 +1223,10 @@ class LLMEngine:
                 continue
 
             request_output = RequestOutputFactory.create(
-                seq_group, use_cache=self.use_cached_outputs)
+                seq_group,
+                self.seq_id_to_seq_group,
+                use_cache=self.use_cached_outputs,
+            )
             if request_output:
                 ctx.request_outputs.append(request_output)
 
@@ -1,13 +1,13 @@
 import time
 from dataclasses import dataclass
-from typing import List, Optional
+from typing import Dict, List, Optional
 from typing import Sequence as GenericSequence
 from typing import Union
 
 from vllm.lora.request import LoRARequest
 from vllm.sampling_params import RequestOutputKind
 from vllm.sequence import (PromptLogprobs, RequestMetrics, SampleLogprobs,
-                           SequenceGroup, SequenceStatus)
+                           SequenceGroup, SequenceGroupBase, SequenceStatus)
 
 
 @dataclass
@@ -114,14 +114,28 @@ class RequestOutput:
         self.encoder_prompt_token_ids = encoder_prompt_token_ids
 
     @classmethod
-    def from_seq_group(cls, seq_group: SequenceGroup,
-                       use_cache: bool) -> Optional["RequestOutput"]:
+    def from_seq_group(
+        cls, seq_group: SequenceGroup, use_cache: bool,
+        seq_id_to_seq_group: Dict[str, SequenceGroupBase]
+    ) -> Optional["RequestOutput"]:
+        finished = seq_group.is_finished()
+
+        if seq_group.request_id in seq_id_to_seq_group:
+            group: SequenceGroupBase = seq_id_to_seq_group[
+                seq_group.request_id]
+            if finished:
+                group.finish_seq(seq_group)
+            assembled_seq_group = group.maybe_assemble_group(seq_group)
+            if assembled_seq_group is None:
+                return None
+            return cls.from_seq_group(assembled_seq_group, use_cache,
+                                      seq_id_to_seq_group)
+
         sampling_params = seq_group.sampling_params
         if sampling_params is None:
             raise ValueError(
                 "Sampling parameters are missing for a CompletionRequest.")
 
-        finished = seq_group.is_finished()
         if sampling_params.output_kind == RequestOutputKind.FINAL_ONLY and (
                 not finished):
             return None
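The branch added above is the heart of the output-side change: every sub-request's SequenceGroup reports through its parent group, and the group decides whether anything is emitted at all. A condensed restatement in plain Python, not vLLM code; build_output stands in for the normal single-request path:

from typing import Optional


def route_output(seq_group, seq_id_to_seq_group, build_output) -> Optional[object]:
    # Sub-request: report to the parent group and only emit once the group
    # hands back an assembled SequenceGroup; otherwise stay silent.
    if seq_group.request_id in seq_id_to_seq_group:
        group = seq_id_to_seq_group[seq_group.request_id]
        if seq_group.is_finished():
            group.finish_seq(seq_group)
        assembled = group.maybe_assemble_group(seq_group)
        if assembled is None:
            return None
        return build_output(assembled)
    # Ordinary request: unchanged single-request path.
    return build_output(seq_group)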
@@ -136,15 +150,7 @@ class RequestOutput:
                 outputs=[],
                 finished=False)
 
-        seqs = seq_group.get_seqs()
-        if len(seqs) == 1:
-            top_n_seqs = seqs
-        else:
-            # Get the top-n sequences.
-            n = sampling_params._real_n or sampling_params.n
-            sorting_key = lambda seq: seq.get_cumulative_logprob()
-            sorted_seqs = sorted(seqs, key=sorting_key, reverse=True)
-            top_n_seqs = sorted_seqs[:n]
+        top_n_seqs = seq_group.get_seqs()
 
         # Create the outputs.
         # NOTE: We need omit logprobs here explicitly because the sequence
@@ -208,7 +214,7 @@ class RequestOutput:
 
         else:
             output = CompletionOutput(
-                seqs.index(seq), output_text, [output_token_ids]
+                top_n_seqs.index(seq), output_text, [output_token_ids]
                 if isinstance(output_token_ids, int) else output_token_ids,
                 seq.get_cumulative_logprob() if include_logprobs else None,
                 output_logprobs,
@@ -309,10 +315,13 @@ class EmbeddingRequestOutput:
 class RequestOutputFactory:
 
     @staticmethod
-    def create(seq_group: SequenceGroup, use_cache: bool = False):
+    def create(seq_group: SequenceGroup,
+               seq_id_to_seq_group: Dict[str, SequenceGroupBase],
+               use_cache: bool = False):
         # Determine the type based on a condition, for example:
         if hasattr(seq_group,
                    'embeddings') and seq_group.embeddings is not None:
             return EmbeddingRequestOutput.from_seq_group(seq_group)
         else:
-            return RequestOutput.from_seq_group(seq_group, use_cache)
+            return RequestOutput.from_seq_group(seq_group, use_cache,
+                                                seq_id_to_seq_group)
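Any code that called RequestOutputFactory.create directly now has to supply the engine's map as well; the engine's own call sites are updated in the hunks earlier in this diff. A hypothetical external call site would change along these lines:

from vllm.engine.llm_engine import LLMEngine
from vllm.outputs import RequestOutputFactory
from vllm.sequence import SequenceGroup


def make_output(engine: LLMEngine, seq_group: SequenceGroup):
    # Hypothetical call site after this commit: the engine's
    # seq_id_to_seq_group map is now a required positional argument.
    return RequestOutputFactory.create(
        seq_group,
        engine.seq_id_to_seq_group,
        use_cache=False,
    )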
vllm/sequence.py (122 changed lines)
@@ -4,7 +4,7 @@ import enum
 from abc import ABC, abstractmethod
 from array import array
 from collections import defaultdict
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from functools import cached_property, reduce
 from typing import TYPE_CHECKING, Any, Callable, Dict, List, Mapping, Optional
 from typing import Sequence as GenericSequence
@@ -17,7 +17,7 @@ from vllm.inputs.parse import is_encoder_decoder_inputs
 from vllm.lora.request import LoRARequest
 from vllm.pooling_params import PoolingParams
 from vllm.prompt_adapter.request import PromptAdapterRequest
-from vllm.sampling_params import SamplingParams
+from vllm.sampling_params import RequestOutputKind, SamplingParams
 from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics
 
 if TYPE_CHECKING:
@@ -1401,3 +1401,121 @@ class ExecuteModelRequest(
             last_sampled_token_ids=self.last_sampled_token_ids.clone()
             if self.last_sampled_token_ids is not None else None,
             async_callback=self.async_callback)
+
+
+@dataclass
+class SequenceGroupBase:
+    group_id: str  # the original request id before splitting
+
+    assembled_seq_group: Optional[SequenceGroup] = None
+
+    # seq id to a unique index inside this group
+    seq_id_to_index: Dict[str, int] = field(default_factory=dict)
+
+    # seq ids to be finished
+    to_be_finished: Dict[str, SequenceGroup] = field(default_factory=dict)
+
+    # seq id to finished sequences
+    finished_reqs: Dict[str, SequenceGroup] = field(default_factory=dict)
+
+    streaming: bool = False
+
+    output_produced: bool = False
+
+    @staticmethod
+    def add_request(request_id: str, engine, params, *args, **kwargs):
+        """When we are ready to add a request with request_id and params
+        into the engine, we can split the request into multiple requests.
+        """
+        raise NotImplementedError
+
+    def finish_seq(self, seq: SequenceGroup):
+        """The sequence `seq` has finished; record the information.
+        """
+        del self.to_be_finished[seq.request_id]
+        self.finished_reqs[seq.request_id] = seq
+
+    def maybe_assemble_group(
+            self, seq_group: SequenceGroup) -> Optional[SequenceGroup]:
+        """Assemble the sequence group, for producing the final
+        output or adding the request to the engine again.
+        """
+        raise NotImplementedError
+
+
+class ParallelSampleSequenceGroup(SequenceGroupBase):
+
+    @staticmethod
+    def add_request(request_id: str, engine, params, **kwargs):
+        original_params = params
+        params = copy.deepcopy(original_params)
+        params.n = 1
+        group = ParallelSampleSequenceGroup(request_id)
+        seqs = []
+        for i in range(original_params.n):
+            request_id_i = f"{request_id}_parallel_sample_{i}"
+            group.seq_id_to_index[request_id_i] = i
+            seq_group = engine.add_request(
+                request_id_i,
+                params=params,
+                **kwargs,
+            )  # type: ignore
+            assert seq_group is not None
+            engine.seq_id_to_seq_group[request_id_i] = group
+            group.to_be_finished[request_id_i] = seq_group
+            seqs.append(seq_group.seqs[0])
+
+        # for parallel sampling, the `assembled_seq_group` is always
+        # available, since we have all the sequences ready, and they
+        # will not change.
+        group.assembled_seq_group = SequenceGroup(
+            request_id=request_id,
+            seqs=seqs,
+            arrival_time=seq_group.arrival_time,
+            sampling_params=original_params,
+            lora_request=seq_group.lora_request,
+            embeddings=seq_group.embeddings,
+            pooling_params=seq_group.pooling_params,
+            encoder_seq=seq_group.encoder_seq,
+            trace_headers=seq_group.trace_headers,
+            prompt_adapter_request=seq_group.prompt_adapter_request,
+            priority=seq_group.priority,
+        )
+
+        group.streaming = params.output_kind == RequestOutputKind.DELTA
+        group.output_produced = False
+
+    def maybe_assemble_group(
+            self, seq_group: SequenceGroup) -> Optional[SequenceGroup]:
+
+        # in the streaming mode, we will return the assembled sequence
+        # for the first sequence, and then return None for the rest of
+        # the sequences
+        if self.streaming:
+            if self.seq_id_to_index[seq_group.request_id] == 0:
+                return self.assembled_seq_group
+            return None
+
+        # in the non-streaming mode, we will return the assembled sequence
+        # once after all sequences finish, and then return None for the
+        # rest of the time
+
+        if len(self.to_be_finished) > 0:
+            return None
+
+        assert self.assembled_seq_group is not None
+        params = self.assembled_seq_group.sampling_params
+        assert isinstance(params, SamplingParams)
+        if not self.output_produced:
+            self.output_produced = True
+            if params._real_n is not None:
+                # Get the top-n sequences.
+                n = params._real_n or params.n
+                seqs = self.assembled_seq_group.seqs
+                sorting_key = lambda seq: seq.get_cumulative_logprob()
+                sorted_seqs = sorted(seqs, key=sorting_key, reverse=True)
+                top_n_seqs = sorted_seqs[:n]
+                self.assembled_seq_group.seqs = top_n_seqs
+            return self.assembled_seq_group
+        if self.output_produced:
+            return None
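The only post-processing left in maybe_assemble_group is the top-n filter, which moved here from vllm/outputs.py: _real_n appears to hold the originally requested n when best_of raises n, so the group over-samples and keeps the n sequences with the highest cumulative log-probability. A toy illustration in plain Python with made-up numbers, not vLLM internals:

# Toy sketch of the top-n selection applied when best_of > n: over-sample,
# then keep the n candidates with the highest cumulative log-probability.
cumulative_logprobs = {
    "sample_0": -4.2,
    "sample_1": -1.3,
    "sample_2": -2.7,
}
n = 2  # the originally requested n; best_of was 3
top_n = sorted(cumulative_logprobs,
               key=cumulative_logprobs.get,
               reverse=True)[:n]
print(top_n)  # ['sample_1', 'sample_2']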