[BugFix][Model] Jamba - Handle aborted requests, Add tests and fix cleanup bug (#6425)
Co-authored-by: Mor Zusman <morz@ai21.com>
This commit is contained in:
parent d6f3b3d5c4
commit 9ad32dacd9
@@ -1,5 +1,6 @@
 import pytest
 
+from tests.models.utils import check_outputs_equal
 from vllm.worker.model_runner import _get_graph_batch_size
 
 MODELS = ["ai21labs/Jamba-tiny-random"]
@@ -34,6 +35,34 @@ def test_models(
             f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")
 
 
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("dtype", ["float"])
+@pytest.mark.parametrize("max_tokens", [20])
+def test_batching(
+    vllm_runner,
+    example_prompts,
+    model: str,
+    dtype: str,
+    max_tokens: int,
+) -> None:
+    # To pass the small model tests, we need full precision.
+    for_loop_outputs = []
+    with vllm_runner(model, dtype=dtype) as vllm_model:
+        for prompt in example_prompts:
+            for_loop_outputs.append(
+                vllm_model.generate_greedy([prompt], max_tokens)[0])
+
+        batched_outputs = vllm_model.generate_greedy(example_prompts,
+                                                     max_tokens)
+
+    check_outputs_equal(
+        outputs_0_lst=for_loop_outputs,
+        outputs_1_lst=batched_outputs,
+        name_0="for_loop_vllm",
+        name_1="batched_vllm",
+    )
+
+
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("dtype", ["bfloat16"])
 @pytest.mark.parametrize("max_tokens", [20])
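check_outputs_equal (imported at the top of the test file) is what ties the new tests together. As a rough mental model only, and not the actual tests.models.utils implementation, it can be read as a pairwise comparison of (token_ids, text) outputs:

from typing import List, Sequence, Tuple


def check_outputs_equal_sketch(
        outputs_0_lst: Sequence[Tuple[List[int], str]],
        outputs_1_lst: Sequence[Tuple[List[int], str]],
        name_0: str,
        name_1: str) -> None:
    # Hypothetical stand-in for tests.models.utils.check_outputs_equal: each
    # output is assumed to be a (token_ids, text) pair and the two runs must
    # agree position by position.
    assert len(outputs_0_lst) == len(outputs_1_lst)
    for i, ((ids_0, text_0), (ids_1, text_1)) in enumerate(
            zip(outputs_0_lst, outputs_1_lst)):
        assert ids_0 == ids_1, f"Test{i}:\n{name_0}: {ids_0}\n{name_1}: {ids_1}"
        assert text_0 == text_1, (
            f"Test{i}:\n{name_0}: {text_0}\n{name_1}: {text_1}")


check_outputs_equal_sketch([([1, 2], "ab")], [([1, 2], "ab")],
                           name_0="for_loop_vllm", name_1="batched_vllm")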
@@ -60,6 +89,60 @@ def test_mamba_cache_cg_padding(
         "Could be related to mamba cache not padded correctly")
 
 
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("dtype", ["float"])
+@pytest.mark.parametrize("max_tokens", [20])
+def test_models_preemption_recompute(
+    hf_runner,
+    vllm_runner,
+    example_prompts,
+    model: str,
+    dtype: str,
+    max_tokens: int,
+) -> None:
+    # Tests that outputs are identical with and without preemptions (recompute).
+    assert dtype == "float"
+
+    with vllm_runner(model, dtype=dtype) as vllm_model:
+        vllm_model.model.llm_engine.scheduler[
+            0].ENABLE_ARTIFICIAL_PREEMPT = True
+        preempt_vllm_outputs = vllm_model.generate_greedy(
+            example_prompts, max_tokens)
+
+        vllm_model.model.llm_engine.scheduler[
+            0].ENABLE_ARTIFICIAL_PREEMPT = False
+        vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
+
+    check_outputs_equal(
+        outputs_0_lst=preempt_vllm_outputs,
+        outputs_1_lst=vllm_outputs,
+        name_0="vllm_preemptions",
+        name_1="vllm",
+    )
+
+
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("dtype", ["float"])
+def test_fail_upon_inc_requests_and_finished_requests_lt_available_blocks(
+    vllm_runner,
+    model: str,
+    dtype: str,
+    example_prompts,
+) -> None:
+    # This test verifies that Jamba's inner state management doesn't collapse
+    # when the number of incoming requests plus finished_requests_ids exceeds
+    # the maximum mamba block capacity. This can happen because Jamba supports
+    # a statelessness mechanism that lets it clean up new incoming requests
+    # within a single step.
+    try:
+        with vllm_runner(model, dtype=dtype, max_num_seqs=10) as vllm_model:
+            vllm_model.generate_greedy([example_prompts[0]] * 100, 10)
+    except ValueError:
+        pytest.fail("Jamba inner state wasn't cleaned up properly between "
+                    "steps; finished requests were registered unnecessarily.")
+
+
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("dtype", ["float"])
 def test_state_cleanup(
@@ -374,6 +374,7 @@ class Scheduler:
             for aborted_group in aborted_groups:
                 # Remove the sequence group from the state queue.
                 state_queue.remove(aborted_group)
+                self._finished_requests_ids.append(aborted_group.request_id)
                 for seq in aborted_group.get_seqs():
                     if seq.is_finished():
                         continue
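The single appended line above is the aborted-request half of the fix: aborting a sequence group used to remove it from the state queue without reporting its request id as finished, so models that key per-request state on request ids (Jamba's mamba cache) never received a release signal for it. A self-contained toy illustration of that leak, with hypothetical names rather than vLLM's real classes:

from typing import Dict, List


class ToyStateCache:
    """Hypothetical request-id-keyed cache, standing in for Jamba's mamba cache."""

    def __init__(self, capacity: int) -> None:
        self.capacity = capacity
        self.slots: Dict[str, int] = {}  # request_id -> slot index

    def allocate(self, request_id: str) -> None:
        if request_id not in self.slots:
            if len(self.slots) >= self.capacity:
                raise ValueError("no free cache slots")
            self.slots[request_id] = len(self.slots)

    def release(self, finished_requests_ids: List[str]) -> None:
        for request_id in finished_requests_ids:
            self.slots.pop(request_id, None)


cache = ToyStateCache(capacity=2)
finished_requests_ids: List[str] = []

# Request "a" runs and is then aborted. If abort does not append "a" to
# finished_requests_ids (the pre-fix behaviour), its slot is never released
# and later requests eventually exhaust the cache.
cache.allocate("a")
finished_requests_ids.append("a")  # the appended scheduler line provides this signal
cache.release(finished_requests_ids)
cache.allocate("b")
cache.allocate("c")  # would raise ValueError without the release above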
@@ -32,7 +32,8 @@ from vllm.model_executor.model_loader.weight_utils import (
     filter_duplicate_safetensors_files, filter_files_not_needed_for_inference,
     get_quant_config, initialize_dummy_weights, np_cache_weights_iterator,
     pt_weights_iterator, safetensors_weights_iterator)
-from vllm.model_executor.models.interfaces import (supports_lora,
+from vllm.model_executor.models.interfaces import (has_inner_state,
+                                                    supports_lora,
                                                     supports_vision)
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
@@ -69,7 +70,7 @@ def _get_model_initialization_kwargs(
         model_class: Type[nn.Module],
         lora_config: Optional[LoRAConfig],
         multimodal_config: Optional[MultiModalConfig],
-) -> Dict[str, Any]:
+        scheduler_config: Optional[SchedulerConfig] = None) -> Dict[str, Any]:
     """Get extra kwargs for model initialization."""
     extra_kwargs: Dict[str, Any] = {}
 
@@ -90,13 +91,19 @@ def _get_model_initialization_kwargs(
 
         extra_kwargs["multimodal_config"] = multimodal_config
 
+    if has_inner_state(model_class) and scheduler_config:
+        extra_kwargs["scheduler_config"] = scheduler_config
+
     return extra_kwargs
 
 
-def _initialize_model(model_config: ModelConfig, load_config: LoadConfig,
-                      lora_config: Optional[LoRAConfig],
-                      multimodal_config: Optional[MultiModalConfig],
-                      cache_config: CacheConfig) -> nn.Module:
+def _initialize_model(
+        model_config: ModelConfig,
+        load_config: LoadConfig,
+        lora_config: Optional[LoRAConfig],
+        multimodal_config: Optional[MultiModalConfig],
+        cache_config: CacheConfig,
+        scheduler_config: Optional[SchedulerConfig] = None) -> nn.Module:
     """Initialize a model with the given configurations."""
     model_class = get_model_architecture(model_config)[0]
     quant_config = _get_quantization_config(model_config, load_config)
@@ -105,7 +112,8 @@ def _initialize_model(model_config: ModelConfig, load_config: LoadConfig,
             cache_config=cache_config,
             quant_config=quant_config,
             **_get_model_initialization_kwargs(
-                model_class, lora_config, multimodal_config))
+                model_class, lora_config, multimodal_config,
+                scheduler_config))
 
 
 class BaseModelLoader(ABC):
@@ -266,7 +274,7 @@ class DefaultModelLoader(BaseModelLoader):
         with torch.device(device_config.device):
             model = _initialize_model(model_config, self.load_config,
                                       lora_config, multimodal_config,
-                                      cache_config)
+                                      cache_config, scheduler_config)
             model.load_weights(
                 self._get_weights_iterator(model_config.model,
                                            model_config.revision,
@@ -302,7 +310,7 @@ class DummyModelLoader(BaseModelLoader):
         with torch.device(device_config.device):
             model = _initialize_model(model_config, self.load_config,
                                       lora_config, multimodal_config,
-                                      cache_config)
+                                      cache_config, scheduler_config)
             # NOTE(woosuk): For accurate performance evaluation, we assign
             # random values to the weights.
             initialize_dummy_weights(model)
@@ -3,7 +3,7 @@ from typing import (ClassVar, Dict, List, Literal, Optional, Protocol, Type,
 
 from typing_extensions import TypeGuard
 
-from vllm.config import LoRAConfig, MultiModalConfig
+from vllm.config import LoRAConfig, MultiModalConfig, SchedulerConfig
 from vllm.logger import init_logger
 
 logger = init_logger(__name__)
@@ -142,3 +142,49 @@ def _supports_lora(
         return isinstance(model, _SupportsLoRAType)
 
     return isinstance(model, SupportsLoRA)
+
+
+@runtime_checkable
+class HasInnerState(Protocol):
+    """The interface required for all models that have inner state."""
+
+    has_inner_state: ClassVar[Literal[True]] = True
+    """
+        A flag that indicates this model has inner state.
+        Models that have inner state usually need access to the
+        scheduler_config for max_num_seqs, etc. (Currently only used by Jamba.)
+    """
+
+    def __init__(self,
+                 *,
+                 scheduler_config: Optional[SchedulerConfig] = None) -> None:
+        ...
+
+
+@runtime_checkable
+class _HasInnerStateType(Protocol):
+    has_inner_state: ClassVar[Literal[True]]
+
+    def __init__(self,
+                 *,
+                 scheduler_config: Optional[SchedulerConfig] = None) -> None:
+        ...
+
+
+@overload
+def has_inner_state(model: object) -> TypeGuard[HasInnerState]:
+    ...
+
+
+@overload
+def has_inner_state(model: Type[object]) -> TypeGuard[Type[HasInnerState]]:
+    ...
+
+
+def has_inner_state(
+    model: Union[Type[object], object]
+) -> Union[TypeGuard[Type[HasInnerState]], TypeGuard[HasInnerState]]:
+    if isinstance(model, type):
+        return isinstance(model, _HasInnerStateType)
+
+    return isinstance(model, HasInnerState)
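The new HasInnerState marker follows the same pattern as the existing supports_lora/supports_vision checks: a runtime_checkable Protocol whose has_inner_state attribute lets has_inner_state() act as a TypeGuard over both classes and instances. A stripped-down, self-contained sketch of that pattern (toy names, not the vLLM definitions):

from typing import ClassVar, Literal, Protocol, runtime_checkable


@runtime_checkable
class ToyHasInnerState(Protocol):
    # Marker attribute: its mere presence is what the runtime check looks for.
    has_inner_state: ClassVar[Literal[True]]


def toy_has_inner_state(model: object) -> bool:
    # isinstance() against a runtime_checkable Protocol only verifies that the
    # attribute exists, for class objects and instances alike, which is all the
    # model loader needs in order to decide whether to pass scheduler_config.
    return isinstance(model, ToyHasInnerState)


class StatefulModel:
    has_inner_state: ClassVar[Literal[True]] = True


class StatelessModel:
    pass


assert toy_has_inner_state(StatefulModel) and toy_has_inner_state(StatefulModel())
assert not toy_has_inner_state(StatelessModel)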
@@ -13,7 +13,7 @@ from transformers import JambaConfig
 
 from vllm.attention.backends.abstract import AttentionMetadata
 from vllm.attention.layer import Attention
-from vllm.config import CacheConfig, LoRAConfig
+from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
 from vllm.distributed import (get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
                               tensor_model_parallel_all_reduce)
@@ -32,10 +32,12 @@ from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.models.interfaces import HasInnerState
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.sequence import IntermediateTensors, SamplerOutput
-from vllm.worker.model_runner import _BATCH_SIZES_TO_CAPTURE
+from vllm.worker.model_runner import (_BATCH_SIZES_TO_CAPTURE,
+                                      _get_graph_batch_size)
 
 KVCache = Tuple[torch.Tensor, torch.Tensor]
 
@@ -612,7 +614,7 @@ class JambaModel(nn.Module):
         return hidden_states
 
 
-class JambaForCausalLM(nn.Module):
+class JambaForCausalLM(nn.Module, HasInnerState):
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
@@ -640,9 +642,11 @@ class JambaForCausalLM(nn.Module):
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
         lora_config: Optional[LoRAConfig] = None,
+        scheduler_config: Optional[SchedulerConfig] = None,
     ) -> None:
         super().__init__()
         self.config = config
+        self.scheduler_config = scheduler_config
         self.model = JambaModel(config,
                                 cache_config=cache_config,
                                 quant_config=quant_config,
@@ -689,6 +693,8 @@ class JambaForCausalLM(nn.Module):
             for key in ["request_ids_to_seq_ids", "finished_requests_ids"])
 
         request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"]
+        finished_requests_ids = kwargs["finished_requests_ids"]
+        self._release_mamba_cache(finished_requests_ids)
         batch_size = input_ids.shape[0]
         if attn_metadata.prefill_metadata:
             batch_size = len(request_ids_to_seq_ids)
@@ -696,9 +702,8 @@ class JambaForCausalLM(nn.Module):
                 current_seqlen_agnostic_cache,
                 indices,
             ) = self._prepare_current_run_mamba_cache(request_ids_to_seq_ids,
-                                                      batch_size)
-            finished_requests_ids = kwargs["finished_requests_ids"]
-            self._release_mamba_cache(finished_requests_ids)
+                                                      batch_size,
+                                                      finished_requests_ids)
         else:
             # CUDA graph capturing runs
             current_seqlen_agnostic_cache, indices = (
@@ -760,10 +765,15 @@ class JambaForCausalLM(nn.Module):
         return indices_for_current_run
 
     def _prepare_current_run_mamba_cache(
-            self, request_ids_to_seq_ids: Dict[str, list[int]], batch_size: int
+            self, request_ids_to_seq_ids: Dict[str, list[int]], batch_size: int,
+            finished_requests_ids: List[str]
    ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], List[int]]:
         indices_for_current_run = []
         for request_id, seqs_id in request_ids_to_seq_ids.items():
+            if request_id in finished_requests_ids:
+                # Do not allocate cache for requests that run
+                # and finish right after
+                continue
             indices_for_current_run += self._assign_seq_id_to_mamba_cache(
                 request_id, seqs_id)
         ## Pad the batch in case of running batch that was not captured via CG
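The skip added above is what the new capacity test exercises: a request that arrives and finishes in the same step appears in both request_ids_to_seq_ids and finished_requests_ids, and allocating a cache slot for it would only waste capacity that is released a moment later. A toy, self-contained version of the allocation loop (made-up names and numbers, not the real cache code):

from typing import Dict, List


def prepare_cache_indices(request_ids_to_seq_ids: Dict[str, List[int]],
                          finished_requests_ids: List[str],
                          capacity: int) -> List[int]:
    # Toy version of the loop above: requests already listed as finished never
    # receive a slot, so a burst of short-lived requests cannot exhaust a cache
    # sized for the padded max batch.
    finished = set(finished_requests_ids)
    indices: List[int] = []
    for request_id, seq_ids in request_ids_to_seq_ids.items():
        if request_id in finished:
            continue  # ran and finished right after; nothing to cache
        for _ in seq_ids:
            if len(indices) >= capacity:
                raise ValueError("mamba cache capacity exceeded")
            indices.append(len(indices))
    return indices


# 30 requests are visible this step but 20 of them have already finished;
# only the 10 live ones need slots, which fits a capacity-20 cache.
requests = {f"req-{i}": [0] for i in range(30)}
finished_ids = [f"req-{i}" for i in range(20)]
assert len(prepare_cache_indices(requests, finished_ids, capacity=20)) == 10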
@@ -787,16 +797,17 @@ class JambaForCausalLM(nn.Module):
         assert all(
             key in kwargs
             for key in ["request_ids_to_seq_ids", "finished_requests_ids"])
+        finished_requests_ids = kwargs["finished_requests_ids"]
+        self._release_mamba_cache(finished_requests_ids)
         request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"]
         cg_batch_size = input_buffers['input_ids'].shape[0]
         (
             current_mamba_cache,
             indices,
         ) = self._prepare_current_run_mamba_cache(request_ids_to_seq_ids,
-                                                  cg_batch_size)
+                                                  cg_batch_size,
+                                                  finished_requests_ids)
         self.current_indices = indices
-        finished_requests_ids = kwargs["finished_requests_ids"]
-        self._release_mamba_cache(finished_requests_ids)
 
         for input_buffer, current_cache_buffer in zip(
             input_buffers["seqlen_agnostic_capture_inputs"],
@@ -860,9 +871,12 @@ class JambaForCausalLM(nn.Module):
         layers_type = self.config.layers_block_type
         mamba_layers = sum(
             [layer_type == "mamba" for layer_type in layers_type])
-        max_batch_size = _BATCH_SIZES_TO_CAPTURE[-1] + 10
+        max_batch_size = (_get_graph_batch_size(
+            self.scheduler_config.max_num_seqs) if self.scheduler_config else
+                          max(_BATCH_SIZES_TO_CAPTURE)) + 10
         conv_state_shape, temporal_state_shape = self._get_mamba_cache_shape()
         assert conv_state_shape is not None and temporal_state_shape is not None
 
         for buffername in ["mamba_cache", "mamba_gc_cache_buffer"]:
             buffer = (torch.empty(size=(mamba_layers, max_batch_size) +
                                   conv_state_shape,
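The sizing change above swaps the hard-coded _BATCH_SIZES_TO_CAPTURE[-1] + 10 for the CUDA-graph-padded max_num_seqs whenever a scheduler_config is available, so the mamba buffers track the configured batch limit instead of the largest capturable batch. A simplified stand-in for _get_graph_batch_size shows the effect; the 8-way alignment below is an assumption about the padding rule, not something stated in this diff:

_BATCH_SIZE_ALIGNMENT = 8  # assumed CUDA-graph capture alignment


def graph_batch_size_sketch(batch_size: int) -> int:
    # Simplified stand-in for vllm.worker.model_runner._get_graph_batch_size:
    # round the batch size up to the next captured CUDA-graph batch size.
    if batch_size <= 2:
        return batch_size
    if batch_size <= 4:
        return 4
    return ((batch_size + _BATCH_SIZE_ALIGNMENT - 1) //
            _BATCH_SIZE_ALIGNMENT * _BATCH_SIZE_ALIGNMENT)


# With e.g. max_num_seqs = 70, the mamba buffers are sized for 72 padded
# sequences plus the +10 headroom, rather than always using the largest
# capturable batch size.
max_num_seqs = 70
max_batch_size = graph_batch_size_sketch(max_num_seqs) + 10
assert max_batch_size == 82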