Speed up Punica compilation (#2632)

This commit is contained in:
Woosuk Kwon 2024-01-27 17:46:56 -08:00 committed by GitHub
parent 5f036d2bcc
commit f8ecb84c02
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
21 changed files with 100 additions and 22 deletions

View File

@ -5,7 +5,7 @@
steps:
- label: ":docker: build image"
commands:
- "docker build --tag {{ docker_image }} --target test --progress plain ."
- "docker build --build-arg max_jobs=16 --tag {{ docker_image }} --target test --progress plain ."
- "docker push {{ docker_image }}"
env:
DOCKER_BUILDKIT: "1"

View File

@ -1,21 +0,0 @@
#include "bgmv_config.h"
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_half, nv_half)
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_half, nv_bfloat16)
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_bfloat16, nv_bfloat16)
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_bfloat16, nv_half)
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_bfloat16, nv_half)
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_half, nv_half)
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_bfloat16, nv_bfloat16)
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_half, nv_bfloat16)
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_half, nv_half)
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_half, nv_bfloat16)
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_bfloat16, nv_bfloat16)
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_bfloat16, nv_half)
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, float, nv_half)
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, float, nv_bfloat16)
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, float, nv_bfloat16)
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, float, nv_half)
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, float, nv_half)
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, float, nv_bfloat16)

View File

@ -0,0 +1,4 @@
#include "bgmv_config.h"
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_bfloat16, nv_bfloat16)

View File

@ -0,0 +1,4 @@
#include "bgmv_config.h"
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_bfloat16, nv_half)

View File

@ -0,0 +1,4 @@
#include "bgmv_config.h"
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_half, nv_bfloat16)

View File

@ -0,0 +1,4 @@
#include "bgmv_config.h"
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_half, nv_half)

View File

@ -0,0 +1,4 @@
#include "bgmv_config.h"
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, float, nv_bfloat16)

View File

@ -0,0 +1,4 @@
#include "bgmv_config.h"
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, float, nv_half)

View File

@ -0,0 +1,4 @@
#include "bgmv_config.h"
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_bfloat16, nv_bfloat16)

View File

@ -0,0 +1,4 @@
#include "bgmv_config.h"
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_bfloat16, nv_half)

View File

@ -0,0 +1,4 @@
#include "bgmv_config.h"
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_half, nv_bfloat16)

View File

@ -0,0 +1,4 @@
#include "bgmv_config.h"
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_half, nv_half)

View File

@ -0,0 +1,4 @@
#include "bgmv_config.h"
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, float, nv_bfloat16)

View File

@ -0,0 +1,4 @@
#include "bgmv_config.h"
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, float, nv_half)

View File

@ -0,0 +1,4 @@
#include "bgmv_config.h"
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_bfloat16, nv_bfloat16)

View File

@ -0,0 +1,4 @@
#include "bgmv_config.h"
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_bfloat16, nv_half)

View File

@ -0,0 +1,4 @@
#include "bgmv_config.h"
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_half, nv_bfloat16)

View File

@ -0,0 +1,4 @@
#include "bgmv_config.h"
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_half, nv_half)

View File

@ -0,0 +1,4 @@
#include "bgmv_config.h"
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, float, nv_bfloat16)

View File

@ -0,0 +1,4 @@
#include "bgmv_config.h"
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, float, nv_half)

View File

@ -0,0 +1,27 @@
DTYPES = ["fp16", "bf16", "fp32"]
DTYPE_MAP = {
"fp16": "nv_half",
"bf16": "nv_bfloat16",
"fp32": "float",
}
TEMPLATE = """
#include "bgmv_config.h"
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, {input_dtype}, {output_dtype}, {weight_dtype})
""".lstrip()
for input_dtype in DTYPES:
for output_dtype in DTYPES:
for weight_dtype in DTYPES:
if weight_dtype == "fp32":
# FP32 weights are not supported.
continue
kernel_definition = TEMPLATE.format(
input_dtype=DTYPE_MAP[input_dtype],
output_dtype=DTYPE_MAP[output_dtype],
weight_dtype=DTYPE_MAP[weight_dtype])
filename = f"bgmv_{input_dtype}_{output_dtype}_{weight_dtype}.cu"
with open(filename, "w") as f:
f.write(kernel_definition)