diff --git a/.buildkite/test-template.j2 b/.buildkite/test-template.j2 index b35511293539..7c709b6097fd 100644 --- a/.buildkite/test-template.j2 +++ b/.buildkite/test-template.j2 @@ -5,7 +5,7 @@ steps: - label: ":docker: build image" commands: - - "docker build --tag {{ docker_image }} --target test --progress plain ." + - "docker build --build-arg max_jobs=16 --tag {{ docker_image }} --target test --progress plain ." - "docker push {{ docker_image }}" env: DOCKER_BUILDKIT: "1" diff --git a/csrc/punica/bgmv/bgmv_all.cu b/csrc/punica/bgmv/bgmv_all.cu deleted file mode 100644 index 2502a67e3c81..000000000000 --- a/csrc/punica/bgmv/bgmv_all.cu +++ /dev/null @@ -1,21 +0,0 @@ -#include "bgmv_config.h" -#include "bgmv_impl.cuh" - -FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_half, nv_half) -FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_half, nv_bfloat16) -FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_bfloat16, nv_bfloat16) -FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_bfloat16, nv_half) -FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_bfloat16, nv_half) -FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_half, nv_half) -FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_bfloat16, nv_bfloat16) -FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_half, nv_bfloat16) -FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_half, nv_half) -FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_half, nv_bfloat16) -FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_bfloat16, nv_bfloat16) -FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_bfloat16, nv_half) -FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, float, nv_half) -FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, float, nv_bfloat16) -FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, float, nv_bfloat16) -FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, float, nv_half) -FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, float, nv_half) -FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, 
float, nv_bfloat16) diff --git a/csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu b/csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu new file mode 100644 index 000000000000..c642e94925fe --- /dev/null +++ b/csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu @@ -0,0 +1,4 @@ +#include "bgmv_config.h" +#include "bgmv_impl.cuh" + +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_bfloat16, nv_bfloat16) diff --git a/csrc/punica/bgmv/bgmv_bf16_bf16_fp16.cu b/csrc/punica/bgmv/bgmv_bf16_bf16_fp16.cu new file mode 100644 index 000000000000..e8202dff561d --- /dev/null +++ b/csrc/punica/bgmv/bgmv_bf16_bf16_fp16.cu @@ -0,0 +1,4 @@ +#include "bgmv_config.h" +#include "bgmv_impl.cuh" + +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_bfloat16, nv_half) diff --git a/csrc/punica/bgmv/bgmv_bf16_fp16_bf16.cu b/csrc/punica/bgmv/bgmv_bf16_fp16_bf16.cu new file mode 100644 index 000000000000..3e7cf31dead0 --- /dev/null +++ b/csrc/punica/bgmv/bgmv_bf16_fp16_bf16.cu @@ -0,0 +1,4 @@ +#include "bgmv_config.h" +#include "bgmv_impl.cuh" + +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_half, nv_bfloat16) diff --git a/csrc/punica/bgmv/bgmv_bf16_fp16_fp16.cu b/csrc/punica/bgmv/bgmv_bf16_fp16_fp16.cu new file mode 100644 index 000000000000..68277fa6b7d5 --- /dev/null +++ b/csrc/punica/bgmv/bgmv_bf16_fp16_fp16.cu @@ -0,0 +1,4 @@ +#include "bgmv_config.h" +#include "bgmv_impl.cuh" + +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_half, nv_half) diff --git a/csrc/punica/bgmv/bgmv_bf16_fp32_bf16.cu b/csrc/punica/bgmv/bgmv_bf16_fp32_bf16.cu new file mode 100644 index 000000000000..0607cebfeac4 --- /dev/null +++ b/csrc/punica/bgmv/bgmv_bf16_fp32_bf16.cu @@ -0,0 +1,4 @@ +#include "bgmv_config.h" +#include "bgmv_impl.cuh" + +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, float, nv_bfloat16) diff --git a/csrc/punica/bgmv/bgmv_bf16_fp32_fp16.cu b/csrc/punica/bgmv/bgmv_bf16_fp32_fp16.cu new file mode 100644 index 000000000000..3b7531b8fbcf --- /dev/null +++ 
b/csrc/punica/bgmv/bgmv_bf16_fp32_fp16.cu @@ -0,0 +1,4 @@ +#include "bgmv_config.h" +#include "bgmv_impl.cuh" + +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, float, nv_half) diff --git a/csrc/punica/bgmv/bgmv_fp16_bf16_bf16.cu b/csrc/punica/bgmv/bgmv_fp16_bf16_bf16.cu new file mode 100644 index 000000000000..b3b74aa3ec90 --- /dev/null +++ b/csrc/punica/bgmv/bgmv_fp16_bf16_bf16.cu @@ -0,0 +1,4 @@ +#include "bgmv_config.h" +#include "bgmv_impl.cuh" + +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_bfloat16, nv_bfloat16) diff --git a/csrc/punica/bgmv/bgmv_fp16_bf16_fp16.cu b/csrc/punica/bgmv/bgmv_fp16_bf16_fp16.cu new file mode 100644 index 000000000000..3cc87f5df76a --- /dev/null +++ b/csrc/punica/bgmv/bgmv_fp16_bf16_fp16.cu @@ -0,0 +1,4 @@ +#include "bgmv_config.h" +#include "bgmv_impl.cuh" + +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_bfloat16, nv_half) diff --git a/csrc/punica/bgmv/bgmv_fp16_fp16_bf16.cu b/csrc/punica/bgmv/bgmv_fp16_fp16_bf16.cu new file mode 100644 index 000000000000..9eda98bd8ddc --- /dev/null +++ b/csrc/punica/bgmv/bgmv_fp16_fp16_bf16.cu @@ -0,0 +1,4 @@ +#include "bgmv_config.h" +#include "bgmv_impl.cuh" + +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_half, nv_bfloat16) diff --git a/csrc/punica/bgmv/bgmv_fp16_fp16_fp16.cu b/csrc/punica/bgmv/bgmv_fp16_fp16_fp16.cu new file mode 100644 index 000000000000..f1db6df5f733 --- /dev/null +++ b/csrc/punica/bgmv/bgmv_fp16_fp16_fp16.cu @@ -0,0 +1,4 @@ +#include "bgmv_config.h" +#include "bgmv_impl.cuh" + +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_half, nv_half) diff --git a/csrc/punica/bgmv/bgmv_fp16_fp32_bf16.cu b/csrc/punica/bgmv/bgmv_fp16_fp32_bf16.cu new file mode 100644 index 000000000000..060f9ebb8c2b --- /dev/null +++ b/csrc/punica/bgmv/bgmv_fp16_fp32_bf16.cu @@ -0,0 +1,4 @@ +#include "bgmv_config.h" +#include "bgmv_impl.cuh" + +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, float, nv_bfloat16) diff --git a/csrc/punica/bgmv/bgmv_fp16_fp32_fp16.cu 
b/csrc/punica/bgmv/bgmv_fp16_fp32_fp16.cu new file mode 100644 index 000000000000..c01ddd009d74 --- /dev/null +++ b/csrc/punica/bgmv/bgmv_fp16_fp32_fp16.cu @@ -0,0 +1,4 @@ +#include "bgmv_config.h" +#include "bgmv_impl.cuh" + +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, float, nv_half) diff --git a/csrc/punica/bgmv/bgmv_fp32_bf16_bf16.cu b/csrc/punica/bgmv/bgmv_fp32_bf16_bf16.cu new file mode 100644 index 000000000000..f45183ffd348 --- /dev/null +++ b/csrc/punica/bgmv/bgmv_fp32_bf16_bf16.cu @@ -0,0 +1,4 @@ +#include "bgmv_config.h" +#include "bgmv_impl.cuh" + +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_bfloat16, nv_bfloat16) diff --git a/csrc/punica/bgmv/bgmv_fp32_bf16_fp16.cu b/csrc/punica/bgmv/bgmv_fp32_bf16_fp16.cu new file mode 100644 index 000000000000..b37e44570bf4 --- /dev/null +++ b/csrc/punica/bgmv/bgmv_fp32_bf16_fp16.cu @@ -0,0 +1,4 @@ +#include "bgmv_config.h" +#include "bgmv_impl.cuh" + +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_bfloat16, nv_half) diff --git a/csrc/punica/bgmv/bgmv_fp32_fp16_bf16.cu b/csrc/punica/bgmv/bgmv_fp32_fp16_bf16.cu new file mode 100644 index 000000000000..06718cbb0a3e --- /dev/null +++ b/csrc/punica/bgmv/bgmv_fp32_fp16_bf16.cu @@ -0,0 +1,4 @@ +#include "bgmv_config.h" +#include "bgmv_impl.cuh" + +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_half, nv_bfloat16) diff --git a/csrc/punica/bgmv/bgmv_fp32_fp16_fp16.cu b/csrc/punica/bgmv/bgmv_fp32_fp16_fp16.cu new file mode 100644 index 000000000000..409774348808 --- /dev/null +++ b/csrc/punica/bgmv/bgmv_fp32_fp16_fp16.cu @@ -0,0 +1,4 @@ +#include "bgmv_config.h" +#include "bgmv_impl.cuh" + +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_half, nv_half) diff --git a/csrc/punica/bgmv/bgmv_fp32_fp32_bf16.cu b/csrc/punica/bgmv/bgmv_fp32_fp32_bf16.cu new file mode 100644 index 000000000000..41fb0e45ef4e --- /dev/null +++ b/csrc/punica/bgmv/bgmv_fp32_fp32_bf16.cu @@ -0,0 +1,4 @@ +#include "bgmv_config.h" +#include "bgmv_impl.cuh" + 
+FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, float, nv_bfloat16)
diff --git a/csrc/punica/bgmv/bgmv_fp32_fp32_fp16.cu b/csrc/punica/bgmv/bgmv_fp32_fp32_fp16.cu
new file mode 100644
index 000000000000..50b7ead9fcef
--- /dev/null
+++ b/csrc/punica/bgmv/bgmv_fp32_fp32_fp16.cu
@@ -0,0 +1,4 @@
+#include "bgmv_config.h"
+#include "bgmv_impl.cuh"
+
+FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, float, nv_half)
diff --git a/csrc/punica/bgmv/generator.py b/csrc/punica/bgmv/generator.py
new file mode 100644
index 000000000000..66de56d74f3e
--- /dev/null
+++ b/csrc/punica/bgmv/generator.py
@@ -0,0 +1,41 @@
+"""Generate the per-dtype BGMV kernel instantiation files.
+
+Writes one ``bgmv_{input}_{output}_{weight}.cu`` file for every supported
+combination of input, output and weight dtype. Each file holds a single
+``FOR_BGMV_WIDE_NARROW`` instantiation, so every combination gets its own
+source file.
+
+The files are written next to this script, whatever the current working
+directory is.
+"""
+import itertools
+from pathlib import Path
+
+DTYPES = ["fp16", "bf16", "fp32"]
+DTYPE_MAP = {
+    "fp16": "nv_half",
+    "bf16": "nv_bfloat16",
+    "fp32": "float",
+}
+# FP32 weights are not supported.
+WEIGHT_DTYPES = [dtype for dtype in DTYPES if dtype != "fp32"]
+
+TEMPLATE = """
+#include "bgmv_config.h"
+#include "bgmv_impl.cuh"
+
+FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, {input_dtype}, {output_dtype}, {weight_dtype})
+""".lstrip()
+
+# Anchor the output to this file's directory so the kernels are regenerated in
+# place even when the script is launched from another directory.
+OUTPUT_DIR = Path(__file__).resolve().parent
+
+for input_dtype, output_dtype, weight_dtype in itertools.product(
+        DTYPES, DTYPES, WEIGHT_DTYPES):
+    kernel_definition = TEMPLATE.format(
+        input_dtype=DTYPE_MAP[input_dtype],
+        output_dtype=DTYPE_MAP[output_dtype],
+        weight_dtype=DTYPE_MAP[weight_dtype])
+    filename = f"bgmv_{input_dtype}_{output_dtype}_{weight_dtype}.cu"
+    (OUTPUT_DIR / filename).write_text(kernel_definition)