mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-28 08:57:04 +08:00
merge
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
commit
286eeb91e8
104
benchmarks/kernels/benchmark_activation.py
Normal file
104
benchmarks/kernels/benchmark_activation.py
Normal file
@ -0,0 +1,104 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
# benchmark custom activation op performance
|
||||||
|
import itertools
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
import vllm.model_executor.layers.activation # noqa F401
|
||||||
|
from vllm.model_executor.custom_op import CustomOp
|
||||||
|
from vllm.platforms import current_platform
|
||||||
|
from vllm.triton_utils import triton
|
||||||
|
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser
|
||||||
|
|
||||||
|
batch_size_range = [1, 16, 32, 64, 128]
|
||||||
|
seq_len_range = [1, 16, 64, 128, 256, 512, 1024, 2048, 4096]
|
||||||
|
intermediate_size = [3072, 9728, 12288]
|
||||||
|
configs = list(itertools.product(batch_size_range, seq_len_range, intermediate_size))
|
||||||
|
|
||||||
|
|
||||||
|
def benchmark_activation(
|
||||||
|
batch_size: int,
|
||||||
|
seq_len: int,
|
||||||
|
intermediate_size: int,
|
||||||
|
provider: str,
|
||||||
|
func_name: str,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
):
|
||||||
|
device = "cuda"
|
||||||
|
num_tokens = batch_size * seq_len
|
||||||
|
dim = intermediate_size
|
||||||
|
current_platform.seed_everything(42)
|
||||||
|
torch.set_default_device(device)
|
||||||
|
|
||||||
|
if func_name == "gelu_and_mul":
|
||||||
|
layer = CustomOp.op_registry[func_name](approximate="none")
|
||||||
|
elif func_name == "gelu_and_mul_tanh":
|
||||||
|
layer = CustomOp.op_registry["gelu_and_mul"](approximate="tanh")
|
||||||
|
elif func_name == "fatrelu_and_mul":
|
||||||
|
threshold = 0.5
|
||||||
|
layer = CustomOp.op_registry[func_name](threshold)
|
||||||
|
else:
|
||||||
|
layer = CustomOp.op_registry[func_name]()
|
||||||
|
|
||||||
|
x = torch.randn(num_tokens, dim, dtype=dtype, device=device)
|
||||||
|
compiled_layer = torch.compile(layer.forward_native)
|
||||||
|
|
||||||
|
if provider == "custom":
|
||||||
|
fn = lambda: layer(x)
|
||||||
|
elif provider == "compiled":
|
||||||
|
fn = lambda: compiled_layer(x)
|
||||||
|
|
||||||
|
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
|
||||||
|
fn, quantiles=[0.5, 0.2, 0.8]
|
||||||
|
)
|
||||||
|
return ms, max_ms, min_ms
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = FlexibleArgumentParser(description="Benchmark the custom activation op.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--func-name",
|
||||||
|
type=str,
|
||||||
|
choices=[
|
||||||
|
"mul_and_silu",
|
||||||
|
"silu_and_mul",
|
||||||
|
"gelu_and_mul",
|
||||||
|
"gelu_and_mul_tanh",
|
||||||
|
"fatrelu_and_mul",
|
||||||
|
"swigluoai_and_mul",
|
||||||
|
"gelu_new",
|
||||||
|
"gelu_fast",
|
||||||
|
"quick_gelu",
|
||||||
|
],
|
||||||
|
default="silu_and_mul",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--dtype", type=str, choices=["half", "bfloat16", "float"], default="bfloat16"
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
assert args
|
||||||
|
|
||||||
|
func_name = args.func_name
|
||||||
|
dtype = STR_DTYPE_TO_TORCH_DTYPE[args.dtype]
|
||||||
|
|
||||||
|
perf_report = triton.testing.perf_report(
|
||||||
|
triton.testing.Benchmark(
|
||||||
|
x_names=["batch_size", "seq_len", "intermediate_size"],
|
||||||
|
x_vals=configs,
|
||||||
|
line_arg="provider",
|
||||||
|
line_vals=["custom", "compiled"],
|
||||||
|
line_names=["Custom OP", "Compiled"],
|
||||||
|
styles=[("blue", "-"), ("green", "-")],
|
||||||
|
ylabel="ms",
|
||||||
|
plot_name=f"{func_name}-op-performance",
|
||||||
|
args={},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
perf_report(
|
||||||
|
lambda batch_size, seq_len, intermediate_size, provider: benchmark_activation(
|
||||||
|
batch_size, seq_len, intermediate_size, provider, func_name, dtype
|
||||||
|
)
|
||||||
|
).run(print_data=True)
|
||||||
@ -678,7 +678,11 @@ def main(args: argparse.Namespace):
|
|||||||
is_fp16 = not (use_fp8_w8a8 or use_int8_w8a16)
|
is_fp16 = not (use_fp8_w8a8 or use_int8_w8a16)
|
||||||
search_space = get_configs_compute_bound(is_fp16, block_quant_shape)
|
search_space = get_configs_compute_bound(is_fp16, block_quant_shape)
|
||||||
print(f"Start tuning over {len(search_space)} configurations...")
|
print(f"Start tuning over {len(search_space)} configurations...")
|
||||||
|
if use_deep_gemm:
|
||||||
|
raise ValueError(
|
||||||
|
"Tuning with --use-deep-gemm is not supported as it only tunes Triton "
|
||||||
|
"kernels. Please remove the flag."
|
||||||
|
)
|
||||||
start = time.time()
|
start = time.time()
|
||||||
configs = _distribute(
|
configs = _distribute(
|
||||||
"tune",
|
"tune",
|
||||||
|
|||||||
@ -73,17 +73,11 @@ async def test_zero_truncation_size(client: openai.AsyncOpenAI):
|
|||||||
"truncate_prompt_tokens": truncation_size
|
"truncate_prompt_tokens": truncation_size
|
||||||
}
|
}
|
||||||
|
|
||||||
with pytest.raises(openai.BadRequestError) as err:
|
response = await client.post(path="embeddings",
|
||||||
await client.post(path="embeddings", cast_to=object, body={**kwargs})
|
cast_to=object,
|
||||||
|
body={**kwargs})
|
||||||
|
|
||||||
assert err.value.status_code == 400
|
assert response["usage"]["prompt_tokens"] == truncation_size
|
||||||
error_details = err.value.response.json()["error"]
|
|
||||||
|
|
||||||
assert error_details["type"] == "BadRequestError"
|
|
||||||
assert "This model's maximum context length is" in error_details["message"]
|
|
||||||
assert "tokens in the input for embedding generation" in error_details[
|
|
||||||
"message"]
|
|
||||||
assert "Please reduce the length of the input" in error_details["message"]
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
|||||||
@ -130,6 +130,23 @@ class TestRenderPrompt:
|
|||||||
assert call_args.kwargs["truncation"] is True
|
assert call_args.kwargs["truncation"] is True
|
||||||
assert call_args.kwargs["max_length"] == 50
|
assert call_args.kwargs["max_length"] == 50
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_truncation_negative(self, renderer, mock_async_tokenizer):
|
||||||
|
# Test that negative truncation uses model's max_model_len
|
||||||
|
mock_async_tokenizer.return_value = MockTokenizerResult(
|
||||||
|
[101, 7592, 2088]) # Truncated to max_model_len
|
||||||
|
renderer.async_tokenizer_pool[
|
||||||
|
renderer.tokenizer] = mock_async_tokenizer
|
||||||
|
|
||||||
|
results = await renderer.render_prompt(prompt_or_prompts="Hello world",
|
||||||
|
max_length=200,
|
||||||
|
truncate_prompt_tokens=-1)
|
||||||
|
|
||||||
|
assert len(results) == 1
|
||||||
|
call_args = mock_async_tokenizer.call_args
|
||||||
|
assert call_args.kwargs["truncation"] is True
|
||||||
|
assert call_args.kwargs["max_length"] == 100 # model's max_model_len
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_token_truncation_last_elements(self, renderer):
|
async def test_token_truncation_last_elements(self, renderer):
|
||||||
# Test that token truncation keeps the last N elements
|
# Test that token truncation keeps the last N elements
|
||||||
|
|||||||
@ -41,9 +41,8 @@ EAGLE_SPEC_CONFIG = {
|
|||||||
PARAMS_MODELS_BACKENDS_TOKENIZER_MODE = [
|
PARAMS_MODELS_BACKENDS_TOKENIZER_MODE = [
|
||||||
("mistralai/Ministral-8B-Instruct-2410", "xgrammar", "auto", None),
|
("mistralai/Ministral-8B-Instruct-2410", "xgrammar", "auto", None),
|
||||||
("mistralai/Ministral-8B-Instruct-2410", "guidance", "auto", None),
|
("mistralai/Ministral-8B-Instruct-2410", "guidance", "auto", None),
|
||||||
#FIXME: This test is flaky on CI thus disabled
|
("mistralai/Ministral-8B-Instruct-2410", "lm-format-enforcer", "auto",
|
||||||
#("mistralai/Ministral-8B-Instruct-2410", "lm-format-enforcer", "auto",
|
None),
|
||||||
# None),
|
|
||||||
("mistralai/Ministral-8B-Instruct-2410", "xgrammar", "mistral", None),
|
("mistralai/Ministral-8B-Instruct-2410", "xgrammar", "mistral", None),
|
||||||
("Qwen/Qwen2.5-1.5B-Instruct", "xgrammar", "auto", None),
|
("Qwen/Qwen2.5-1.5B-Instruct", "xgrammar", "auto", None),
|
||||||
("Qwen/Qwen2.5-1.5B-Instruct", "lm-format-enforcer", "auto", None),
|
("Qwen/Qwen2.5-1.5B-Instruct", "lm-format-enforcer", "auto", None),
|
||||||
@ -123,6 +122,7 @@ def test_structured_output(
|
|||||||
guided_decoding_backend=guided_decoding_backend,
|
guided_decoding_backend=guided_decoding_backend,
|
||||||
guided_decoding_disable_any_whitespace=(guided_decoding_backend
|
guided_decoding_disable_any_whitespace=(guided_decoding_backend
|
||||||
in {"xgrammar", "guidance"}),
|
in {"xgrammar", "guidance"}),
|
||||||
|
seed=120,
|
||||||
tokenizer_mode=tokenizer_mode,
|
tokenizer_mode=tokenizer_mode,
|
||||||
speculative_config=speculative_config)
|
speculative_config=speculative_config)
|
||||||
|
|
||||||
|
|||||||
@ -6,8 +6,12 @@ import pytest
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.v1.sample.ops.topk_topp_sampler import (apply_top_k_top_p,
|
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
|
||||||
apply_top_k_top_p_tpu)
|
|
||||||
|
# isort: off
|
||||||
|
from vllm.v1.sample.tpu.sampler import (apply_top_k_top_p as
|
||||||
|
apply_top_k_top_p_tpu)
|
||||||
|
# isort: on
|
||||||
|
|
||||||
if not current_platform.is_tpu():
|
if not current_platform.is_tpu():
|
||||||
pytest.skip("This test needs a TPU.", allow_module_level=True)
|
pytest.skip("This test needs a TPU.", allow_module_level=True)
|
||||||
|
|||||||
@ -257,6 +257,32 @@ class AttentionLayer(Protocol):
|
|||||||
|
|
||||||
class AttentionImpl(ABC, Generic[T]):
|
class AttentionImpl(ABC, Generic[T]):
|
||||||
|
|
||||||
|
# Whether the attention impl can return the softmax lse for decode.
|
||||||
|
# Some features like decode context parallelism require the softmax lse.
|
||||||
|
can_return_lse_for_decode: bool = False
|
||||||
|
|
||||||
|
# some attention backends might not always want to return lse
|
||||||
|
# even if they can return lse (for efficiency reasons)
|
||||||
|
need_to_return_lse_for_decode: bool = False
|
||||||
|
|
||||||
|
dcp_world_size: int
|
||||||
|
dcp_rank: int
|
||||||
|
|
||||||
|
def __new__(cls, *args, **kwargs):
|
||||||
|
# use __new__ so that all subclasses will call this
|
||||||
|
self = super().__new__(cls)
|
||||||
|
try:
|
||||||
|
from vllm.distributed.parallel_state import get_dcp_group
|
||||||
|
self.dcp_world_size = get_dcp_group().world_size
|
||||||
|
self.dcp_rank = get_dcp_group().rank_in_group
|
||||||
|
except AssertionError:
|
||||||
|
# DCP might not be initialized in testing
|
||||||
|
self.dcp_world_size = 1
|
||||||
|
self.dcp_rank = 0
|
||||||
|
self.need_to_return_lse_for_decode = self.dcp_world_size > 1 \
|
||||||
|
and self.can_return_lse_for_decode
|
||||||
|
return self
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@ -74,6 +74,7 @@ DEFAULT_CONDA_PATTERNS = {
|
|||||||
"zmq",
|
"zmq",
|
||||||
"nvidia",
|
"nvidia",
|
||||||
"pynvml",
|
"pynvml",
|
||||||
|
"flashinfer-python",
|
||||||
}
|
}
|
||||||
|
|
||||||
DEFAULT_PIP_PATTERNS = {
|
DEFAULT_PIP_PATTERNS = {
|
||||||
@ -89,6 +90,7 @@ DEFAULT_PIP_PATTERNS = {
|
|||||||
"zmq",
|
"zmq",
|
||||||
"nvidia",
|
"nvidia",
|
||||||
"pynvml",
|
"pynvml",
|
||||||
|
"flashinfer-python",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -54,14 +54,11 @@ class ClassificationMixin(OpenAIServing):
|
|||||||
ctx.tokenizer = await self.engine_client.get_tokenizer(
|
ctx.tokenizer = await self.engine_client.get_tokenizer(
|
||||||
ctx.lora_request)
|
ctx.lora_request)
|
||||||
|
|
||||||
(
|
renderer = self._get_renderer(ctx.tokenizer)
|
||||||
ctx.request_prompts,
|
ctx.engine_prompts = await renderer.render_prompt(
|
||||||
ctx.engine_prompts,
|
prompt_or_prompts=ctx.request.input,
|
||||||
) = await self._preprocess_completion(
|
max_length=self.max_model_len,
|
||||||
ctx.request,
|
truncate_prompt_tokens=ctx.request.truncate_prompt_tokens)
|
||||||
ctx.tokenizer,
|
|
||||||
ctx.request.input,
|
|
||||||
)
|
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|||||||
@ -24,7 +24,6 @@ from vllm.entrypoints.openai.protocol import (EmbeddingChatRequest,
|
|||||||
ErrorResponse, UsageInfo)
|
ErrorResponse, UsageInfo)
|
||||||
from vllm.entrypoints.openai.serving_engine import (EmbeddingServeContext,
|
from vllm.entrypoints.openai.serving_engine import (EmbeddingServeContext,
|
||||||
OpenAIServing,
|
OpenAIServing,
|
||||||
RequestPrompt,
|
|
||||||
ServeContext,
|
ServeContext,
|
||||||
TextTokensPrompt)
|
TextTokensPrompt)
|
||||||
# yapf: enable
|
# yapf: enable
|
||||||
@ -79,11 +78,12 @@ class EmbeddingMixin(OpenAIServing):
|
|||||||
|
|
||||||
tokenizer = await self.engine_client.get_tokenizer(ctx.lora_request
|
tokenizer = await self.engine_client.get_tokenizer(ctx.lora_request
|
||||||
)
|
)
|
||||||
|
renderer = self._get_renderer(tokenizer)
|
||||||
|
|
||||||
if isinstance(ctx.request, EmbeddingChatRequest):
|
if isinstance(ctx.request, EmbeddingChatRequest):
|
||||||
(
|
(
|
||||||
_,
|
_,
|
||||||
ctx.request_prompts,
|
_,
|
||||||
ctx.engine_prompts,
|
ctx.engine_prompts,
|
||||||
) = await self._preprocess_chat(
|
) = await self._preprocess_chat(
|
||||||
ctx.request,
|
ctx.request,
|
||||||
@ -98,13 +98,18 @@ class EmbeddingMixin(OpenAIServing):
|
|||||||
add_special_tokens=ctx.request.add_special_tokens,
|
add_special_tokens=ctx.request.add_special_tokens,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
(ctx.request_prompts,
|
# Set max_length based on chunked processing capability
|
||||||
ctx.engine_prompts) = await self._preprocess_completion(
|
if self._should_use_chunked_processing(ctx.request):
|
||||||
ctx.request,
|
max_length = None
|
||||||
tokenizer,
|
else:
|
||||||
ctx.request.input,
|
max_length = self.max_embed_len or self.max_model_len
|
||||||
add_special_tokens=ctx.request.add_special_tokens,
|
|
||||||
)
|
ctx.engine_prompts = await renderer.render_prompt(
|
||||||
|
prompt_or_prompts=ctx.request.input,
|
||||||
|
max_length=max_length,
|
||||||
|
truncate_prompt_tokens=ctx.request.truncate_prompt_tokens,
|
||||||
|
add_special_tokens=ctx.request.add_special_tokens,
|
||||||
|
)
|
||||||
return None
|
return None
|
||||||
except (ValueError, TypeError) as e:
|
except (ValueError, TypeError) as e:
|
||||||
logger.exception("Error in preprocessing prompt inputs")
|
logger.exception("Error in preprocessing prompt inputs")
|
||||||
@ -286,7 +291,6 @@ class EmbeddingMixin(OpenAIServing):
|
|||||||
self,
|
self,
|
||||||
ctx: EmbeddingServeContext,
|
ctx: EmbeddingServeContext,
|
||||||
engine_prompt: Union[EngineTokensPrompt, EngineEmbedsPrompt],
|
engine_prompt: Union[EngineTokensPrompt, EngineEmbedsPrompt],
|
||||||
request_prompt: RequestPrompt,
|
|
||||||
pooling_params: PoolingParams,
|
pooling_params: PoolingParams,
|
||||||
trace_headers: Optional[Mapping[str, str]],
|
trace_headers: Optional[Mapping[str, str]],
|
||||||
prompt_index: int,
|
prompt_index: int,
|
||||||
@ -295,7 +299,7 @@ class EmbeddingMixin(OpenAIServing):
|
|||||||
request_id_item = f"{ctx.request_id}-{prompt_index}"
|
request_id_item = f"{ctx.request_id}-{prompt_index}"
|
||||||
|
|
||||||
self._log_inputs(request_id_item,
|
self._log_inputs(request_id_item,
|
||||||
request_prompt,
|
engine_prompt,
|
||||||
params=pooling_params,
|
params=pooling_params,
|
||||||
lora_request=ctx.lora_request)
|
lora_request=ctx.lora_request)
|
||||||
|
|
||||||
@ -353,20 +357,14 @@ class EmbeddingMixin(OpenAIServing):
|
|||||||
return self.create_error_response(
|
return self.create_error_response(
|
||||||
"Engine prompts not available")
|
"Engine prompts not available")
|
||||||
|
|
||||||
if ctx.request_prompts is None:
|
|
||||||
return self.create_error_response(
|
|
||||||
"Request prompts not available")
|
|
||||||
|
|
||||||
max_pos_embeddings = self._get_max_position_embeddings()
|
max_pos_embeddings = self._get_max_position_embeddings()
|
||||||
|
|
||||||
for i, engine_prompt in enumerate(ctx.engine_prompts):
|
for i, engine_prompt in enumerate(ctx.engine_prompts):
|
||||||
request_prompt = ctx.request_prompts[i]
|
|
||||||
|
|
||||||
# Check if this specific prompt needs chunked processing
|
# Check if this specific prompt needs chunked processing
|
||||||
if self._is_text_tokens_prompt(request_prompt):
|
if self._is_text_tokens_prompt(engine_prompt):
|
||||||
# Cast to TextTokensPrompt since we've verified
|
# Cast to TextTokensPrompt since we've verified
|
||||||
# prompt_token_ids
|
# prompt_token_ids
|
||||||
text_tokens_prompt = cast(TextTokensPrompt, request_prompt)
|
text_tokens_prompt = cast(TextTokensPrompt, engine_prompt)
|
||||||
if (len(text_tokens_prompt["prompt_token_ids"])
|
if (len(text_tokens_prompt["prompt_token_ids"])
|
||||||
> max_pos_embeddings):
|
> max_pos_embeddings):
|
||||||
# Use chunked processing for this prompt
|
# Use chunked processing for this prompt
|
||||||
@ -382,8 +380,7 @@ class EmbeddingMixin(OpenAIServing):
|
|||||||
Union[EngineTokensPrompt, EngineEmbedsPrompt],
|
Union[EngineTokensPrompt, EngineEmbedsPrompt],
|
||||||
engine_prompt)
|
engine_prompt)
|
||||||
generator = await self._create_single_prompt_generator(
|
generator = await self._create_single_prompt_generator(
|
||||||
ctx, engine_prompt_typed, request_prompt, pooling_params,
|
ctx, engine_prompt_typed, pooling_params, trace_headers, i)
|
||||||
trace_headers, i)
|
|
||||||
generators.append(generator)
|
generators.append(generator)
|
||||||
|
|
||||||
from vllm.utils import merge_async_iterators
|
from vllm.utils import merge_async_iterators
|
||||||
@ -419,10 +416,6 @@ class EmbeddingMixin(OpenAIServing):
|
|||||||
if not use_chunked:
|
if not use_chunked:
|
||||||
return await super()._collect_batch(ctx=ctx)
|
return await super()._collect_batch(ctx=ctx)
|
||||||
|
|
||||||
if ctx.request_prompts is None:
|
|
||||||
return self.create_error_response(
|
|
||||||
"Request prompts not available")
|
|
||||||
|
|
||||||
if ctx.result_generator is None:
|
if ctx.result_generator is None:
|
||||||
return self.create_error_response(
|
return self.create_error_response(
|
||||||
"Result generator not available")
|
"Result generator not available")
|
||||||
@ -538,7 +531,7 @@ class EmbeddingMixin(OpenAIServing):
|
|||||||
data=final_embedding)
|
data=final_embedding)
|
||||||
|
|
||||||
# Get original prompt token IDs for this prompt
|
# Get original prompt token IDs for this prompt
|
||||||
original_prompt = ctx.request_prompts[prompt_idx]
|
original_prompt = ctx.engine_prompts[prompt_idx]
|
||||||
if not self._is_text_tokens_prompt(original_prompt):
|
if not self._is_text_tokens_prompt(original_prompt):
|
||||||
return self.create_error_response(
|
return self.create_error_response(
|
||||||
f"Chunked prompt {prompt_idx} is not a "
|
f"Chunked prompt {prompt_idx} is not a "
|
||||||
|
|||||||
@ -368,23 +368,20 @@ class OpenAIServing:
|
|||||||
for i, engine_prompt in enumerate(ctx.engine_prompts):
|
for i, engine_prompt in enumerate(ctx.engine_prompts):
|
||||||
request_id_item = f"{ctx.request_id}-{i}"
|
request_id_item = f"{ctx.request_id}-{i}"
|
||||||
|
|
||||||
if ctx.request_prompts is None:
|
|
||||||
return self.create_error_response(
|
|
||||||
"Request prompts not available")
|
|
||||||
|
|
||||||
self._log_inputs(
|
|
||||||
request_id_item,
|
|
||||||
ctx.request_prompts[i],
|
|
||||||
params=pooling_params,
|
|
||||||
lora_request=ctx.lora_request,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Mypy has an existing bug related to inferring the variance of
|
# Mypy has an existing bug related to inferring the variance of
|
||||||
# TypedDicts with `builtins.enumerate`:
|
# TypedDicts with `builtins.enumerate`:
|
||||||
# https://github.com/python/mypy/issues/8586#issuecomment-2867698435
|
# https://github.com/python/mypy/issues/8586#issuecomment-2867698435
|
||||||
engine_prompt = cast(
|
engine_prompt = cast(
|
||||||
Union[EngineTokensPrompt, EngineEmbedsPrompt],
|
Union[EngineTokensPrompt, EngineEmbedsPrompt],
|
||||||
engine_prompt)
|
engine_prompt)
|
||||||
|
|
||||||
|
self._log_inputs(
|
||||||
|
request_id_item,
|
||||||
|
engine_prompt,
|
||||||
|
params=pooling_params,
|
||||||
|
lora_request=ctx.lora_request,
|
||||||
|
)
|
||||||
|
|
||||||
generator = self.engine_client.encode(
|
generator = self.engine_client.encode(
|
||||||
engine_prompt,
|
engine_prompt,
|
||||||
pooling_params,
|
pooling_params,
|
||||||
|
|||||||
@ -108,10 +108,15 @@ class CompletionRenderer(BaseRenderer):
|
|||||||
for detailed parameter documentation.
|
for detailed parameter documentation.
|
||||||
"""
|
"""
|
||||||
if truncate_prompt_tokens is not None:
|
if truncate_prompt_tokens is not None:
|
||||||
if max_length is not None:
|
|
||||||
assert 0 <= truncate_prompt_tokens <= max_length
|
|
||||||
if truncate_prompt_tokens == 0:
|
if truncate_prompt_tokens == 0:
|
||||||
return []
|
return []
|
||||||
|
if truncate_prompt_tokens < 0:
|
||||||
|
truncate_prompt_tokens = self.model_config.max_model_len
|
||||||
|
if max_length is not None and truncate_prompt_tokens > max_length:
|
||||||
|
raise ValueError(
|
||||||
|
f"truncate_prompt_tokens ({truncate_prompt_tokens}) "
|
||||||
|
f"cannot be greater than max_length ({max_length}). "
|
||||||
|
f"Please select a smaller truncation size.")
|
||||||
|
|
||||||
# Parse and batch the input prompts
|
# Parse and batch the input prompts
|
||||||
batch_inputs = parse_and_batch_prompt(prompt_or_prompts)
|
batch_inputs = parse_and_batch_prompt(prompt_or_prompts)
|
||||||
|
|||||||
@ -0,0 +1,146 @@
|
|||||||
|
{
|
||||||
|
"1": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 4
|
||||||
|
},
|
||||||
|
"2": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 4
|
||||||
|
},
|
||||||
|
"4": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 256,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 4
|
||||||
|
},
|
||||||
|
"8": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 3
|
||||||
|
},
|
||||||
|
"16": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 256,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 4
|
||||||
|
},
|
||||||
|
"24": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 256,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 3
|
||||||
|
},
|
||||||
|
"32": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 5
|
||||||
|
},
|
||||||
|
"48": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 5
|
||||||
|
},
|
||||||
|
"64": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 256,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 5
|
||||||
|
},
|
||||||
|
"96": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 256,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 5
|
||||||
|
},
|
||||||
|
"128": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 5
|
||||||
|
},
|
||||||
|
"256": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 4
|
||||||
|
},
|
||||||
|
"512": {
|
||||||
|
"BLOCK_SIZE_M": 64,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 256,
|
||||||
|
"GROUP_SIZE_M": 16,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 4
|
||||||
|
},
|
||||||
|
"1024": {
|
||||||
|
"BLOCK_SIZE_M": 64,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 16,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 4
|
||||||
|
},
|
||||||
|
"1536": {
|
||||||
|
"BLOCK_SIZE_M": 64,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 256,
|
||||||
|
"GROUP_SIZE_M": 16,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 4
|
||||||
|
},
|
||||||
|
"2048": {
|
||||||
|
"BLOCK_SIZE_M": 64,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 256,
|
||||||
|
"GROUP_SIZE_M": 16,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 4
|
||||||
|
},
|
||||||
|
"3072": {
|
||||||
|
"BLOCK_SIZE_M": 64,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 32,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 4
|
||||||
|
},
|
||||||
|
"4096": {
|
||||||
|
"BLOCK_SIZE_M": 64,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 256,
|
||||||
|
"GROUP_SIZE_M": 32,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 4
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -0,0 +1,146 @@
|
|||||||
|
{
|
||||||
|
"1": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 4
|
||||||
|
},
|
||||||
|
"2": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 4
|
||||||
|
},
|
||||||
|
"4": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 256,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 4
|
||||||
|
},
|
||||||
|
"8": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 3
|
||||||
|
},
|
||||||
|
"16": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 256,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 3
|
||||||
|
},
|
||||||
|
"24": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 256,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 3
|
||||||
|
},
|
||||||
|
"32": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 256,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 5
|
||||||
|
},
|
||||||
|
"48": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 256,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 5
|
||||||
|
},
|
||||||
|
"64": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 5
|
||||||
|
},
|
||||||
|
"96": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 5
|
||||||
|
},
|
||||||
|
"128": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 256,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 5
|
||||||
|
},
|
||||||
|
"256": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 4
|
||||||
|
},
|
||||||
|
"512": {
|
||||||
|
"BLOCK_SIZE_M": 64,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 256,
|
||||||
|
"GROUP_SIZE_M": 64,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 4
|
||||||
|
},
|
||||||
|
"1024": {
|
||||||
|
"BLOCK_SIZE_M": 64,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 256,
|
||||||
|
"GROUP_SIZE_M": 16,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 4
|
||||||
|
},
|
||||||
|
"1536": {
|
||||||
|
"BLOCK_SIZE_M": 64,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 16,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 4
|
||||||
|
},
|
||||||
|
"2048": {
|
||||||
|
"BLOCK_SIZE_M": 64,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 256,
|
||||||
|
"GROUP_SIZE_M": 16,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 4
|
||||||
|
},
|
||||||
|
"3072": {
|
||||||
|
"BLOCK_SIZE_M": 64,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 256,
|
||||||
|
"GROUP_SIZE_M": 16,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 4
|
||||||
|
},
|
||||||
|
"4096": {
|
||||||
|
"BLOCK_SIZE_M": 64,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 256,
|
||||||
|
"GROUP_SIZE_M": 32,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 4
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -9,16 +9,16 @@
|
|||||||
},
|
},
|
||||||
"2": {
|
"2": {
|
||||||
"BLOCK_SIZE_M": 16,
|
"BLOCK_SIZE_M": 16,
|
||||||
"BLOCK_SIZE_N": 256,
|
"BLOCK_SIZE_N": 128,
|
||||||
"BLOCK_SIZE_K": 128,
|
"BLOCK_SIZE_K": 128,
|
||||||
"GROUP_SIZE_M": 1,
|
"GROUP_SIZE_M": 1,
|
||||||
"num_warps": 4,
|
"num_warps": 4,
|
||||||
"num_stages": 5
|
"num_stages": 4
|
||||||
},
|
},
|
||||||
"4": {
|
"4": {
|
||||||
"BLOCK_SIZE_M": 16,
|
"BLOCK_SIZE_M": 16,
|
||||||
"BLOCK_SIZE_N": 128,
|
"BLOCK_SIZE_N": 128,
|
||||||
"BLOCK_SIZE_K": 256,
|
"BLOCK_SIZE_K": 128,
|
||||||
"GROUP_SIZE_M": 1,
|
"GROUP_SIZE_M": 1,
|
||||||
"num_warps": 4,
|
"num_warps": 4,
|
||||||
"num_stages": 4
|
"num_stages": 4
|
||||||
@ -26,15 +26,15 @@
|
|||||||
"8": {
|
"8": {
|
||||||
"BLOCK_SIZE_M": 16,
|
"BLOCK_SIZE_M": 16,
|
||||||
"BLOCK_SIZE_N": 128,
|
"BLOCK_SIZE_N": 128,
|
||||||
"BLOCK_SIZE_K": 128,
|
"BLOCK_SIZE_K": 256,
|
||||||
"GROUP_SIZE_M": 1,
|
"GROUP_SIZE_M": 1,
|
||||||
"num_warps": 4,
|
"num_warps": 4,
|
||||||
"num_stages": 4
|
"num_stages": 3
|
||||||
},
|
},
|
||||||
"16": {
|
"16": {
|
||||||
"BLOCK_SIZE_M": 16,
|
"BLOCK_SIZE_M": 16,
|
||||||
"BLOCK_SIZE_N": 128,
|
"BLOCK_SIZE_N": 128,
|
||||||
"BLOCK_SIZE_K": 256,
|
"BLOCK_SIZE_K": 128,
|
||||||
"GROUP_SIZE_M": 1,
|
"GROUP_SIZE_M": 1,
|
||||||
"num_warps": 4,
|
"num_warps": 4,
|
||||||
"num_stages": 4
|
"num_stages": 4
|
||||||
@ -42,7 +42,7 @@
|
|||||||
"24": {
|
"24": {
|
||||||
"BLOCK_SIZE_M": 16,
|
"BLOCK_SIZE_M": 16,
|
||||||
"BLOCK_SIZE_N": 128,
|
"BLOCK_SIZE_N": 128,
|
||||||
"BLOCK_SIZE_K": 256,
|
"BLOCK_SIZE_K": 128,
|
||||||
"GROUP_SIZE_M": 1,
|
"GROUP_SIZE_M": 1,
|
||||||
"num_warps": 4,
|
"num_warps": 4,
|
||||||
"num_stages": 4
|
"num_stages": 4
|
||||||
@ -53,12 +53,12 @@
|
|||||||
"BLOCK_SIZE_K": 256,
|
"BLOCK_SIZE_K": 256,
|
||||||
"GROUP_SIZE_M": 1,
|
"GROUP_SIZE_M": 1,
|
||||||
"num_warps": 4,
|
"num_warps": 4,
|
||||||
"num_stages": 4
|
"num_stages": 5
|
||||||
},
|
},
|
||||||
"48": {
|
"48": {
|
||||||
"BLOCK_SIZE_M": 16,
|
"BLOCK_SIZE_M": 16,
|
||||||
"BLOCK_SIZE_N": 128,
|
"BLOCK_SIZE_N": 128,
|
||||||
"BLOCK_SIZE_K": 128,
|
"BLOCK_SIZE_K": 256,
|
||||||
"GROUP_SIZE_M": 1,
|
"GROUP_SIZE_M": 1,
|
||||||
"num_warps": 4,
|
"num_warps": 4,
|
||||||
"num_stages": 4
|
"num_stages": 4
|
||||||
@ -82,10 +82,10 @@
|
|||||||
"128": {
|
"128": {
|
||||||
"BLOCK_SIZE_M": 16,
|
"BLOCK_SIZE_M": 16,
|
||||||
"BLOCK_SIZE_N": 128,
|
"BLOCK_SIZE_N": 128,
|
||||||
"BLOCK_SIZE_K": 256,
|
"BLOCK_SIZE_K": 128,
|
||||||
"GROUP_SIZE_M": 1,
|
"GROUP_SIZE_M": 1,
|
||||||
"num_warps": 4,
|
"num_warps": 4,
|
||||||
"num_stages": 4
|
"num_stages": 3
|
||||||
},
|
},
|
||||||
"256": {
|
"256": {
|
||||||
"BLOCK_SIZE_M": 16,
|
"BLOCK_SIZE_M": 16,
|
||||||
@ -98,8 +98,8 @@
|
|||||||
"512": {
|
"512": {
|
||||||
"BLOCK_SIZE_M": 64,
|
"BLOCK_SIZE_M": 64,
|
||||||
"BLOCK_SIZE_N": 128,
|
"BLOCK_SIZE_N": 128,
|
||||||
"BLOCK_SIZE_K": 256,
|
"BLOCK_SIZE_K": 128,
|
||||||
"GROUP_SIZE_M": 32,
|
"GROUP_SIZE_M": 1,
|
||||||
"num_warps": 4,
|
"num_warps": 4,
|
||||||
"num_stages": 4
|
"num_stages": 4
|
||||||
},
|
},
|
||||||
@ -107,7 +107,7 @@
|
|||||||
"BLOCK_SIZE_M": 64,
|
"BLOCK_SIZE_M": 64,
|
||||||
"BLOCK_SIZE_N": 128,
|
"BLOCK_SIZE_N": 128,
|
||||||
"BLOCK_SIZE_K": 128,
|
"BLOCK_SIZE_K": 128,
|
||||||
"GROUP_SIZE_M": 32,
|
"GROUP_SIZE_M": 16,
|
||||||
"num_warps": 4,
|
"num_warps": 4,
|
||||||
"num_stages": 4
|
"num_stages": 4
|
||||||
},
|
},
|
||||||
@ -115,7 +115,7 @@
|
|||||||
"BLOCK_SIZE_M": 64,
|
"BLOCK_SIZE_M": 64,
|
||||||
"BLOCK_SIZE_N": 128,
|
"BLOCK_SIZE_N": 128,
|
||||||
"BLOCK_SIZE_K": 128,
|
"BLOCK_SIZE_K": 128,
|
||||||
"GROUP_SIZE_M": 1,
|
"GROUP_SIZE_M": 16,
|
||||||
"num_warps": 4,
|
"num_warps": 4,
|
||||||
"num_stages": 4
|
"num_stages": 4
|
||||||
},
|
},
|
||||||
@ -123,15 +123,15 @@
|
|||||||
"BLOCK_SIZE_M": 64,
|
"BLOCK_SIZE_M": 64,
|
||||||
"BLOCK_SIZE_N": 128,
|
"BLOCK_SIZE_N": 128,
|
||||||
"BLOCK_SIZE_K": 256,
|
"BLOCK_SIZE_K": 256,
|
||||||
"GROUP_SIZE_M": 32,
|
"GROUP_SIZE_M": 16,
|
||||||
"num_warps": 4,
|
"num_warps": 4,
|
||||||
"num_stages": 4
|
"num_stages": 4
|
||||||
},
|
},
|
||||||
"3072": {
|
"3072": {
|
||||||
"BLOCK_SIZE_M": 64,
|
"BLOCK_SIZE_M": 64,
|
||||||
"BLOCK_SIZE_N": 128,
|
"BLOCK_SIZE_N": 128,
|
||||||
"BLOCK_SIZE_K": 128,
|
"BLOCK_SIZE_K": 256,
|
||||||
"GROUP_SIZE_M": 32,
|
"GROUP_SIZE_M": 1,
|
||||||
"num_warps": 4,
|
"num_warps": 4,
|
||||||
"num_stages": 4
|
"num_stages": 4
|
||||||
},
|
},
|
||||||
|
|||||||
@ -0,0 +1,146 @@
|
|||||||
|
{
|
||||||
|
"1": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 4
|
||||||
|
},
|
||||||
|
"2": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 4
|
||||||
|
},
|
||||||
|
"4": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 4
|
||||||
|
},
|
||||||
|
"8": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 3
|
||||||
|
},
|
||||||
|
"16": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 256,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 4
|
||||||
|
},
|
||||||
|
"24": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 3
|
||||||
|
},
|
||||||
|
"32": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 5
|
||||||
|
},
|
||||||
|
"48": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 4
|
||||||
|
},
|
||||||
|
"64": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 256,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 4
|
||||||
|
},
|
||||||
|
"96": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 3
|
||||||
|
},
|
||||||
|
"128": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 256,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 3
|
||||||
|
},
|
||||||
|
"256": {
|
||||||
|
"BLOCK_SIZE_M": 16,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 16,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 3
|
||||||
|
},
|
||||||
|
"512": {
|
||||||
|
"BLOCK_SIZE_M": 64,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 256,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 4
|
||||||
|
},
|
||||||
|
"1024": {
|
||||||
|
"BLOCK_SIZE_M": 64,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 256,
|
||||||
|
"GROUP_SIZE_M": 64,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 4
|
||||||
|
},
|
||||||
|
"1536": {
|
||||||
|
"BLOCK_SIZE_M": 64,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 256,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 4
|
||||||
|
},
|
||||||
|
"2048": {
|
||||||
|
"BLOCK_SIZE_M": 64,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 256,
|
||||||
|
"GROUP_SIZE_M": 64,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 4
|
||||||
|
},
|
||||||
|
"3072": {
|
||||||
|
"BLOCK_SIZE_M": 64,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 16,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 4
|
||||||
|
},
|
||||||
|
"4096": {
|
||||||
|
"BLOCK_SIZE_M": 64,
|
||||||
|
"BLOCK_SIZE_N": 256,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 32,
|
||||||
|
"num_warps": 8,
|
||||||
|
"num_stages": 3
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -18,18 +18,18 @@
|
|||||||
"4": {
|
"4": {
|
||||||
"BLOCK_SIZE_M": 16,
|
"BLOCK_SIZE_M": 16,
|
||||||
"BLOCK_SIZE_N": 128,
|
"BLOCK_SIZE_N": 128,
|
||||||
"BLOCK_SIZE_K": 128,
|
"BLOCK_SIZE_K": 256,
|
||||||
"GROUP_SIZE_M": 1,
|
"GROUP_SIZE_M": 32,
|
||||||
"num_warps": 4,
|
"num_warps": 4,
|
||||||
"num_stages": 3
|
"num_stages": 3
|
||||||
},
|
},
|
||||||
"8": {
|
"8": {
|
||||||
"BLOCK_SIZE_M": 16,
|
"BLOCK_SIZE_M": 16,
|
||||||
"BLOCK_SIZE_N": 128,
|
"BLOCK_SIZE_N": 128,
|
||||||
"BLOCK_SIZE_K": 256,
|
"BLOCK_SIZE_K": 128,
|
||||||
"GROUP_SIZE_M": 64,
|
"GROUP_SIZE_M": 64,
|
||||||
"num_warps": 4,
|
"num_warps": 4,
|
||||||
"num_stages": 4
|
"num_stages": 3
|
||||||
},
|
},
|
||||||
"16": {
|
"16": {
|
||||||
"BLOCK_SIZE_M": 16,
|
"BLOCK_SIZE_M": 16,
|
||||||
@ -58,7 +58,7 @@
|
|||||||
"48": {
|
"48": {
|
||||||
"BLOCK_SIZE_M": 64,
|
"BLOCK_SIZE_M": 64,
|
||||||
"BLOCK_SIZE_N": 128,
|
"BLOCK_SIZE_N": 128,
|
||||||
"BLOCK_SIZE_K": 128,
|
"BLOCK_SIZE_K": 256,
|
||||||
"GROUP_SIZE_M": 64,
|
"GROUP_SIZE_M": 64,
|
||||||
"num_warps": 4,
|
"num_warps": 4,
|
||||||
"num_stages": 4
|
"num_stages": 4
|
||||||
@ -74,73 +74,73 @@
|
|||||||
"96": {
|
"96": {
|
||||||
"BLOCK_SIZE_M": 64,
|
"BLOCK_SIZE_M": 64,
|
||||||
"BLOCK_SIZE_N": 128,
|
"BLOCK_SIZE_N": 128,
|
||||||
"BLOCK_SIZE_K": 128,
|
"BLOCK_SIZE_K": 256,
|
||||||
"GROUP_SIZE_M": 16,
|
"GROUP_SIZE_M": 32,
|
||||||
"num_warps": 4,
|
"num_warps": 4,
|
||||||
"num_stages": 3
|
"num_stages": 4
|
||||||
},
|
},
|
||||||
"128": {
|
"128": {
|
||||||
"BLOCK_SIZE_M": 128,
|
"BLOCK_SIZE_M": 64,
|
||||||
"BLOCK_SIZE_N": 256,
|
"BLOCK_SIZE_N": 128,
|
||||||
"BLOCK_SIZE_K": 128,
|
"BLOCK_SIZE_K": 128,
|
||||||
"GROUP_SIZE_M": 1,
|
"GROUP_SIZE_M": 64,
|
||||||
"num_warps": 4,
|
"num_warps": 4,
|
||||||
"num_stages": 2
|
"num_stages": 4
|
||||||
},
|
},
|
||||||
"256": {
|
"256": {
|
||||||
"BLOCK_SIZE_M": 16,
|
"BLOCK_SIZE_M": 64,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 256,
|
||||||
|
"GROUP_SIZE_M": 1,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 4
|
||||||
|
},
|
||||||
|
"512": {
|
||||||
|
"BLOCK_SIZE_M": 64,
|
||||||
"BLOCK_SIZE_N": 128,
|
"BLOCK_SIZE_N": 128,
|
||||||
"BLOCK_SIZE_K": 256,
|
"BLOCK_SIZE_K": 256,
|
||||||
"GROUP_SIZE_M": 64,
|
"GROUP_SIZE_M": 64,
|
||||||
"num_warps": 4,
|
"num_warps": 4,
|
||||||
"num_stages": 3
|
|
||||||
},
|
|
||||||
"512": {
|
|
||||||
"BLOCK_SIZE_M": 256,
|
|
||||||
"BLOCK_SIZE_N": 256,
|
|
||||||
"BLOCK_SIZE_K": 256,
|
|
||||||
"GROUP_SIZE_M": 64,
|
|
||||||
"num_warps": 8,
|
|
||||||
"num_stages": 4
|
"num_stages": 4
|
||||||
},
|
},
|
||||||
"1024": {
|
"1024": {
|
||||||
"BLOCK_SIZE_M": 256,
|
"BLOCK_SIZE_M": 64,
|
||||||
"BLOCK_SIZE_N": 256,
|
"BLOCK_SIZE_N": 128,
|
||||||
"BLOCK_SIZE_K": 256,
|
"BLOCK_SIZE_K": 256,
|
||||||
"GROUP_SIZE_M": 16,
|
"GROUP_SIZE_M": 64,
|
||||||
"num_warps": 4,
|
"num_warps": 4,
|
||||||
"num_stages": 4
|
"num_stages": 4
|
||||||
},
|
},
|
||||||
"1536": {
|
"1536": {
|
||||||
"BLOCK_SIZE_M": 64,
|
"BLOCK_SIZE_M": 64,
|
||||||
"BLOCK_SIZE_N": 256,
|
|
||||||
"BLOCK_SIZE_K": 128,
|
|
||||||
"GROUP_SIZE_M": 1,
|
|
||||||
"num_warps": 8,
|
|
||||||
"num_stages": 4
|
|
||||||
},
|
|
||||||
"2048": {
|
|
||||||
"BLOCK_SIZE_M": 32,
|
|
||||||
"BLOCK_SIZE_N": 128,
|
"BLOCK_SIZE_N": 128,
|
||||||
"BLOCK_SIZE_K": 256,
|
"BLOCK_SIZE_K": 256,
|
||||||
"GROUP_SIZE_M": 16,
|
"GROUP_SIZE_M": 16,
|
||||||
"num_warps": 8,
|
"num_warps": 4,
|
||||||
"num_stages": 5
|
"num_stages": 3
|
||||||
},
|
},
|
||||||
"3072": {
|
"2048": {
|
||||||
"BLOCK_SIZE_M": 128,
|
"BLOCK_SIZE_M": 64,
|
||||||
"BLOCK_SIZE_N": 128,
|
"BLOCK_SIZE_N": 128,
|
||||||
"BLOCK_SIZE_K": 128,
|
"BLOCK_SIZE_K": 128,
|
||||||
"GROUP_SIZE_M": 64,
|
"GROUP_SIZE_M": 64,
|
||||||
"num_warps": 4,
|
"num_warps": 4,
|
||||||
"num_stages": 4
|
"num_stages": 3
|
||||||
|
},
|
||||||
|
"3072": {
|
||||||
|
"BLOCK_SIZE_M": 64,
|
||||||
|
"BLOCK_SIZE_N": 128,
|
||||||
|
"BLOCK_SIZE_K": 128,
|
||||||
|
"GROUP_SIZE_M": 32,
|
||||||
|
"num_warps": 4,
|
||||||
|
"num_stages": 3
|
||||||
},
|
},
|
||||||
"4096": {
|
"4096": {
|
||||||
"BLOCK_SIZE_M": 128,
|
"BLOCK_SIZE_M": 64,
|
||||||
"BLOCK_SIZE_N": 256,
|
"BLOCK_SIZE_N": 128,
|
||||||
"BLOCK_SIZE_K": 256,
|
"BLOCK_SIZE_K": 128,
|
||||||
"GROUP_SIZE_M": 64,
|
"GROUP_SIZE_M": 16,
|
||||||
"num_warps": 8,
|
"num_warps": 4,
|
||||||
"num_stages": 5
|
"num_stages": 3
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -25,7 +25,7 @@
|
|||||||
from collections.abc import Iterable, Mapping, Sequence
|
from collections.abc import Iterable, Mapping, Sequence
|
||||||
from copy import copy
|
from copy import copy
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Any, Callable, Optional, Union
|
from typing import Annotated, Any, Callable, Literal, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
@ -41,15 +41,13 @@ from transformers.models.whisper import WhisperFeatureExtractor
|
|||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
|
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
|
||||||
from vllm.model_executor.models.module_mapping import MultiModelKeys
|
|
||||||
from vllm.model_executor.models.qwen2_5_vl import (
|
from vllm.model_executor.models.qwen2_5_vl import (
|
||||||
Qwen2_5_VisionTransformer, Qwen2_5_VLImageEmbeddingInputs,
|
Qwen2_5_VisionTransformer, Qwen2_5_VLImageEmbeddingInputs,
|
||||||
Qwen2_5_VLImageInputs, Qwen2_5_VLImagePixelInputs,
|
Qwen2_5_VLImageInputs, Qwen2_5_VLImagePixelInputs,
|
||||||
Qwen2_5_VLProcessingInfo, Qwen2_5_VLVideoEmbeddingInputs,
|
Qwen2_5_VLProcessingInfo, Qwen2_5_VLVideoEmbeddingInputs,
|
||||||
Qwen2_5_VLVideoInputs, Qwen2_5_VLVideoPixelInputs)
|
Qwen2_5_VLVideoInputs, Qwen2_5_VLVideoPixelInputs)
|
||||||
from vllm.model_executor.models.qwen2_audio import (
|
from vllm.model_executor.models.qwen2_audio import (
|
||||||
Qwen2AudioFeatureInputs, Qwen2AudioProcessingInfo,
|
Qwen2AudioProcessingInfo, _get_feat_extract_output_lengths)
|
||||||
_get_feat_extract_output_lengths)
|
|
||||||
from vllm.model_executor.models.qwen2_vl import Qwen2VLMultiModalDataParser
|
from vllm.model_executor.models.qwen2_vl import Qwen2VLMultiModalDataParser
|
||||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||||
@ -66,9 +64,9 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
|||||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
from vllm.transformers_utils.tokenizer import decode_tokens, encode_tokens
|
from vllm.transformers_utils.tokenizer import decode_tokens, encode_tokens
|
||||||
|
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||||
|
|
||||||
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
|
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
|
||||||
SupportsMultiModal, SupportsPP)
|
|
||||||
from .utils import (AutoWeightsLoader, WeightsMapper,
|
from .utils import (AutoWeightsLoader, WeightsMapper,
|
||||||
init_vllm_registered_model, maybe_prefix,
|
init_vllm_registered_model, maybe_prefix,
|
||||||
merge_multimodal_embeddings)
|
merge_multimodal_embeddings)
|
||||||
@ -81,6 +79,26 @@ except (ImportError, ModuleNotFoundError):
|
|||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen2_5OmniAudioFeatureInputs(TensorSchema):
|
||||||
|
"""
|
||||||
|
Dimensions:
|
||||||
|
- na: Number of audios
|
||||||
|
- nmb: Number of mel bins
|
||||||
|
- msl: Maximum sequence length
|
||||||
|
- tsl: Total sequence length
|
||||||
|
"""
|
||||||
|
type: Literal["audio_features"]
|
||||||
|
input_features: Annotated[
|
||||||
|
Union[torch.Tensor, list[torch.Tensor]],
|
||||||
|
TensorShape("nmb", "tsl"),
|
||||||
|
]
|
||||||
|
|
||||||
|
feature_attention_mask: Annotated[
|
||||||
|
torch.Tensor,
|
||||||
|
TensorShape("na", "msl"),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
def create_qwen2_5_omni_thinker_field_factory(
|
def create_qwen2_5_omni_thinker_field_factory(
|
||||||
spatial_merge_size: int
|
spatial_merge_size: int
|
||||||
) -> Callable[[Mapping[str, torch.Tensor]], Mapping[str,
|
) -> Callable[[Mapping[str, torch.Tensor]], Mapping[str,
|
||||||
@ -536,7 +554,7 @@ class Qwen2_5OmniConditionalGenerationMixin:
|
|||||||
return torch.concat(mm_input, dim=dim)
|
return torch.concat(mm_input, dim=dim)
|
||||||
|
|
||||||
def _parse_and_validate_audio_input(
|
def _parse_and_validate_audio_input(
|
||||||
self, **kwargs: object) -> Optional[Qwen2AudioFeatureInputs]:
|
self, **kwargs: object) -> Optional[Qwen2_5OmniAudioFeatureInputs]:
|
||||||
input_audio_features = kwargs.pop('input_audio_features', None)
|
input_audio_features = kwargs.pop('input_audio_features', None)
|
||||||
audio_feature_lengths = kwargs.pop('audio_feature_lengths', None)
|
audio_feature_lengths = kwargs.pop('audio_feature_lengths', None)
|
||||||
feature_attention_mask = kwargs.pop('feature_attention_mask', None)
|
feature_attention_mask = kwargs.pop('feature_attention_mask', None)
|
||||||
@ -550,7 +568,8 @@ class Qwen2_5OmniConditionalGenerationMixin:
|
|||||||
if not isinstance(input_audio_features, (torch.Tensor, list)):
|
if not isinstance(input_audio_features, (torch.Tensor, list)):
|
||||||
raise ValueError("Incorrect type of audio input features. "
|
raise ValueError("Incorrect type of audio input features. "
|
||||||
f"Got type: {type(input_audio_features)}")
|
f"Got type: {type(input_audio_features)}")
|
||||||
return Qwen2AudioFeatureInputs(
|
return Qwen2_5OmniAudioFeatureInputs(
|
||||||
|
type="audio_features",
|
||||||
input_features=input_audio_features,
|
input_features=input_audio_features,
|
||||||
audio_feature_lengths=audio_feature_lengths,
|
audio_feature_lengths=audio_feature_lengths,
|
||||||
feature_attention_mask=feature_attention_mask)
|
feature_attention_mask=feature_attention_mask)
|
||||||
@ -633,7 +652,7 @@ class Qwen2_5OmniConditionalGenerationMixin:
|
|||||||
|
|
||||||
def _process_audio_input(
|
def _process_audio_input(
|
||||||
self,
|
self,
|
||||||
audio_input: Qwen2AudioFeatureInputs,
|
audio_input: Qwen2_5OmniAudioFeatureInputs,
|
||||||
audio_hashes: list[str] = None,
|
audio_hashes: list[str] = None,
|
||||||
cached_audio_features: torch.Tensor = None,
|
cached_audio_features: torch.Tensor = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
@ -660,8 +679,8 @@ class Qwen2_5OmniConditionalGenerationMixin:
|
|||||||
feature_lens=audio_feature_lengths,
|
feature_lens=audio_feature_lengths,
|
||||||
aftercnn_lens=audio_feat_lengths,
|
aftercnn_lens=audio_feat_lengths,
|
||||||
)
|
)
|
||||||
audio_features = audio_outputs.last_hidden_state
|
return audio_outputs.last_hidden_state.split(
|
||||||
return audio_features.split(audio_output_lengths.tolist())
|
audio_output_lengths.tolist())
|
||||||
|
|
||||||
def _process_image_input(
|
def _process_image_input(
|
||||||
self,
|
self,
|
||||||
@ -707,7 +726,7 @@ class Qwen2_5OmniConditionalGenerationMixin:
|
|||||||
dummy_inputs=Qwen2_5OmniThinkerDummyInputsBuilder,
|
dummy_inputs=Qwen2_5OmniThinkerDummyInputsBuilder,
|
||||||
)
|
)
|
||||||
class Qwen2_5OmniThinkerForConditionalGeneration(
|
class Qwen2_5OmniThinkerForConditionalGeneration(
|
||||||
nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
|
nn.Module, SupportsMultiModal, SupportsPP,
|
||||||
Qwen2_5OmniConditionalGenerationMixin):
|
Qwen2_5OmniConditionalGenerationMixin):
|
||||||
hf_to_vllm_mapper = WeightsMapper(
|
hf_to_vllm_mapper = WeightsMapper(
|
||||||
orig_to_new_prefix={
|
orig_to_new_prefix={
|
||||||
@ -800,15 +819,6 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
|
|||||||
def get_language_model(self) -> torch.nn.Module:
|
def get_language_model(self) -> torch.nn.Module:
|
||||||
return self.language_model
|
return self.language_model
|
||||||
|
|
||||||
def get_mm_mapping(self) -> MultiModelKeys:
|
|
||||||
"""Get module prefix for multimodal models to filter LoRA modules."""
|
|
||||||
return MultiModelKeys.from_string_field(
|
|
||||||
language_model="language_model",
|
|
||||||
connector=[], # No explicit connector in this model
|
|
||||||
tower_model=["visual",
|
|
||||||
"audio_tower"], # Exclude vision and audio towers
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_multimodal_embeddings(self,
|
def get_multimodal_embeddings(self,
|
||||||
**kwargs: object) -> MultiModalEmbeddings:
|
**kwargs: object) -> MultiModalEmbeddings:
|
||||||
|
|
||||||
|
|||||||
@ -27,7 +27,7 @@
|
|||||||
"""Inference-only Qwen2.5-VL model compatible with HuggingFace weights."""
|
"""Inference-only Qwen2.5-VL model compatible with HuggingFace weights."""
|
||||||
from collections.abc import Iterable, Mapping
|
from collections.abc import Iterable, Mapping
|
||||||
from functools import lru_cache, partial
|
from functools import lru_cache, partial
|
||||||
from typing import Callable, Literal, Optional, TypedDict, Union
|
from typing import Annotated, Callable, Literal, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
@ -64,6 +64,7 @@ from vllm.multimodal.utils import run_dp_sharded_mrope_vision_model
|
|||||||
from vllm.platforms import _Backend
|
from vllm.platforms import _Backend
|
||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
from vllm.transformers_utils.config import uses_mrope
|
from vllm.transformers_utils.config import uses_mrope
|
||||||
|
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||||
|
|
||||||
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
|
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
|
||||||
SupportsMultiModal, SupportsPP, SupportsQuant)
|
SupportsMultiModal, SupportsPP, SupportsQuant)
|
||||||
@ -80,84 +81,125 @@ logger = init_logger(__name__)
|
|||||||
# === Vision Inputs === #
|
# === Vision Inputs === #
|
||||||
|
|
||||||
|
|
||||||
class Qwen2_5_VLImagePixelInputs(TypedDict):
|
class Qwen2_5_VLImagePixelInputs(TensorSchema):
|
||||||
|
"""
|
||||||
|
Dimensions:
|
||||||
|
- np: Number of patches
|
||||||
|
- ni: Number of images
|
||||||
|
- cps: Number of channels * patch_size * patch_size
|
||||||
|
|
||||||
|
Historical context:
|
||||||
|
- pixel_values shape: (num_patches, num_channels * patch_size *
|
||||||
|
patch_size)
|
||||||
|
- image_grid_thw shape: (num_images, 3) in (grid_t, grid_h, grid_w)
|
||||||
|
formatnum_channels * patch_size * patch_size
|
||||||
|
"""
|
||||||
type: Literal["pixel_values"]
|
type: Literal["pixel_values"]
|
||||||
pixel_values: torch.Tensor
|
|
||||||
"""Shape:
|
pixel_values: Annotated[
|
||||||
`(num_patches, num_channels * patch_size * patch_size)`
|
torch.Tensor,
|
||||||
|
TensorShape("np", "cps"),
|
||||||
|
]
|
||||||
|
|
||||||
|
image_grid_thw: Annotated[
|
||||||
|
torch.Tensor,
|
||||||
|
TensorShape("ni", 3),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen2_5_VLImageEmbeddingInputs(TensorSchema):
|
||||||
"""
|
"""
|
||||||
|
Dimensions:
|
||||||
image_grid_thw: torch.Tensor
|
- nf: Number of image features
|
||||||
"""Shape: `(num_images, 3)`
|
- hs: Hidden size
|
||||||
This should be in `(grid_t, grid_h, grid_w)` format.
|
- ni: Number of images
|
||||||
|
|
||||||
|
Historical context:
|
||||||
|
- image_embeds shape: (num_image_features, hidden_size)
|
||||||
|
- num_image_features varies based on the number and resolution of the
|
||||||
|
images.
|
||||||
|
- hidden_size must match the hidden size of language model backbone.
|
||||||
|
- image_grid_thw shape: (num_images, 3) in (grid_t, grid_h, grid_w)
|
||||||
|
format
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
class Qwen2_5_VLImageEmbeddingInputs(TypedDict):
|
|
||||||
type: Literal["image_embeds"]
|
type: Literal["image_embeds"]
|
||||||
image_embeds: torch.Tensor
|
|
||||||
"""Supported types:
|
|
||||||
- list[`torch.Tensor`]: A list of tensors holding all images' features.
|
|
||||||
Each tensor holds an image's features.
|
|
||||||
- `torch.Tensor`: A tensor holding all images' features
|
|
||||||
(concatenation of all images' feature tensors).
|
|
||||||
|
|
||||||
Tensor shape: `(num_image_features, hidden_size)`
|
image_embeds: Annotated[
|
||||||
- `num_image_features` varies based on
|
torch.Tensor,
|
||||||
the number and resolution of the images.
|
TensorShape("nf", "hs"),
|
||||||
- `hidden_size` must match the hidden size of language model backbone.
|
]
|
||||||
"""
|
|
||||||
|
|
||||||
image_grid_thw: torch.Tensor
|
image_grid_thw: Annotated[
|
||||||
"""Shape: `(num_images, 3)`
|
torch.Tensor,
|
||||||
This should be in `(grid_t, grid_h, grid_w)` format.
|
TensorShape("ni", 3),
|
||||||
"""
|
]
|
||||||
|
|
||||||
|
|
||||||
Qwen2_5_VLImageInputs = Union[Qwen2_5_VLImagePixelInputs,
|
Qwen2_5_VLImageInputs = Union[Qwen2_5_VLImagePixelInputs,
|
||||||
Qwen2_5_VLImageEmbeddingInputs]
|
Qwen2_5_VLImageEmbeddingInputs]
|
||||||
|
|
||||||
|
|
||||||
class Qwen2_5_VLVideoPixelInputs(TypedDict):
|
class Qwen2_5_VLVideoPixelInputs(TensorSchema):
|
||||||
|
"""
|
||||||
|
Dimensions:
|
||||||
|
- np: Number of patches
|
||||||
|
- nv: Number of videos
|
||||||
|
- ctps: Number of channels * temporal_patch_size * patch_size *
|
||||||
|
patch_size
|
||||||
|
|
||||||
|
Historical context:
|
||||||
|
- pixel_values_videos shape: (num_patches, num_channels *
|
||||||
|
temporal_patch_size * patch_size * patch_size)
|
||||||
|
- video_grid_thw shape: (num_videos, 3) in (grid_t, grid_h, grid_w)
|
||||||
|
format
|
||||||
|
- second_per_grid_ts: The video time interval (in seconds) for each
|
||||||
|
grid along the temporal dimension in the 3D position IDs. Returned
|
||||||
|
when `videos` is not `None`.
|
||||||
|
"""
|
||||||
type: Literal["pixel_values_videos"]
|
type: Literal["pixel_values_videos"]
|
||||||
pixel_values_videos: torch.Tensor
|
|
||||||
"""Shape:
|
pixel_values_videos: Annotated[
|
||||||
`(num_patches,
|
torch.Tensor,
|
||||||
num_channels * temporal_patch_size * patch_size * patch_size)`
|
TensorShape("np", "ctps"),
|
||||||
|
]
|
||||||
|
|
||||||
|
video_grid_thw: Annotated[
|
||||||
|
torch.Tensor,
|
||||||
|
TensorShape("nv", 3),
|
||||||
|
]
|
||||||
|
|
||||||
|
second_per_grid_ts: Annotated[
|
||||||
|
Optional[torch.Tensor],
|
||||||
|
TensorShape("nv"),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen2_5_VLVideoEmbeddingInputs(TensorSchema):
|
||||||
"""
|
"""
|
||||||
|
Dimensions:
|
||||||
video_grid_thw: torch.Tensor
|
- nf: Number of video features
|
||||||
"""Shape: `(num_videos, 3)`
|
- hs: Hidden size
|
||||||
|
- nv: Number of videos
|
||||||
This should be in `(grid_t, grid_h, grid_w)` format.
|
|
||||||
|
Historical context:
|
||||||
|
- video_embeds shape: (num_video_features, hidden_size)
|
||||||
|
- num_video_features varies based on the number and resolution of the
|
||||||
|
videos.
|
||||||
|
- hidden_size must match the hidden size of language model backbone.
|
||||||
|
- video_grid_thw shape: (num_videos, 3) in (grid_t, grid_h, grid_w)
|
||||||
|
format
|
||||||
"""
|
"""
|
||||||
|
|
||||||
second_per_grid_ts: torch.Tensor
|
|
||||||
"""
|
|
||||||
The video time interval (in seconds) for each grid along the temporal
|
|
||||||
dimension in the 3D position IDs. Returned when `videos` is not `None`.
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
class Qwen2_5_VLVideoEmbeddingInputs(TypedDict):
|
|
||||||
type: Literal["video_embeds"]
|
type: Literal["video_embeds"]
|
||||||
video_embeds: torch.Tensor
|
|
||||||
"""Supported types:
|
|
||||||
- list[`torch.Tensor`]: A list of tensors holding all videos' features.
|
|
||||||
Each tensor holds an video's features.
|
|
||||||
- `torch.Tensor`: A tensor holding all videos' features
|
|
||||||
(concatenation of all videos' feature tensors).
|
|
||||||
|
|
||||||
Tensor shape: `(num_image_features, hidden_size)`
|
video_embeds: Annotated[
|
||||||
- `num_image_features` varies based on
|
torch.Tensor,
|
||||||
the number and resolution of the videos.
|
TensorShape("nf", "hs"),
|
||||||
- `hidden_size` must match the hidden size of language model backbone.
|
]
|
||||||
"""
|
|
||||||
|
|
||||||
video_grid_thw: torch.Tensor
|
video_grid_thw: Annotated[
|
||||||
"""Shape: `(num_videos, 3)`
|
torch.Tensor,
|
||||||
This should be in `(grid_t, grid_h, grid_w)` format.
|
TensorShape("nv", 3),
|
||||||
"""
|
]
|
||||||
|
|
||||||
|
|
||||||
Qwen2_5_VLVideoInputs = Union[Qwen2_5_VLVideoPixelInputs,
|
Qwen2_5_VLVideoInputs = Union[Qwen2_5_VLVideoPixelInputs,
|
||||||
@ -936,10 +978,6 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
|
|||||||
image_grid_thw = self._validate_and_reshape_mm_tensor(
|
image_grid_thw = self._validate_and_reshape_mm_tensor(
|
||||||
image_grid_thw, "image grid_thw")
|
image_grid_thw, "image grid_thw")
|
||||||
|
|
||||||
if not isinstance(pixel_values, (torch.Tensor, list)):
|
|
||||||
raise ValueError("Incorrect type of image pixel values. "
|
|
||||||
f"Got type: {type(pixel_values)}")
|
|
||||||
|
|
||||||
return Qwen2_5_VLImagePixelInputs(type="pixel_values",
|
return Qwen2_5_VLImagePixelInputs(type="pixel_values",
|
||||||
pixel_values=pixel_values,
|
pixel_values=pixel_values,
|
||||||
image_grid_thw=image_grid_thw)
|
image_grid_thw=image_grid_thw)
|
||||||
@ -950,9 +988,6 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
|
|||||||
image_grid_thw = self._validate_and_reshape_mm_tensor(
|
image_grid_thw = self._validate_and_reshape_mm_tensor(
|
||||||
image_grid_thw, "image grid_thw")
|
image_grid_thw, "image grid_thw")
|
||||||
|
|
||||||
if not isinstance(image_embeds, torch.Tensor):
|
|
||||||
raise ValueError("Incorrect type of image embeddings. "
|
|
||||||
f"Got type: {type(image_embeds)}")
|
|
||||||
return Qwen2_5_VLImageEmbeddingInputs(
|
return Qwen2_5_VLImageEmbeddingInputs(
|
||||||
type="image_embeds",
|
type="image_embeds",
|
||||||
image_embeds=image_embeds,
|
image_embeds=image_embeds,
|
||||||
@ -973,7 +1008,8 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
|
|||||||
pixel_values_videos, "video pixel values")
|
pixel_values_videos, "video pixel values")
|
||||||
video_grid_thw = self._validate_and_reshape_mm_tensor(
|
video_grid_thw = self._validate_and_reshape_mm_tensor(
|
||||||
video_grid_thw, "video grid_thw")
|
video_grid_thw, "video grid_thw")
|
||||||
|
if second_per_grid_ts is not None and second_per_grid_ts.ndim == 2:
|
||||||
|
second_per_grid_ts = second_per_grid_ts.squeeze(-1)
|
||||||
return Qwen2_5_VLVideoPixelInputs(
|
return Qwen2_5_VLVideoPixelInputs(
|
||||||
type="pixel_values_videos",
|
type="pixel_values_videos",
|
||||||
pixel_values_videos=pixel_values_videos,
|
pixel_values_videos=pixel_values_videos,
|
||||||
@ -987,9 +1023,6 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
|
|||||||
video_grid_thw = self._validate_and_reshape_mm_tensor(
|
video_grid_thw = self._validate_and_reshape_mm_tensor(
|
||||||
video_grid_thw, "video grid_thw")
|
video_grid_thw, "video grid_thw")
|
||||||
|
|
||||||
if not isinstance(video_embeds, torch.Tensor):
|
|
||||||
raise ValueError("Incorrect type of video embeddings. "
|
|
||||||
f"Got type: {type(video_embeds)}")
|
|
||||||
return Qwen2_5_VLVideoEmbeddingInputs(
|
return Qwen2_5_VLVideoEmbeddingInputs(
|
||||||
type="video_embeds",
|
type="video_embeds",
|
||||||
video_embeds=video_embeds,
|
video_embeds=video_embeds,
|
||||||
|
|||||||
@ -23,7 +23,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""Inference-only Qwen2-Audio model compatible with HuggingFace weights."""
|
"""Inference-only Qwen2-Audio model compatible with HuggingFace weights."""
|
||||||
from collections.abc import Iterable, Mapping, Sequence
|
from collections.abc import Iterable, Mapping, Sequence
|
||||||
from typing import Any, Literal, Optional, TypedDict, Union
|
from typing import Annotated, Any, Literal, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
@ -47,6 +47,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
|||||||
PromptUpdate, PromptUpdateDetails)
|
PromptUpdate, PromptUpdateDetails)
|
||||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
|
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||||
|
|
||||||
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
|
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
|
||||||
from .utils import (AutoWeightsLoader, init_vllm_registered_model,
|
from .utils import (AutoWeightsLoader, init_vllm_registered_model,
|
||||||
@ -54,21 +55,38 @@ from .utils import (AutoWeightsLoader, init_vllm_registered_model,
|
|||||||
|
|
||||||
|
|
||||||
# # === Audio Inputs === #
|
# # === Audio Inputs === #
|
||||||
class Qwen2AudioFeatureInputs(TypedDict):
|
class Qwen2AudioFeatureInputs(TensorSchema):
|
||||||
type: Literal["audio_features"]
|
|
||||||
input_features: torch.Tensor
|
|
||||||
"""Shape: `(num_audios, num_mel_bins, 3000)`"""
|
|
||||||
|
|
||||||
feature_attention_mask: torch.Tensor
|
|
||||||
"""Shape: `(num_audios, 3000)`"""
|
|
||||||
|
|
||||||
|
|
||||||
class Qwen2AudioEmbeddingInputs(TypedDict):
|
|
||||||
type: Literal["audio_embeds"]
|
|
||||||
audio_embeds: list[torch.Tensor]
|
|
||||||
"""Shape: `(num_audio_features, hidden_size)`
|
|
||||||
`hidden_size` must match the hidden size of language model backbone.
|
|
||||||
"""
|
"""
|
||||||
|
Dimensions:
|
||||||
|
- na: Number of audios
|
||||||
|
- nmb: Number of mel bins
|
||||||
|
"""
|
||||||
|
type: Literal["audio_features"]
|
||||||
|
input_features: Annotated[
|
||||||
|
Union[torch.Tensor, list[torch.Tensor]],
|
||||||
|
TensorShape("na", "nmb", 3000),
|
||||||
|
]
|
||||||
|
|
||||||
|
feature_attention_mask: Annotated[
|
||||||
|
torch.Tensor,
|
||||||
|
TensorShape("na", 3000),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen2AudioEmbeddingInputs(TensorSchema):
|
||||||
|
"""
|
||||||
|
Dimensions:
|
||||||
|
- bn: Batch size
|
||||||
|
- naf: Number of audio features
|
||||||
|
- hs: Hidden size (must match the hidden size of language model
|
||||||
|
backbone)
|
||||||
|
"""
|
||||||
|
type: Literal["audio_embeds"] = "audio_embeds"
|
||||||
|
|
||||||
|
audio_embeds: Annotated[
|
||||||
|
list[torch.Tensor],
|
||||||
|
TensorShape("bn", "naf", "hs"),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
Qwen2AudioInputs = Union[Qwen2AudioFeatureInputs, Qwen2AudioEmbeddingInputs]
|
Qwen2AudioInputs = Union[Qwen2AudioFeatureInputs, Qwen2AudioEmbeddingInputs]
|
||||||
|
|||||||
@ -26,7 +26,7 @@
|
|||||||
"""Inference-only Qwen2-VL model compatible with HuggingFace weights."""
|
"""Inference-only Qwen2-VL model compatible with HuggingFace weights."""
|
||||||
from collections.abc import Iterable, Mapping, Sequence
|
from collections.abc import Iterable, Mapping, Sequence
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Any, Callable, Literal, Optional, TypedDict, Union
|
from typing import Annotated, Any, Callable, Literal, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
@ -70,6 +70,7 @@ from vllm.platforms import _Backend, current_platform
|
|||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
from vllm.transformers_utils.config import uses_mrope
|
from vllm.transformers_utils.config import uses_mrope
|
||||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||||
|
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||||
|
|
||||||
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
|
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
|
||||||
SupportsMultiModal, SupportsPP)
|
SupportsMultiModal, SupportsPP)
|
||||||
@ -86,78 +87,119 @@ _MAX_FRAMES_PER_VIDEO = 16
|
|||||||
# === Vision Inputs === #
|
# === Vision Inputs === #
|
||||||
|
|
||||||
|
|
||||||
class Qwen2VLImagePixelInputs(TypedDict):
|
class Qwen2VLImagePixelInputs(TensorSchema):
|
||||||
type: Literal["pixel_values"]
|
|
||||||
pixel_values: torch.Tensor
|
|
||||||
"""Shape:
|
|
||||||
`(num_patches, num_channels * patch_size * patch_size)`
|
|
||||||
"""
|
"""
|
||||||
|
Dimensions:
|
||||||
image_grid_thw: torch.Tensor
|
- np: The total number of patches over each image over each prompt in
|
||||||
"""Shape: `(num_images, 3)`
|
the batch
|
||||||
This should be in `(grid_t, grid_h, grid_w)` format.
|
- ni: Number of images
|
||||||
"""
|
- cps: Number of channels * patch_size * patch_size
|
||||||
|
|
||||||
|
|
||||||
class Qwen2VLImageEmbeddingInputs(TypedDict):
|
|
||||||
type: Literal["image_embeds"]
|
|
||||||
image_embeds: torch.Tensor
|
|
||||||
"""Supported types:
|
|
||||||
- list[`torch.Tensor`]: A list of tensors holding all images' features.
|
|
||||||
Each tensor holds an image's features.
|
|
||||||
- `torch.Tensor`: A tensor holding all images' features
|
|
||||||
(concatenation of all images' feature tensors).
|
|
||||||
|
|
||||||
Tensor shape: `(num_image_features, hidden_size)`
|
Historical context:
|
||||||
- `num_image_features` varies based on
|
- pixel_values shape: (num_patches, num_channels * patch_size *
|
||||||
the number and resolution of the images.
|
patch_size)
|
||||||
- `hidden_size` must match the hidden size of language model backbone.
|
- image_grid_thw shape: (num_images, 3) in (grid_t, grid_h, grid_w)
|
||||||
|
format
|
||||||
"""
|
"""
|
||||||
|
type: Literal["pixel_values"]
|
||||||
|
|
||||||
image_grid_thw: torch.Tensor
|
pixel_values: Annotated[
|
||||||
"""Shape: `(num_images, 3)`
|
torch.Tensor,
|
||||||
This should be in `(grid_t, grid_h, grid_w)` format.
|
TensorShape("np", "cps"),
|
||||||
|
]
|
||||||
|
|
||||||
|
image_grid_thw: Annotated[
|
||||||
|
torch.Tensor,
|
||||||
|
TensorShape("ni", 3),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen2VLImageEmbeddingInputs(TensorSchema):
|
||||||
"""
|
"""
|
||||||
|
Dimensions:
|
||||||
|
- nf: Number of image features
|
||||||
|
- hs: Hidden size
|
||||||
|
- ni: Number of images
|
||||||
|
|
||||||
|
Historical context:
|
||||||
|
- image_embeds shape: (num_image_features, hidden_size)
|
||||||
|
- num_image_features varies based on the number and resolution of the
|
||||||
|
images.
|
||||||
|
- hidden_size must match the hidden size of language model backbone.
|
||||||
|
- image_grid_thw shape: (num_images, 3) in (grid_t, grid_h, grid_w)
|
||||||
|
format
|
||||||
|
"""
|
||||||
|
type: Literal["image_embeds"]
|
||||||
|
|
||||||
|
image_embeds: Annotated[
|
||||||
|
torch.Tensor,
|
||||||
|
TensorShape("nf", "hs"),
|
||||||
|
]
|
||||||
|
|
||||||
|
image_grid_thw: Annotated[
|
||||||
|
torch.Tensor,
|
||||||
|
TensorShape("ni", 3),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
Qwen2VLImageInputs = Union[Qwen2VLImagePixelInputs,
|
Qwen2VLImageInputs = Union[Qwen2VLImagePixelInputs,
|
||||||
Qwen2VLImageEmbeddingInputs]
|
Qwen2VLImageEmbeddingInputs]
|
||||||
|
|
||||||
|
|
||||||
class Qwen2VLVideoPixelInputs(TypedDict):
|
class Qwen2VLVideoPixelInputs(TensorSchema):
|
||||||
type: Literal["pixel_values_videos"]
|
|
||||||
pixel_values_videos: torch.Tensor
|
|
||||||
"""Shape:
|
|
||||||
`(num_patches,
|
|
||||||
num_channels * temporal_patch_size * patch_size * patch_size)`
|
|
||||||
"""
|
"""
|
||||||
|
Dimensions:
|
||||||
video_grid_thw: torch.Tensor
|
- np: The total number of patches over each video over each prompt in
|
||||||
"""Shape: `(num_videos, 3)`
|
the batch
|
||||||
|
- ctps: Number of channels * temporal_patch_size * patch_size *
|
||||||
This should be in `(grid_t, grid_h, grid_w)` format.
|
patch_size
|
||||||
"""
|
- nv: Number of videos
|
||||||
|
|
||||||
|
|
||||||
class Qwen2VLVideoEmbeddingInputs(TypedDict):
|
|
||||||
type: Literal["video_embeds"]
|
|
||||||
video_embeds: torch.Tensor
|
|
||||||
"""Supported types:
|
|
||||||
- list[`torch.Tensor`]: A list of tensors holding all videos' features.
|
|
||||||
Each tensor holds an video's features.
|
|
||||||
- `torch.Tensor`: A tensor holding all videos' features
|
|
||||||
(concatenation of all videos' feature tensors).
|
|
||||||
|
|
||||||
Tensor shape: `(num_image_features, hidden_size)`
|
Historical context:
|
||||||
- `num_image_features` varies based on
|
- pixel_values_videos shape: (num_patches, num_channels *
|
||||||
the number and resolution of the videos.
|
temporal_patch_size * patch_size * patch_size)
|
||||||
- `hidden_size` must match the hidden size of language model backbone.
|
- video_grid_thw shape: (num_videos, 3) in (grid_t, grid_h, grid_w)
|
||||||
|
format
|
||||||
"""
|
"""
|
||||||
|
type: Literal["pixel_values_videos"]
|
||||||
|
|
||||||
video_grid_thw: torch.Tensor
|
pixel_values_videos: Annotated[
|
||||||
"""Shape: `(num_videos, 3)`
|
torch.Tensor,
|
||||||
This should be in `(grid_t, grid_h, grid_w)` format.
|
TensorShape("np", "ctps"),
|
||||||
|
]
|
||||||
|
|
||||||
|
video_grid_thw: Annotated[
|
||||||
|
torch.Tensor,
|
||||||
|
TensorShape("nv", 3),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen2VLVideoEmbeddingInputs(TensorSchema):
|
||||||
"""
|
"""
|
||||||
|
Dimensions:
|
||||||
|
- nf: Number of video features
|
||||||
|
- hs: Hidden size
|
||||||
|
- nv: Number of videos
|
||||||
|
|
||||||
|
Historical context:
|
||||||
|
- video_embeds shape: (num_video_features, hidden_size)
|
||||||
|
- num_video_features varies based on the number and resolution of the
|
||||||
|
videos.
|
||||||
|
- hidden_size must match the hidden size of language model backbone.
|
||||||
|
- video_grid_thw shape: (num_videos, 3) in (grid_t, grid_h, grid_w)
|
||||||
|
format
|
||||||
|
"""
|
||||||
|
type: Literal["video_embeds"]
|
||||||
|
|
||||||
|
video_embeds: Annotated[
|
||||||
|
torch.Tensor,
|
||||||
|
TensorShape("nf", "hs"),
|
||||||
|
]
|
||||||
|
|
||||||
|
video_grid_thw: Annotated[
|
||||||
|
torch.Tensor,
|
||||||
|
TensorShape("nv", 3),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
Qwen2VLVideoInputs = Union[Qwen2VLVideoPixelInputs,
|
Qwen2VLVideoInputs = Union[Qwen2VLVideoPixelInputs,
|
||||||
@ -1126,10 +1168,6 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
|
|||||||
image_grid_thw = self._validate_and_reshape_mm_tensor(
|
image_grid_thw = self._validate_and_reshape_mm_tensor(
|
||||||
image_grid_thw, "image grid_thw")
|
image_grid_thw, "image grid_thw")
|
||||||
|
|
||||||
if not isinstance(pixel_values, (torch.Tensor, list)):
|
|
||||||
raise ValueError("Incorrect type of image pixel values. "
|
|
||||||
f"Got type: {type(pixel_values)}")
|
|
||||||
|
|
||||||
return Qwen2VLImagePixelInputs(type="pixel_values",
|
return Qwen2VLImagePixelInputs(type="pixel_values",
|
||||||
pixel_values=pixel_values,
|
pixel_values=pixel_values,
|
||||||
image_grid_thw=image_grid_thw)
|
image_grid_thw=image_grid_thw)
|
||||||
@ -1140,9 +1178,6 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
|
|||||||
image_grid_thw = self._validate_and_reshape_mm_tensor(
|
image_grid_thw = self._validate_and_reshape_mm_tensor(
|
||||||
image_grid_thw, "image grid_thw")
|
image_grid_thw, "image grid_thw")
|
||||||
|
|
||||||
if not isinstance(image_embeds, torch.Tensor):
|
|
||||||
raise ValueError("Incorrect type of image embeddings. "
|
|
||||||
f"Got type: {type(image_embeds)}")
|
|
||||||
return Qwen2VLImageEmbeddingInputs(type="image_embeds",
|
return Qwen2VLImageEmbeddingInputs(type="image_embeds",
|
||||||
image_embeds=image_embeds,
|
image_embeds=image_embeds,
|
||||||
image_grid_thw=image_grid_thw)
|
image_grid_thw=image_grid_thw)
|
||||||
@ -1174,9 +1209,6 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
|
|||||||
video_grid_thw = self._validate_and_reshape_mm_tensor(
|
video_grid_thw = self._validate_and_reshape_mm_tensor(
|
||||||
video_grid_thw, "video grid_thw")
|
video_grid_thw, "video grid_thw")
|
||||||
|
|
||||||
if not isinstance(video_embeds, torch.Tensor):
|
|
||||||
raise ValueError("Incorrect type of video embeddings. "
|
|
||||||
f"Got type: {type(video_embeds)}")
|
|
||||||
return Qwen2VLVideoEmbeddingInputs(type="video_embeds",
|
return Qwen2VLVideoEmbeddingInputs(type="video_embeds",
|
||||||
video_embeds=video_embeds,
|
video_embeds=video_embeds,
|
||||||
video_grid_thw=video_grid_thw)
|
video_grid_thw=video_grid_thw)
|
||||||
|
|||||||
@ -1592,10 +1592,6 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
|
|||||||
|
|
||||||
# recorect dcp attn_out with lse.
|
# recorect dcp attn_out with lse.
|
||||||
if self.dcp_world_size > 1:
|
if self.dcp_world_size > 1:
|
||||||
assert lse is not None, (
|
|
||||||
"For a mla backend want to enable"
|
|
||||||
"DCP, it is mandatory that the corresponding decode attn"
|
|
||||||
"kernel return the softmax lse.")
|
|
||||||
attn_out = cp_lse_ag_out_rs(attn_out, lse, get_dcp_group())
|
attn_out = cp_lse_ag_out_rs(attn_out, lse, get_dcp_group())
|
||||||
|
|
||||||
# v_up projection
|
# v_up projection
|
||||||
|
|||||||
@ -133,6 +133,8 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
|
|||||||
|
|
||||||
class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
|
class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
|
||||||
|
|
||||||
|
can_return_lse_for_decode: bool = True
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
num_heads: int,
|
num_heads: int,
|
||||||
|
|||||||
@ -73,10 +73,8 @@ class TopKTopPSampler(nn.Module):
|
|||||||
self.forward = self.forward_native
|
self.forward = self.forward_native
|
||||||
else:
|
else:
|
||||||
self.forward = self.forward_native
|
self.forward = self.forward_native
|
||||||
if current_platform.is_tpu():
|
|
||||||
self.apply_top_k_top_p = apply_top_k_top_p_tpu
|
self.apply_top_k_top_p = apply_top_k_top_p
|
||||||
else:
|
|
||||||
self.apply_top_k_top_p = apply_top_k_top_p
|
|
||||||
|
|
||||||
def forward_native(
|
def forward_native(
|
||||||
self,
|
self,
|
||||||
@ -125,53 +123,6 @@ class TopKTopPSampler(nn.Module):
|
|||||||
return flashinfer_sample(logits.contiguous(), k, p, generators), None
|
return flashinfer_sample(logits.contiguous(), k, p, generators), None
|
||||||
|
|
||||||
|
|
||||||
def apply_top_k_top_p_tpu(
|
|
||||||
logits: torch.Tensor,
|
|
||||||
k: torch.Tensor,
|
|
||||||
p: torch.Tensor,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
"""
|
|
||||||
Apply top-k and top-p optimized for TPU.
|
|
||||||
|
|
||||||
This algorithm avoids using torch.scatter which is extremely slow on TPU.
|
|
||||||
This is achieved by finding a "cut-off" element in the original logit, and
|
|
||||||
after thresholding the logit using this cut-off, the remaining elements
|
|
||||||
shall constitute the top-p set.
|
|
||||||
|
|
||||||
Note: in the case of tie (i.e. multipple cut-off elements present in the
|
|
||||||
logit), all tie elements are included in the top-p set. In other words,
|
|
||||||
this function does not break ties. Instead, these tie tokens have equal
|
|
||||||
chance of being chosen during final sampling, so we can consider the tie
|
|
||||||
being broken then.
|
|
||||||
"""
|
|
||||||
probs = logits.softmax(dim=-1)
|
|
||||||
probs_sort, _ = probs.sort(dim=-1, descending=False)
|
|
||||||
|
|
||||||
if k is not None:
|
|
||||||
top_k_count = probs_sort.size(1) - k.to(torch.long) # shape: (batch, )
|
|
||||||
top_k_count = top_k_count.unsqueeze(dim=1)
|
|
||||||
top_k_cutoff = probs_sort.gather(-1, top_k_count)
|
|
||||||
|
|
||||||
# Make sure the no top-k rows are no-op.
|
|
||||||
no_top_k_mask = (k == logits.shape[1]).unsqueeze(dim=1)
|
|
||||||
top_k_cutoff.masked_fill_(no_top_k_mask, -float("inf"))
|
|
||||||
|
|
||||||
elements_to_discard = probs < top_k_cutoff
|
|
||||||
logits.masked_fill_(elements_to_discard, -float("inf"))
|
|
||||||
|
|
||||||
if p is not None:
|
|
||||||
cumprob = torch.cumsum(probs_sort, dim=-1)
|
|
||||||
top_p_mask = cumprob <= 1 - p.unsqueeze(dim=1)
|
|
||||||
top_p_mask[:, -1] = False # at least one
|
|
||||||
|
|
||||||
top_p_count = top_p_mask.sum(dim=-1).unsqueeze(1)
|
|
||||||
top_p_cutoff = probs_sort.gather(-1, top_p_count)
|
|
||||||
elements_to_discard = probs < top_p_cutoff
|
|
||||||
logits.masked_fill_(elements_to_discard, -float("inf"))
|
|
||||||
|
|
||||||
return logits
|
|
||||||
|
|
||||||
|
|
||||||
def apply_top_k_top_p(
|
def apply_top_k_top_p(
|
||||||
logits: torch.Tensor,
|
logits: torch.Tensor,
|
||||||
k: Optional[torch.Tensor],
|
k: Optional[torch.Tensor],
|
||||||
|
|||||||
@ -2,11 +2,12 @@
|
|||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
"""Sampler layer implementing TPU supported operations."""
|
"""Sampler layer implementing TPU supported operations."""
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
from vllm.v1.outputs import LogprobsTensors, SamplerOutput
|
from vllm.v1.outputs import LogprobsTensors, SamplerOutput
|
||||||
from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler
|
|
||||||
from vllm.v1.sample.tpu.metadata import TPUSupportedSamplingMetadata
|
from vllm.v1.sample.tpu.metadata import TPUSupportedSamplingMetadata
|
||||||
|
|
||||||
_SAMPLING_EPS = 1e-5
|
_SAMPLING_EPS = 1e-5
|
||||||
@ -17,7 +18,6 @@ class Sampler(nn.Module):
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
# TODO(houseroad): Add support for logprobs_mode.
|
# TODO(houseroad): Add support for logprobs_mode.
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.topk_topp_sampler = TopKTopPSampler()
|
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@ -65,13 +65,17 @@ class Sampler(nn.Module):
|
|||||||
logits = self.apply_min_p(logits, sampling_metadata.min_p)
|
logits = self.apply_min_p(logits, sampling_metadata.min_p)
|
||||||
|
|
||||||
# Apply top_k and/or top_p.
|
# Apply top_k and/or top_p.
|
||||||
random_sampled, _ = self.topk_topp_sampler(
|
logits = apply_top_k_top_p(
|
||||||
logits,
|
logits,
|
||||||
sampling_metadata.generators,
|
|
||||||
sampling_metadata.top_k,
|
sampling_metadata.top_k,
|
||||||
sampling_metadata.top_p,
|
sampling_metadata.top_p,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Random sample.
|
||||||
|
probs = logits.softmax(dim=-1, dtype=torch.float32)
|
||||||
|
random_sampled = self.random_sample(probs,
|
||||||
|
sampling_metadata.generators)
|
||||||
|
|
||||||
sampled = torch.where(sampling_metadata.temperature < _SAMPLING_EPS,
|
sampled = torch.where(sampling_metadata.temperature < _SAMPLING_EPS,
|
||||||
greedy_sampled, random_sampled)
|
greedy_sampled, random_sampled)
|
||||||
return sampled
|
return sampled
|
||||||
@ -144,3 +148,66 @@ class Sampler(nn.Module):
|
|||||||
# Apply mask using boolean indexing (xla friendly)
|
# Apply mask using boolean indexing (xla friendly)
|
||||||
logits.masked_fill_(~valid_token_mask, -float("inf"))
|
logits.masked_fill_(~valid_token_mask, -float("inf"))
|
||||||
return logits
|
return logits
|
||||||
|
|
||||||
|
def random_sample(
|
||||||
|
self,
|
||||||
|
probs: torch.Tensor,
|
||||||
|
generators: dict[int, torch.Generator],
|
||||||
|
) -> torch.Tensor:
|
||||||
|
q = torch.empty_like(probs)
|
||||||
|
# NOTE(woosuk): To batch-process the requests without their own seeds,
|
||||||
|
# which is the common case, we first assume that every request does
|
||||||
|
# not have its own seed. Then, we overwrite the values for the requests
|
||||||
|
# that have their own seeds.
|
||||||
|
q.exponential_()
|
||||||
|
if generators:
|
||||||
|
for i, generator in generators.items():
|
||||||
|
q[i].exponential_(generator=generator)
|
||||||
|
return probs.div_(q).argmax(dim=-1).view(-1)
|
||||||
|
|
||||||
|
|
||||||
|
def apply_top_k_top_p(
|
||||||
|
logits: torch.Tensor,
|
||||||
|
k: Optional[torch.Tensor],
|
||||||
|
p: Optional[torch.Tensor],
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Apply top-k and top-p optimized for TPU.
|
||||||
|
|
||||||
|
This algorithm avoids using torch.scatter which is extremely slow on TPU.
|
||||||
|
This is achieved by finding a "cut-off" element in the original logit, and
|
||||||
|
after thresholding the logit using this cut-off, the remaining elements
|
||||||
|
shall constitute the top-p set.
|
||||||
|
|
||||||
|
Note: in the case of tie (i.e. multipple cut-off elements present in the
|
||||||
|
logit), all tie elements are included in the top-p set. In other words,
|
||||||
|
this function does not break ties. Instead, these tie tokens have equal
|
||||||
|
chance of being chosen during final sampling, so we can consider the tie
|
||||||
|
being broken then.
|
||||||
|
"""
|
||||||
|
probs = logits.softmax(dim=-1)
|
||||||
|
probs_sort, _ = probs.sort(dim=-1, descending=False)
|
||||||
|
|
||||||
|
if k is not None:
|
||||||
|
top_k_count = probs_sort.size(1) - k.to(torch.long) # shape: (batch, )
|
||||||
|
top_k_count = top_k_count.unsqueeze(dim=1)
|
||||||
|
top_k_cutoff = probs_sort.gather(-1, top_k_count)
|
||||||
|
|
||||||
|
# Make sure the no top-k rows are no-op.
|
||||||
|
no_top_k_mask = (k == logits.shape[1]).unsqueeze(dim=1)
|
||||||
|
top_k_cutoff.masked_fill_(no_top_k_mask, -float("inf"))
|
||||||
|
|
||||||
|
elements_to_discard = probs < top_k_cutoff
|
||||||
|
logits.masked_fill_(elements_to_discard, -float("inf"))
|
||||||
|
|
||||||
|
if p is not None:
|
||||||
|
cumprob = torch.cumsum(probs_sort, dim=-1)
|
||||||
|
top_p_mask = cumprob <= 1 - p.unsqueeze(dim=1)
|
||||||
|
top_p_mask[:, -1] = False # at least one
|
||||||
|
|
||||||
|
top_p_count = top_p_mask.sum(dim=-1).unsqueeze(1)
|
||||||
|
top_p_cutoff = probs_sort.gather(-1, top_p_count)
|
||||||
|
elements_to_discard = probs < top_p_cutoff
|
||||||
|
logits.masked_fill_(elements_to_discard, -float("inf"))
|
||||||
|
|
||||||
|
return logits
|
||||||
|
|||||||
@ -55,7 +55,6 @@ from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
|
|||||||
GiB_bytes, LazyLoader, cdiv, check_use_alibi,
|
GiB_bytes, LazyLoader, cdiv, check_use_alibi,
|
||||||
get_dtype_size, is_pin_memory_available, round_up,
|
get_dtype_size, is_pin_memory_available, round_up,
|
||||||
supports_dynamo)
|
supports_dynamo)
|
||||||
from vllm.v1.attention.backends.mla.flashmla import FlashMLABackend
|
|
||||||
from vllm.v1.attention.backends.utils import (
|
from vllm.v1.attention.backends.utils import (
|
||||||
AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata,
|
AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata,
|
||||||
create_fast_prefill_custom_backend)
|
create_fast_prefill_custom_backend)
|
||||||
@ -1343,16 +1342,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
num_pad, num_tokens_across_dp = self.get_dp_padding(num_input_tokens)
|
num_pad, num_tokens_across_dp = self.get_dp_padding(num_input_tokens)
|
||||||
num_input_tokens += num_pad
|
num_input_tokens += num_pad
|
||||||
|
|
||||||
# _prepare_inputs decides the order of the requests, so we must gather
|
# _prepare_inputs may reorder the batch, so we must gather multi
|
||||||
# multimodal outputs after that.
|
# modal outputs after that to ensure the correct order
|
||||||
if self.supports_mm_inputs:
|
if self.supports_mm_inputs and get_pp_group().is_first_rank:
|
||||||
# Run the multimodal encoder if any.
|
# Run the multimodal encoder if any.
|
||||||
self._execute_mm_encoder(scheduler_output)
|
self._execute_mm_encoder(scheduler_output)
|
||||||
mm_embeds = self._gather_mm_embeddings(input_batch)
|
mm_embeds = self._gather_mm_embeddings(scheduler_output)
|
||||||
else:
|
|
||||||
mm_embeds = []
|
|
||||||
|
|
||||||
if self.supports_mm_inputs and get_pp_group().is_first_rank:
|
|
||||||
# NOTE(woosuk): To unify token ids and soft tokens (vision
|
# NOTE(woosuk): To unify token ids and soft tokens (vision
|
||||||
# embeddings), we always use embeddings (rather than token ids)
|
# embeddings), we always use embeddings (rather than token ids)
|
||||||
# as input to the multimodal model, even when the input is text.
|
# as input to the multimodal model, even when the input is text.
|
||||||
@ -3066,10 +3062,16 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
copy_kv_blocks)
|
copy_kv_blocks)
|
||||||
|
|
||||||
if self.dcp_world_size > 1:
|
if self.dcp_world_size > 1:
|
||||||
assert self.attn_groups[0][0].backend is FlashMLABackend, (
|
layer_names = self.attn_groups[0][0].layer_names
|
||||||
"DCP only support flashmla now."
|
layers = get_layers_from_vllm_config(self.vllm_config,
|
||||||
"For a mla backend want to enable DCP, it is mandatory that the"
|
AttentionLayerBase,
|
||||||
"corresponding decode attn kernel return the softmax lse.")
|
layer_names)
|
||||||
|
for layer in layers.values():
|
||||||
|
assert layer.impl.need_to_return_lse_for_decode, (
|
||||||
|
"DCP requires attention impls to return"
|
||||||
|
" the softmax lse for decode, but the impl "
|
||||||
|
f"{layer.impl.__class__.__name__} "
|
||||||
|
"does not return the softmax lse for decode.")
|
||||||
|
|
||||||
def may_add_encoder_only_layers_to_kv_cache_config(self) -> None:
|
def may_add_encoder_only_layers_to_kv_cache_config(self) -> None:
|
||||||
"""
|
"""
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user