mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-10 00:15:51 +08:00)
[unrevert] Add batch invariant kernel override for FlashInfer backend [2/n] (#26373)
Signed-off-by: Bram Wasti <bwasti@meta.com>
Signed-off-by: Bram Wasti <bwasti@fb.com>
This commit is contained in:
parent 8e67b2557a
commit 3263799056
@@ -21,7 +21,6 @@
 #include <c10/cuda/CUDAGuard.h>
 #include "../cuda_compat.h"
 #include "../cub_helpers.h"
-#include "../core/batch_invariant.hpp"
 
 #define MAX(a, b) ((a) > (b) ? (a) : (b))
 #define MIN(a, b) ((a) < (b) ? (a) : (b))
@@ -406,8 +405,7 @@ void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, f
 using Constants = detail::TopkConstants<EXPERTS, BYTES_PER_LDG, WARP_SIZE_PARAM>;
 static constexpr int VPT = Constants::VPT;
 static constexpr int ROWS_PER_WARP = Constants::ROWS_PER_WARP;
-const bool batch_invariant_launch = vllm::vllm_kernel_override_batch_invariant();
-const int num_warps = batch_invariant_launch ? 32 : (num_rows + ROWS_PER_WARP - 1) / ROWS_PER_WARP;
+const int num_warps = (num_rows + ROWS_PER_WARP - 1) / ROWS_PER_WARP;
 const int num_blocks = (num_warps + WARPS_PER_TB - 1) / WARPS_PER_TB;
 
 dim3 block_dim(WARP_SIZE_PARAM, WARPS_PER_TB);
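For reference, the override removed above pinned num_warps to 32 so the kernel launch geometry did not depend on how many rows were in the batch. A minimal Python sketch of that launch-size arithmetic (not vLLM code; the helper name launch_shape is made up here):

def launch_shape(num_rows: int, rows_per_warp: int, warps_per_tb: int,
                 batch_invariant: bool = False) -> tuple[int, int]:
    # Ceiling division, mirroring (x + y - 1) / y in the CUDA source above.
    num_warps = 32 if batch_invariant else -(-num_rows // rows_per_warp)
    num_blocks = -(-num_warps // warps_per_tb)
    return num_warps, num_blocks

# Example: launch_shape(10, 4, 8) -> (3, 1), while the old batch-invariant
# override gave launch_shape(10, 4, 8, True) -> (32, 4): a fixed shape for any batch.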
@@ -76,18 +76,21 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle():
 seed.
 - Keep max_tokens and max_model_len bounded for speed and memory use.
 """
-random.seed(12345)
+seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
+random.seed(seed)
 
 # Allow overrides from environment (useful for CI tuning)
 # "facebook/opt-125m" is too small, doesn't reliably test determinism
 model = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B")
 num_trials = int(os.getenv("VLLM_NEEDLE_TRIALS", "5"))
-batch_size = int(os.getenv("VLLM_NEEDLE_BATCH_SIZE", "64"))
-assert batch_size >= 2, "Batch size should be >= 2 to mix needle."
+max_batch_size = int(os.getenv("VLLM_NEEDLE_BATCH_SIZE", "128"))
+min_random_prompt = int(os.getenv("VLLM_MIN_PROMPT", "1024"))
+max_random_prompt = int(os.getenv("VLLM_MAX_PROMPT", "2048"))
+assert max_batch_size >= 2, "Batch size should be >= 2 to mix needle."
 
 # Keep GPU memory usage low to avoid startup allocation failures.
-gpu_mem_util = float(os.getenv("VLLM_GPU_MEMORY_UTILIZATION", "0.3"))
-max_model_len = int(os.getenv("VLLM_MAX_MODEL_LEN", "4096"))
+gpu_mem_util = float(os.getenv("VLLM_GPU_MEMORY_UTILIZATION", "0.4"))
+max_model_len = int(os.getenv("VLLM_MAX_MODEL_LEN", "5120"))
 swap_space_gb = int(os.getenv("VLLM_SWAP_SPACE_GB", "4"))
 
 # Sampling parameters: longer outputs with a more random-sounding
@@ -111,7 +114,7 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle():
 # Engine with bs=1 behavior
 llm_bs1 = LLM_with_max_seqs(
 model=model,
-max_num_seqs=1,
+max_num_seqs=max_batch_size,
 gpu_memory_utilization=gpu_mem_util,
 max_model_len=max_model_len,
 swap_space=swap_space_gb,
@@ -126,7 +129,7 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle():
 # Engine with larger batch limit (e.g., 64)
 llm_bsN = LLM_with_max_seqs(
 model=model,
-max_num_seqs=batch_size,
+max_num_seqs=max_batch_size,
 gpu_memory_utilization=gpu_mem_util,
 max_model_len=max_model_len,
 swap_space=swap_space_gb,
@@ -135,15 +138,16 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle():
 mismatches = 0
 
 for trial in range(num_trials):
-# Create a batch of size `batch_size` and insert the needle at
+# Create a batch of size `max_batch_size` and insert the needle at
 # a random index
 prompts: list[str] = []
+batch_size = random.randint(max_batch_size // 2, max_batch_size)
 needle_pos = random.randint(0, batch_size - 1)
 for i in range(batch_size):
 if i == needle_pos:
 prompts.append(needle_prompt)
 else:
-prompts.append(_random_prompt())
+prompts.append(_random_prompt(min_random_prompt, max_random_prompt))
 
 # Generate with the larger-batch engine
 outputs = llm_bsN.generate(prompts, sampling)
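Condensed, the per-trial batch construction after this change looks like the sketch below. It is illustrative only: build_trial_batch is a made-up name and make_prompt stands in for the test's _random_prompt helper.

import random

def build_trial_batch(max_batch_size: int, needle_prompt: str, make_prompt):
    # Each trial samples its own batch size and hides the needle at a random slot,
    # so determinism is exercised across many different batch compositions.
    batch_size = random.randint(max_batch_size // 2, max_batch_size)
    needle_pos = random.randint(0, batch_size - 1)
    prompts = [needle_prompt if i == needle_pos else make_prompt()
               for i in range(batch_size)]
    return prompts, needle_pos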
@@ -154,19 +158,20 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle():
 text = needle_output.outputs[0].text
 
 if text != baseline_text:
+print(f"{text}\n\n== Not the same as ==\n\n{baseline_text}\n\n")
 mismatches += 1
 
 passes = num_trials - mismatches
 # Dump how many passed vs failed
 print(
 f"[determinism] total={num_trials}, passed={passes}, "
-f"failed={mismatches}, batch_size={batch_size}"
+f"failed={mismatches}, max_batch_size={max_batch_size}"
 )
 
 if mismatches > 0:
 pytest.fail(
 f"Nondeterministic outputs detected: {mismatches} failed out "
-f"of {num_trials} trials (batch_size={batch_size})."
+f"of {num_trials} trials (max_batch_size={max_batch_size})."
 )
 
 finally:
@@ -199,8 +204,13 @@ def _extract_step_logprobs(request_output):
 not torch.cuda.is_available(),
 reason="Requires CUDA to match production inference path.",
 )
-def test_logprobs_bitwise_batch_invariance_bs1_vs_bs2():
-# model_name = os.getenv("VLLM_TEST_MODEL", "facebook/opt-125m")
+@pytest.mark.parametrize("backend", ["FLEX_ATTENTION", "FLASHINFER"])
+def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(backend):
+backend = os.getenv("VLLM_ATTENTION_BACKEND", backend)
+os.environ["VLLM_ATTENTION_BACKEND"] = backend
+
+seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
+random.seed(seed)
 model_name = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B")
 tp_size = int(os.getenv("VLLM_TEST_TP_SIZE", "1"))
 
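The test is now parametrized over both supported backends, while an externally set VLLM_ATTENTION_BACKEND still wins so CI can pin a single backend. A tiny sketch of that precedence rule (illustrative, not vLLM code; resolve_backend is a made-up name):

import os

def resolve_backend(parametrized: str) -> str:
    # The environment variable, if set, overrides the pytest parameter; the
    # result is re-exported so the engine under test sees the same choice.
    backend = os.getenv("VLLM_ATTENTION_BACKEND", parametrized)
    os.environ["VLLM_ATTENTION_BACKEND"] = backend
    return backend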
@@ -208,16 +218,14 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bs2():
 llm = LLM(
 model=model_name,
 tensor_parallel_size=tp_size,
-enforce_eager=True, # helps reduce nondeterminism from some backends
+enforce_eager=True,
+enable_prefix_caching=False,
 )
 
-prompts = [
-"The capital of France is",
-"The capital of Germany is",
-]
+prompts = [_random_prompt(10, 1024) for i in range(100)]
 
 sp = SamplingParams(
-temperature=0.0,
+temperature=0.6,
 top_p=1.0,
 max_tokens=8,
 # Seed shouldn't matter at temperature=0, but keeping it stable anyway.
@@ -238,11 +246,11 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bs2():
 )
 bs1_logprobs_per_prompt.append(step_logprobs)
 
-# BS=2: run prompts in a batch and collect logprobs per step for each
+# BS=N: run prompts in a batch and collect logprobs per step for each
 # prompt.
 outs_batched = llm.generate(prompts, sp, use_tqdm=False)
 assert len(outs_batched) == len(prompts)
-bs2_logprobs_per_prompt = []
+bsN_logprobs_per_prompt = []
 for o in outs_batched:
 step_logprobs = _extract_step_logprobs(o)
 if step_logprobs is None:
@@ -250,17 +258,17 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bs2():
 "Logits are not available on RequestOutput; "
 "enable logprobs return to run this test."
 )
-bs2_logprobs_per_prompt.append(step_logprobs)
+bsN_logprobs_per_prompt.append(step_logprobs)
 
-# Compare step-by-step logprobs for each prompt between BS=1 and BS=2 runs.
-for i, (logprobs_bs1, logprobs_bs2) in enumerate(
-zip(bs1_logprobs_per_prompt, bs2_logprobs_per_prompt)
+# Compare step-by-step logprobs for each prompt between BS=1 and BS=N runs.
+for i, (logprobs_bs1, logprobs_bsN) in enumerate(
+zip(bs1_logprobs_per_prompt, bsN_logprobs_per_prompt)
 ):
-assert len(logprobs_bs1) == len(logprobs_bs2), (
+assert len(logprobs_bs1) == len(logprobs_bsN), (
 f"Different number of generation steps for prompt index {i}: "
-f"{len(logprobs_bs1)} (BS=1) vs {len(logprobs_bs2)} (BS=2)"
+f"{len(logprobs_bs1)} (BS=1) vs {len(logprobs_bsN)} (BS=N)"
 )
-for t, (a, b) in enumerate(zip(logprobs_bs1, logprobs_bs2)):
+for t, (a, b) in enumerate(zip(logprobs_bs1, logprobs_bsN)):
 assert a.shape == b.shape, (
 f"Logits shape mismatch at prompt {i}, step {t}: {a.shape} vs {b.shape}"
 )
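The hunks above only rename bs2 to bsN and tighten the messages; the actual equality assertion lives in unchanged lines not shown here. For context, a minimal sketch of the kind of exact check a batch-invariance test relies on (illustrative, not the test's literal code; assert_bitwise_equal is a made-up helper):

import torch

def assert_bitwise_equal(a: torch.Tensor, b: torch.Tensor, prompt_idx: int, step: int) -> None:
    # Batch invariance is a bitwise claim, so exact equality (torch.equal) is the
    # right check rather than torch.allclose: the same position must yield
    # identical logprobs whether it was computed at batch size 1 or batch size N.
    assert a.shape == b.shape, f"shape mismatch at prompt {prompt_idx}, step {step}"
    assert torch.equal(a, b), f"logprobs differ at prompt {prompt_idx}, step {step}"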
@@ -297,6 +305,7 @@ def LLM_with_max_seqs(
 tensor_parallel_size=int(os.getenv("VLLM_TP_SIZE", "1")),
 trust_remote_code=os.getenv("VLLM_TRUST_REMOTE_CODE", "0") == "1",
 enable_prefix_caching=False,
+enforce_eager=True,
 # Enable for MOE models
 # enable_expert_parallel=True,
 )
@@ -8,8 +8,12 @@ from typing import Any
 
 import torch
 
+import vllm.envs as envs
+from vllm.logger import init_logger
 from vllm.triton_utils import tl, triton
 
+logger = init_logger(__name__)
+
 
 def _matmul_launch_metadata(
 grid: Callable[..., Any], kernel: Any, args: dict[str, Any]
@@ -562,5 +566,14 @@ def vllm_kernel_override_batch_invariant():
 def init_batch_invariance():
 # this will hit all the csrc overrides as well
 if vllm_kernel_override_batch_invariant():
-os.environ["VLLM_ATTENTION_BACKEND"] = "FLEX_ATTENTION"
+curr_attn_backend = envs.VLLM_ATTENTION_BACKEND
+supported_backends = ["FLEX_ATTENTION", "FLASHINFER"]
+if curr_attn_backend not in supported_backends:
+warning = (
+"Forcibly updating attention backend to"
+f" {supported_backends[0]} for batch_invariant. "
+f" Supported backends: {supported_backends}."
+)
+logger.warning_once(warning)
+os.environ["VLLM_ATTENTION_BACKEND"] = supported_backends[0]
 enable_batch_invariant_mode()
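Net effect of the change above: when batch-invariant mode is requested, an already-configured FLASHINFER (or FLEX_ATTENTION) backend is kept, and only unsupported backends are forced over to FLEX_ATTENTION with a warning. A standalone sketch of that selection rule (illustrative; the function and constant names below are made up, and the real code reads the configured backend via vllm.envs as shown):

SUPPORTED_BATCH_INVARIANT_BACKENDS = ["FLEX_ATTENTION", "FLASHINFER"]

def pick_backend_for_batch_invariance(configured: str | None) -> str:
    # Keep a supported backend as configured; otherwise fall back to the first
    # supported one (vLLM additionally logs a warning in that case).
    if configured in SUPPORTED_BATCH_INVARIANT_BACKENDS:
        return configured
    return SUPPORTED_BATCH_INVARIANT_BACKENDS[0]

assert pick_backend_for_batch_invariance("FLASHINFER") == "FLASHINFER"   # newly allowed
assert pick_backend_for_batch_invariance("FLASH_ATTN") == "FLEX_ATTENTION"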
@@ -25,6 +25,9 @@ from vllm.attention.backends.abstract import (
 )
 from vllm.config import CUDAGraphMode, VllmConfig
 from vllm.logger import init_logger
+from vllm.model_executor.layers.batch_invariant import (
+vllm_kernel_override_batch_invariant,
+)
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
 QuantKey,
 kFp8StaticTensorSym,
@@ -50,6 +53,7 @@ from vllm.v1.attention.backends.utils import (
 from vllm.v1.kv_cache_interface import AttentionSpec
 
 FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024
+FLASHINFER_WORKSPACE_BUFFER_SIZE_BATCH_INVARIANT = 2048 * 1024 * 1024
 
 FP8_DTYPE = current_platform.fp8_dtype()
 FP4_DTYPE = torch.uint8
@@ -288,6 +292,15 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
 self._prefill_wrapper = None # Wrapper for prefill/append
 self._decode_wrapper = None # Wrapper for decode (general shape)
 
+if vllm_kernel_override_batch_invariant():
+self.decode_fixed_split_size = 2048
+self.prefill_fixed_split_size = 4096
+self.disable_split_kv = True
+else:
+self.decode_fixed_split_size = -1
+self.prefill_fixed_split_size = -1
+self.disable_split_kv = False
+
 self.compilation_config = vllm_config.compilation_config
 max_num_pages_per_req = cdiv(
 self.model_config.max_model_len, self.kv_cache_spec.block_size
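Why this matters: with heuristic split-KV, how the KV cache is partitioned (and therefore the order in which partial attention results are reduced) can change with batch composition, which changes floating-point rounding. Pinning the split sizes keeps the reduction schedule stable. A condensed sketch of the configuration added above (the dataclass and function names are only for illustration; -1 is taken here to mean "let FlashInfer choose"):

from dataclasses import dataclass

@dataclass
class SplitKvConfig:
    decode_fixed_split_size: int
    prefill_fixed_split_size: int
    disable_split_kv: bool

def split_kv_config(batch_invariant: bool) -> SplitKvConfig:
    # Mirrors the builder __init__ above: fixed split sizes and no heuristic
    # split-KV when batch-invariant mode is on, defaults otherwise.
    if batch_invariant:
        return SplitKvConfig(decode_fixed_split_size=2048,
                             prefill_fixed_split_size=4096,
                             disable_split_kv=True)
    return SplitKvConfig(decode_fixed_split_size=-1,
                         prefill_fixed_split_size=-1,
                         disable_split_kv=False)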
@@ -391,8 +404,11 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
 
 def _get_workspace_buffer(self):
 if self._workspace_buffer is None:
+buffer_size = FLASHINFER_WORKSPACE_BUFFER_SIZE
+if vllm_kernel_override_batch_invariant():
+buffer_size = FLASHINFER_WORKSPACE_BUFFER_SIZE_BATCH_INVARIANT
 self._workspace_buffer = torch.zeros(
-FLASHINFER_WORKSPACE_BUFFER_SIZE, dtype=torch.uint8, device=self.device
+buffer_size, dtype=torch.uint8, device=self.device
 )
 return self._workspace_buffer
 
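In effect, the lazily allocated workspace grows from 256 MiB to 2 GiB when batch-invariant mode is active, presumably to accommodate the fixed-split planning. A condensed standalone sketch of that pattern (illustrative, not the vLLM class; WorkspaceOwner is a made-up name):

import torch

FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024
FLASHINFER_WORKSPACE_BUFFER_SIZE_BATCH_INVARIANT = 2048 * 1024 * 1024

class WorkspaceOwner:
    def __init__(self, device: torch.device, batch_invariant: bool) -> None:
        self._workspace_buffer: torch.Tensor | None = None
        self._device = device
        self._batch_invariant = batch_invariant

    def get_workspace_buffer(self) -> torch.Tensor:
        # Allocate once, on first use, with the size chosen by the mode.
        if self._workspace_buffer is None:
            size = (FLASHINFER_WORKSPACE_BUFFER_SIZE_BATCH_INVARIANT
                    if self._batch_invariant
                    else FLASHINFER_WORKSPACE_BUFFER_SIZE)
            self._workspace_buffer = torch.zeros(
                size, dtype=torch.uint8, device=self._device)
        return self._workspace_buffer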
@@ -669,6 +685,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
 logits_soft_cap=self.logits_soft_cap,
 q_data_type=self.q_data_type,
 kv_data_type=self.kv_cache_dtype,
+fixed_split_size=self.prefill_fixed_split_size,
+disable_split_kv=self.disable_split_kv,
 )
 else:
 attn_metadata.qo_indptr_gpu = qo_indptr_cpu.to(
@@ -730,6 +748,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
 logits_soft_cap=self.logits_soft_cap,
 q_data_type=self.q_data_type,
 kv_data_type=self.kv_cache_dtype,
+fixed_split_size=self.decode_fixed_split_size,
+disable_split_kv=self.disable_split_kv,
 )
 return attn_metadata
 
@@ -1121,6 +1141,8 @@ def fast_plan_decode(
 rope_scale: float | None = None,
 rope_theta: float | None = None,
 non_blocking: bool = True,
+fixed_split_size: int = -1,
+disable_split_kv: bool = False,
 ) -> None:
 """
 A faster version of BatchDecodeWithPagedKVCacheWrapper::plan used for
@@ -1157,6 +1179,10 @@ def fast_plan_decode(
 rope_scale,
 rope_theta,
 non_blocking,
+None, # block_tables
+None, # seq_lens
+fixed_split_size,
+disable_split_kv,
 )
 self.vllm_first_call = False
 return
@@ -1222,8 +1248,8 @@
 head_dim,
 False, # causal
 window_left,
--1, # fixed_split_size
-False, # disable_split_kv
+fixed_split_size,
+disable_split_kv,
 )
 except Exception as e:
 raise RuntimeError(f"Error in tensor core plan: {e}") from e