mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-22 08:04:27 +08:00
[Model] Jamba support (#4115)
Signed-off-by: Muralidhar Andoorveedu <muralidhar.andoorveedu@centml.ai> Co-authored-by: Erez Schwartz <erezs@ai21.com> Co-authored-by: Mor Zusman <morz@ai21.com> Co-authored-by: tomeras91 <57313761+tomeras91@users.noreply.github.com> Co-authored-by: Tomer Asida <tomera@ai21.com> Co-authored-by: Zhuohan Li <zhuohan123@gmail.com> Co-authored-by: Muralidhar Andoorveedu <muralidhar.andoorveedu@centml.ai>
This commit is contained in:
parent
ee93f4f92a
commit
9d6a8daa87
@ -23,4 +23,4 @@ docker exec cpu-test-avx2 bash -c "python3 examples/offline_inference.py"
|
|||||||
docker exec cpu-test bash -c "cd tests;
|
docker exec cpu-test bash -c "cd tests;
|
||||||
pip install pytest Pillow protobuf
|
pip install pytest Pillow protobuf
|
||||||
cd ../
|
cd ../
|
||||||
pytest -v -s tests/models -m \"not vlm\" --ignore=tests/models/test_embedding.py --ignore=tests/models/test_registry.py"
|
pytest -v -s tests/models -m \"not vlm\" --ignore=tests/models/test_embedding.py --ignore=tests/models/test_registry.py --ignore=tests/models/test_jamba.py" # Mamba on CPU is not supported
|
||||||
|
|||||||
23
Dockerfile
23
Dockerfile
@ -43,6 +43,10 @@ COPY requirements-cuda.txt requirements-cuda.txt
|
|||||||
RUN --mount=type=cache,target=/root/.cache/pip \
|
RUN --mount=type=cache,target=/root/.cache/pip \
|
||||||
python3 -m pip install -r requirements-cuda.txt
|
python3 -m pip install -r requirements-cuda.txt
|
||||||
|
|
||||||
|
COPY requirements-mamba.txt requirements-mamba.txt
|
||||||
|
RUN python3 -m pip install packaging
|
||||||
|
RUN python3 -m pip install -r requirements-mamba.txt
|
||||||
|
|
||||||
# cuda arch list used by torch
|
# cuda arch list used by torch
|
||||||
# can be useful for both `dev` and `test`
|
# can be useful for both `dev` and `test`
|
||||||
# explicitly set the list to avoid issues with torch 2.2
|
# explicitly set the list to avoid issues with torch 2.2
|
||||||
@ -123,6 +127,21 @@ RUN --mount=type=cache,target=/root/.cache/pip \
|
|||||||
python3 -m pip install -r requirements-dev.txt
|
python3 -m pip install -r requirements-dev.txt
|
||||||
|
|
||||||
#################### DEV IMAGE ####################
|
#################### DEV IMAGE ####################
|
||||||
|
#################### MAMBA Build IMAGE ####################
|
||||||
|
FROM dev as mamba-builder
|
||||||
|
# max jobs used for build
|
||||||
|
ARG max_jobs=2
|
||||||
|
ENV MAX_JOBS=${max_jobs}
|
||||||
|
|
||||||
|
WORKDIR /usr/src/mamba
|
||||||
|
|
||||||
|
COPY requirements-mamba.txt requirements-mamba.txt
|
||||||
|
|
||||||
|
# Download the wheel or build it if a pre-compiled release doesn't exist
|
||||||
|
RUN pip --verbose wheel -r requirements-mamba.txt \
|
||||||
|
--no-build-isolation --no-deps --no-cache-dir
|
||||||
|
|
||||||
|
#################### MAMBA Build IMAGE ####################
|
||||||
|
|
||||||
#################### vLLM installation IMAGE ####################
|
#################### vLLM installation IMAGE ####################
|
||||||
# image with vLLM installed
|
# image with vLLM installed
|
||||||
@ -143,6 +162,10 @@ RUN ldconfig /usr/local/cuda-$(echo $CUDA_VERSION | cut -d. -f1,2)/compat/
|
|||||||
RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist \
|
RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist \
|
||||||
--mount=type=cache,target=/root/.cache/pip \
|
--mount=type=cache,target=/root/.cache/pip \
|
||||||
python3 -m pip install dist/*.whl --verbose
|
python3 -m pip install dist/*.whl --verbose
|
||||||
|
|
||||||
|
RUN --mount=type=bind,from=mamba-builder,src=/usr/src/mamba,target=/usr/src/mamba \
|
||||||
|
--mount=type=cache,target=/root/.cache/pip \
|
||||||
|
python3 -m pip install /usr/src/mamba/*.whl --no-cache-dir
|
||||||
#################### vLLM installation IMAGE ####################
|
#################### vLLM installation IMAGE ####################
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -87,6 +87,10 @@ Alongside each architecture, we include some popular models that use it.
|
|||||||
- Jais
|
- Jais
|
||||||
- :code:`core42/jais-13b`, :code:`core42/jais-13b-chat`, :code:`core42/jais-30b-v3`, :code:`core42/jais-30b-chat-v3`, etc.
|
- :code:`core42/jais-13b`, :code:`core42/jais-13b-chat`, :code:`core42/jais-30b-v3`, :code:`core42/jais-30b-chat-v3`, etc.
|
||||||
-
|
-
|
||||||
|
* - :code:`JambaForCausalLM`
|
||||||
|
- Jamba
|
||||||
|
- :code:`ai21labs/Jamba-v0.1`, etc.
|
||||||
|
- ✅︎
|
||||||
* - :code:`LlamaForCausalLM`
|
* - :code:`LlamaForCausalLM`
|
||||||
- LLaMA, Llama 2, Meta Llama 3, Vicuna, Alpaca, Yi
|
- LLaMA, Llama 2, Meta Llama 3, Vicuna, Alpaca, Yi
|
||||||
- :code:`meta-llama/Meta-Llama-3-8B-Instruct`, :code:`meta-llama/Meta-Llama-3-70B-Instruct`, :code:`meta-llama/Llama-2-13b-hf`, :code:`meta-llama/Llama-2-70b-hf`, :code:`openlm-research/open_llama_13b`, :code:`lmsys/vicuna-13b-v1.3`, :code:`01-ai/Yi-6B`, :code:`01-ai/Yi-34B`, etc.
|
- :code:`meta-llama/Meta-Llama-3-8B-Instruct`, :code:`meta-llama/Meta-Llama-3-70B-Instruct`, :code:`meta-llama/Llama-2-13b-hf`, :code:`meta-llama/Llama-2-70b-hf`, :code:`openlm-research/open_llama_13b`, :code:`lmsys/vicuna-13b-v1.3`, :code:`01-ai/Yi-6B`, :code:`01-ai/Yi-34B`, etc.
|
||||||
|
|||||||
3
requirements-mamba.txt
Normal file
3
requirements-mamba.txt
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
# Mamba dependencies
|
||||||
|
mamba-ssm>=1.2.2
|
||||||
|
causal-conv1d>=1.2.0
|
||||||
65
tests/models/test_jamba.py
Normal file
65
tests/models/test_jamba.py
Normal file
@ -0,0 +1,65 @@
|
|||||||
|
import pytest
|
||||||
|
|
||||||
|
MODELS = ["ai21labs/Jamba-tiny-random"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("model", MODELS)
|
||||||
|
@pytest.mark.parametrize("dtype", ["float"])
|
||||||
|
@pytest.mark.parametrize("max_tokens", [20])
|
||||||
|
def test_models(
|
||||||
|
hf_runner,
|
||||||
|
vllm_runner,
|
||||||
|
example_prompts,
|
||||||
|
model: str,
|
||||||
|
dtype: str,
|
||||||
|
max_tokens: int,
|
||||||
|
) -> None:
|
||||||
|
# To pass the small model tests, we need full precision.
|
||||||
|
assert dtype == "float"
|
||||||
|
|
||||||
|
with hf_runner(model, dtype=dtype) as hf_model:
|
||||||
|
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
|
||||||
|
|
||||||
|
with vllm_runner(model, dtype=dtype) as vllm_model:
|
||||||
|
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
|
||||||
|
|
||||||
|
for i in range(len(example_prompts)):
|
||||||
|
hf_output_ids, hf_output_str = hf_outputs[i]
|
||||||
|
vllm_output_ids, vllm_output_str = vllm_outputs[i]
|
||||||
|
assert hf_output_str == vllm_output_str, (
|
||||||
|
f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
|
||||||
|
assert hf_output_ids == vllm_output_ids, (
|
||||||
|
f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("model", MODELS)
|
||||||
|
@pytest.mark.parametrize("dtype", ["float"])
|
||||||
|
def test_state_cleanup(
|
||||||
|
vllm_runner,
|
||||||
|
model: str,
|
||||||
|
dtype: str,
|
||||||
|
example_prompts,
|
||||||
|
) -> None:
|
||||||
|
# This test is for verifying that the Jamba state is cleaned up between
|
||||||
|
# steps, If its not cleaned, an error would be expected.
|
||||||
|
try:
|
||||||
|
with vllm_runner(model, dtype=dtype) as vllm_model:
|
||||||
|
for _ in range(10):
|
||||||
|
vllm_model.generate_greedy([example_prompts[0]] * 100, 1)
|
||||||
|
except ValueError:
|
||||||
|
pytest.fail("Jamba inner state wasn't cleaned up between states, "
|
||||||
|
"could be related to finished_requests_ids")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("model", MODELS)
|
||||||
|
@pytest.mark.parametrize("dtype", ["float"])
|
||||||
|
def test_model_print(
|
||||||
|
vllm_runner,
|
||||||
|
model: str,
|
||||||
|
dtype: str,
|
||||||
|
) -> None:
|
||||||
|
with vllm_runner(model, dtype=dtype) as vllm_model:
|
||||||
|
# This test is for verifying whether the model's extra_repr
|
||||||
|
# can be printed correctly.
|
||||||
|
print(vllm_model.model.llm_engine.model_executor.driver_worker.
|
||||||
|
model_runner.model)
|
||||||
@ -386,9 +386,36 @@ class ModelConfig:
|
|||||||
return num_heads // parallel_config.tensor_parallel_size
|
return num_heads // parallel_config.tensor_parallel_size
|
||||||
|
|
||||||
def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
|
def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
|
||||||
total_num_hidden_layers = self.hf_text_config.num_hidden_layers
|
total_num_hidden_layers = getattr(self.hf_text_config,
|
||||||
|
"num_hidden_layers", 0)
|
||||||
return total_num_hidden_layers // parallel_config.pipeline_parallel_size
|
return total_num_hidden_layers // parallel_config.pipeline_parallel_size
|
||||||
|
|
||||||
|
def contains_seqlen_agnostic_layers(
|
||||||
|
self, parallel_config: "ParallelConfig") -> bool:
|
||||||
|
"""True for Mamba/SSM models (Jamba)"""
|
||||||
|
return self._get_num_seqlen_agnostic_layers(parallel_config) > 0
|
||||||
|
|
||||||
|
def get_layers_block_type(self,
|
||||||
|
parallel_config: "ParallelConfig") -> List[str]:
|
||||||
|
num_layers = self.get_num_layers(parallel_config)
|
||||||
|
# Transformers supports layers_block_type @property
|
||||||
|
return getattr(self.hf_config, "layers_block_type",
|
||||||
|
["attention"] * num_layers)
|
||||||
|
|
||||||
|
def get_num_attention_layers(self,
|
||||||
|
parallel_config: "ParallelConfig") -> int:
|
||||||
|
return len([
|
||||||
|
t for t in self.get_layers_block_type(parallel_config)
|
||||||
|
if t == "attention"
|
||||||
|
])
|
||||||
|
|
||||||
|
def _get_num_seqlen_agnostic_layers(
|
||||||
|
self, parallel_config: "ParallelConfig") -> int:
|
||||||
|
return len([
|
||||||
|
t for t in self.get_layers_block_type(parallel_config)
|
||||||
|
if t != "attention"
|
||||||
|
])
|
||||||
|
|
||||||
|
|
||||||
class CacheConfig:
|
class CacheConfig:
|
||||||
"""Configuration for the KV cache.
|
"""Configuration for the KV cache.
|
||||||
|
|||||||
@ -299,7 +299,10 @@ class Scheduler:
|
|||||||
# Sequence groups in the SWAPPED state.
|
# Sequence groups in the SWAPPED state.
|
||||||
# Contain decode requests that are swapped out.
|
# Contain decode requests that are swapped out.
|
||||||
self.swapped: Deque[SequenceGroup] = deque()
|
self.swapped: Deque[SequenceGroup] = deque()
|
||||||
|
# Sequence groups finished requests ids since last step iteration.
|
||||||
|
# It lets the model know that any state associated with these requests
|
||||||
|
# can and must be released after the current step.
|
||||||
|
self._finished_requests_ids: List[str] = list()
|
||||||
# Time at previous scheduling step
|
# Time at previous scheduling step
|
||||||
self.prev_time = 0.0
|
self.prev_time = 0.0
|
||||||
# Did we schedule a prompt at previous step?
|
# Did we schedule a prompt at previous step?
|
||||||
@ -373,6 +376,12 @@ class Scheduler:
|
|||||||
def get_num_unfinished_seq_groups(self) -> int:
|
def get_num_unfinished_seq_groups(self) -> int:
|
||||||
return len(self.waiting) + len(self.running) + len(self.swapped)
|
return len(self.waiting) + len(self.running) + len(self.swapped)
|
||||||
|
|
||||||
|
def get_and_reset_finished_requests_ids(self) -> List[str]:
|
||||||
|
"""Flushes the list of request ids of previously finished seq_groups."""
|
||||||
|
finished_requests_ids = self._finished_requests_ids
|
||||||
|
self._finished_requests_ids = list()
|
||||||
|
return finished_requests_ids
|
||||||
|
|
||||||
def _schedule_running(
|
def _schedule_running(
|
||||||
self,
|
self,
|
||||||
running_queue: deque,
|
running_queue: deque,
|
||||||
@ -1036,6 +1045,11 @@ class Scheduler:
|
|||||||
self.block_manager.free(seq)
|
self.block_manager.free(seq)
|
||||||
|
|
||||||
def free_finished_seq_groups(self) -> None:
|
def free_finished_seq_groups(self) -> None:
|
||||||
|
for queue in [self.running, self.swapped, self.waiting]:
|
||||||
|
self._finished_requests_ids += [
|
||||||
|
seq_group.request_id for seq_group in queue
|
||||||
|
if seq_group.is_finished()
|
||||||
|
]
|
||||||
self.running = deque(seq_group for seq_group in self.running
|
self.running = deque(seq_group for seq_group in self.running
|
||||||
if not seq_group.is_finished())
|
if not seq_group.is_finished())
|
||||||
|
|
||||||
|
|||||||
@ -224,6 +224,8 @@ class _AsyncLLMEngine(LLMEngine):
|
|||||||
"""
|
"""
|
||||||
seq_group_metadata_list, scheduler_outputs = self.scheduler[
|
seq_group_metadata_list, scheduler_outputs = self.scheduler[
|
||||||
virtual_engine].schedule()
|
virtual_engine].schedule()
|
||||||
|
finished_requests_ids = self.scheduler[
|
||||||
|
virtual_engine].get_and_reset_finished_requests_ids()
|
||||||
|
|
||||||
if not scheduler_outputs.is_empty():
|
if not scheduler_outputs.is_empty():
|
||||||
# Execute the model.
|
# Execute the model.
|
||||||
@ -235,7 +237,7 @@ class _AsyncLLMEngine(LLMEngine):
|
|||||||
virtual_engine=virtual_engine,
|
virtual_engine=virtual_engine,
|
||||||
num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
|
num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
|
||||||
running_queue_size=scheduler_outputs.running_queue_size,
|
running_queue_size=scheduler_outputs.running_queue_size,
|
||||||
)
|
finished_requests_ids=finished_requests_ids)
|
||||||
output = await self.model_executor.execute_model_async(
|
output = await self.model_executor.execute_model_async(
|
||||||
execute_model_req)
|
execute_model_req)
|
||||||
else:
|
else:
|
||||||
|
|||||||
@ -846,6 +846,8 @@ class LLMEngine:
|
|||||||
"as performance will be severely degraded otherwise.")
|
"as performance will be severely degraded otherwise.")
|
||||||
seq_group_metadata_list, scheduler_outputs = self.scheduler[
|
seq_group_metadata_list, scheduler_outputs = self.scheduler[
|
||||||
0].schedule()
|
0].schedule()
|
||||||
|
finished_requests_ids = self.scheduler[
|
||||||
|
0].get_and_reset_finished_requests_ids()
|
||||||
|
|
||||||
if not scheduler_outputs.is_empty():
|
if not scheduler_outputs.is_empty():
|
||||||
execute_model_req = ExecuteModelRequest(
|
execute_model_req = ExecuteModelRequest(
|
||||||
@ -855,7 +857,7 @@ class LLMEngine:
|
|||||||
blocks_to_copy=scheduler_outputs.blocks_to_copy,
|
blocks_to_copy=scheduler_outputs.blocks_to_copy,
|
||||||
num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
|
num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
|
||||||
running_queue_size=scheduler_outputs.running_queue_size,
|
running_queue_size=scheduler_outputs.running_queue_size,
|
||||||
)
|
finished_requests_ids=finished_requests_ids)
|
||||||
output = self.model_executor.execute_model(
|
output = self.model_executor.execute_model(
|
||||||
execute_model_req=execute_model_req)
|
execute_model_req=execute_model_req)
|
||||||
else:
|
else:
|
||||||
|
|||||||
@ -63,6 +63,7 @@ _GENERATION_MODELS = {
|
|||||||
"XverseForCausalLM": ("xverse", "XverseForCausalLM"),
|
"XverseForCausalLM": ("xverse", "XverseForCausalLM"),
|
||||||
"Phi3SmallForCausalLM": ("phi3_small", "Phi3SmallForCausalLM"),
|
"Phi3SmallForCausalLM": ("phi3_small", "Phi3SmallForCausalLM"),
|
||||||
"MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
|
"MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
|
||||||
|
"JambaForCausalLM": ("jamba", "JambaForCausalLM")
|
||||||
}
|
}
|
||||||
|
|
||||||
_EMBEDDING_MODELS = {
|
_EMBEDDING_MODELS = {
|
||||||
|
|||||||
955
vllm/model_executor/models/jamba.py
Normal file
955
vllm/model_executor/models/jamba.py
Normal file
@ -0,0 +1,955 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
"""Inference-only Jurassic model."""
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Dict, Iterable, List, Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
|
||||||
|
from mamba_ssm.ops.selective_scan_interface import selective_scan_fn
|
||||||
|
from mamba_ssm.ops.triton.selective_state_update import selective_state_update
|
||||||
|
from torch import nn
|
||||||
|
from torch.nn.parameter import Parameter
|
||||||
|
from transformers import JambaConfig
|
||||||
|
|
||||||
|
from vllm.attention.backends.abstract import AttentionMetadata
|
||||||
|
from vllm.attention.layer import Attention
|
||||||
|
from vllm.config import CacheConfig, LoRAConfig
|
||||||
|
from vllm.distributed import (get_tensor_model_parallel_rank,
|
||||||
|
get_tensor_model_parallel_world_size,
|
||||||
|
tensor_model_parallel_all_reduce)
|
||||||
|
from vllm.model_executor.layers.activation import SiluAndMul
|
||||||
|
from vllm.model_executor.layers.fused_moe import fused_moe
|
||||||
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||||
|
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||||
|
MergedColumnParallelLinear,
|
||||||
|
QKVParallelLinear,
|
||||||
|
ReplicatedLinear,
|
||||||
|
RowParallelLinear)
|
||||||
|
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||||
|
from vllm.model_executor.layers.quantization.base_config import (
|
||||||
|
QuantizationConfig)
|
||||||
|
from vllm.model_executor.layers.sampler import Sampler
|
||||||
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||||
|
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
|
||||||
|
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||||
|
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||||
|
from vllm.model_executor.utils import set_weight_attrs
|
||||||
|
from vllm.sequence import IntermediateTensors, SamplerOutput
|
||||||
|
from vllm.worker.model_runner import _BATCH_SIZES_TO_CAPTURE
|
||||||
|
|
||||||
|
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class MambaCacheParams:
|
||||||
|
is_prompt: bool = False
|
||||||
|
conv_state: torch.Tensor = torch.Tensor()
|
||||||
|
ssm_state: torch.Tensor = torch.Tensor()
|
||||||
|
|
||||||
|
|
||||||
|
# Adapted from transformers.models.mamba.modeling_mamba.MambaMixer
|
||||||
|
class JambaMambaMixer(nn.Module):
|
||||||
|
"""
|
||||||
|
Compute ∆, A, B, C, and D the state space parameters and compute
|
||||||
|
the `contextualized_states`. A, D are input independent
|
||||||
|
(see Mamba paper [1] Section 3.5.2 "Interpretation of A"
|
||||||
|
for why A isn't selective) ∆, B, C are input-dependent
|
||||||
|
(this is a key difference between Mamba and the linear time
|
||||||
|
invariant S4, and is why Mamba is called
|
||||||
|
**selective** state spaces)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, config: JambaConfig, layer_idx):
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
self.layer_idx = layer_idx
|
||||||
|
self.hidden_size = config.hidden_size
|
||||||
|
self.ssm_state_size = config.mamba_d_state
|
||||||
|
self.conv_kernel_size = config.mamba_d_conv
|
||||||
|
self.intermediate_size = config.mamba_expand * config.hidden_size
|
||||||
|
self.time_step_rank = config.mamba_dt_rank
|
||||||
|
self.use_conv_bias = config.mamba_conv_bias
|
||||||
|
self.use_bias = config.mamba_proj_bias
|
||||||
|
self.conv1d = ColumnParallelLinear(
|
||||||
|
input_size=self.conv_kernel_size,
|
||||||
|
output_size=self.intermediate_size,
|
||||||
|
bias=self.use_conv_bias,
|
||||||
|
)
|
||||||
|
# unsqueeze to fit conv1d weights shape into the linear weights shape.
|
||||||
|
# Can't do this in `weight_loader` since it already exists in
|
||||||
|
# `ColumnParallelLinear` and `set_weight_attrs`
|
||||||
|
# doesn't allow to override it
|
||||||
|
self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1)
|
||||||
|
|
||||||
|
self.in_proj = MergedColumnParallelLinear(self.hidden_size,
|
||||||
|
[self.intermediate_size] * 2,
|
||||||
|
bias=self.use_bias)
|
||||||
|
# selective projection used to make dt, B and C input dependent
|
||||||
|
self.x_proj = RowParallelLinear(
|
||||||
|
self.intermediate_size,
|
||||||
|
self.time_step_rank + self.ssm_state_size * 2,
|
||||||
|
bias=False,
|
||||||
|
)
|
||||||
|
# time step projection (discretization) -
|
||||||
|
# In the forward we need to apply dt_proj without the bias,
|
||||||
|
# as the bias is added in the selective scan kernel.
|
||||||
|
self.dt_proj = ColumnParallelLinear(self.time_step_rank,
|
||||||
|
self.intermediate_size,
|
||||||
|
bias=True,
|
||||||
|
skip_bias_add=True)
|
||||||
|
|
||||||
|
def weight_loader(param: Parameter, loaded_weight: torch.Tensor):
|
||||||
|
tp_rank = get_tensor_model_parallel_rank()
|
||||||
|
tp_size = get_tensor_model_parallel_world_size()
|
||||||
|
param.data.copy_(
|
||||||
|
loaded_weight.data.split(loaded_weight.shape[0] // tp_size,
|
||||||
|
dim=0)[tp_rank])
|
||||||
|
|
||||||
|
def A_weight_loader(param: Parameter, loaded_weight: torch.Tensor):
|
||||||
|
weight_loader(param, -torch.exp(loaded_weight.float()))
|
||||||
|
|
||||||
|
tp_size = get_tensor_model_parallel_world_size()
|
||||||
|
self.A = nn.Parameter(
|
||||||
|
torch.empty(
|
||||||
|
self.intermediate_size // tp_size,
|
||||||
|
self.ssm_state_size,
|
||||||
|
dtype=torch.float32,
|
||||||
|
))
|
||||||
|
self.D = nn.Parameter(torch.ones(self.intermediate_size // tp_size))
|
||||||
|
|
||||||
|
set_weight_attrs(self.D, {"weight_loader": weight_loader})
|
||||||
|
set_weight_attrs(self.A, {"weight_loader": A_weight_loader})
|
||||||
|
|
||||||
|
self.out_proj = RowParallelLinear(
|
||||||
|
self.intermediate_size,
|
||||||
|
self.hidden_size,
|
||||||
|
bias=self.use_bias,
|
||||||
|
input_is_parallel=True,
|
||||||
|
)
|
||||||
|
self.activation = config.hidden_act
|
||||||
|
|
||||||
|
self.dt_layernorm = RMSNorm(self.time_step_rank,
|
||||||
|
eps=config.rms_norm_eps)
|
||||||
|
self.b_layernorm = RMSNorm(self.ssm_state_size,
|
||||||
|
eps=config.rms_norm_eps)
|
||||||
|
self.c_layernorm = RMSNorm(self.ssm_state_size,
|
||||||
|
eps=config.rms_norm_eps)
|
||||||
|
|
||||||
|
def mamba_forward(self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
cache_params: MambaCacheParams = None):
|
||||||
|
# 1. Gated MLP's linear projection
|
||||||
|
projected_states = self.in_proj(hidden_states)[0].transpose(1, 2)
|
||||||
|
hidden_states, gate = projected_states.chunk(2, dim=1)
|
||||||
|
|
||||||
|
# 2. Convolution sequence transformation
|
||||||
|
conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0),
|
||||||
|
self.conv1d.weight.size(2))
|
||||||
|
if cache_params is not None and not cache_params.is_prompt:
|
||||||
|
hidden_states = causal_conv1d_update(
|
||||||
|
hidden_states.squeeze(-1),
|
||||||
|
cache_params.conv_state,
|
||||||
|
conv_weights,
|
||||||
|
self.conv1d.bias,
|
||||||
|
self.activation,
|
||||||
|
)
|
||||||
|
hidden_states = hidden_states.unsqueeze(-1)
|
||||||
|
else:
|
||||||
|
if cache_params is not None:
|
||||||
|
conv_states = nn.functional.pad(
|
||||||
|
hidden_states,
|
||||||
|
(self.conv_kernel_size - hidden_states.shape[-1], 0))
|
||||||
|
cache_params.conv_state.copy_(conv_states)
|
||||||
|
|
||||||
|
hidden_states = causal_conv1d_fn(
|
||||||
|
hidden_states,
|
||||||
|
conv_weights,
|
||||||
|
self.conv1d.bias,
|
||||||
|
activation=self.activation,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 3. State Space Model sequence transformation
|
||||||
|
# 3.a. input varying initialization of time_step, B and C
|
||||||
|
ssm_parameters = self.x_proj(hidden_states.transpose(1, 2))[0]
|
||||||
|
|
||||||
|
time_step, B, C = torch.split(
|
||||||
|
ssm_parameters,
|
||||||
|
[self.time_step_rank, self.ssm_state_size, self.ssm_state_size],
|
||||||
|
dim=-1,
|
||||||
|
)
|
||||||
|
time_step = self.dt_layernorm(time_step.contiguous())
|
||||||
|
B = self.b_layernorm(B.contiguous())
|
||||||
|
C = self.c_layernorm(C.contiguous())
|
||||||
|
|
||||||
|
discrete_time_step = self.dt_proj(time_step)[0].transpose(1, 2)
|
||||||
|
# 3.c perform the recurrence y ← SSM(A, B, C)(x)
|
||||||
|
time_proj_bias = (self.dt_proj.bias.float() if hasattr(
|
||||||
|
self.dt_proj, "bias") else None)
|
||||||
|
if cache_params is not None and not cache_params.is_prompt:
|
||||||
|
scan_outputs = selective_state_update(
|
||||||
|
cache_params.ssm_state,
|
||||||
|
hidden_states[..., 0],
|
||||||
|
discrete_time_step[..., 0],
|
||||||
|
self.A,
|
||||||
|
B[:, 0],
|
||||||
|
C[:, 0],
|
||||||
|
self.D,
|
||||||
|
gate[..., 0],
|
||||||
|
time_proj_bias,
|
||||||
|
dt_softplus=True,
|
||||||
|
).unsqueeze(-1)
|
||||||
|
else:
|
||||||
|
scan_outputs, ssm_state = selective_scan_fn(
|
||||||
|
hidden_states,
|
||||||
|
discrete_time_step,
|
||||||
|
self.A,
|
||||||
|
B.transpose(1, 2),
|
||||||
|
C.transpose(1, 2),
|
||||||
|
self.D.float(),
|
||||||
|
gate,
|
||||||
|
time_proj_bias,
|
||||||
|
delta_softplus=True,
|
||||||
|
return_last_state=True,
|
||||||
|
)
|
||||||
|
if ssm_state is not None and cache_params is not None:
|
||||||
|
cache_params.ssm_state.copy_(ssm_state)
|
||||||
|
|
||||||
|
# 4. Final linear projection
|
||||||
|
contextualized_states = self.out_proj(scan_outputs.transpose(1, 2))[0]
|
||||||
|
return contextualized_states
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
attn_metadata: AttentionMetadata,
|
||||||
|
conv_state: torch.Tensor,
|
||||||
|
ssm_state: torch.Tensor,
|
||||||
|
):
|
||||||
|
if attn_metadata.prefill_metadata is not None:
|
||||||
|
offset = 0
|
||||||
|
for i, prompt_len in enumerate(
|
||||||
|
attn_metadata.prefill_metadata.seq_lens):
|
||||||
|
cache = MambaCacheParams(True,
|
||||||
|
conv_state=conv_state[i].unsqueeze(0),
|
||||||
|
ssm_state=ssm_state[i].unsqueeze(0))
|
||||||
|
hidden_states[offset:offset + prompt_len].copy_(
|
||||||
|
self.mamba_forward(hidden_states[offset:offset +
|
||||||
|
prompt_len].unsqueeze(0),
|
||||||
|
cache_params=cache)[0])
|
||||||
|
offset += prompt_len
|
||||||
|
else:
|
||||||
|
cache = MambaCacheParams(False,
|
||||||
|
conv_state=conv_state,
|
||||||
|
ssm_state=ssm_state)
|
||||||
|
hidden_states = self.mamba_forward(hidden_states.unsqueeze(1),
|
||||||
|
cache_params=cache)
|
||||||
|
hidden_states = hidden_states.squeeze(1)
|
||||||
|
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class JambaMLP(nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: JambaConfig,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
hidden_size = config.hidden_size
|
||||||
|
intermediate_size = config.intermediate_size
|
||||||
|
hidden_act = config.hidden_act
|
||||||
|
self.gate_up_proj = MergedColumnParallelLinear(
|
||||||
|
hidden_size, [intermediate_size] * 2,
|
||||||
|
bias=False,
|
||||||
|
quant_config=quant_config)
|
||||||
|
self.down_proj = RowParallelLinear(intermediate_size,
|
||||||
|
hidden_size,
|
||||||
|
bias=False,
|
||||||
|
quant_config=quant_config)
|
||||||
|
if hidden_act != "silu":
|
||||||
|
raise ValueError(f"Unsupported activation: {hidden_act}. "
|
||||||
|
"Only silu is supported for now.")
|
||||||
|
self.act_fn = SiluAndMul()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
gate_up, _ = self.gate_up_proj(x)
|
||||||
|
x = self.act_fn(gate_up)
|
||||||
|
x, _ = self.down_proj(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class JambaMoE(nn.Module):
|
||||||
|
"""A tensor-parallel MoE implementation for Mixtral that shards each expert
|
||||||
|
across all ranks.
|
||||||
|
|
||||||
|
Each expert's weights are sharded across all ranks and a fused MoE
|
||||||
|
kernel is used for the forward pass, and finally we reduce the outputs
|
||||||
|
across ranks.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: JambaConfig,
|
||||||
|
params_dtype: Optional[torch.dtype] = None,
|
||||||
|
tp_size: Optional[int] = None,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.tp_size = tp_size or get_tensor_model_parallel_world_size()
|
||||||
|
self.num_total_experts = config.num_experts
|
||||||
|
self.top_k = config.num_experts_per_tok
|
||||||
|
self.hidden_size = config.hidden_size
|
||||||
|
self.intermediate_size = config.intermediate_size // self.tp_size
|
||||||
|
|
||||||
|
if params_dtype is None:
|
||||||
|
params_dtype = torch.get_default_dtype()
|
||||||
|
self.params_dtype = params_dtype
|
||||||
|
|
||||||
|
self.router = ReplicatedLinear(self.hidden_size,
|
||||||
|
self.num_total_experts,
|
||||||
|
bias=False,
|
||||||
|
params_dtype=self.params_dtype)
|
||||||
|
|
||||||
|
self.ws = nn.Parameter(
|
||||||
|
torch.empty(
|
||||||
|
self.num_total_experts,
|
||||||
|
2 * self.intermediate_size,
|
||||||
|
self.hidden_size,
|
||||||
|
device="cuda",
|
||||||
|
dtype=self.params_dtype,
|
||||||
|
))
|
||||||
|
self.w2s = nn.Parameter(
|
||||||
|
torch.empty(
|
||||||
|
self.num_total_experts,
|
||||||
|
self.hidden_size,
|
||||||
|
self.intermediate_size,
|
||||||
|
device="cuda",
|
||||||
|
dtype=self.params_dtype,
|
||||||
|
))
|
||||||
|
|
||||||
|
set_weight_attrs(
|
||||||
|
self.ws,
|
||||||
|
{
|
||||||
|
"weight_loader": self.weight_loader,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
set_weight_attrs(
|
||||||
|
self.w2s,
|
||||||
|
{
|
||||||
|
"weight_loader": self.weight_loader,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
def weight_loader(
|
||||||
|
self,
|
||||||
|
param: nn.Parameter,
|
||||||
|
loaded_weight: torch.Tensor,
|
||||||
|
weight_name: str,
|
||||||
|
expert_id: int,
|
||||||
|
):
|
||||||
|
tp_rank = get_tensor_model_parallel_rank()
|
||||||
|
param_data = param.data
|
||||||
|
shard_size = self.intermediate_size
|
||||||
|
shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
|
||||||
|
if weight_name.endswith("gate_proj.weight"):
|
||||||
|
param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :]
|
||||||
|
if weight_name.endswith("up_proj.weight"):
|
||||||
|
param_data[expert_id,
|
||||||
|
shard_size:2 * shard_size, :] = loaded_weight[shard, :]
|
||||||
|
if weight_name.endswith("down_proj.weight"):
|
||||||
|
param_data[expert_id, :, :] = loaded_weight[:, shard]
|
||||||
|
|
||||||
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
|
num_tokens, hidden_size = hidden_states.shape
|
||||||
|
hidden_states = hidden_states.view(-1, self.hidden_size)
|
||||||
|
# router_logits: (batch * sequence_length, n_experts)
|
||||||
|
router_logits, _ = self.router(hidden_states)
|
||||||
|
|
||||||
|
final_hidden_states = fused_moe(
|
||||||
|
hidden_states,
|
||||||
|
self.ws,
|
||||||
|
self.w2s,
|
||||||
|
router_logits,
|
||||||
|
self.top_k,
|
||||||
|
renormalize=
|
||||||
|
False, # Mixtral normalize the expert probs to 1. We don't!
|
||||||
|
inplace=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.tp_size > 1:
|
||||||
|
final_hidden_states = tensor_model_parallel_all_reduce(
|
||||||
|
final_hidden_states)
|
||||||
|
|
||||||
|
return final_hidden_states.view(num_tokens, hidden_size)
|
||||||
|
|
||||||
|
|
||||||
|
class JambaMambaDecoderLayer(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
config: JambaConfig,
|
||||||
|
layer_idx: int,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.layer_idx = layer_idx
|
||||||
|
self.config = config
|
||||||
|
self.mamba = JambaMambaMixer(config, layer_idx)
|
||||||
|
|
||||||
|
num_experts = config.layers_num_experts[layer_idx]
|
||||||
|
ffn_layer_class = JambaMoE if num_experts > 1 else JambaMLP
|
||||||
|
self.feed_forward = ffn_layer_class(config, quant_config=quant_config)
|
||||||
|
self.input_layernorm = RMSNorm(config.hidden_size,
|
||||||
|
eps=config.rms_norm_eps)
|
||||||
|
self.pre_ff_layernorm = RMSNorm(config.hidden_size,
|
||||||
|
eps=config.rms_norm_eps)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
attn_metadata: AttentionMetadata,
|
||||||
|
residual: Optional[torch.Tensor],
|
||||||
|
conv_state: torch.Tensor,
|
||||||
|
ssm_state: torch.Tensor,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
if residual is None:
|
||||||
|
residual = hidden_states
|
||||||
|
hidden_states = self.input_layernorm(hidden_states)
|
||||||
|
else:
|
||||||
|
hidden_states, residual = self.input_layernorm(
|
||||||
|
hidden_states, residual)
|
||||||
|
|
||||||
|
hidden_states = self.mamba(hidden_states, attn_metadata, conv_state,
|
||||||
|
ssm_state)
|
||||||
|
# Fully Connected
|
||||||
|
hidden_states, residual = self.pre_ff_layernorm(
|
||||||
|
hidden_states, residual)
|
||||||
|
hidden_states = self.feed_forward(hidden_states)
|
||||||
|
return hidden_states, residual
|
||||||
|
|
||||||
|
|
||||||
|
class JambaAttentionDecoderLayer(nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: JambaConfig,
|
||||||
|
layer_idx: int,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.hidden_size = config.hidden_size
|
||||||
|
tp_size = get_tensor_model_parallel_world_size()
|
||||||
|
self.total_num_heads = config.num_attention_heads
|
||||||
|
assert self.total_num_heads % tp_size == 0
|
||||||
|
self.num_heads = self.total_num_heads // tp_size
|
||||||
|
self.total_num_kv_heads = config.num_key_value_heads
|
||||||
|
if self.total_num_kv_heads >= tp_size:
|
||||||
|
# Number of KV heads is greater than TP size, so we partition
|
||||||
|
# the KV heads across multiple tensor parallel GPUs.
|
||||||
|
assert self.total_num_kv_heads % tp_size == 0
|
||||||
|
else:
|
||||||
|
# Number of KV heads is less than TP size, so we replicate
|
||||||
|
# the KV heads across multiple tensor parallel GPUs.
|
||||||
|
assert tp_size % self.total_num_kv_heads == 0
|
||||||
|
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
|
||||||
|
self.head_dim = config.hidden_size // self.total_num_heads
|
||||||
|
self.q_size = self.num_heads * self.head_dim
|
||||||
|
self.kv_size = self.num_kv_heads * self.head_dim
|
||||||
|
self.scaling = self.head_dim**-0.5
|
||||||
|
|
||||||
|
self.qkv_proj = QKVParallelLinear(
|
||||||
|
config.hidden_size,
|
||||||
|
self.head_dim,
|
||||||
|
self.total_num_heads,
|
||||||
|
self.total_num_kv_heads,
|
||||||
|
bias=False,
|
||||||
|
quant_config=quant_config,
|
||||||
|
)
|
||||||
|
self.o_proj = RowParallelLinear(self.total_num_heads * self.head_dim,
|
||||||
|
config.hidden_size,
|
||||||
|
bias=False,
|
||||||
|
quant_config=quant_config)
|
||||||
|
|
||||||
|
self.attn = Attention(
|
||||||
|
self.num_heads,
|
||||||
|
self.head_dim,
|
||||||
|
self.scaling,
|
||||||
|
num_kv_heads=self.num_kv_heads,
|
||||||
|
cache_config=cache_config,
|
||||||
|
)
|
||||||
|
|
||||||
|
num_experts = config.layers_num_experts[layer_idx]
|
||||||
|
ffn_layer_class = JambaMoE if num_experts > 1 else JambaMLP
|
||||||
|
self.feed_forward = ffn_layer_class(config, quant_config=quant_config)
|
||||||
|
self.input_layernorm = RMSNorm(config.hidden_size,
|
||||||
|
eps=config.rms_norm_eps)
|
||||||
|
self.pre_ff_layernorm = RMSNorm(config.hidden_size,
|
||||||
|
eps=config.rms_norm_eps)
|
||||||
|
|
||||||
|
def self_attention(
|
||||||
|
self,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
kv_cache: torch.Tensor,
|
||||||
|
attn_metadata: AttentionMetadata,
|
||||||
|
**kwargs,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
qkv, _ = self.qkv_proj(hidden_states)
|
||||||
|
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||||
|
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
|
||||||
|
output, _ = self.o_proj(attn_output)
|
||||||
|
return output
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
kv_cache: torch.Tensor,
|
||||||
|
attn_metadata: AttentionMetadata,
|
||||||
|
residual: Optional[torch.Tensor],
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
if residual is None:
|
||||||
|
residual = hidden_states
|
||||||
|
hidden_states = self.input_layernorm(hidden_states)
|
||||||
|
else:
|
||||||
|
hidden_states, residual = self.input_layernorm(
|
||||||
|
hidden_states, residual)
|
||||||
|
|
||||||
|
hidden_states = self.self_attention(
|
||||||
|
positions=positions,
|
||||||
|
hidden_states=hidden_states,
|
||||||
|
kv_cache=kv_cache,
|
||||||
|
attn_metadata=attn_metadata,
|
||||||
|
)
|
||||||
|
# Fully Connected
|
||||||
|
hidden_states, residual = self.pre_ff_layernorm(
|
||||||
|
hidden_states, residual)
|
||||||
|
hidden_states = self.feed_forward(hidden_states)
|
||||||
|
return hidden_states, residual
|
||||||
|
|
||||||
|
|
||||||
|
ALL_DECODER_LAYER_TYPES = {
|
||||||
|
"attention": JambaAttentionDecoderLayer,
|
||||||
|
"mamba": JambaMambaDecoderLayer
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class JambaModel(nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: JambaConfig,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
|
lora_config: Optional[LoRAConfig] = None,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
self.padding_idx = config.pad_token_id
|
||||||
|
lora_vocab = ((lora_config.lora_extra_vocab_size *
|
||||||
|
(lora_config.max_loras or 1)) if lora_config else 0)
|
||||||
|
self.vocab_size = config.vocab_size + lora_vocab
|
||||||
|
self.org_vocab_size = config.vocab_size
|
||||||
|
|
||||||
|
self.embed_tokens = VocabParallelEmbedding(
|
||||||
|
self.vocab_size,
|
||||||
|
config.hidden_size,
|
||||||
|
org_num_embeddings=config.vocab_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
decoder_layers = []
|
||||||
|
for i in range(config.num_hidden_layers):
|
||||||
|
layer_class = ALL_DECODER_LAYER_TYPES[config.layers_block_type[i]]
|
||||||
|
decoder_layers.append(
|
||||||
|
layer_class(config,
|
||||||
|
layer_idx=i,
|
||||||
|
cache_config=cache_config,
|
||||||
|
quant_config=quant_config))
|
||||||
|
self.layers = nn.ModuleList(decoder_layers)
|
||||||
|
self.final_layernorm = RMSNorm(config.hidden_size,
|
||||||
|
eps=config.rms_norm_eps)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.Tensor,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
kv_caches: List[torch.Tensor],
|
||||||
|
attn_metadata: AttentionMetadata,
|
||||||
|
conv_state: torch.Tensor,
|
||||||
|
ssm_state: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
hidden_states = self.embed_tokens(input_ids)
|
||||||
|
residual = None
|
||||||
|
|
||||||
|
for i in range(len(self.layers)):
|
||||||
|
layer = self.layers[i]
|
||||||
|
kv_cache = None
|
||||||
|
current_ssm_state = None
|
||||||
|
current_conv_state = None
|
||||||
|
if isinstance(layer, JambaAttentionDecoderLayer):
|
||||||
|
kv_cache = kv_caches[(i - self.config.attn_layer_offset) //
|
||||||
|
self.config.attn_layer_period]
|
||||||
|
if isinstance(layer, JambaMambaDecoderLayer):
|
||||||
|
current_state_layer = i - (1 +
|
||||||
|
(i - self.config.attn_layer_offset)
|
||||||
|
// self.config.attn_layer_period)
|
||||||
|
current_ssm_state = ssm_state[current_state_layer]
|
||||||
|
current_conv_state = conv_state[current_state_layer]
|
||||||
|
|
||||||
|
hidden_states, residual = layer(
|
||||||
|
positions=positions,
|
||||||
|
hidden_states=hidden_states,
|
||||||
|
kv_cache=kv_cache,
|
||||||
|
attn_metadata=attn_metadata,
|
||||||
|
residual=residual,
|
||||||
|
conv_state=current_conv_state,
|
||||||
|
ssm_state=current_ssm_state,
|
||||||
|
)
|
||||||
|
hidden_states, _ = self.final_layernorm(hidden_states, residual)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class JambaForCausalLM(nn.Module):
|
||||||
|
packed_modules_mapping = {
|
||||||
|
"qkv_proj": [
|
||||||
|
"q_proj",
|
||||||
|
"k_proj",
|
||||||
|
"v_proj",
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
# LoRA specific attributes
|
||||||
|
supported_lora_modules = [
|
||||||
|
"qkv_proj",
|
||||||
|
"o_proj",
|
||||||
|
"embed_tokens",
|
||||||
|
"lm_head",
|
||||||
|
]
|
||||||
|
embedding_modules = {
|
||||||
|
"embed_tokens": "input_embeddings",
|
||||||
|
"lm_head": "output_embeddings",
|
||||||
|
}
|
||||||
|
embedding_padding_modules = ["lm_head"]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: JambaConfig,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
lora_config: Optional[LoRAConfig] = None,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
self.model = JambaModel(config,
|
||||||
|
cache_config=cache_config,
|
||||||
|
quant_config=quant_config,
|
||||||
|
lora_config=lora_config)
|
||||||
|
self.unpadded_vocab_size = config.vocab_size
|
||||||
|
if lora_config:
|
||||||
|
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
|
||||||
|
self.lm_head = ParallelLMHead(
|
||||||
|
self.unpadded_vocab_size,
|
||||||
|
config.hidden_size,
|
||||||
|
org_num_embeddings=config.vocab_size,
|
||||||
|
padding_size=DEFAULT_VOCAB_PADDING_SIZE
|
||||||
|
# We need bigger padding if using lora for kernel
|
||||||
|
# compatibility
|
||||||
|
if not lora_config else lora_config.lora_vocab_padding_size,
|
||||||
|
)
|
||||||
|
# Current step used indices
|
||||||
|
self.current_indices: List[int] = []
|
||||||
|
# Used to track and store by the Mamba cache between steps.
|
||||||
|
self.mamba_cache: Tuple[torch.Tensor, torch.Tensor] = tuple()
|
||||||
|
# Used as an input_buffer for the CUDA graph runs.
|
||||||
|
self.mamba_gc_cache_buffer: Tuple[torch.Tensor, torch.Tensor] = tuple()
|
||||||
|
# Maps between the request id and a dict that maps between the seq_id
|
||||||
|
# and its index inside the self.mamba_cache
|
||||||
|
self.mamba_cache_indices_mapping: Dict[str, Dict[int, int]] = {}
|
||||||
|
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
|
||||||
|
config.vocab_size)
|
||||||
|
self.sampler = Sampler()
|
||||||
|
|
||||||
|
def forward(self,
|
||||||
|
input_ids: torch.Tensor,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
kv_caches: List[KVCache],
|
||||||
|
attn_metadata: AttentionMetadata,
|
||||||
|
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||||
|
**kwargs):
|
||||||
|
if not self.mamba_cache:
|
||||||
|
self._prepare_mamba_cache()
|
||||||
|
|
||||||
|
if "seqlen_agnostic_capture_inputs" not in kwargs:
|
||||||
|
# We get here only on Prefill/Eager mode runs
|
||||||
|
assert all(
|
||||||
|
key in kwargs
|
||||||
|
for key in ["request_ids_to_seq_ids", "finished_requests_ids"])
|
||||||
|
|
||||||
|
request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"]
|
||||||
|
batch_size = input_ids.shape[0]
|
||||||
|
if attn_metadata.prefill_metadata:
|
||||||
|
batch_size = len(request_ids_to_seq_ids)
|
||||||
|
(
|
||||||
|
current_seqlen_agnostic_cache,
|
||||||
|
indices,
|
||||||
|
) = self._prepare_current_run_mamba_cache(request_ids_to_seq_ids,
|
||||||
|
batch_size)
|
||||||
|
finished_requests_ids = kwargs["finished_requests_ids"]
|
||||||
|
self._release_mamba_cache(finished_requests_ids)
|
||||||
|
else:
|
||||||
|
# CUDA graph capturing runs
|
||||||
|
current_seqlen_agnostic_cache, indices = (
|
||||||
|
kwargs["seqlen_agnostic_capture_inputs"],
|
||||||
|
[],
|
||||||
|
)
|
||||||
|
self.current_indices = indices
|
||||||
|
|
||||||
|
hidden_states = self.model(input_ids, positions, kv_caches,
|
||||||
|
attn_metadata,
|
||||||
|
current_seqlen_agnostic_cache[0],
|
||||||
|
current_seqlen_agnostic_cache[1])
|
||||||
|
|
||||||
|
if "seqlen_agnostic_capture_inputs" not in kwargs:
|
||||||
|
self._copy_mamba_cache_by_indices(self.current_indices,
|
||||||
|
current_seqlen_agnostic_cache)
|
||||||
|
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
def _copy_mamba_cache_by_indices(
|
||||||
|
self, indices: List[int],
|
||||||
|
current_seqlen_agnostic_cache: Tuple[torch.Tensor, torch.Tensor]):
|
||||||
|
for i, offset in enumerate(indices):
|
||||||
|
self._copy_mamba_cache(offset, i, current_seqlen_agnostic_cache)
|
||||||
|
|
||||||
|
def _copy_mamba_cache(self, index_to: int, index_from: int,
|
||||||
|
from_buffer: Tuple[torch.Tensor, torch.Tensor]):
|
||||||
|
assert len(self.mamba_cache) > 0
|
||||||
|
for (cache_t, from_buffer_t) in zip(self.mamba_cache, from_buffer):
|
||||||
|
cache_t[:, index_to].copy_(from_buffer_t[:, index_from],
|
||||||
|
non_blocking=True)
|
||||||
|
|
||||||
|
def _assign_seq_id_to_mamba_cache(self, cur_rid: str,
|
||||||
|
seqs_id: List[int]) -> List[int]:
|
||||||
|
indices_for_current_run = []
|
||||||
|
for seq_id in seqs_id:
|
||||||
|
if cur_rid not in self.mamba_cache_indices_mapping:
|
||||||
|
self.mamba_cache_indices_mapping[cur_rid] = {}
|
||||||
|
first_free_index = self._first_free_index_in_mamba_cache()
|
||||||
|
self.mamba_cache_indices_mapping[cur_rid][
|
||||||
|
seq_id] = first_free_index
|
||||||
|
index_for_current_run = first_free_index
|
||||||
|
## case of decoding n>1, copy prefill cache to decoding indices
|
||||||
|
elif seq_id not in (seq_ids2indices :=
|
||||||
|
self.mamba_cache_indices_mapping[cur_rid]):
|
||||||
|
first_free_index = self._first_free_index_in_mamba_cache()
|
||||||
|
index_exist = list(seq_ids2indices.values())[0]
|
||||||
|
self._copy_mamba_cache(index_from=index_exist,
|
||||||
|
index_to=first_free_index,
|
||||||
|
from_buffer=self.mamba_cache)
|
||||||
|
self.mamba_cache_indices_mapping[cur_rid][
|
||||||
|
seq_id] = first_free_index
|
||||||
|
index_for_current_run = first_free_index
|
||||||
|
else:
|
||||||
|
index_for_current_run = self.mamba_cache_indices_mapping[
|
||||||
|
cur_rid][seq_id]
|
||||||
|
|
||||||
|
indices_for_current_run.append(index_for_current_run)
|
||||||
|
return indices_for_current_run
|
||||||
|
|
||||||
|
def _prepare_current_run_mamba_cache(
|
||||||
|
self, request_ids_to_seq_ids: Dict[str, list[int]], batch_size: int
|
||||||
|
) -> Tuple[Tuple[torch.Tensor, torch.Tensor], List[int]]:
|
||||||
|
indices_for_current_run = []
|
||||||
|
for request_id, seqs_id in request_ids_to_seq_ids.items():
|
||||||
|
indices_for_current_run += self._assign_seq_id_to_mamba_cache(
|
||||||
|
request_id, seqs_id)
|
||||||
|
## Pad the batch in case of running batch that was not captured via CG
|
||||||
|
padded_indices = indices_for_current_run.copy()
|
||||||
|
pad_index = self._first_free_index_in_mamba_cache()
|
||||||
|
|
||||||
|
for _ in range(batch_size - len(indices_for_current_run)):
|
||||||
|
padded_indices.append(pad_index)
|
||||||
|
|
||||||
|
conv_state = self.mamba_cache[0][:, padded_indices]
|
||||||
|
temporal_state = self.mamba_cache[1][:, padded_indices]
|
||||||
|
|
||||||
|
return (conv_state, temporal_state), indices_for_current_run
|
||||||
|
|
||||||
|
def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
|
||||||
|
"""
|
||||||
|
Copy the relevant Mamba cache into the CUDA graph input buffer
|
||||||
|
that was provided during the capture runs
|
||||||
|
(JambaForCausalLM.mamba_gc_cache_buffer).
|
||||||
|
"""
|
||||||
|
assert all(
|
||||||
|
key in kwargs
|
||||||
|
for key in ["request_ids_to_seq_ids", "finished_requests_ids"])
|
||||||
|
request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"]
|
||||||
|
batch_size = len(request_ids_to_seq_ids)
|
||||||
|
(
|
||||||
|
current_mamba_cache,
|
||||||
|
indices,
|
||||||
|
) = self._prepare_current_run_mamba_cache(request_ids_to_seq_ids,
|
||||||
|
batch_size)
|
||||||
|
self.current_indices = indices
|
||||||
|
finished_requests_ids = kwargs["finished_requests_ids"]
|
||||||
|
self._release_mamba_cache(finished_requests_ids)
|
||||||
|
|
||||||
|
for input_buffer, current_cache_buffer in zip(
|
||||||
|
input_buffers["seqlen_agnostic_capture_inputs"],
|
||||||
|
current_mamba_cache):
|
||||||
|
input_buffer.copy_(current_cache_buffer, non_blocking=True)
|
||||||
|
|
||||||
|
def copy_outputs_after_cuda_graphs(self, input_buffers, **kwargs):
|
||||||
|
"""
|
||||||
|
Copy the relevant Mamba cache from the CUDA graph input_buffers
|
||||||
|
back to the JambaForCausalLM.mamba_cache after CUDA
|
||||||
|
graph replay run is done.
|
||||||
|
"""
|
||||||
|
self._copy_mamba_cache_by_indices(
|
||||||
|
self.current_indices,
|
||||||
|
input_buffers["seqlen_agnostic_capture_inputs"])
|
||||||
|
|
||||||
|
def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
|
||||||
|
"""
|
||||||
|
Provide the CUDA graph capture runs with a buffer in adjusted size.
|
||||||
|
The buffer is used to maintain the Mamba Cache during the CUDA graph
|
||||||
|
replay runs.
|
||||||
|
"""
|
||||||
|
return tuple(buffer[:, :batch_size]
|
||||||
|
for buffer in self.mamba_gc_cache_buffer)
|
||||||
|
|
||||||
|
def _release_mamba_cache(self, finished_seq_groups_req_ids: List[str]):
|
||||||
|
for req_id in finished_seq_groups_req_ids:
|
||||||
|
if req_id in self.mamba_cache_indices_mapping:
|
||||||
|
self.mamba_cache_indices_mapping.pop(req_id)
|
||||||
|
|
||||||
|
def _first_free_index_in_mamba_cache(self) -> int:
|
||||||
|
if self.mamba_cache:
|
||||||
|
max_possible_batch_size = self.mamba_cache[0].shape[1]
|
||||||
|
occupied = [
|
||||||
|
id for seq_ids in self.mamba_cache_indices_mapping.values()
|
||||||
|
for id in seq_ids.values()
|
||||||
|
]
|
||||||
|
first_free_index = [
|
||||||
|
i not in occupied for i in range(max_possible_batch_size)
|
||||||
|
].index(True)
|
||||||
|
return first_free_index
|
||||||
|
return 0
|
||||||
|
|
||||||
|
def _get_mamba_cache_shape(
|
||||||
|
self
|
||||||
|
) -> Tuple[Optional[Tuple[int, int]], Optional[Tuple[int, int]]]:
|
||||||
|
world_size = get_tensor_model_parallel_world_size()
|
||||||
|
hidden_size = self.config.hidden_size
|
||||||
|
conv_state_shape = (
|
||||||
|
self.config.mamba_expand * hidden_size // world_size,
|
||||||
|
self.config.mamba_d_conv,
|
||||||
|
)
|
||||||
|
temporal_state_shape = (
|
||||||
|
self.config.mamba_expand * self.config.hidden_size // world_size,
|
||||||
|
self.config.mamba_d_state,
|
||||||
|
)
|
||||||
|
return conv_state_shape, temporal_state_shape
|
||||||
|
|
||||||
|
def _prepare_mamba_cache(self):
|
||||||
|
dtype = self.lm_head.weight.dtype
|
||||||
|
layers_type = self.config.layers_block_type
|
||||||
|
mamba_layers = sum(
|
||||||
|
[layer_type == "mamba" for layer_type in layers_type])
|
||||||
|
max_batch_size = _BATCH_SIZES_TO_CAPTURE[-1] + 10
|
||||||
|
conv_state_shape, temporal_state_shape = self._get_mamba_cache_shape()
|
||||||
|
assert conv_state_shape is not None and temporal_state_shape is not None
|
||||||
|
for buffername in ["mamba_cache", "mamba_gc_cache_buffer"]:
|
||||||
|
buffer = (torch.empty(size=(mamba_layers, max_batch_size) +
|
||||||
|
conv_state_shape,
|
||||||
|
dtype=dtype,
|
||||||
|
device="cuda"),
|
||||||
|
torch.empty(size=(mamba_layers, max_batch_size) +
|
||||||
|
temporal_state_shape,
|
||||||
|
dtype=dtype,
|
||||||
|
device="cuda"))
|
||||||
|
setattr(self, buffername, buffer)
|
||||||
|
|
||||||
|
def compute_logits(self, hidden_states: torch.Tensor,
|
||||||
|
sampling_metadata: SamplingMetadata) -> torch.Tensor:
|
||||||
|
logits = self.logits_processor(self.lm_head.weight, hidden_states,
|
||||||
|
sampling_metadata)
|
||||||
|
return logits
|
||||||
|
|
||||||
|
def sample(
|
||||||
|
self,
|
||||||
|
logits: Optional[torch.Tensor],
|
||||||
|
sampling_metadata: SamplingMetadata,
|
||||||
|
) -> Optional[SamplerOutput]:
|
||||||
|
next_tokens = self.sampler(logits, sampling_metadata)
|
||||||
|
return next_tokens
|
||||||
|
|
||||||
|
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||||
|
stacked_params_mapping = [
|
||||||
|
# (param_name, shard_name, shard_id)
|
||||||
|
("qkv_proj", "q_proj", "q"),
|
||||||
|
("qkv_proj", "k_proj", "k"),
|
||||||
|
("qkv_proj", "v_proj", "v"),
|
||||||
|
("gate_up_proj", "gate_proj", 0),
|
||||||
|
("gate_up_proj", "up_proj", 1),
|
||||||
|
]
|
||||||
|
|
||||||
|
expert_params_mapping = [
|
||||||
|
# (param_name, weight_name, expert_id)
|
||||||
|
(
|
||||||
|
"ws" if weight_name in ["gate_proj", "up_proj"] else "w2s",
|
||||||
|
f"experts.{expert_id}.{weight_name}.weight",
|
||||||
|
expert_id,
|
||||||
|
) for expert_id in range(self.config.num_experts)
|
||||||
|
for weight_name in ["down_proj", "up_proj", "gate_proj"]
|
||||||
|
]
|
||||||
|
|
||||||
|
params_dict = dict(self.named_parameters())
|
||||||
|
for name, loaded_weight in weights:
|
||||||
|
if "rotary_emb.inv_freq" in name:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if "A_log" in name:
|
||||||
|
name = name.replace("A_log", "A")
|
||||||
|
|
||||||
|
if ".self_attn." in name:
|
||||||
|
name = name.replace(".self_attn", "")
|
||||||
|
|
||||||
|
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||||
|
if weight_name not in name:
|
||||||
|
continue
|
||||||
|
if 'experts' in name:
|
||||||
|
continue
|
||||||
|
name = name.replace(weight_name, param_name)
|
||||||
|
# Skip loading extra bias for GPTQ models.
|
||||||
|
if name.endswith(".bias") and name not in params_dict:
|
||||||
|
continue
|
||||||
|
param = params_dict[name]
|
||||||
|
weight_loader = param.weight_loader
|
||||||
|
weight_loader(param, loaded_weight, shard_id)
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
for param_name, weight_name, expert_id in expert_params_mapping:
|
||||||
|
if weight_name not in name:
|
||||||
|
continue
|
||||||
|
name = name.replace(weight_name, param_name)
|
||||||
|
param = params_dict[name]
|
||||||
|
weight_loader = param.weight_loader
|
||||||
|
weight_loader(param,
|
||||||
|
loaded_weight,
|
||||||
|
weight_name,
|
||||||
|
expert_id=expert_id)
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
# Skip loading extra bias for GPTQ models.
|
||||||
|
if name.endswith(".bias") and name not in params_dict:
|
||||||
|
continue
|
||||||
|
|
||||||
|
param = params_dict[name]
|
||||||
|
weight_loader = getattr(param, "weight_loader",
|
||||||
|
default_weight_loader)
|
||||||
|
weight_loader(param, loaded_weight)
|
||||||
@ -934,6 +934,8 @@ class ExecuteModelRequest:
|
|||||||
previous_hidden_states: Optional[HiddenStates] = None
|
previous_hidden_states: Optional[HiddenStates] = None
|
||||||
# The number of forward steps to run.
|
# The number of forward steps to run.
|
||||||
num_steps: int = 1
|
num_steps: int = 1
|
||||||
|
# Finished request ids since last step.
|
||||||
|
finished_requests_ids: List[str] = field(default_factory=list)
|
||||||
|
|
||||||
def clone(
|
def clone(
|
||||||
self, seq_group_metadata_list: List[SequenceGroupMetadata]
|
self, seq_group_metadata_list: List[SequenceGroupMetadata]
|
||||||
@ -949,4 +951,4 @@ class ExecuteModelRequest:
|
|||||||
running_queue_size=self.running_queue_size,
|
running_queue_size=self.running_queue_size,
|
||||||
previous_hidden_states=self.previous_hidden_states,
|
previous_hidden_states=self.previous_hidden_states,
|
||||||
num_steps=self.num_steps,
|
num_steps=self.num_steps,
|
||||||
)
|
finished_requests_ids=self.finished_requests_ids)
|
||||||
|
|||||||
@ -75,15 +75,19 @@ class TP1DraftModelRunner(ModelRunner):
|
|||||||
List[SequenceGroupMetadata]] = None
|
List[SequenceGroupMetadata]] = None
|
||||||
|
|
||||||
def prepare_model_input(
|
def prepare_model_input(
|
||||||
self,
|
self,
|
||||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||||
virtual_engine: int = 0) -> ModelInputForGPUWithSamplingMetadata:
|
virtual_engine: int = 0,
|
||||||
|
finished_requests_ids: Optional[List[str]] = None
|
||||||
|
) -> ModelInputForGPUWithSamplingMetadata:
|
||||||
"""A temporary solution that caches the seq_group_metadata_list
|
"""A temporary solution that caches the seq_group_metadata_list
|
||||||
for multi-step execution.
|
for multi-step execution.
|
||||||
TODO: In-place update model_input and remove this function.
|
TODO: In-place update model_input and remove this function.
|
||||||
"""
|
"""
|
||||||
self.cached_seq_group_metadata_list = seq_group_metadata_list
|
self.cached_seq_group_metadata_list = seq_group_metadata_list
|
||||||
return super().prepare_model_input(seq_group_metadata_list)
|
return super().prepare_model_input(
|
||||||
|
seq_group_metadata_list,
|
||||||
|
finished_requests_ids=finished_requests_ids)
|
||||||
|
|
||||||
def update_model_input(
|
def update_model_input(
|
||||||
self, model_input: ModelInputForGPUWithSamplingMetadata,
|
self, model_input: ModelInputForGPUWithSamplingMetadata,
|
||||||
|
|||||||
@ -33,7 +33,9 @@ class CacheEngine:
|
|||||||
self.device_config = device_config
|
self.device_config = device_config
|
||||||
|
|
||||||
self.head_size = model_config.get_head_size()
|
self.head_size = model_config.get_head_size()
|
||||||
self.num_layers = model_config.get_num_layers(parallel_config)
|
# Models like Jamba, have mixed typed layers, E.g Mamba
|
||||||
|
self.num_attention_layers = model_config.get_num_attention_layers(
|
||||||
|
parallel_config)
|
||||||
self.num_kv_heads = model_config.get_num_kv_heads(parallel_config)
|
self.num_kv_heads = model_config.get_num_kv_heads(parallel_config)
|
||||||
|
|
||||||
self.block_size = cache_config.block_size
|
self.block_size = cache_config.block_size
|
||||||
@ -75,7 +77,7 @@ class CacheEngine:
|
|||||||
num_blocks, self.block_size, self.num_kv_heads, self.head_size)
|
num_blocks, self.block_size, self.num_kv_heads, self.head_size)
|
||||||
pin_memory = is_pin_memory_available() if device == "cpu" else False
|
pin_memory = is_pin_memory_available() if device == "cpu" else False
|
||||||
kv_cache: List[torch.Tensor] = []
|
kv_cache: List[torch.Tensor] = []
|
||||||
for _ in range(self.num_layers):
|
for _ in range(self.num_attention_layers):
|
||||||
# null block in CpuGpuBlockAllocator requires at least that
|
# null block in CpuGpuBlockAllocator requires at least that
|
||||||
# block to be zeroed-out.
|
# block to be zeroed-out.
|
||||||
# We zero-out everything for simplicity.
|
# We zero-out everything for simplicity.
|
||||||
@ -87,12 +89,12 @@ class CacheEngine:
|
|||||||
return kv_cache
|
return kv_cache
|
||||||
|
|
||||||
def swap_in(self, src_to_dst: torch.Tensor) -> None:
|
def swap_in(self, src_to_dst: torch.Tensor) -> None:
|
||||||
for i in range(self.num_layers):
|
for i in range(self.num_attention_layers):
|
||||||
self.attn_backend.swap_blocks(self.cpu_cache[i], self.gpu_cache[i],
|
self.attn_backend.swap_blocks(self.cpu_cache[i], self.gpu_cache[i],
|
||||||
src_to_dst)
|
src_to_dst)
|
||||||
|
|
||||||
def swap_out(self, src_to_dst: torch.Tensor) -> None:
|
def swap_out(self, src_to_dst: torch.Tensor) -> None:
|
||||||
for i in range(self.num_layers):
|
for i in range(self.num_attention_layers):
|
||||||
self.attn_backend.swap_blocks(self.gpu_cache[i], self.cpu_cache[i],
|
self.attn_backend.swap_blocks(self.gpu_cache[i], self.cpu_cache[i],
|
||||||
src_to_dst)
|
src_to_dst)
|
||||||
|
|
||||||
@ -107,11 +109,12 @@ class CacheEngine:
|
|||||||
) -> int:
|
) -> int:
|
||||||
head_size = model_config.get_head_size()
|
head_size = model_config.get_head_size()
|
||||||
num_heads = model_config.get_num_kv_heads(parallel_config)
|
num_heads = model_config.get_num_kv_heads(parallel_config)
|
||||||
num_layers = model_config.get_num_layers(parallel_config)
|
num_attention_layers = model_config.get_num_attention_layers(
|
||||||
|
parallel_config)
|
||||||
|
|
||||||
key_cache_block = cache_config.block_size * num_heads * head_size
|
key_cache_block = cache_config.block_size * num_heads * head_size
|
||||||
value_cache_block = key_cache_block
|
value_cache_block = key_cache_block
|
||||||
total = num_layers * (key_cache_block + value_cache_block)
|
total = num_attention_layers * (key_cache_block + value_cache_block)
|
||||||
if cache_config.cache_dtype == "auto":
|
if cache_config.cache_dtype == "auto":
|
||||||
dtype = model_config.dtype
|
dtype = model_config.dtype
|
||||||
else:
|
else:
|
||||||
|
|||||||
@ -314,9 +314,10 @@ class CPUModelRunner(ModelRunnerBase[CPUModelInput]):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def prepare_model_input(
|
def prepare_model_input(
|
||||||
self,
|
self,
|
||||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||||
virtual_engine: int = 0,
|
virtual_engine: int = 0,
|
||||||
|
finished_requests_ids: Optional[List[str]] = None
|
||||||
) -> CPUModelInput:
|
) -> CPUModelInput:
|
||||||
multi_modal_kwargs = None
|
multi_modal_kwargs = None
|
||||||
# NOTE: We assume that all sequences in the group are all prompts or
|
# NOTE: We assume that all sequences in the group are all prompts or
|
||||||
|
|||||||
@ -120,10 +120,11 @@ class EmbeddingModelRunner(
|
|||||||
self,
|
self,
|
||||||
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
|
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
|
||||||
virtual_engine: int = 0,
|
virtual_engine: int = 0,
|
||||||
|
finished_requests_ids: Optional[List[str]] = None
|
||||||
) -> ModelInputForGPUWithPoolingMetadata:
|
) -> ModelInputForGPUWithPoolingMetadata:
|
||||||
assert seq_group_metadata_list is not None
|
assert seq_group_metadata_list is not None
|
||||||
model_input = self._prepare_model_input_tensors(
|
model_input = self._prepare_model_input_tensors(
|
||||||
seq_group_metadata_list)
|
seq_group_metadata_list, finished_requests_ids)
|
||||||
# Prepare PoolingMetadata.
|
# Prepare PoolingMetadata.
|
||||||
assert model_input.seq_lens is not None
|
assert model_input.seq_lens is not None
|
||||||
pooling_metadata = self._prepare_pooling(seq_group_metadata_list,
|
pooling_metadata = self._prepare_pooling(seq_group_metadata_list,
|
||||||
|
|||||||
@ -84,6 +84,8 @@ class ModelInputForGPU(ModelRunnerInputBase):
|
|||||||
lora_requests: Optional[Set[LoRARequest]] = None
|
lora_requests: Optional[Set[LoRARequest]] = None
|
||||||
attn_metadata: Optional["AttentionMetadata"] = None
|
attn_metadata: Optional["AttentionMetadata"] = None
|
||||||
multi_modal_kwargs: Optional[Dict[str, torch.Tensor]] = None
|
multi_modal_kwargs: Optional[Dict[str, torch.Tensor]] = None
|
||||||
|
request_ids_to_seq_ids: Optional[Dict[str, List[int]]] = None
|
||||||
|
finished_requests_ids: Optional[List[str]] = None
|
||||||
virtual_engine: int = 0
|
virtual_engine: int = 0
|
||||||
|
|
||||||
def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
|
def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
|
||||||
@ -94,6 +96,8 @@ class ModelInputForGPU(ModelRunnerInputBase):
|
|||||||
"lora_mapping": self.lora_mapping,
|
"lora_mapping": self.lora_mapping,
|
||||||
"multi_modal_kwargs": self.multi_modal_kwargs,
|
"multi_modal_kwargs": self.multi_modal_kwargs,
|
||||||
"virtual_engine": self.virtual_engine,
|
"virtual_engine": self.virtual_engine,
|
||||||
|
"request_ids_to_seq_ids": self.request_ids_to_seq_ids,
|
||||||
|
"finished_requests_ids": self.finished_requests_ids,
|
||||||
}
|
}
|
||||||
_add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
|
_add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
|
||||||
return tensor_dict
|
return tensor_dict
|
||||||
@ -128,6 +132,8 @@ class ModelInputForGPUWithSamplingMetadata(ModelInputForGPU):
|
|||||||
"lora_mapping": self.lora_mapping,
|
"lora_mapping": self.lora_mapping,
|
||||||
"multi_modal_kwargs": self.multi_modal_kwargs,
|
"multi_modal_kwargs": self.multi_modal_kwargs,
|
||||||
"virtual_engine": self.virtual_engine,
|
"virtual_engine": self.virtual_engine,
|
||||||
|
"request_ids_to_seq_ids": self.request_ids_to_seq_ids,
|
||||||
|
"finished_requests_ids": self.finished_requests_ids,
|
||||||
}
|
}
|
||||||
_add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
|
_add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
|
||||||
_add_sampling_metadata_broadcastable_dict(tensor_dict,
|
_add_sampling_metadata_broadcastable_dict(tensor_dict,
|
||||||
@ -191,6 +197,10 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
|
|||||||
]
|
]
|
||||||
self.graph_memory_pool: Optional[Tuple[
|
self.graph_memory_pool: Optional[Tuple[
|
||||||
int, int]] = None # Set during graph capture.
|
int, int]] = None # Set during graph capture.
|
||||||
|
|
||||||
|
self.has_seqlen_agnostic = model_config.contains_seqlen_agnostic_layers(
|
||||||
|
parallel_config)
|
||||||
|
|
||||||
# When using CUDA graph, the input block tables must be padded to
|
# When using CUDA graph, the input block tables must be padded to
|
||||||
# max_seq_len_to_capture. However, creating the block table in
|
# max_seq_len_to_capture. However, creating the block table in
|
||||||
# Python can be expensive. To optimize this, we cache the block table
|
# Python can be expensive. To optimize this, we cache the block table
|
||||||
@ -317,6 +327,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
|
|||||||
def _prepare_model_input_tensors(
|
def _prepare_model_input_tensors(
|
||||||
self,
|
self,
|
||||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||||
|
finished_requests_ids: Optional[List[str]] = None
|
||||||
) -> TModelInputForGPU:
|
) -> TModelInputForGPU:
|
||||||
"""Helper method to prepare the model input based on a given sequence
|
"""Helper method to prepare the model input based on a given sequence
|
||||||
group. Prepares metadata needed for the base model forward pass but not
|
group. Prepares metadata needed for the base model forward pass but not
|
||||||
@ -347,6 +358,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
|
|||||||
block_tables: List[List[int]] = []
|
block_tables: List[List[int]] = []
|
||||||
multi_modal_kwargs_list: Dict[str,
|
multi_modal_kwargs_list: Dict[str,
|
||||||
List[torch.Tensor]] = defaultdict(list)
|
List[torch.Tensor]] = defaultdict(list)
|
||||||
|
request_ids_to_seq_ids: Dict[str, List[int]] = defaultdict(list)
|
||||||
decode_only = True
|
decode_only = True
|
||||||
num_prefills = 0
|
num_prefills = 0
|
||||||
num_prefill_tokens = 0
|
num_prefill_tokens = 0
|
||||||
@ -738,7 +750,11 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
|
|||||||
k: torch.cat(v, dim=0).to(self.device)
|
k: torch.cat(v, dim=0).to(self.device)
|
||||||
for k, v in multi_modal_kwargs_list.items()
|
for k, v in multi_modal_kwargs_list.items()
|
||||||
}
|
}
|
||||||
|
request_ids_to_seq_ids = {
|
||||||
|
seq_group_metadata.request_id:
|
||||||
|
list(seq_group_metadata.seq_data.keys())
|
||||||
|
for seq_group_metadata in seq_group_metadata_list
|
||||||
|
}
|
||||||
return self._model_input_cls(
|
return self._model_input_cls(
|
||||||
input_tokens=input_tokens_tensor,
|
input_tokens=input_tokens_tensor,
|
||||||
input_positions=input_positions_tensor,
|
input_positions=input_positions_tensor,
|
||||||
@ -748,7 +764,8 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
|
|||||||
lora_mapping=lora_mapping,
|
lora_mapping=lora_mapping,
|
||||||
lora_requests=lora_requests,
|
lora_requests=lora_requests,
|
||||||
multi_modal_kwargs=multi_modal_kwargs,
|
multi_modal_kwargs=multi_modal_kwargs,
|
||||||
)
|
request_ids_to_seq_ids=request_ids_to_seq_ids,
|
||||||
|
finished_requests_ids=finished_requests_ids)
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def profile_run(self) -> None:
|
def profile_run(self) -> None:
|
||||||
@ -821,7 +838,9 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
|
|||||||
# Run the model with the dummy inputs.
|
# Run the model with the dummy inputs.
|
||||||
num_layers = self.model_config.get_num_layers(self.parallel_config)
|
num_layers = self.model_config.get_num_layers(self.parallel_config)
|
||||||
kv_caches = [None] * num_layers
|
kv_caches = [None] * num_layers
|
||||||
model_input = self.prepare_model_input(seqs)
|
finished_requests_ids = [seq.request_id for seq in seqs]
|
||||||
|
model_input = self.prepare_model_input(
|
||||||
|
seqs, finished_requests_ids=finished_requests_ids)
|
||||||
intermediate_tensors = None
|
intermediate_tensors = None
|
||||||
if not get_pp_group().is_first_rank:
|
if not get_pp_group().is_first_rank:
|
||||||
intermediate_tensors = self.model.make_empty_intermediate_tensors(
|
intermediate_tensors = self.model.make_empty_intermediate_tensors(
|
||||||
@ -1033,21 +1052,37 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
|
|||||||
graph_runner.flashinfer_decode_wrapper = \
|
graph_runner.flashinfer_decode_wrapper = \
|
||||||
decode_wrapper
|
decode_wrapper
|
||||||
|
|
||||||
graph_runner.capture(
|
capture_inputs = {
|
||||||
|
"input_ids":
|
||||||
input_tokens[:batch_size],
|
input_tokens[:batch_size],
|
||||||
|
"positions":
|
||||||
input_positions[:batch_size],
|
input_positions[:batch_size],
|
||||||
|
"hidden_or_intermediate_states":
|
||||||
hidden_or_intermediate_states[
|
hidden_or_intermediate_states[
|
||||||
virtual_engine] # type: ignore
|
virtual_engine] # type: ignore
|
||||||
[:batch_size]
|
[:batch_size]
|
||||||
if hidden_or_intermediate_states[virtual_engine]
|
if hidden_or_intermediate_states[virtual_engine]
|
||||||
is not None else None,
|
is not None else None,
|
||||||
|
"intermediate_inputs":
|
||||||
intermediate_inputs[:batch_size]
|
intermediate_inputs[:batch_size]
|
||||||
if intermediate_inputs is not None else None,
|
if intermediate_inputs is not None else None,
|
||||||
|
"kv_caches":
|
||||||
kv_caches[virtual_engine],
|
kv_caches[virtual_engine],
|
||||||
|
"attn_metadata":
|
||||||
attn_metadata,
|
attn_metadata,
|
||||||
memory_pool=self.graph_memory_pool,
|
"memory_pool":
|
||||||
stream=graph_capture_context.stream,
|
self.graph_memory_pool,
|
||||||
)
|
"stream":
|
||||||
|
graph_capture_context.stream
|
||||||
|
}
|
||||||
|
if self.has_seqlen_agnostic:
|
||||||
|
# Only used by Mamba-based models CUDA graph atm (Jamba)
|
||||||
|
capture_inputs.update({
|
||||||
|
"seqlen_agnostic_capture_inputs":
|
||||||
|
self.model.get_seqlen_agnostic_capture_inputs(
|
||||||
|
batch_size)
|
||||||
|
})
|
||||||
|
graph_runner.capture(**capture_inputs)
|
||||||
self.graph_memory_pool = graph_runner.graph.pool()
|
self.graph_memory_pool = graph_runner.graph.pool()
|
||||||
self.graph_runners[virtual_engine][batch_size] = (
|
self.graph_runners[virtual_engine][batch_size] = (
|
||||||
graph_runner)
|
graph_runner)
|
||||||
@ -1084,6 +1119,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
|
|||||||
self,
|
self,
|
||||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||||
virtual_engine: int = 0,
|
virtual_engine: int = 0,
|
||||||
|
finished_requests_ids: Optional[List[str]] = None
|
||||||
) -> ModelInputForGPUWithSamplingMetadata:
|
) -> ModelInputForGPUWithSamplingMetadata:
|
||||||
"""Prepare the model input based on a given sequence group, including
|
"""Prepare the model input based on a given sequence group, including
|
||||||
metadata for the sampling step.
|
metadata for the sampling step.
|
||||||
@ -1099,7 +1135,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
|
|||||||
If cuda graph is required, this API automatically pads inputs.
|
If cuda graph is required, this API automatically pads inputs.
|
||||||
"""
|
"""
|
||||||
model_input = self._prepare_model_input_tensors(
|
model_input = self._prepare_model_input_tensors(
|
||||||
seq_group_metadata_list)
|
seq_group_metadata_list, finished_requests_ids)
|
||||||
sampling_metadata = SamplingMetadata.prepare(seq_group_metadata_list,
|
sampling_metadata = SamplingMetadata.prepare(seq_group_metadata_list,
|
||||||
model_input.seq_lens,
|
model_input.seq_lens,
|
||||||
model_input.query_lens,
|
model_input.query_lens,
|
||||||
@ -1175,6 +1211,10 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
|
|||||||
model_executable = self.model
|
model_executable = self.model
|
||||||
|
|
||||||
multi_modal_kwargs = model_input.multi_modal_kwargs or {}
|
multi_modal_kwargs = model_input.multi_modal_kwargs or {}
|
||||||
|
seqlen_agnostic_kwargs = {
|
||||||
|
"finished_requests_ids": model_input.finished_requests_ids,
|
||||||
|
"request_ids_to_seq_ids": model_input.request_ids_to_seq_ids,
|
||||||
|
} if self.has_seqlen_agnostic else {}
|
||||||
hidden_or_intermediate_states = model_executable(
|
hidden_or_intermediate_states = model_executable(
|
||||||
input_ids=model_input.input_tokens,
|
input_ids=model_input.input_tokens,
|
||||||
positions=model_input.input_positions,
|
positions=model_input.input_positions,
|
||||||
@ -1182,7 +1222,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
|
|||||||
attn_metadata=model_input.attn_metadata,
|
attn_metadata=model_input.attn_metadata,
|
||||||
intermediate_tensors=intermediate_tensors,
|
intermediate_tensors=intermediate_tensors,
|
||||||
**multi_modal_kwargs,
|
**multi_modal_kwargs,
|
||||||
)
|
**seqlen_agnostic_kwargs)
|
||||||
|
|
||||||
# Compute the logits in the last pipeline stage.
|
# Compute the logits in the last pipeline stage.
|
||||||
if not get_pp_group().is_last_rank:
|
if not get_pp_group().is_last_rank:
|
||||||
@ -1305,6 +1345,7 @@ class CUDAGraphRunner:
|
|||||||
"positions": positions,
|
"positions": positions,
|
||||||
"kv_caches": kv_caches,
|
"kv_caches": kv_caches,
|
||||||
"slot_mapping": attn_metadata.slot_mapping,
|
"slot_mapping": attn_metadata.slot_mapping,
|
||||||
|
**kwargs,
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
self.input_buffers = {
|
self.input_buffers = {
|
||||||
@ -1315,6 +1356,7 @@ class CUDAGraphRunner:
|
|||||||
"seq_lens_tensor":
|
"seq_lens_tensor":
|
||||||
attn_metadata.decode_metadata.seq_lens_tensor,
|
attn_metadata.decode_metadata.seq_lens_tensor,
|
||||||
"block_tables": attn_metadata.decode_metadata.block_tables,
|
"block_tables": attn_metadata.decode_metadata.block_tables,
|
||||||
|
**kwargs,
|
||||||
}
|
}
|
||||||
if intermediate_inputs is not None:
|
if intermediate_inputs is not None:
|
||||||
self.input_buffers.update(intermediate_inputs.tensors)
|
self.input_buffers.update(intermediate_inputs.tensors)
|
||||||
@ -1349,13 +1391,18 @@ class CUDAGraphRunner:
|
|||||||
non_blocking=True)
|
non_blocking=True)
|
||||||
self.input_buffers["block_tables"].copy_(
|
self.input_buffers["block_tables"].copy_(
|
||||||
attn_metadata.decode_metadata.block_tables, non_blocking=True)
|
attn_metadata.decode_metadata.block_tables, non_blocking=True)
|
||||||
|
if "seqlen_agnostic_capture_inputs" in self.input_buffers:
|
||||||
|
self.model.copy_inputs_before_cuda_graphs(self.input_buffers,
|
||||||
|
**kwargs)
|
||||||
if intermediate_tensors is not None:
|
if intermediate_tensors is not None:
|
||||||
for key in intermediate_tensors.tensors:
|
for key in intermediate_tensors.tensors:
|
||||||
self.input_buffers[key].copy_(intermediate_tensors[key],
|
self.input_buffers[key].copy_(intermediate_tensors[key],
|
||||||
non_blocking=True)
|
non_blocking=True)
|
||||||
# Run the graph.
|
# Run the graph.
|
||||||
self.graph.replay()
|
self.graph.replay()
|
||||||
|
if "seqlen_agnostic_capture_inputs" in self.input_buffers:
|
||||||
|
self.model.copy_outputs_after_cuda_graphs(self.input_buffers,
|
||||||
|
**kwargs)
|
||||||
# Return the output tensor.
|
# Return the output tensor.
|
||||||
if get_pp_group().is_last_rank:
|
if get_pp_group().is_last_rank:
|
||||||
return self.output_buffers["hidden_states"]
|
return self.output_buffers["hidden_states"]
|
||||||
|
|||||||
@ -139,6 +139,7 @@ class ModelRunnerBase(ABC, Generic[T]):
|
|||||||
self,
|
self,
|
||||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||||
virtual_engine: int = 0,
|
virtual_engine: int = 0,
|
||||||
|
finished_requests_ids: Optional[List[str]] = None,
|
||||||
) -> T:
|
) -> T:
|
||||||
"""
|
"""
|
||||||
Prepare the inputs to ModelRunnerBase.execute_model from an execution
|
Prepare the inputs to ModelRunnerBase.execute_model from an execution
|
||||||
|
|||||||
@ -177,6 +177,7 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
|
|||||||
self,
|
self,
|
||||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||||
virtual_engine: int = 0,
|
virtual_engine: int = 0,
|
||||||
|
finished_requests_ids: Optional[List[str]] = None
|
||||||
) -> ModelInputForNeuron:
|
) -> ModelInputForNeuron:
|
||||||
# NOTE: We assume that all sequences in the group are all prompts or
|
# NOTE: We assume that all sequences in the group are all prompts or
|
||||||
# all decodes.
|
# all decodes.
|
||||||
|
|||||||
@ -234,7 +234,8 @@ class LocalOrDistributedWorkerBase(WorkerBase):
|
|||||||
model_input: ModelRunnerInputBase = (
|
model_input: ModelRunnerInputBase = (
|
||||||
self.model_runner.prepare_model_input(
|
self.model_runner.prepare_model_input(
|
||||||
execute_model_req.seq_group_metadata_list,
|
execute_model_req.seq_group_metadata_list,
|
||||||
execute_model_req.virtual_engine))
|
execute_model_req.virtual_engine,
|
||||||
|
execute_model_req.finished_requests_ids))
|
||||||
num_steps = execute_model_req.num_steps
|
num_steps = execute_model_req.num_steps
|
||||||
|
|
||||||
if self.do_metadata_broadcast:
|
if self.do_metadata_broadcast:
|
||||||
|
|||||||
@ -189,9 +189,10 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
|
|||||||
))
|
))
|
||||||
|
|
||||||
def prepare_model_input(
|
def prepare_model_input(
|
||||||
self,
|
self,
|
||||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||||
virtual_engine: int = 0,
|
virtual_engine: int = 0,
|
||||||
|
finished_requests_ids: Optional[List[str]] = None
|
||||||
) -> ModelInputForXPU:
|
) -> ModelInputForXPU:
|
||||||
multi_modal_input = None
|
multi_modal_input = None
|
||||||
if self.is_driver_worker:
|
if self.is_driver_worker:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user