[BugFix][Model] Jamba - Handle aborted requests, Add tests and fix cleanup bug (#6425)

Co-authored-by: Mor Zusman <morz@ai21.com>
Mor Zusman 2024-07-16 04:32:55 +03:00 committed by GitHub
parent d6f3b3d5c4
commit 9ad32dacd9
5 changed files with 176 additions and 24 deletions

tests/models/test_jamba.py

@@ -1,5 +1,6 @@
 import pytest
 
+from tests.models.utils import check_outputs_equal
 from vllm.worker.model_runner import _get_graph_batch_size
 
 MODELS = ["ai21labs/Jamba-tiny-random"]
@@ -34,6 +35,34 @@ def test_models(
             f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")
 
 
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("dtype", ["float"])
+@pytest.mark.parametrize("max_tokens", [20])
+def test_batching(
+    vllm_runner,
+    example_prompts,
+    model: str,
+    dtype: str,
+    max_tokens: int,
+) -> None:
+    # To pass the small model tests, we need full precision.
+    for_loop_outputs = []
+    with vllm_runner(model, dtype=dtype) as vllm_model:
+        for prompt in example_prompts:
+            for_loop_outputs.append(
+                vllm_model.generate_greedy([prompt], max_tokens)[0])
+
+        batched_outputs = vllm_model.generate_greedy(example_prompts,
+                                                     max_tokens)
+
+    check_outputs_equal(
+        outputs_0_lst=for_loop_outputs,
+        outputs_1_lst=batched_outputs,
+        name_0="for_loop_vllm",
+        name_1="batched_vllm",
+    )
+
+
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("dtype", ["bfloat16"])
 @pytest.mark.parametrize("max_tokens", [20])
@@ -60,6 +89,60 @@ def test_mamba_cache_cg_padding(
             "Could be related to mamba cache not padded correctly")
 
 
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("dtype", ["float"])
+@pytest.mark.parametrize("max_tokens", [20])
+def test_models_preemption_recompute(
+    hf_runner,
+    vllm_runner,
+    example_prompts,
+    model: str,
+    dtype: str,
+    max_tokens: int,
+) -> None:
+    # Tests that outputs are identical with and w/o preemptions (recompute)
+    assert dtype == "float"
+
+    with vllm_runner(model, dtype=dtype) as vllm_model:
+        vllm_model.model.llm_engine.scheduler[
+            0].ENABLE_ARTIFICIAL_PREEMPT = True
+        preempt_vllm_outputs = vllm_model.generate_greedy(
+            example_prompts, max_tokens)
+
+        vllm_model.model.llm_engine.scheduler[
+            0].ENABLE_ARTIFICIAL_PREEMPT = False
+        vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
+
+    check_outputs_equal(
+        outputs_0_lst=preempt_vllm_outputs,
+        outputs_1_lst=vllm_outputs,
+        name_0="vllm_preemptions",
+        name_1="vllm",
+    )
+
+
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("dtype", ["float"])
+def test_fail_upon_inc_requests_and_finished_requests_lt_available_blocks(
+    vllm_runner,
+    model: str,
+    dtype: str,
+    example_prompts,
+) -> None:
+    # This test verifies that the Jamba inner state management doesn't
+    # collapse when the number of incoming requests and finished_requests_ids
+    # is larger than the maximum mamba block capacity. This can happen because
+    # Jamba supports a statelessness mechanism where it can clean up new
+    # incoming requests in a single step.
+    try:
+        with vllm_runner(model, dtype=dtype, max_num_seqs=10) as vllm_model:
+            vllm_model.generate_greedy([example_prompts[0]] * 100, 10)
+    except ValueError:
+        pytest.fail("Jamba inner state wasn't cleaned up properly between "
+                    "steps; finished requests were registered unnecessarily.")
+
+
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("dtype", ["float"])
 def test_state_cleanup(

vllm/core/scheduler.py

@@ -374,6 +374,7 @@ class Scheduler:
             for aborted_group in aborted_groups:
                 # Remove the sequence group from the state queue.
                 state_queue.remove(aborted_group)
+                self._finished_requests_ids.append(aborted_group.request_id)
                 for seq in aborted_group.get_seqs():
                     if seq.is_finished():
                         continue
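The bookkeeping this one-line change fixes, as a toy sketch (illustrative classes only, not vLLM's actual Scheduler or Jamba cache code): once an aborted request id is recorded in _finished_requests_ids, the model can release that request's mamba cache slot on a later step instead of leaking it.

from typing import Dict, List


class ToyScheduler:
    """Toy stand-in: collects request ids the model may release cache for."""

    def __init__(self) -> None:
        self._finished_requests_ids: List[str] = []

    def abort_seq_group(self, request_id: str) -> None:
        # The fix above: aborted requests are now recorded here too,
        # not only requests that finish normally.
        self._finished_requests_ids.append(request_id)

    def get_and_reset_finished(self) -> List[str]:
        finished = self._finished_requests_ids
        self._finished_requests_ids = []
        return finished


class ToyMambaCache:
    """Toy per-request state that must be freed when a request ends."""

    def __init__(self) -> None:
        self._state: Dict[str, object] = {}

    def allocate(self, request_id: str) -> None:
        self._state[request_id] = object()

    def release(self, finished_requests_ids: List[str]) -> None:
        for request_id in finished_requests_ids:
            self._state.pop(request_id, None)

    def holds(self, request_id: str) -> bool:
        return request_id in self._state


scheduler, cache = ToyScheduler(), ToyMambaCache()
cache.allocate("req-0")
scheduler.abort_seq_group("req-0")  # request aborted mid-flight
cache.release(scheduler.get_and_reset_finished())
assert not cache.holds("req-0")  # the aborted request's slot is reclaimed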

vllm/model_executor/model_loader/loader.py

@@ -32,7 +32,8 @@ from vllm.model_executor.model_loader.weight_utils import (
     filter_duplicate_safetensors_files, filter_files_not_needed_for_inference,
     get_quant_config, initialize_dummy_weights, np_cache_weights_iterator,
     pt_weights_iterator, safetensors_weights_iterator)
-from vllm.model_executor.models.interfaces import (supports_lora,
+from vllm.model_executor.models.interfaces import (has_inner_state,
+                                                   supports_lora,
                                                    supports_vision)
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
@@ -69,7 +70,7 @@ def _get_model_initialization_kwargs(
         model_class: Type[nn.Module],
         lora_config: Optional[LoRAConfig],
         multimodal_config: Optional[MultiModalConfig],
-) -> Dict[str, Any]:
+        scheduler_config: Optional[SchedulerConfig] = None) -> Dict[str, Any]:
     """Get extra kwargs for model initialization."""
     extra_kwargs: Dict[str, Any] = {}
@@ -90,13 +91,19 @@
         extra_kwargs["multimodal_config"] = multimodal_config
 
+    if has_inner_state(model_class) and scheduler_config:
+        extra_kwargs["scheduler_config"] = scheduler_config
+
     return extra_kwargs
 
 
-def _initialize_model(model_config: ModelConfig, load_config: LoadConfig,
-                      lora_config: Optional[LoRAConfig],
-                      multimodal_config: Optional[MultiModalConfig],
-                      cache_config: CacheConfig) -> nn.Module:
+def _initialize_model(
+        model_config: ModelConfig,
+        load_config: LoadConfig,
+        lora_config: Optional[LoRAConfig],
+        multimodal_config: Optional[MultiModalConfig],
+        cache_config: CacheConfig,
+        scheduler_config: Optional[SchedulerConfig] = None) -> nn.Module:
     """Initialize a model with the given configurations."""
     model_class = get_model_architecture(model_config)[0]
     quant_config = _get_quantization_config(model_config, load_config)
@@ -105,7 +112,8 @@ def _initialize_model(model_config: ModelConfig, load_config: LoadConfig,
                        cache_config=cache_config,
                        quant_config=quant_config,
                        **_get_model_initialization_kwargs(
-                           model_class, lora_config, multimodal_config))
+                           model_class, lora_config, multimodal_config,
+                           scheduler_config))
 
 
 class BaseModelLoader(ABC):
@@ -266,7 +274,7 @@ class DefaultModelLoader(BaseModelLoader):
         with torch.device(device_config.device):
             model = _initialize_model(model_config, self.load_config,
                                       lora_config, multimodal_config,
-                                      cache_config)
+                                      cache_config, scheduler_config)
             model.load_weights(
                 self._get_weights_iterator(model_config.model,
                                            model_config.revision,
@@ -302,7 +310,7 @@ class DummyModelLoader(BaseModelLoader):
         with torch.device(device_config.device):
             model = _initialize_model(model_config, self.load_config,
                                       lora_config, multimodal_config,
-                                      cache_config)
+                                      cache_config, scheduler_config)
             # NOTE(woosuk): For accurate performance evaluation, we assign
             # random values to the weights.
             initialize_dummy_weights(model)

vllm/model_executor/models/interfaces.py

@@ -3,7 +3,7 @@ from typing import (ClassVar, Dict, List, Literal, Optional, Protocol, Type,
 
 from typing_extensions import TypeGuard
 
-from vllm.config import LoRAConfig, MultiModalConfig
+from vllm.config import LoRAConfig, MultiModalConfig, SchedulerConfig
 from vllm.logger import init_logger
 
 logger = init_logger(__name__)
@@ -142,3 +142,49 @@ def _supports_lora(
         return isinstance(model, _SupportsLoRAType)
 
     return isinstance(model, SupportsLoRA)
+
+
+@runtime_checkable
+class HasInnerState(Protocol):
+    """The interface required for all models that have inner state."""
+
+    has_inner_state: ClassVar[Literal[True]] = True
+    """
+        A flag that indicates this model has inner state.
+        Models that have inner state usually need access to the
+        scheduler_config for max_num_seqs, etc. (Currently only used by Jamba)
+    """
+
+    def __init__(self,
+                 *,
+                 scheduler_config: Optional[SchedulerConfig] = None) -> None:
+        ...
+
+
+@runtime_checkable
+class _HasInnerStateType(Protocol):
+    has_inner_state: ClassVar[Literal[True]]
+
+    def __init__(self,
+                 *,
+                 scheduler_config: Optional[SchedulerConfig] = None) -> None:
+        ...
+
+
+@overload
+def has_inner_state(model: object) -> TypeGuard[HasInnerState]:
+    ...
+
+
+@overload
+def has_inner_state(model: Type[object]) -> TypeGuard[Type[HasInnerState]]:
+    ...
+
+
+def has_inner_state(
+    model: Union[Type[object], object]
+) -> Union[TypeGuard[Type[HasInnerState]], TypeGuard[HasInnerState]]:
+    if isinstance(model, type):
+        return isinstance(model, _HasInnerStateType)
+
+    return isinstance(model, HasInnerState)
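A minimal usage sketch of the new interface (ToyStatefulModel is a made-up example, not part of this commit; HasInnerState, has_inner_state and SchedulerConfig are the names added above):

from typing import Optional

import torch.nn as nn

from vllm.config import SchedulerConfig
from vllm.model_executor.models.interfaces import (HasInnerState,
                                                   has_inner_state)


class ToyStatefulModel(nn.Module, HasInnerState):
    """Made-up model class that opts into the inner-state interface."""

    def __init__(self,
                 *,
                 scheduler_config: Optional[SchedulerConfig] = None) -> None:
        super().__init__()
        # A stateful model can size its per-request caches from the
        # scheduler config (e.g. max_num_seqs), as Jamba does in the next file.
        self.scheduler_config = scheduler_config


# Loader-side check, mirroring _get_model_initialization_kwargs above:
# only classes flagged with has_inner_state receive scheduler_config.
assert has_inner_state(ToyStatefulModel)
assert not has_inner_state(nn.Module)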

vllm/model_executor/models/jamba.py

@@ -13,7 +13,7 @@ from transformers import JambaConfig
 from vllm.attention.backends.abstract import AttentionMetadata
 from vllm.attention.layer import Attention
-from vllm.config import CacheConfig, LoRAConfig
+from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
 from vllm.distributed import (get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
                               tensor_model_parallel_all_reduce)
@@ -32,10 +32,12 @@ from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.models.interfaces import HasInnerState
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.sequence import IntermediateTensors, SamplerOutput
-from vllm.worker.model_runner import _BATCH_SIZES_TO_CAPTURE
+from vllm.worker.model_runner import (_BATCH_SIZES_TO_CAPTURE,
+                                      _get_graph_batch_size)
 
 KVCache = Tuple[torch.Tensor, torch.Tensor]
@@ -612,7 +614,7 @@ class JambaModel(nn.Module):
         return hidden_states
 
 
-class JambaForCausalLM(nn.Module):
+class JambaForCausalLM(nn.Module, HasInnerState):
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
@@ -640,9 +642,11 @@ class JambaForCausalLM(nn.Module):
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
         lora_config: Optional[LoRAConfig] = None,
+        scheduler_config: Optional[SchedulerConfig] = None,
     ) -> None:
         super().__init__()
         self.config = config
+        self.scheduler_config = scheduler_config
         self.model = JambaModel(config,
                                 cache_config=cache_config,
                                 quant_config=quant_config,
@@ -689,6 +693,8 @@ class JambaForCausalLM(nn.Module):
                 for key in ["request_ids_to_seq_ids", "finished_requests_ids"])
 
             request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"]
+            finished_requests_ids = kwargs["finished_requests_ids"]
+            self._release_mamba_cache(finished_requests_ids)
             batch_size = input_ids.shape[0]
             if attn_metadata.prefill_metadata:
                 batch_size = len(request_ids_to_seq_ids)
@@ -696,9 +702,8 @@ class JambaForCausalLM(nn.Module):
                 current_seqlen_agnostic_cache,
                 indices,
             ) = self._prepare_current_run_mamba_cache(request_ids_to_seq_ids,
-                                                      batch_size)
-            finished_requests_ids = kwargs["finished_requests_ids"]
-            self._release_mamba_cache(finished_requests_ids)
+                                                      batch_size,
+                                                      finished_requests_ids)
         else:
             # CUDA graph capturing runs
             current_seqlen_agnostic_cache, indices = (
@@ -760,10 +765,15 @@ class JambaForCausalLM(nn.Module):
         return indices_for_current_run
 
     def _prepare_current_run_mamba_cache(
-            self, request_ids_to_seq_ids: Dict[str, list[int]], batch_size: int
+            self, request_ids_to_seq_ids: Dict[str, list[int]], batch_size: int,
+            finished_requests_ids: List[str]
     ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], List[int]]:
         indices_for_current_run = []
         for request_id, seqs_id in request_ids_to_seq_ids.items():
+            if request_id in finished_requests_ids:
+                # Do not allocate cache for requests that run
+                # and finish right after
+                continue
             indices_for_current_run += self._assign_seq_id_to_mamba_cache(
                 request_id, seqs_id)
         ## Pad the batch in case of running batch that was not captured via CG
@@ -787,16 +797,17 @@ class JambaForCausalLM(nn.Module):
         assert all(
             key in kwargs
             for key in ["request_ids_to_seq_ids", "finished_requests_ids"])
+        finished_requests_ids = kwargs["finished_requests_ids"]
+        self._release_mamba_cache(finished_requests_ids)
         request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"]
         cg_batch_size = input_buffers['input_ids'].shape[0]
         (
             current_mamba_cache,
             indices,
         ) = self._prepare_current_run_mamba_cache(request_ids_to_seq_ids,
-                                                  cg_batch_size)
+                                                  cg_batch_size,
+                                                  finished_requests_ids)
         self.current_indices = indices
-        finished_requests_ids = kwargs["finished_requests_ids"]
-        self._release_mamba_cache(finished_requests_ids)
 
         for input_buffer, current_cache_buffer in zip(
                 input_buffers["seqlen_agnostic_capture_inputs"],
@@ -860,9 +871,12 @@ class JambaForCausalLM(nn.Module):
         layers_type = self.config.layers_block_type
         mamba_layers = sum(
             [layer_type == "mamba" for layer_type in layers_type])
-        max_batch_size = _BATCH_SIZES_TO_CAPTURE[-1] + 10
+        max_batch_size = (_get_graph_batch_size(
+            self.scheduler_config.max_num_seqs) if self.scheduler_config else
+                          max(_BATCH_SIZES_TO_CAPTURE)) + 10
         conv_state_shape, temporal_state_shape = self._get_mamba_cache_shape()
         assert conv_state_shape is not None and temporal_state_shape is not None
         for buffername in ["mamba_cache", "mamba_gc_cache_buffer"]:
             buffer = (torch.empty(size=(mamba_layers, max_batch_size) +
                                   conv_state_shape,