mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-16 16:07:07 +08:00
merge
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
commit
286eeb91e8
104
benchmarks/kernels/benchmark_activation.py
Normal file
104
benchmarks/kernels/benchmark_activation.py
Normal file
@ -0,0 +1,104 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# benchmark custom activation op performance
|
||||
import itertools
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.model_executor.layers.activation # noqa F401
|
||||
from vllm.model_executor.custom_op import CustomOp
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.triton_utils import triton
|
||||
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser
|
||||
|
||||
batch_size_range = [1, 16, 32, 64, 128]
|
||||
seq_len_range = [1, 16, 64, 128, 256, 512, 1024, 2048, 4096]
|
||||
intermediate_size = [3072, 9728, 12288]
|
||||
configs = list(itertools.product(batch_size_range, seq_len_range, intermediate_size))
|
||||
|
||||
|
||||
def benchmark_activation(
|
||||
batch_size: int,
|
||||
seq_len: int,
|
||||
intermediate_size: int,
|
||||
provider: str,
|
||||
func_name: str,
|
||||
dtype: torch.dtype,
|
||||
):
|
||||
device = "cuda"
|
||||
num_tokens = batch_size * seq_len
|
||||
dim = intermediate_size
|
||||
current_platform.seed_everything(42)
|
||||
torch.set_default_device(device)
|
||||
|
||||
if func_name == "gelu_and_mul":
|
||||
layer = CustomOp.op_registry[func_name](approximate="none")
|
||||
elif func_name == "gelu_and_mul_tanh":
|
||||
layer = CustomOp.op_registry["gelu_and_mul"](approximate="tanh")
|
||||
elif func_name == "fatrelu_and_mul":
|
||||
threshold = 0.5
|
||||
layer = CustomOp.op_registry[func_name](threshold)
|
||||
else:
|
||||
layer = CustomOp.op_registry[func_name]()
|
||||
|
||||
x = torch.randn(num_tokens, dim, dtype=dtype, device=device)
|
||||
compiled_layer = torch.compile(layer.forward_native)
|
||||
|
||||
if provider == "custom":
|
||||
fn = lambda: layer(x)
|
||||
elif provider == "compiled":
|
||||
fn = lambda: compiled_layer(x)
|
||||
|
||||
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
|
||||
fn, quantiles=[0.5, 0.2, 0.8]
|
||||
)
|
||||
return ms, max_ms, min_ms
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = FlexibleArgumentParser(description="Benchmark the custom activation op.")
|
||||
parser.add_argument(
|
||||
"--func-name",
|
||||
type=str,
|
||||
choices=[
|
||||
"mul_and_silu",
|
||||
"silu_and_mul",
|
||||
"gelu_and_mul",
|
||||
"gelu_and_mul_tanh",
|
||||
"fatrelu_and_mul",
|
||||
"swigluoai_and_mul",
|
||||
"gelu_new",
|
||||
"gelu_fast",
|
||||
"quick_gelu",
|
||||
],
|
||||
default="silu_and_mul",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dtype", type=str, choices=["half", "bfloat16", "float"], default="bfloat16"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
assert args
|
||||
|
||||
func_name = args.func_name
|
||||
dtype = STR_DTYPE_TO_TORCH_DTYPE[args.dtype]
|
||||
|
||||
perf_report = triton.testing.perf_report(
|
||||
triton.testing.Benchmark(
|
||||
x_names=["batch_size", "seq_len", "intermediate_size"],
|
||||
x_vals=configs,
|
||||
line_arg="provider",
|
||||
line_vals=["custom", "compiled"],
|
||||
line_names=["Custom OP", "Compiled"],
|
||||
styles=[("blue", "-"), ("green", "-")],
|
||||
ylabel="ms",
|
||||
plot_name=f"{func_name}-op-performance",
|
||||
args={},
|
||||
)
|
||||
)
|
||||
|
||||
perf_report(
|
||||
lambda batch_size, seq_len, intermediate_size, provider: benchmark_activation(
|
||||
batch_size, seq_len, intermediate_size, provider, func_name, dtype
|
||||
)
|
||||
).run(print_data=True)
|
||||
@ -678,7 +678,11 @@ def main(args: argparse.Namespace):
|
||||
is_fp16 = not (use_fp8_w8a8 or use_int8_w8a16)
|
||||
search_space = get_configs_compute_bound(is_fp16, block_quant_shape)
|
||||
print(f"Start tuning over {len(search_space)} configurations...")
|
||||
|
||||
if use_deep_gemm:
|
||||
raise ValueError(
|
||||
"Tuning with --use-deep-gemm is not supported as it only tunes Triton "
|
||||
"kernels. Please remove the flag."
|
||||
)
|
||||
start = time.time()
|
||||
configs = _distribute(
|
||||
"tune",
|
||||
|
||||
@ -73,17 +73,11 @@ async def test_zero_truncation_size(client: openai.AsyncOpenAI):
|
||||
"truncate_prompt_tokens": truncation_size
|
||||
}
|
||||
|
||||
with pytest.raises(openai.BadRequestError) as err:
|
||||
await client.post(path="embeddings", cast_to=object, body={**kwargs})
|
||||
response = await client.post(path="embeddings",
|
||||
cast_to=object,
|
||||
body={**kwargs})
|
||||
|
||||
assert err.value.status_code == 400
|
||||
error_details = err.value.response.json()["error"]
|
||||
|
||||
assert error_details["type"] == "BadRequestError"
|
||||
assert "This model's maximum context length is" in error_details["message"]
|
||||
assert "tokens in the input for embedding generation" in error_details[
|
||||
"message"]
|
||||
assert "Please reduce the length of the input" in error_details["message"]
|
||||
assert response["usage"]["prompt_tokens"] == truncation_size
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
@ -130,6 +130,23 @@ class TestRenderPrompt:
|
||||
assert call_args.kwargs["truncation"] is True
|
||||
assert call_args.kwargs["max_length"] == 50
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_truncation_negative(self, renderer, mock_async_tokenizer):
|
||||
# Test that negative truncation uses model's max_model_len
|
||||
mock_async_tokenizer.return_value = MockTokenizerResult(
|
||||
[101, 7592, 2088]) # Truncated to max_model_len
|
||||
renderer.async_tokenizer_pool[
|
||||
renderer.tokenizer] = mock_async_tokenizer
|
||||
|
||||
results = await renderer.render_prompt(prompt_or_prompts="Hello world",
|
||||
max_length=200,
|
||||
truncate_prompt_tokens=-1)
|
||||
|
||||
assert len(results) == 1
|
||||
call_args = mock_async_tokenizer.call_args
|
||||
assert call_args.kwargs["truncation"] is True
|
||||
assert call_args.kwargs["max_length"] == 100 # model's max_model_len
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_token_truncation_last_elements(self, renderer):
|
||||
# Test that token truncation keeps the last N elements
|
||||
|
||||
@ -41,9 +41,8 @@ EAGLE_SPEC_CONFIG = {
|
||||
PARAMS_MODELS_BACKENDS_TOKENIZER_MODE = [
|
||||
("mistralai/Ministral-8B-Instruct-2410", "xgrammar", "auto", None),
|
||||
("mistralai/Ministral-8B-Instruct-2410", "guidance", "auto", None),
|
||||
#FIXME: This test is flaky on CI thus disabled
|
||||
#("mistralai/Ministral-8B-Instruct-2410", "lm-format-enforcer", "auto",
|
||||
# None),
|
||||
("mistralai/Ministral-8B-Instruct-2410", "lm-format-enforcer", "auto",
|
||||
None),
|
||||
("mistralai/Ministral-8B-Instruct-2410", "xgrammar", "mistral", None),
|
||||
("Qwen/Qwen2.5-1.5B-Instruct", "xgrammar", "auto", None),
|
||||
("Qwen/Qwen2.5-1.5B-Instruct", "lm-format-enforcer", "auto", None),
|
||||
@ -123,6 +122,7 @@ def test_structured_output(
|
||||
guided_decoding_backend=guided_decoding_backend,
|
||||
guided_decoding_disable_any_whitespace=(guided_decoding_backend
|
||||
in {"xgrammar", "guidance"}),
|
||||
seed=120,
|
||||
tokenizer_mode=tokenizer_mode,
|
||||
speculative_config=speculative_config)
|
||||
|
||||
|
||||
@ -6,8 +6,12 @@ import pytest
|
||||
import torch
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.v1.sample.ops.topk_topp_sampler import (apply_top_k_top_p,
|
||||
apply_top_k_top_p_tpu)
|
||||
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
|
||||
|
||||
# isort: off
|
||||
from vllm.v1.sample.tpu.sampler import (apply_top_k_top_p as
|
||||
apply_top_k_top_p_tpu)
|
||||
# isort: on
|
||||
|
||||
if not current_platform.is_tpu():
|
||||
pytest.skip("This test needs a TPU.", allow_module_level=True)
|
||||
|
||||
@ -257,6 +257,32 @@ class AttentionLayer(Protocol):
|
||||
|
||||
class AttentionImpl(ABC, Generic[T]):
|
||||
|
||||
# Whether the attention impl can return the softmax lse for decode.
|
||||
# Some features like decode context parallelism require the softmax lse.
|
||||
can_return_lse_for_decode: bool = False
|
||||
|
||||
# some attention backends might not always want to return lse
|
||||
# even if they can return lse (for efficiency reasons)
|
||||
need_to_return_lse_for_decode: bool = False
|
||||
|
||||
dcp_world_size: int
|
||||
dcp_rank: int
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
# use __new__ so that all subclasses will call this
|
||||
self = super().__new__(cls)
|
||||
try:
|
||||
from vllm.distributed.parallel_state import get_dcp_group
|
||||
self.dcp_world_size = get_dcp_group().world_size
|
||||
self.dcp_rank = get_dcp_group().rank_in_group
|
||||
except AssertionError:
|
||||
# DCP might not be initialized in testing
|
||||
self.dcp_world_size = 1
|
||||
self.dcp_rank = 0
|
||||
self.need_to_return_lse_for_decode = self.dcp_world_size > 1 \
|
||||
and self.can_return_lse_for_decode
|
||||
return self
|
||||
|
||||
@abstractmethod
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@ -74,6 +74,7 @@ DEFAULT_CONDA_PATTERNS = {
|
||||
"zmq",
|
||||
"nvidia",
|
||||
"pynvml",
|
||||
"flashinfer-python",
|
||||
}
|
||||
|
||||
DEFAULT_PIP_PATTERNS = {
|
||||
@ -89,6 +90,7 @@ DEFAULT_PIP_PATTERNS = {
|
||||
"zmq",
|
||||
"nvidia",
|
||||
"pynvml",
|
||||
"flashinfer-python",
|
||||
}
|
||||
|
||||
|
||||
|
||||
@ -54,14 +54,11 @@ class ClassificationMixin(OpenAIServing):
|
||||
ctx.tokenizer = await self.engine_client.get_tokenizer(
|
||||
ctx.lora_request)
|
||||
|
||||
(
|
||||
ctx.request_prompts,
|
||||
ctx.engine_prompts,
|
||||
) = await self._preprocess_completion(
|
||||
ctx.request,
|
||||
ctx.tokenizer,
|
||||
ctx.request.input,
|
||||
)
|
||||
renderer = self._get_renderer(ctx.tokenizer)
|
||||
ctx.engine_prompts = await renderer.render_prompt(
|
||||
prompt_or_prompts=ctx.request.input,
|
||||
max_length=self.max_model_len,
|
||||
truncate_prompt_tokens=ctx.request.truncate_prompt_tokens)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
@ -24,7 +24,6 @@ from vllm.entrypoints.openai.protocol import (EmbeddingChatRequest,
|
||||
ErrorResponse, UsageInfo)
|
||||
from vllm.entrypoints.openai.serving_engine import (EmbeddingServeContext,
|
||||
OpenAIServing,
|
||||
RequestPrompt,
|
||||
ServeContext,
|
||||
TextTokensPrompt)
|
||||
# yapf: enable
|
||||
@ -79,11 +78,12 @@ class EmbeddingMixin(OpenAIServing):
|
||||
|
||||
tokenizer = await self.engine_client.get_tokenizer(ctx.lora_request
|
||||
)
|
||||
renderer = self._get_renderer(tokenizer)
|
||||
|
||||
if isinstance(ctx.request, EmbeddingChatRequest):
|
||||
(
|
||||
_,
|
||||
ctx.request_prompts,
|
||||
_,
|
||||
ctx.engine_prompts,
|
||||
) = await self._preprocess_chat(
|
||||
ctx.request,
|
||||
@ -98,13 +98,18 @@ class EmbeddingMixin(OpenAIServing):
|
||||
add_special_tokens=ctx.request.add_special_tokens,
|
||||
)
|
||||
else:
|
||||
(ctx.request_prompts,
|
||||
ctx.engine_prompts) = await self._preprocess_completion(
|
||||
ctx.request,
|
||||
tokenizer,
|
||||
ctx.request.input,
|
||||
add_special_tokens=ctx.request.add_special_tokens,
|
||||
)
|
||||
# Set max_length based on chunked processing capability
|
||||
if self._should_use_chunked_processing(ctx.request):
|
||||
max_length = None
|
||||
else:
|
||||
max_length = self.max_embed_len or self.max_model_len
|
||||
|
||||
ctx.engine_prompts = await renderer.render_prompt(
|
||||
prompt_or_prompts=ctx.request.input,
|
||||
max_length=max_length,
|
||||
truncate_prompt_tokens=ctx.request.truncate_prompt_tokens,
|
||||
add_special_tokens=ctx.request.add_special_tokens,
|
||||
)
|
||||
return None
|
||||
except (ValueError, TypeError) as e:
|
||||
logger.exception("Error in preprocessing prompt inputs")
|
||||
@ -286,7 +291,6 @@ class EmbeddingMixin(OpenAIServing):
|
||||
self,
|
||||
ctx: EmbeddingServeContext,
|
||||
engine_prompt: Union[EngineTokensPrompt, EngineEmbedsPrompt],
|
||||
request_prompt: RequestPrompt,
|
||||
pooling_params: PoolingParams,
|
||||
trace_headers: Optional[Mapping[str, str]],
|
||||
prompt_index: int,
|
||||
@ -295,7 +299,7 @@ class EmbeddingMixin(OpenAIServing):
|
||||
request_id_item = f"{ctx.request_id}-{prompt_index}"
|
||||
|
||||
self._log_inputs(request_id_item,
|
||||
request_prompt,
|
||||
engine_prompt,
|
||||
params=pooling_params,
|
||||
lora_request=ctx.lora_request)
|
||||
|
||||
@ -353,20 +357,14 @@ class EmbeddingMixin(OpenAIServing):
|
||||
return self.create_error_response(
|
||||
"Engine prompts not available")
|
||||
|
||||
if ctx.request_prompts is None:
|
||||
return self.create_error_response(
|
||||
"Request prompts not available")
|
||||
|
||||
max_pos_embeddings = self._get_max_position_embeddings()
|
||||
|
||||
for i, engine_prompt in enumerate(ctx.engine_prompts):
|
||||
request_prompt = ctx.request_prompts[i]
|
||||
|
||||
# Check if this specific prompt needs chunked processing
|
||||
if self._is_text_tokens_prompt(request_prompt):
|
||||
if self._is_text_tokens_prompt(engine_prompt):
|
||||
# Cast to TextTokensPrompt since we've verified
|
||||
# prompt_token_ids
|
||||
text_tokens_prompt = cast(TextTokensPrompt, request_prompt)
|
||||
text_tokens_prompt = cast(TextTokensPrompt, engine_prompt)
|
||||
if (len(text_tokens_prompt["prompt_token_ids"])
|
||||
> max_pos_embeddings):
|
||||
# Use chunked processing for this prompt
|
||||
@ -382,8 +380,7 @@ class EmbeddingMixin(OpenAIServing):
|
||||
Union[EngineTokensPrompt, EngineEmbedsPrompt],
|
||||
engine_prompt)
|
||||
generator = await self._create_single_prompt_generator(
|
||||
ctx, engine_prompt_typed, request_prompt, pooling_params,
|
||||
trace_headers, i)
|
||||
ctx, engine_prompt_typed, pooling_params, trace_headers, i)
|
||||
generators.append(generator)
|
||||
|
||||
from vllm.utils import merge_async_iterators
|
||||
@ -419,10 +416,6 @@ class EmbeddingMixin(OpenAIServing):
|
||||
if not use_chunked:
|
||||
return await super()._collect_batch(ctx=ctx)
|
||||
|
||||
if ctx.request_prompts is None:
|
||||
return self.create_error_response(
|
||||
"Request prompts not available")
|
||||
|
||||
if ctx.result_generator is None:
|
||||
return self.create_error_response(
|
||||
"Result generator not available")
|
||||
@ -538,7 +531,7 @@ class EmbeddingMixin(OpenAIServing):
|
||||
data=final_embedding)
|
||||
|
||||
# Get original prompt token IDs for this prompt
|
||||
original_prompt = ctx.request_prompts[prompt_idx]
|
||||
original_prompt = ctx.engine_prompts[prompt_idx]
|
||||
if not self._is_text_tokens_prompt(original_prompt):
|
||||
return self.create_error_response(
|
||||
f"Chunked prompt {prompt_idx} is not a "
|
||||
|
||||
@ -368,23 +368,20 @@ class OpenAIServing:
|
||||
for i, engine_prompt in enumerate(ctx.engine_prompts):
|
||||
request_id_item = f"{ctx.request_id}-{i}"
|
||||
|
||||
if ctx.request_prompts is None:
|
||||
return self.create_error_response(
|
||||
"Request prompts not available")
|
||||
|
||||
self._log_inputs(
|
||||
request_id_item,
|
||||
ctx.request_prompts[i],
|
||||
params=pooling_params,
|
||||
lora_request=ctx.lora_request,
|
||||
)
|
||||
|
||||
# Mypy has an existing bug related to inferring the variance of
|
||||
# TypedDicts with `builtins.enumerate`:
|
||||
# https://github.com/python/mypy/issues/8586#issuecomment-2867698435
|
||||
engine_prompt = cast(
|
||||
Union[EngineTokensPrompt, EngineEmbedsPrompt],
|
||||
engine_prompt)
|
||||
|
||||
self._log_inputs(
|
||||
request_id_item,
|
||||
engine_prompt,
|
||||
params=pooling_params,
|
||||
lora_request=ctx.lora_request,
|
||||
)
|
||||
|
||||
generator = self.engine_client.encode(
|
||||
engine_prompt,
|
||||
pooling_params,
|
||||
|
||||
@ -108,10 +108,15 @@ class CompletionRenderer(BaseRenderer):
|
||||
for detailed parameter documentation.
|
||||
"""
|
||||
if truncate_prompt_tokens is not None:
|
||||
if max_length is not None:
|
||||
assert 0 <= truncate_prompt_tokens <= max_length
|
||||
if truncate_prompt_tokens == 0:
|
||||
return []
|
||||
if truncate_prompt_tokens < 0:
|
||||
truncate_prompt_tokens = self.model_config.max_model_len
|
||||
if max_length is not None and truncate_prompt_tokens > max_length:
|
||||
raise ValueError(
|
||||
f"truncate_prompt_tokens ({truncate_prompt_tokens}) "
|
||||
f"cannot be greater than max_length ({max_length}). "
|
||||
f"Please select a smaller truncation size.")
|
||||
|
||||
# Parse and batch the input prompts
|
||||
batch_inputs = parse_and_batch_prompt(prompt_or_prompts)
|
||||
|
||||
@ -0,0 +1,146 @@
|
||||
{
|
||||
"1": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"2": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"24": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"48": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"96": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"1024": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"1536": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"2048": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,146 @@
|
||||
{
|
||||
"1": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"2": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"24": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"48": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"96": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"1024": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"1536": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"2048": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
}
|
||||
}
|
||||
@ -9,16 +9,16 @@
|
||||
},
|
||||
"2": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
"num_stages": 4
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
@ -26,15 +26,15 @@
|
||||
"8": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
"num_stages": 3
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
@ -42,7 +42,7 @@
|
||||
"24": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
@ -53,12 +53,12 @@
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
"num_stages": 5
|
||||
},
|
||||
"48": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
@ -82,10 +82,10 @@
|
||||
"128": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
"num_stages": 3
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
@ -98,8 +98,8 @@
|
||||
"512": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
@ -107,7 +107,7 @@
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
@ -115,7 +115,7 @@
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
@ -123,15 +123,15 @@
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
|
||||
@ -0,0 +1,146 @@
|
||||
{
|
||||
"1": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"2": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"24": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"48": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"96": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"1024": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"1536": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"2048": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 8,
|
||||
"num_stages": 3
|
||||
}
|
||||
}
|
||||
@ -18,18 +18,18 @@
|
||||
"4": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
"num_stages": 3
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
@ -58,7 +58,7 @@
|
||||
"48": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
@ -74,73 +74,73 @@
|
||||
"96": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
"num_stages": 4
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2
|
||||
"num_stages": 4
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_M": 256,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 8,
|
||||
"num_stages": 4
|
||||
},
|
||||
"1024": {
|
||||
"BLOCK_SIZE_M": 256,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"1536": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 4
|
||||
},
|
||||
"2048": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 8,
|
||||
"num_stages": 5
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"2048": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
"num_stages": 3
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 8,
|
||||
"num_stages": 5
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
}
|
||||
}
|
||||
|
||||
@ -25,7 +25,7 @@
|
||||
from collections.abc import Iterable, Mapping, Sequence
|
||||
from copy import copy
|
||||
from functools import partial
|
||||
from typing import Any, Callable, Optional, Union
|
||||
from typing import Annotated, Any, Callable, Literal, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@ -41,15 +41,13 @@ from transformers.models.whisper import WhisperFeatureExtractor
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
|
||||
from vllm.model_executor.models.module_mapping import MultiModelKeys
|
||||
from vllm.model_executor.models.qwen2_5_vl import (
|
||||
Qwen2_5_VisionTransformer, Qwen2_5_VLImageEmbeddingInputs,
|
||||
Qwen2_5_VLImageInputs, Qwen2_5_VLImagePixelInputs,
|
||||
Qwen2_5_VLProcessingInfo, Qwen2_5_VLVideoEmbeddingInputs,
|
||||
Qwen2_5_VLVideoInputs, Qwen2_5_VLVideoPixelInputs)
|
||||
from vllm.model_executor.models.qwen2_audio import (
|
||||
Qwen2AudioFeatureInputs, Qwen2AudioProcessingInfo,
|
||||
_get_feat_extract_output_lengths)
|
||||
Qwen2AudioProcessingInfo, _get_feat_extract_output_lengths)
|
||||
from vllm.model_executor.models.qwen2_vl import Qwen2VLMultiModalDataParser
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
@ -66,9 +64,9 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.transformers_utils.tokenizer import decode_tokens, encode_tokens
|
||||
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||
|
||||
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
|
||||
SupportsMultiModal, SupportsPP)
|
||||
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
|
||||
from .utils import (AutoWeightsLoader, WeightsMapper,
|
||||
init_vllm_registered_model, maybe_prefix,
|
||||
merge_multimodal_embeddings)
|
||||
@ -81,6 +79,26 @@ except (ImportError, ModuleNotFoundError):
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class Qwen2_5OmniAudioFeatureInputs(TensorSchema):
|
||||
"""
|
||||
Dimensions:
|
||||
- na: Number of audios
|
||||
- nmb: Number of mel bins
|
||||
- msl: Maximum sequence length
|
||||
- tsl: Total sequence length
|
||||
"""
|
||||
type: Literal["audio_features"]
|
||||
input_features: Annotated[
|
||||
Union[torch.Tensor, list[torch.Tensor]],
|
||||
TensorShape("nmb", "tsl"),
|
||||
]
|
||||
|
||||
feature_attention_mask: Annotated[
|
||||
torch.Tensor,
|
||||
TensorShape("na", "msl"),
|
||||
]
|
||||
|
||||
|
||||
def create_qwen2_5_omni_thinker_field_factory(
|
||||
spatial_merge_size: int
|
||||
) -> Callable[[Mapping[str, torch.Tensor]], Mapping[str,
|
||||
@ -536,7 +554,7 @@ class Qwen2_5OmniConditionalGenerationMixin:
|
||||
return torch.concat(mm_input, dim=dim)
|
||||
|
||||
def _parse_and_validate_audio_input(
|
||||
self, **kwargs: object) -> Optional[Qwen2AudioFeatureInputs]:
|
||||
self, **kwargs: object) -> Optional[Qwen2_5OmniAudioFeatureInputs]:
|
||||
input_audio_features = kwargs.pop('input_audio_features', None)
|
||||
audio_feature_lengths = kwargs.pop('audio_feature_lengths', None)
|
||||
feature_attention_mask = kwargs.pop('feature_attention_mask', None)
|
||||
@ -550,7 +568,8 @@ class Qwen2_5OmniConditionalGenerationMixin:
|
||||
if not isinstance(input_audio_features, (torch.Tensor, list)):
|
||||
raise ValueError("Incorrect type of audio input features. "
|
||||
f"Got type: {type(input_audio_features)}")
|
||||
return Qwen2AudioFeatureInputs(
|
||||
return Qwen2_5OmniAudioFeatureInputs(
|
||||
type="audio_features",
|
||||
input_features=input_audio_features,
|
||||
audio_feature_lengths=audio_feature_lengths,
|
||||
feature_attention_mask=feature_attention_mask)
|
||||
@ -633,7 +652,7 @@ class Qwen2_5OmniConditionalGenerationMixin:
|
||||
|
||||
def _process_audio_input(
|
||||
self,
|
||||
audio_input: Qwen2AudioFeatureInputs,
|
||||
audio_input: Qwen2_5OmniAudioFeatureInputs,
|
||||
audio_hashes: list[str] = None,
|
||||
cached_audio_features: torch.Tensor = None,
|
||||
) -> torch.Tensor:
|
||||
@ -660,8 +679,8 @@ class Qwen2_5OmniConditionalGenerationMixin:
|
||||
feature_lens=audio_feature_lengths,
|
||||
aftercnn_lens=audio_feat_lengths,
|
||||
)
|
||||
audio_features = audio_outputs.last_hidden_state
|
||||
return audio_features.split(audio_output_lengths.tolist())
|
||||
return audio_outputs.last_hidden_state.split(
|
||||
audio_output_lengths.tolist())
|
||||
|
||||
def _process_image_input(
|
||||
self,
|
||||
@ -707,7 +726,7 @@ class Qwen2_5OmniConditionalGenerationMixin:
|
||||
dummy_inputs=Qwen2_5OmniThinkerDummyInputsBuilder,
|
||||
)
|
||||
class Qwen2_5OmniThinkerForConditionalGeneration(
|
||||
nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
|
||||
nn.Module, SupportsMultiModal, SupportsPP,
|
||||
Qwen2_5OmniConditionalGenerationMixin):
|
||||
hf_to_vllm_mapper = WeightsMapper(
|
||||
orig_to_new_prefix={
|
||||
@ -800,15 +819,6 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
|
||||
def get_language_model(self) -> torch.nn.Module:
|
||||
return self.language_model
|
||||
|
||||
def get_mm_mapping(self) -> MultiModelKeys:
|
||||
"""Get module prefix for multimodal models to filter LoRA modules."""
|
||||
return MultiModelKeys.from_string_field(
|
||||
language_model="language_model",
|
||||
connector=[], # No explicit connector in this model
|
||||
tower_model=["visual",
|
||||
"audio_tower"], # Exclude vision and audio towers
|
||||
)
|
||||
|
||||
def get_multimodal_embeddings(self,
|
||||
**kwargs: object) -> MultiModalEmbeddings:
|
||||
|
||||
|
||||
@ -27,7 +27,7 @@
|
||||
"""Inference-only Qwen2.5-VL model compatible with HuggingFace weights."""
|
||||
from collections.abc import Iterable, Mapping
|
||||
from functools import lru_cache, partial
|
||||
from typing import Callable, Literal, Optional, TypedDict, Union
|
||||
from typing import Annotated, Callable, Literal, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@ -64,6 +64,7 @@ from vllm.multimodal.utils import run_dp_sharded_mrope_vision_model
|
||||
from vllm.platforms import _Backend
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.transformers_utils.config import uses_mrope
|
||||
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||
|
||||
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
|
||||
SupportsMultiModal, SupportsPP, SupportsQuant)
|
||||
@ -80,84 +81,125 @@ logger = init_logger(__name__)
|
||||
# === Vision Inputs === #
|
||||
|
||||
|
||||
class Qwen2_5_VLImagePixelInputs(TypedDict):
|
||||
class Qwen2_5_VLImagePixelInputs(TensorSchema):
|
||||
"""
|
||||
Dimensions:
|
||||
- np: Number of patches
|
||||
- ni: Number of images
|
||||
- cps: Number of channels * patch_size * patch_size
|
||||
|
||||
Historical context:
|
||||
- pixel_values shape: (num_patches, num_channels * patch_size *
|
||||
patch_size)
|
||||
- image_grid_thw shape: (num_images, 3) in (grid_t, grid_h, grid_w)
|
||||
formatnum_channels * patch_size * patch_size
|
||||
"""
|
||||
type: Literal["pixel_values"]
|
||||
pixel_values: torch.Tensor
|
||||
"""Shape:
|
||||
`(num_patches, num_channels * patch_size * patch_size)`
|
||||
|
||||
pixel_values: Annotated[
|
||||
torch.Tensor,
|
||||
TensorShape("np", "cps"),
|
||||
]
|
||||
|
||||
image_grid_thw: Annotated[
|
||||
torch.Tensor,
|
||||
TensorShape("ni", 3),
|
||||
]
|
||||
|
||||
|
||||
class Qwen2_5_VLImageEmbeddingInputs(TensorSchema):
|
||||
"""
|
||||
|
||||
image_grid_thw: torch.Tensor
|
||||
"""Shape: `(num_images, 3)`
|
||||
This should be in `(grid_t, grid_h, grid_w)` format.
|
||||
Dimensions:
|
||||
- nf: Number of image features
|
||||
- hs: Hidden size
|
||||
- ni: Number of images
|
||||
|
||||
Historical context:
|
||||
- image_embeds shape: (num_image_features, hidden_size)
|
||||
- num_image_features varies based on the number and resolution of the
|
||||
images.
|
||||
- hidden_size must match the hidden size of language model backbone.
|
||||
- image_grid_thw shape: (num_images, 3) in (grid_t, grid_h, grid_w)
|
||||
format
|
||||
"""
|
||||
|
||||
|
||||
class Qwen2_5_VLImageEmbeddingInputs(TypedDict):
|
||||
type: Literal["image_embeds"]
|
||||
image_embeds: torch.Tensor
|
||||
"""Supported types:
|
||||
- list[`torch.Tensor`]: A list of tensors holding all images' features.
|
||||
Each tensor holds an image's features.
|
||||
- `torch.Tensor`: A tensor holding all images' features
|
||||
(concatenation of all images' feature tensors).
|
||||
|
||||
Tensor shape: `(num_image_features, hidden_size)`
|
||||
- `num_image_features` varies based on
|
||||
the number and resolution of the images.
|
||||
- `hidden_size` must match the hidden size of language model backbone.
|
||||
"""
|
||||
image_embeds: Annotated[
|
||||
torch.Tensor,
|
||||
TensorShape("nf", "hs"),
|
||||
]
|
||||
|
||||
image_grid_thw: torch.Tensor
|
||||
"""Shape: `(num_images, 3)`
|
||||
This should be in `(grid_t, grid_h, grid_w)` format.
|
||||
"""
|
||||
image_grid_thw: Annotated[
|
||||
torch.Tensor,
|
||||
TensorShape("ni", 3),
|
||||
]
|
||||
|
||||
|
||||
Qwen2_5_VLImageInputs = Union[Qwen2_5_VLImagePixelInputs,
|
||||
Qwen2_5_VLImageEmbeddingInputs]
|
||||
|
||||
|
||||
class Qwen2_5_VLVideoPixelInputs(TypedDict):
|
||||
class Qwen2_5_VLVideoPixelInputs(TensorSchema):
|
||||
"""
|
||||
Dimensions:
|
||||
- np: Number of patches
|
||||
- nv: Number of videos
|
||||
- ctps: Number of channels * temporal_patch_size * patch_size *
|
||||
patch_size
|
||||
|
||||
Historical context:
|
||||
- pixel_values_videos shape: (num_patches, num_channels *
|
||||
temporal_patch_size * patch_size * patch_size)
|
||||
- video_grid_thw shape: (num_videos, 3) in (grid_t, grid_h, grid_w)
|
||||
format
|
||||
- second_per_grid_ts: The video time interval (in seconds) for each
|
||||
grid along the temporal dimension in the 3D position IDs. Returned
|
||||
when `videos` is not `None`.
|
||||
"""
|
||||
type: Literal["pixel_values_videos"]
|
||||
pixel_values_videos: torch.Tensor
|
||||
"""Shape:
|
||||
`(num_patches,
|
||||
num_channels * temporal_patch_size * patch_size * patch_size)`
|
||||
|
||||
pixel_values_videos: Annotated[
|
||||
torch.Tensor,
|
||||
TensorShape("np", "ctps"),
|
||||
]
|
||||
|
||||
video_grid_thw: Annotated[
|
||||
torch.Tensor,
|
||||
TensorShape("nv", 3),
|
||||
]
|
||||
|
||||
second_per_grid_ts: Annotated[
|
||||
Optional[torch.Tensor],
|
||||
TensorShape("nv"),
|
||||
]
|
||||
|
||||
|
||||
class Qwen2_5_VLVideoEmbeddingInputs(TensorSchema):
|
||||
"""
|
||||
|
||||
video_grid_thw: torch.Tensor
|
||||
"""Shape: `(num_videos, 3)`
|
||||
|
||||
This should be in `(grid_t, grid_h, grid_w)` format.
|
||||
Dimensions:
|
||||
- nf: Number of video features
|
||||
- hs: Hidden size
|
||||
- nv: Number of videos
|
||||
|
||||
Historical context:
|
||||
- video_embeds shape: (num_video_features, hidden_size)
|
||||
- num_video_features varies based on the number and resolution of the
|
||||
videos.
|
||||
- hidden_size must match the hidden size of language model backbone.
|
||||
- video_grid_thw shape: (num_videos, 3) in (grid_t, grid_h, grid_w)
|
||||
format
|
||||
"""
|
||||
|
||||
second_per_grid_ts: torch.Tensor
|
||||
"""
|
||||
The video time interval (in seconds) for each grid along the temporal
|
||||
dimension in the 3D position IDs. Returned when `videos` is not `None`.
|
||||
"""
|
||||
|
||||
|
||||
class Qwen2_5_VLVideoEmbeddingInputs(TypedDict):
|
||||
type: Literal["video_embeds"]
|
||||
video_embeds: torch.Tensor
|
||||
"""Supported types:
|
||||
- list[`torch.Tensor`]: A list of tensors holding all videos' features.
|
||||
Each tensor holds an video's features.
|
||||
- `torch.Tensor`: A tensor holding all videos' features
|
||||
(concatenation of all videos' feature tensors).
|
||||
|
||||
Tensor shape: `(num_image_features, hidden_size)`
|
||||
- `num_image_features` varies based on
|
||||
the number and resolution of the videos.
|
||||
- `hidden_size` must match the hidden size of language model backbone.
|
||||
"""
|
||||
video_embeds: Annotated[
|
||||
torch.Tensor,
|
||||
TensorShape("nf", "hs"),
|
||||
]
|
||||
|
||||
video_grid_thw: torch.Tensor
|
||||
"""Shape: `(num_videos, 3)`
|
||||
This should be in `(grid_t, grid_h, grid_w)` format.
|
||||
"""
|
||||
video_grid_thw: Annotated[
|
||||
torch.Tensor,
|
||||
TensorShape("nv", 3),
|
||||
]
|
||||
|
||||
|
||||
Qwen2_5_VLVideoInputs = Union[Qwen2_5_VLVideoPixelInputs,
|
||||
@ -936,10 +978,6 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
image_grid_thw = self._validate_and_reshape_mm_tensor(
|
||||
image_grid_thw, "image grid_thw")
|
||||
|
||||
if not isinstance(pixel_values, (torch.Tensor, list)):
|
||||
raise ValueError("Incorrect type of image pixel values. "
|
||||
f"Got type: {type(pixel_values)}")
|
||||
|
||||
return Qwen2_5_VLImagePixelInputs(type="pixel_values",
|
||||
pixel_values=pixel_values,
|
||||
image_grid_thw=image_grid_thw)
|
||||
@ -950,9 +988,6 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
image_grid_thw = self._validate_and_reshape_mm_tensor(
|
||||
image_grid_thw, "image grid_thw")
|
||||
|
||||
if not isinstance(image_embeds, torch.Tensor):
|
||||
raise ValueError("Incorrect type of image embeddings. "
|
||||
f"Got type: {type(image_embeds)}")
|
||||
return Qwen2_5_VLImageEmbeddingInputs(
|
||||
type="image_embeds",
|
||||
image_embeds=image_embeds,
|
||||
@ -973,7 +1008,8 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
pixel_values_videos, "video pixel values")
|
||||
video_grid_thw = self._validate_and_reshape_mm_tensor(
|
||||
video_grid_thw, "video grid_thw")
|
||||
|
||||
if second_per_grid_ts is not None and second_per_grid_ts.ndim == 2:
|
||||
second_per_grid_ts = second_per_grid_ts.squeeze(-1)
|
||||
return Qwen2_5_VLVideoPixelInputs(
|
||||
type="pixel_values_videos",
|
||||
pixel_values_videos=pixel_values_videos,
|
||||
@ -987,9 +1023,6 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
video_grid_thw = self._validate_and_reshape_mm_tensor(
|
||||
video_grid_thw, "video grid_thw")
|
||||
|
||||
if not isinstance(video_embeds, torch.Tensor):
|
||||
raise ValueError("Incorrect type of video embeddings. "
|
||||
f"Got type: {type(video_embeds)}")
|
||||
return Qwen2_5_VLVideoEmbeddingInputs(
|
||||
type="video_embeds",
|
||||
video_embeds=video_embeds,
|
||||
|
||||
@ -23,7 +23,7 @@
|
||||
# limitations under the License.
|
||||
"""Inference-only Qwen2-Audio model compatible with HuggingFace weights."""
|
||||
from collections.abc import Iterable, Mapping, Sequence
|
||||
from typing import Any, Literal, Optional, TypedDict, Union
|
||||
from typing import Annotated, Any, Literal, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@ -47,6 +47,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
PromptUpdate, PromptUpdateDetails)
|
||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||
|
||||
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
|
||||
from .utils import (AutoWeightsLoader, init_vllm_registered_model,
|
||||
@ -54,21 +55,38 @@ from .utils import (AutoWeightsLoader, init_vllm_registered_model,
|
||||
|
||||
|
||||
# # === Audio Inputs === #
|
||||
class Qwen2AudioFeatureInputs(TypedDict):
|
||||
type: Literal["audio_features"]
|
||||
input_features: torch.Tensor
|
||||
"""Shape: `(num_audios, num_mel_bins, 3000)`"""
|
||||
|
||||
feature_attention_mask: torch.Tensor
|
||||
"""Shape: `(num_audios, 3000)`"""
|
||||
|
||||
|
||||
class Qwen2AudioEmbeddingInputs(TypedDict):
|
||||
type: Literal["audio_embeds"]
|
||||
audio_embeds: list[torch.Tensor]
|
||||
"""Shape: `(num_audio_features, hidden_size)`
|
||||
`hidden_size` must match the hidden size of language model backbone.
|
||||
class Qwen2AudioFeatureInputs(TensorSchema):
|
||||
"""
|
||||
Dimensions:
|
||||
- na: Number of audios
|
||||
- nmb: Number of mel bins
|
||||
"""
|
||||
type: Literal["audio_features"]
|
||||
input_features: Annotated[
|
||||
Union[torch.Tensor, list[torch.Tensor]],
|
||||
TensorShape("na", "nmb", 3000),
|
||||
]
|
||||
|
||||
feature_attention_mask: Annotated[
|
||||
torch.Tensor,
|
||||
TensorShape("na", 3000),
|
||||
]
|
||||
|
||||
|
||||
class Qwen2AudioEmbeddingInputs(TensorSchema):
|
||||
"""
|
||||
Dimensions:
|
||||
- bn: Batch size
|
||||
- naf: Number of audio features
|
||||
- hs: Hidden size (must match the hidden size of language model
|
||||
backbone)
|
||||
"""
|
||||
type: Literal["audio_embeds"] = "audio_embeds"
|
||||
|
||||
audio_embeds: Annotated[
|
||||
list[torch.Tensor],
|
||||
TensorShape("bn", "naf", "hs"),
|
||||
]
|
||||
|
||||
|
||||
Qwen2AudioInputs = Union[Qwen2AudioFeatureInputs, Qwen2AudioEmbeddingInputs]
|
||||
|
||||
@ -26,7 +26,7 @@
|
||||
"""Inference-only Qwen2-VL model compatible with HuggingFace weights."""
|
||||
from collections.abc import Iterable, Mapping, Sequence
|
||||
from functools import partial
|
||||
from typing import Any, Callable, Literal, Optional, TypedDict, Union
|
||||
from typing import Annotated, Any, Callable, Literal, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@ -70,6 +70,7 @@ from vllm.platforms import _Backend, current_platform
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.transformers_utils.config import uses_mrope
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||
|
||||
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
|
||||
SupportsMultiModal, SupportsPP)
|
||||
@ -86,78 +87,119 @@ _MAX_FRAMES_PER_VIDEO = 16
|
||||
# === Vision Inputs === #
|
||||
|
||||
|
||||
class Qwen2VLImagePixelInputs(TypedDict):
|
||||
type: Literal["pixel_values"]
|
||||
pixel_values: torch.Tensor
|
||||
"""Shape:
|
||||
`(num_patches, num_channels * patch_size * patch_size)`
|
||||
class Qwen2VLImagePixelInputs(TensorSchema):
|
||||
"""
|
||||
|
||||
image_grid_thw: torch.Tensor
|
||||
"""Shape: `(num_images, 3)`
|
||||
This should be in `(grid_t, grid_h, grid_w)` format.
|
||||
"""
|
||||
|
||||
|
||||
class Qwen2VLImageEmbeddingInputs(TypedDict):
|
||||
type: Literal["image_embeds"]
|
||||
image_embeds: torch.Tensor
|
||||
"""Supported types:
|
||||
- list[`torch.Tensor`]: A list of tensors holding all images' features.
|
||||
Each tensor holds an image's features.
|
||||
- `torch.Tensor`: A tensor holding all images' features
|
||||
(concatenation of all images' feature tensors).
|
||||
Dimensions:
|
||||
- np: The total number of patches over each image over each prompt in
|
||||
the batch
|
||||
- ni: Number of images
|
||||
- cps: Number of channels * patch_size * patch_size
|
||||
|
||||
Tensor shape: `(num_image_features, hidden_size)`
|
||||
- `num_image_features` varies based on
|
||||
the number and resolution of the images.
|
||||
- `hidden_size` must match the hidden size of language model backbone.
|
||||
Historical context:
|
||||
- pixel_values shape: (num_patches, num_channels * patch_size *
|
||||
patch_size)
|
||||
- image_grid_thw shape: (num_images, 3) in (grid_t, grid_h, grid_w)
|
||||
format
|
||||
"""
|
||||
type: Literal["pixel_values"]
|
||||
|
||||
image_grid_thw: torch.Tensor
|
||||
"""Shape: `(num_images, 3)`
|
||||
This should be in `(grid_t, grid_h, grid_w)` format.
|
||||
pixel_values: Annotated[
|
||||
torch.Tensor,
|
||||
TensorShape("np", "cps"),
|
||||
]
|
||||
|
||||
image_grid_thw: Annotated[
|
||||
torch.Tensor,
|
||||
TensorShape("ni", 3),
|
||||
]
|
||||
|
||||
|
||||
class Qwen2VLImageEmbeddingInputs(TensorSchema):
|
||||
"""
|
||||
Dimensions:
|
||||
- nf: Number of image features
|
||||
- hs: Hidden size
|
||||
- ni: Number of images
|
||||
|
||||
Historical context:
|
||||
- image_embeds shape: (num_image_features, hidden_size)
|
||||
- num_image_features varies based on the number and resolution of the
|
||||
images.
|
||||
- hidden_size must match the hidden size of language model backbone.
|
||||
- image_grid_thw shape: (num_images, 3) in (grid_t, grid_h, grid_w)
|
||||
format
|
||||
"""
|
||||
type: Literal["image_embeds"]
|
||||
|
||||
image_embeds: Annotated[
|
||||
torch.Tensor,
|
||||
TensorShape("nf", "hs"),
|
||||
]
|
||||
|
||||
image_grid_thw: Annotated[
|
||||
torch.Tensor,
|
||||
TensorShape("ni", 3),
|
||||
]
|
||||
|
||||
|
||||
Qwen2VLImageInputs = Union[Qwen2VLImagePixelInputs,
|
||||
Qwen2VLImageEmbeddingInputs]
|
||||
|
||||
|
||||
class Qwen2VLVideoPixelInputs(TypedDict):
|
||||
type: Literal["pixel_values_videos"]
|
||||
pixel_values_videos: torch.Tensor
|
||||
"""Shape:
|
||||
`(num_patches,
|
||||
num_channels * temporal_patch_size * patch_size * patch_size)`
|
||||
class Qwen2VLVideoPixelInputs(TensorSchema):
|
||||
"""
|
||||
|
||||
video_grid_thw: torch.Tensor
|
||||
"""Shape: `(num_videos, 3)`
|
||||
|
||||
This should be in `(grid_t, grid_h, grid_w)` format.
|
||||
"""
|
||||
|
||||
|
||||
class Qwen2VLVideoEmbeddingInputs(TypedDict):
|
||||
type: Literal["video_embeds"]
|
||||
video_embeds: torch.Tensor
|
||||
"""Supported types:
|
||||
- list[`torch.Tensor`]: A list of tensors holding all videos' features.
|
||||
Each tensor holds an video's features.
|
||||
- `torch.Tensor`: A tensor holding all videos' features
|
||||
(concatenation of all videos' feature tensors).
|
||||
Dimensions:
|
||||
- np: The total number of patches over each video over each prompt in
|
||||
the batch
|
||||
- ctps: Number of channels * temporal_patch_size * patch_size *
|
||||
patch_size
|
||||
- nv: Number of videos
|
||||
|
||||
Tensor shape: `(num_image_features, hidden_size)`
|
||||
- `num_image_features` varies based on
|
||||
the number and resolution of the videos.
|
||||
- `hidden_size` must match the hidden size of language model backbone.
|
||||
Historical context:
|
||||
- pixel_values_videos shape: (num_patches, num_channels *
|
||||
temporal_patch_size * patch_size * patch_size)
|
||||
- video_grid_thw shape: (num_videos, 3) in (grid_t, grid_h, grid_w)
|
||||
format
|
||||
"""
|
||||
type: Literal["pixel_values_videos"]
|
||||
|
||||
video_grid_thw: torch.Tensor
|
||||
"""Shape: `(num_videos, 3)`
|
||||
This should be in `(grid_t, grid_h, grid_w)` format.
|
||||
pixel_values_videos: Annotated[
|
||||
torch.Tensor,
|
||||
TensorShape("np", "ctps"),
|
||||
]
|
||||
|
||||
video_grid_thw: Annotated[
|
||||
torch.Tensor,
|
||||
TensorShape("nv", 3),
|
||||
]
|
||||
|
||||
|
||||
class Qwen2VLVideoEmbeddingInputs(TensorSchema):
|
||||
"""
|
||||
Dimensions:
|
||||
- nf: Number of video features
|
||||
- hs: Hidden size
|
||||
- nv: Number of videos
|
||||
|
||||
Historical context:
|
||||
- video_embeds shape: (num_video_features, hidden_size)
|
||||
- num_video_features varies based on the number and resolution of the
|
||||
videos.
|
||||
- hidden_size must match the hidden size of language model backbone.
|
||||
- video_grid_thw shape: (num_videos, 3) in (grid_t, grid_h, grid_w)
|
||||
format
|
||||
"""
|
||||
type: Literal["video_embeds"]
|
||||
|
||||
video_embeds: Annotated[
|
||||
torch.Tensor,
|
||||
TensorShape("nf", "hs"),
|
||||
]
|
||||
|
||||
video_grid_thw: Annotated[
|
||||
torch.Tensor,
|
||||
TensorShape("nv", 3),
|
||||
]
|
||||
|
||||
|
||||
Qwen2VLVideoInputs = Union[Qwen2VLVideoPixelInputs,
|
||||
@ -1126,10 +1168,6 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
image_grid_thw = self._validate_and_reshape_mm_tensor(
|
||||
image_grid_thw, "image grid_thw")
|
||||
|
||||
if not isinstance(pixel_values, (torch.Tensor, list)):
|
||||
raise ValueError("Incorrect type of image pixel values. "
|
||||
f"Got type: {type(pixel_values)}")
|
||||
|
||||
return Qwen2VLImagePixelInputs(type="pixel_values",
|
||||
pixel_values=pixel_values,
|
||||
image_grid_thw=image_grid_thw)
|
||||
@ -1140,9 +1178,6 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
image_grid_thw = self._validate_and_reshape_mm_tensor(
|
||||
image_grid_thw, "image grid_thw")
|
||||
|
||||
if not isinstance(image_embeds, torch.Tensor):
|
||||
raise ValueError("Incorrect type of image embeddings. "
|
||||
f"Got type: {type(image_embeds)}")
|
||||
return Qwen2VLImageEmbeddingInputs(type="image_embeds",
|
||||
image_embeds=image_embeds,
|
||||
image_grid_thw=image_grid_thw)
|
||||
@ -1174,9 +1209,6 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
video_grid_thw = self._validate_and_reshape_mm_tensor(
|
||||
video_grid_thw, "video grid_thw")
|
||||
|
||||
if not isinstance(video_embeds, torch.Tensor):
|
||||
raise ValueError("Incorrect type of video embeddings. "
|
||||
f"Got type: {type(video_embeds)}")
|
||||
return Qwen2VLVideoEmbeddingInputs(type="video_embeds",
|
||||
video_embeds=video_embeds,
|
||||
video_grid_thw=video_grid_thw)
|
||||
|
||||
@ -1592,10 +1592,6 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
|
||||
|
||||
# recorect dcp attn_out with lse.
|
||||
if self.dcp_world_size > 1:
|
||||
assert lse is not None, (
|
||||
"For a mla backend want to enable"
|
||||
"DCP, it is mandatory that the corresponding decode attn"
|
||||
"kernel return the softmax lse.")
|
||||
attn_out = cp_lse_ag_out_rs(attn_out, lse, get_dcp_group())
|
||||
|
||||
# v_up projection
|
||||
|
||||
@ -133,6 +133,8 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
|
||||
|
||||
class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
|
||||
|
||||
can_return_lse_for_decode: bool = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
|
||||
@ -73,10 +73,8 @@ class TopKTopPSampler(nn.Module):
|
||||
self.forward = self.forward_native
|
||||
else:
|
||||
self.forward = self.forward_native
|
||||
if current_platform.is_tpu():
|
||||
self.apply_top_k_top_p = apply_top_k_top_p_tpu
|
||||
else:
|
||||
self.apply_top_k_top_p = apply_top_k_top_p
|
||||
|
||||
self.apply_top_k_top_p = apply_top_k_top_p
|
||||
|
||||
def forward_native(
|
||||
self,
|
||||
@ -125,53 +123,6 @@ class TopKTopPSampler(nn.Module):
|
||||
return flashinfer_sample(logits.contiguous(), k, p, generators), None
|
||||
|
||||
|
||||
def apply_top_k_top_p_tpu(
|
||||
logits: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
p: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Apply top-k and top-p optimized for TPU.
|
||||
|
||||
This algorithm avoids using torch.scatter which is extremely slow on TPU.
|
||||
This is achieved by finding a "cut-off" element in the original logit, and
|
||||
after thresholding the logit using this cut-off, the remaining elements
|
||||
shall constitute the top-p set.
|
||||
|
||||
Note: in the case of tie (i.e. multipple cut-off elements present in the
|
||||
logit), all tie elements are included in the top-p set. In other words,
|
||||
this function does not break ties. Instead, these tie tokens have equal
|
||||
chance of being chosen during final sampling, so we can consider the tie
|
||||
being broken then.
|
||||
"""
|
||||
probs = logits.softmax(dim=-1)
|
||||
probs_sort, _ = probs.sort(dim=-1, descending=False)
|
||||
|
||||
if k is not None:
|
||||
top_k_count = probs_sort.size(1) - k.to(torch.long) # shape: (batch, )
|
||||
top_k_count = top_k_count.unsqueeze(dim=1)
|
||||
top_k_cutoff = probs_sort.gather(-1, top_k_count)
|
||||
|
||||
# Make sure the no top-k rows are no-op.
|
||||
no_top_k_mask = (k == logits.shape[1]).unsqueeze(dim=1)
|
||||
top_k_cutoff.masked_fill_(no_top_k_mask, -float("inf"))
|
||||
|
||||
elements_to_discard = probs < top_k_cutoff
|
||||
logits.masked_fill_(elements_to_discard, -float("inf"))
|
||||
|
||||
if p is not None:
|
||||
cumprob = torch.cumsum(probs_sort, dim=-1)
|
||||
top_p_mask = cumprob <= 1 - p.unsqueeze(dim=1)
|
||||
top_p_mask[:, -1] = False # at least one
|
||||
|
||||
top_p_count = top_p_mask.sum(dim=-1).unsqueeze(1)
|
||||
top_p_cutoff = probs_sort.gather(-1, top_p_count)
|
||||
elements_to_discard = probs < top_p_cutoff
|
||||
logits.masked_fill_(elements_to_discard, -float("inf"))
|
||||
|
||||
return logits
|
||||
|
||||
|
||||
def apply_top_k_top_p(
|
||||
logits: torch.Tensor,
|
||||
k: Optional[torch.Tensor],
|
||||
|
||||
@ -2,11 +2,12 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Sampler layer implementing TPU supported operations."""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.v1.outputs import LogprobsTensors, SamplerOutput
|
||||
from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler
|
||||
from vllm.v1.sample.tpu.metadata import TPUSupportedSamplingMetadata
|
||||
|
||||
_SAMPLING_EPS = 1e-5
|
||||
@ -17,7 +18,6 @@ class Sampler(nn.Module):
|
||||
def __init__(self):
|
||||
# TODO(houseroad): Add support for logprobs_mode.
|
||||
super().__init__()
|
||||
self.topk_topp_sampler = TopKTopPSampler()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -65,13 +65,17 @@ class Sampler(nn.Module):
|
||||
logits = self.apply_min_p(logits, sampling_metadata.min_p)
|
||||
|
||||
# Apply top_k and/or top_p.
|
||||
random_sampled, _ = self.topk_topp_sampler(
|
||||
logits = apply_top_k_top_p(
|
||||
logits,
|
||||
sampling_metadata.generators,
|
||||
sampling_metadata.top_k,
|
||||
sampling_metadata.top_p,
|
||||
)
|
||||
|
||||
# Random sample.
|
||||
probs = logits.softmax(dim=-1, dtype=torch.float32)
|
||||
random_sampled = self.random_sample(probs,
|
||||
sampling_metadata.generators)
|
||||
|
||||
sampled = torch.where(sampling_metadata.temperature < _SAMPLING_EPS,
|
||||
greedy_sampled, random_sampled)
|
||||
return sampled
|
||||
@ -144,3 +148,66 @@ class Sampler(nn.Module):
|
||||
# Apply mask using boolean indexing (xla friendly)
|
||||
logits.masked_fill_(~valid_token_mask, -float("inf"))
|
||||
return logits
|
||||
|
||||
def random_sample(
|
||||
self,
|
||||
probs: torch.Tensor,
|
||||
generators: dict[int, torch.Generator],
|
||||
) -> torch.Tensor:
|
||||
q = torch.empty_like(probs)
|
||||
# NOTE(woosuk): To batch-process the requests without their own seeds,
|
||||
# which is the common case, we first assume that every request does
|
||||
# not have its own seed. Then, we overwrite the values for the requests
|
||||
# that have their own seeds.
|
||||
q.exponential_()
|
||||
if generators:
|
||||
for i, generator in generators.items():
|
||||
q[i].exponential_(generator=generator)
|
||||
return probs.div_(q).argmax(dim=-1).view(-1)
|
||||
|
||||
|
||||
def apply_top_k_top_p(
|
||||
logits: torch.Tensor,
|
||||
k: Optional[torch.Tensor],
|
||||
p: Optional[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Apply top-k and top-p optimized for TPU.
|
||||
|
||||
This algorithm avoids using torch.scatter which is extremely slow on TPU.
|
||||
This is achieved by finding a "cut-off" element in the original logit, and
|
||||
after thresholding the logit using this cut-off, the remaining elements
|
||||
shall constitute the top-p set.
|
||||
|
||||
Note: in the case of tie (i.e. multipple cut-off elements present in the
|
||||
logit), all tie elements are included in the top-p set. In other words,
|
||||
this function does not break ties. Instead, these tie tokens have equal
|
||||
chance of being chosen during final sampling, so we can consider the tie
|
||||
being broken then.
|
||||
"""
|
||||
probs = logits.softmax(dim=-1)
|
||||
probs_sort, _ = probs.sort(dim=-1, descending=False)
|
||||
|
||||
if k is not None:
|
||||
top_k_count = probs_sort.size(1) - k.to(torch.long) # shape: (batch, )
|
||||
top_k_count = top_k_count.unsqueeze(dim=1)
|
||||
top_k_cutoff = probs_sort.gather(-1, top_k_count)
|
||||
|
||||
# Make sure the no top-k rows are no-op.
|
||||
no_top_k_mask = (k == logits.shape[1]).unsqueeze(dim=1)
|
||||
top_k_cutoff.masked_fill_(no_top_k_mask, -float("inf"))
|
||||
|
||||
elements_to_discard = probs < top_k_cutoff
|
||||
logits.masked_fill_(elements_to_discard, -float("inf"))
|
||||
|
||||
if p is not None:
|
||||
cumprob = torch.cumsum(probs_sort, dim=-1)
|
||||
top_p_mask = cumprob <= 1 - p.unsqueeze(dim=1)
|
||||
top_p_mask[:, -1] = False # at least one
|
||||
|
||||
top_p_count = top_p_mask.sum(dim=-1).unsqueeze(1)
|
||||
top_p_cutoff = probs_sort.gather(-1, top_p_count)
|
||||
elements_to_discard = probs < top_p_cutoff
|
||||
logits.masked_fill_(elements_to_discard, -float("inf"))
|
||||
|
||||
return logits
|
||||
|
||||
@ -55,7 +55,6 @@ from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
|
||||
GiB_bytes, LazyLoader, cdiv, check_use_alibi,
|
||||
get_dtype_size, is_pin_memory_available, round_up,
|
||||
supports_dynamo)
|
||||
from vllm.v1.attention.backends.mla.flashmla import FlashMLABackend
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata,
|
||||
create_fast_prefill_custom_backend)
|
||||
@ -1343,16 +1342,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
num_pad, num_tokens_across_dp = self.get_dp_padding(num_input_tokens)
|
||||
num_input_tokens += num_pad
|
||||
|
||||
# _prepare_inputs decides the order of the requests, so we must gather
|
||||
# multimodal outputs after that.
|
||||
if self.supports_mm_inputs:
|
||||
# _prepare_inputs may reorder the batch, so we must gather multi
|
||||
# modal outputs after that to ensure the correct order
|
||||
if self.supports_mm_inputs and get_pp_group().is_first_rank:
|
||||
# Run the multimodal encoder if any.
|
||||
self._execute_mm_encoder(scheduler_output)
|
||||
mm_embeds = self._gather_mm_embeddings(input_batch)
|
||||
else:
|
||||
mm_embeds = []
|
||||
mm_embeds = self._gather_mm_embeddings(scheduler_output)
|
||||
|
||||
if self.supports_mm_inputs and get_pp_group().is_first_rank:
|
||||
# NOTE(woosuk): To unify token ids and soft tokens (vision
|
||||
# embeddings), we always use embeddings (rather than token ids)
|
||||
# as input to the multimodal model, even when the input is text.
|
||||
@ -3066,10 +3062,16 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
copy_kv_blocks)
|
||||
|
||||
if self.dcp_world_size > 1:
|
||||
assert self.attn_groups[0][0].backend is FlashMLABackend, (
|
||||
"DCP only support flashmla now."
|
||||
"For a mla backend want to enable DCP, it is mandatory that the"
|
||||
"corresponding decode attn kernel return the softmax lse.")
|
||||
layer_names = self.attn_groups[0][0].layer_names
|
||||
layers = get_layers_from_vllm_config(self.vllm_config,
|
||||
AttentionLayerBase,
|
||||
layer_names)
|
||||
for layer in layers.values():
|
||||
assert layer.impl.need_to_return_lse_for_decode, (
|
||||
"DCP requires attention impls to return"
|
||||
" the softmax lse for decode, but the impl "
|
||||
f"{layer.impl.__class__.__name__} "
|
||||
"does not return the softmax lse for decode.")
|
||||
|
||||
def may_add_encoder_only_layers_to_kv_cache_config(self) -> None:
|
||||
"""
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user