mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-02 22:44:32 +08:00
[V1][TPU] TPU multimodal model support for ragged attention (#14158)
Signed-off-by: Michael Goin <mgoin64@gmail.com>
This commit is contained in:
parent
3e1d223626
commit
fbfc3ee37e
@ -15,14 +15,18 @@ from vllm.attention.backends.abstract import AttentionType
|
|||||||
from vllm.attention.layer import Attention
|
from vllm.attention.layer import Attention
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.forward_context import get_forward_context, set_forward_context
|
from vllm.forward_context import get_forward_context, set_forward_context
|
||||||
|
from vllm.inputs import INPUT_REGISTRY
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor.model_loader import get_model
|
from vllm.model_executor.model_loader import get_model
|
||||||
|
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
|
||||||
|
from vllm.multimodal.utils import group_mm_inputs_by_modality
|
||||||
from vllm.sampling_params import SamplingType
|
from vllm.sampling_params import SamplingType
|
||||||
from vllm.utils import LayerBlockType, cdiv, is_pin_memory_available
|
from vllm.utils import LayerBlockType, cdiv, is_pin_memory_available
|
||||||
from vllm.v1.attention.backends.pallas import (NUM_KV_PAGES_PER_BLOCK,
|
from vllm.v1.attention.backends.pallas import (NUM_KV_PAGES_PER_BLOCK,
|
||||||
NUM_QUERIES_PER_BLOCK,
|
NUM_QUERIES_PER_BLOCK,
|
||||||
PallasAttentionBackend,
|
PallasAttentionBackend,
|
||||||
PallasMetadata)
|
PallasMetadata)
|
||||||
|
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
|
||||||
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
|
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
|
||||||
KVCacheSpec)
|
KVCacheSpec)
|
||||||
from vllm.v1.outputs import LogprobsTensors, ModelRunnerOutput
|
from vllm.v1.outputs import LogprobsTensors, ModelRunnerOutput
|
||||||
@ -72,8 +76,10 @@ class TPUModelRunner:
|
|||||||
self.block_size = cache_config.block_size
|
self.block_size = cache_config.block_size
|
||||||
self.max_model_len = model_config.max_model_len
|
self.max_model_len = model_config.max_model_len
|
||||||
self.max_num_blocks_per_req = cdiv(self.max_model_len, self.block_size)
|
self.max_num_blocks_per_req = cdiv(self.max_model_len, self.block_size)
|
||||||
self.max_num_tokens = scheduler_config.max_num_batched_tokens
|
self.max_num_tokens = _get_padded_number(
|
||||||
self.max_num_reqs = scheduler_config.max_num_seqs
|
scheduler_config.max_num_batched_tokens, NUM_QUERIES_PER_BLOCK)
|
||||||
|
self.max_num_reqs = _get_padded_number(scheduler_config.max_num_seqs,
|
||||||
|
NUM_QUERIES_PER_BLOCK)
|
||||||
|
|
||||||
# Model-related.
|
# Model-related.
|
||||||
self.num_attn_layers = model_config.get_num_layers_by_block_type(
|
self.num_attn_layers = model_config.get_num_layers_by_block_type(
|
||||||
@ -84,6 +90,28 @@ class TPUModelRunner:
|
|||||||
self.head_size = model_config.get_head_size()
|
self.head_size = model_config.get_head_size()
|
||||||
self.hidden_size = model_config.get_hidden_size()
|
self.hidden_size = model_config.get_hidden_size()
|
||||||
|
|
||||||
|
# Multi-modal data support
|
||||||
|
self.input_registry = INPUT_REGISTRY
|
||||||
|
self.mm_registry = MULTIMODAL_REGISTRY
|
||||||
|
self.uses_mrope = model_config.uses_mrope
|
||||||
|
# TODO: Support M-RoPE (e.g, Qwen2-VL)
|
||||||
|
assert not self.uses_mrope, "TPU does not support M-RoPE yet."
|
||||||
|
|
||||||
|
encoder_compute_budget, encoder_cache_size = compute_encoder_budget(
|
||||||
|
model_config=model_config,
|
||||||
|
scheduler_config=scheduler_config,
|
||||||
|
)
|
||||||
|
self.max_num_encoder_input_tokens = encoder_compute_budget
|
||||||
|
self.encoder_cache_size = encoder_cache_size
|
||||||
|
|
||||||
|
# Lazy initialization
|
||||||
|
# self.model: nn.Module # Set after load_model
|
||||||
|
self.kv_caches: list[torch.Tensor] = []
|
||||||
|
# req_id -> (input_id -> encoder_output)
|
||||||
|
self.encoder_cache: dict[str, dict[int, torch.Tensor]] = {}
|
||||||
|
|
||||||
|
# Request states.
|
||||||
|
self.requests: dict[str, CachedRequestState] = {}
|
||||||
# Persistent batch.
|
# Persistent batch.
|
||||||
self.input_batch = InputBatch(
|
self.input_batch = InputBatch(
|
||||||
max_num_reqs=self.max_num_reqs,
|
max_num_reqs=self.max_num_reqs,
|
||||||
@ -91,18 +119,9 @@ class TPUModelRunner:
|
|||||||
max_num_blocks_per_req=self.max_num_blocks_per_req,
|
max_num_blocks_per_req=self.max_num_blocks_per_req,
|
||||||
device=self.device,
|
device=self.device,
|
||||||
pin_memory=self.pin_memory,
|
pin_memory=self.pin_memory,
|
||||||
vocab_size=self.model_config.get_vocab_size(),
|
vocab_size=model_config.get_vocab_size(),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Request states.
|
|
||||||
self.requests: dict[str, CachedRequestState] = {}
|
|
||||||
|
|
||||||
# req_id -> (input_id -> encoder_output)
|
|
||||||
self.encoder_cache: dict[str, dict[int, torch.Tensor]] = {}
|
|
||||||
|
|
||||||
# KV caches for forward pass
|
|
||||||
self.kv_caches: list[tuple[torch.Tensor, torch.Tensor]] = []
|
|
||||||
|
|
||||||
# Cached torch/numpy tensor
|
# Cached torch/numpy tensor
|
||||||
# The pytorch tensor and numpy array share the same buffer.
|
# The pytorch tensor and numpy array share the same buffer.
|
||||||
# Sometimes the numpy op is faster so we create both.
|
# Sometimes the numpy op is faster so we create both.
|
||||||
@ -164,6 +183,7 @@ class TPUModelRunner:
|
|||||||
# Remove finished requests from the cached states.
|
# Remove finished requests from the cached states.
|
||||||
for req_id in scheduler_output.finished_req_ids:
|
for req_id in scheduler_output.finished_req_ids:
|
||||||
self.requests.pop(req_id, None)
|
self.requests.pop(req_id, None)
|
||||||
|
self.encoder_cache.pop(req_id, None)
|
||||||
|
|
||||||
# Remove the finished requests from the persistent batch.
|
# Remove the finished requests from the persistent batch.
|
||||||
# NOTE(woosuk): There could be an edge case where finished_req_ids and
|
# NOTE(woosuk): There could be an edge case where finished_req_ids and
|
||||||
@ -177,6 +197,14 @@ class TPUModelRunner:
|
|||||||
if req_index is not None:
|
if req_index is not None:
|
||||||
removed_req_indices.append(req_index)
|
removed_req_indices.append(req_index)
|
||||||
|
|
||||||
|
# Free the cached encoder outputs.
|
||||||
|
for req_id, input_id in scheduler_output.free_encoder_input_ids:
|
||||||
|
encoder_outputs = self.encoder_cache.get(req_id)
|
||||||
|
if encoder_outputs is not None:
|
||||||
|
encoder_outputs.pop(input_id, None)
|
||||||
|
if not encoder_outputs:
|
||||||
|
self.encoder_cache.pop(req_id, None)
|
||||||
|
|
||||||
# Remove the unscheduled requests from the persistent batch.
|
# Remove the unscheduled requests from the persistent batch.
|
||||||
# NOTE(woosuk): The unscheduled requests are either preempted requests
|
# NOTE(woosuk): The unscheduled requests are either preempted requests
|
||||||
# or running requests that are not scheduled in this step. We remove
|
# or running requests that are not scheduled in this step. We remove
|
||||||
@ -426,6 +454,92 @@ class TPUModelRunner:
|
|||||||
logits_indices = query_start_loc[1:] - 1
|
logits_indices = query_start_loc[1:] - 1
|
||||||
return attn_metadata, logits_indices
|
return attn_metadata, logits_indices
|
||||||
|
|
||||||
|
def _execute_encoder(self, scheduler_output: "SchedulerOutput"):
|
||||||
|
scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs
|
||||||
|
if not scheduled_encoder_inputs:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Batch the multi-modal inputs.
|
||||||
|
mm_inputs: list[MultiModalKwargs] = []
|
||||||
|
req_input_ids: list[tuple[str, int]] = []
|
||||||
|
for req_id, encoder_input_ids in scheduled_encoder_inputs.items():
|
||||||
|
req_state = self.requests[req_id]
|
||||||
|
for input_id in encoder_input_ids:
|
||||||
|
mm_inputs.append(req_state.mm_inputs[input_id])
|
||||||
|
req_input_ids.append((req_id, input_id))
|
||||||
|
|
||||||
|
# Batch mm inputs as much as we can: if a request in the batch has
|
||||||
|
# multiple modalities or a different modality than the previous one,
|
||||||
|
# we process it separately to preserve item order.
|
||||||
|
# FIXME(ywang96): This is a hacky way to deal with multiple modalities
|
||||||
|
# in the same batch while still being able to benefit from batching
|
||||||
|
# multimodal inputs. The proper solution should be reordering the
|
||||||
|
# encoder outputs.
|
||||||
|
grouped_mm_inputs_list = group_mm_inputs_by_modality(mm_inputs)
|
||||||
|
|
||||||
|
encoder_outputs = []
|
||||||
|
for grouped_mm_inputs in grouped_mm_inputs_list:
|
||||||
|
batched_mm_inputs = MultiModalKwargs.batch(grouped_mm_inputs)
|
||||||
|
batched_mm_inputs = MultiModalKwargs.as_kwargs(batched_mm_inputs,
|
||||||
|
device=self.device)
|
||||||
|
|
||||||
|
# Run the encoder.
|
||||||
|
# `curr_group_outputs` is either of the following:
|
||||||
|
# 1. A tensor of shape (num_items, feature_size, hidden_size)
|
||||||
|
# in case feature_size is fixed across all multimodal items.
|
||||||
|
# 2. A list or tuple (length: num_items) of tensors, each of shape
|
||||||
|
# (feature_size, hidden_size) in case the feature size is dynamic
|
||||||
|
# depending on the input multimodal items.
|
||||||
|
curr_group_outputs = self.model.get_multimodal_embeddings(
|
||||||
|
**batched_mm_inputs)
|
||||||
|
|
||||||
|
for output in curr_group_outputs:
|
||||||
|
encoder_outputs.append(output)
|
||||||
|
|
||||||
|
# Cache the encoder outputs.
|
||||||
|
for (req_id, input_id), output in zip(req_input_ids, encoder_outputs):
|
||||||
|
if req_id not in self.encoder_cache:
|
||||||
|
self.encoder_cache[req_id] = {}
|
||||||
|
self.encoder_cache[req_id][input_id] = output
|
||||||
|
|
||||||
|
def _gather_encoder_outputs(
|
||||||
|
self,
|
||||||
|
scheduler_output: "SchedulerOutput",
|
||||||
|
) -> list[torch.Tensor]:
|
||||||
|
encoder_outputs: list[torch.Tensor] = []
|
||||||
|
for req_id in self.input_batch.req_ids:
|
||||||
|
num_scheduled_tokens = scheduler_output.num_scheduled_tokens[
|
||||||
|
req_id]
|
||||||
|
req_state = self.requests[req_id]
|
||||||
|
num_computed_tokens = req_state.num_computed_tokens
|
||||||
|
mm_positions = req_state.mm_positions
|
||||||
|
for i, pos_info in enumerate(mm_positions):
|
||||||
|
start_pos = pos_info["offset"]
|
||||||
|
num_encoder_tokens = pos_info["length"]
|
||||||
|
|
||||||
|
# The encoder output is needed if the two ranges overlap:
|
||||||
|
# [num_computed_tokens,
|
||||||
|
# num_computed_tokens + num_scheduled_tokens) and
|
||||||
|
# [start_pos, start_pos + num_encoder_tokens)
|
||||||
|
if start_pos >= num_computed_tokens + num_scheduled_tokens:
|
||||||
|
# The encoder output is not needed in this step.
|
||||||
|
break
|
||||||
|
if start_pos + num_encoder_tokens <= num_computed_tokens:
|
||||||
|
# The encoder output is already processed and stored
|
||||||
|
# in the decoder's KV cache.
|
||||||
|
continue
|
||||||
|
|
||||||
|
start_idx = max(num_computed_tokens - start_pos, 0)
|
||||||
|
end_idx = min(
|
||||||
|
num_computed_tokens - start_pos + num_scheduled_tokens,
|
||||||
|
num_encoder_tokens)
|
||||||
|
assert start_idx < end_idx
|
||||||
|
assert req_id in self.encoder_cache
|
||||||
|
assert i in self.encoder_cache[req_id]
|
||||||
|
encoder_output = self.encoder_cache[req_id][i]
|
||||||
|
encoder_outputs.append(encoder_output[start_idx:end_idx])
|
||||||
|
return encoder_outputs
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def execute_model(
|
def execute_model(
|
||||||
self,
|
self,
|
||||||
@ -434,16 +548,42 @@ class TPUModelRunner:
|
|||||||
# Update cached state
|
# Update cached state
|
||||||
self._update_states(scheduler_output)
|
self._update_states(scheduler_output)
|
||||||
|
|
||||||
|
if self.is_multimodal_model:
|
||||||
|
# Run the multimodal encoder if any.
|
||||||
|
self._execute_encoder(scheduler_output)
|
||||||
|
encoder_outputs = self._gather_encoder_outputs(scheduler_output)
|
||||||
|
else:
|
||||||
|
encoder_outputs = []
|
||||||
|
|
||||||
# Prepare inputs
|
# Prepare inputs
|
||||||
attn_metadata, logits_indices = self._prepare_inputs(scheduler_output)
|
attn_metadata, logits_indices = self._prepare_inputs(scheduler_output)
|
||||||
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
|
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
|
||||||
|
|
||||||
|
if self.is_multimodal_model:
|
||||||
|
# NOTE(woosuk): To unify token ids and soft tokens (vision
|
||||||
|
# embeddings), we always use embeddings (rather than token ids)
|
||||||
|
# as input to the multimodal model, even when the input is text.
|
||||||
|
if encoder_outputs:
|
||||||
|
inputs_embeds = self.model.get_input_embeddings(
|
||||||
|
self.input_ids, encoder_outputs)
|
||||||
|
else:
|
||||||
|
inputs_embeds = self.model.get_input_embeddings(self.input_ids)
|
||||||
|
input_ids = None
|
||||||
|
else:
|
||||||
|
# For text-only models, we use token ids as input.
|
||||||
|
# While it is possible to use embeddings as input just like the
|
||||||
|
# multimodal models, it is not desirable for performance since
|
||||||
|
# then the embedding layer is not included in the CUDA graph.
|
||||||
|
input_ids = self.input_ids
|
||||||
|
inputs_embeds = None
|
||||||
|
|
||||||
# Run the decoder
|
# Run the decoder
|
||||||
with set_forward_context(attn_metadata, self.vllm_config):
|
with set_forward_context(attn_metadata, self.vllm_config):
|
||||||
hidden_states = self.model(
|
hidden_states = self.model(
|
||||||
token_ids=self.input_ids,
|
input_ids=input_ids,
|
||||||
position_ids=self.position_ids,
|
positions=self.position_ids,
|
||||||
kv_caches=self.kv_caches,
|
kv_caches=self.kv_caches,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
)
|
)
|
||||||
hidden_states = hidden_states[:total_num_scheduled_tokens]
|
hidden_states = hidden_states[:total_num_scheduled_tokens]
|
||||||
num_reqs = self.input_batch.num_reqs
|
num_reqs = self.input_batch.num_reqs
|
||||||
@ -538,14 +678,21 @@ class TPUModelRunner:
|
|||||||
fullgraph=True,
|
fullgraph=True,
|
||||||
dynamic=False)
|
dynamic=False)
|
||||||
|
|
||||||
def dummy_run(
|
def _dummy_run(
|
||||||
self,
|
self,
|
||||||
kv_caches,
|
kv_caches,
|
||||||
num_tokens: int,
|
num_tokens: int,
|
||||||
) -> None:
|
) -> None:
|
||||||
input_ids = torch.zeros(num_tokens,
|
if self.is_multimodal_model:
|
||||||
dtype=torch.int32,
|
input_ids = None
|
||||||
device=self.device)
|
inputs_embeds = torch.zeros((num_tokens, self.hidden_size),
|
||||||
|
dtype=self.dtype,
|
||||||
|
device=self.device)
|
||||||
|
else:
|
||||||
|
input_ids = torch.zeros((num_tokens),
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=self.device)
|
||||||
|
inputs_embeds = None
|
||||||
position_ids = torch.zeros(num_tokens,
|
position_ids = torch.zeros(num_tokens,
|
||||||
dtype=torch.int32,
|
dtype=torch.int32,
|
||||||
device=self.device)
|
device=self.device)
|
||||||
@ -571,7 +718,10 @@ class TPUModelRunner:
|
|||||||
num_seqs=num_tokens,
|
num_seqs=num_tokens,
|
||||||
)
|
)
|
||||||
|
|
||||||
torch._dynamo.mark_dynamic(input_ids, 0)
|
if self.is_multimodal_model:
|
||||||
|
torch._dynamo.mark_dynamic(inputs_embeds, 0)
|
||||||
|
else:
|
||||||
|
torch._dynamo.mark_dynamic(input_ids, 0)
|
||||||
torch._dynamo.mark_dynamic(position_ids, 0)
|
torch._dynamo.mark_dynamic(position_ids, 0)
|
||||||
torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0)
|
torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0)
|
||||||
torch._dynamo.mark_dynamic(attn_metadata.block_tables, 0)
|
torch._dynamo.mark_dynamic(attn_metadata.block_tables, 0)
|
||||||
@ -580,7 +730,12 @@ class TPUModelRunner:
|
|||||||
|
|
||||||
with set_forward_context(attn_metadata, self.vllm_config, 0):
|
with set_forward_context(attn_metadata, self.vllm_config, 0):
|
||||||
assert self.model is not None
|
assert self.model is not None
|
||||||
self.model(input_ids, position_ids, kv_caches)
|
self.model(
|
||||||
|
input_ids=input_ids,
|
||||||
|
positions=position_ids,
|
||||||
|
kv_caches=kv_caches,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
)
|
||||||
|
|
||||||
def capture_model(self) -> None:
|
def capture_model(self) -> None:
|
||||||
"""Compile the model."""
|
"""Compile the model."""
|
||||||
@ -590,11 +745,11 @@ class TPUModelRunner:
|
|||||||
start = time.perf_counter()
|
start = time.perf_counter()
|
||||||
num_tokens = 16
|
num_tokens = 16
|
||||||
while True:
|
while True:
|
||||||
self.dummy_run(self.kv_caches, num_tokens)
|
self._dummy_run(self.kv_caches, num_tokens)
|
||||||
logger.info(" -- num_tokens: %d", num_tokens)
|
logger.info(" -- num_tokens: %d", num_tokens)
|
||||||
xm.mark_step()
|
xm.mark_step()
|
||||||
xm.wait_device_ops()
|
xm.wait_device_ops()
|
||||||
if num_tokens >= self.scheduler_config.max_num_batched_tokens:
|
if num_tokens >= self.max_num_tokens:
|
||||||
break
|
break
|
||||||
num_tokens *= 2
|
num_tokens *= 2
|
||||||
end = time.perf_counter()
|
end = time.perf_counter()
|
||||||
@ -647,17 +802,20 @@ class ModelWrapperV1(nn.Module):
|
|||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
token_ids: torch.Tensor,
|
input_ids: torch.Tensor,
|
||||||
position_ids: torch.Tensor,
|
positions: torch.Tensor,
|
||||||
kv_caches: list[tuple[torch.Tensor, torch.Tensor]],
|
kv_caches: list[tuple[torch.Tensor, torch.Tensor]],
|
||||||
|
inputs_embeds: Optional[torch.Tensor] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""Executes the forward pass of the model and samples the next token.
|
"""Executes the forward pass of the model and samples the next token.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
token_ids: The input token IDs of shape [num_tokens].
|
input_ids: The input token IDs of shape [num_tokens].
|
||||||
position_ids: The input position IDs of shape [num_tokens].
|
positions: The input position IDs of shape [num_tokens].
|
||||||
kv_caches: The key and value caches. They can be None during the
|
kv_caches: The key and value caches. They can be None during the
|
||||||
memory profiling at initialization.
|
memory profiling at initialization.
|
||||||
|
inputs_embeds: The input embeddings of shape [num_tokens,
|
||||||
|
hidden_size]. It is used for multimodal models.
|
||||||
"""
|
"""
|
||||||
# Skip this in memory profiling at initialization.
|
# Skip this in memory profiling at initialization.
|
||||||
if kv_caches[0][0].numel() > 0:
|
if kv_caches[0][0].numel() > 0:
|
||||||
@ -684,9 +842,9 @@ class ModelWrapperV1(nn.Module):
|
|||||||
|
|
||||||
assert self.model is not None
|
assert self.model is not None
|
||||||
hidden_states = self.model(
|
hidden_states = self.model(
|
||||||
token_ids,
|
input_ids=input_ids,
|
||||||
position_ids,
|
positions=positions,
|
||||||
kv_caches,
|
inputs_embeds=inputs_embeds,
|
||||||
)
|
)
|
||||||
|
|
||||||
return hidden_states
|
return hidden_states
|
||||||
@ -699,6 +857,12 @@ class ModelWrapperV1(nn.Module):
|
|||||||
logits = self.model.compute_logits(hidden_states, sampling_metadata)
|
logits = self.model.compute_logits(hidden_states, sampling_metadata)
|
||||||
return logits
|
return logits
|
||||||
|
|
||||||
|
def get_multimodal_embeddings(self, *args, **kwargs):
|
||||||
|
return self.model.get_multimodal_embeddings(*args, **kwargs)
|
||||||
|
|
||||||
|
def get_input_embeddings(self, *args, **kwargs):
|
||||||
|
return self.model.get_input_embeddings(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def _get_padded_number(n: int, multiple: int) -> int:
|
def _get_padded_number(n: int, multiple: int) -> int:
|
||||||
return ((n + multiple - 1) // multiple) * multiple
|
return ((n + multiple - 1) // multiple) * multiple
|
||||||
|
|||||||
@ -134,7 +134,7 @@ class TPUWorker:
|
|||||||
self.vllm_config.compilation_config.static_forward_context,
|
self.vllm_config.compilation_config.static_forward_context,
|
||||||
runner_kv_caches)
|
runner_kv_caches)
|
||||||
|
|
||||||
self.model_runner.dummy_run(
|
self.model_runner._dummy_run(
|
||||||
runner_kv_caches,
|
runner_kv_caches,
|
||||||
num_tokens=self.scheduler_config.max_num_batched_tokens,
|
num_tokens=self.scheduler_config.max_num_batched_tokens,
|
||||||
)
|
)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user