diff --git a/0001-modify-operator-throughput-speed.patch b/0001-modify-operator-throughput-speed.patch new file mode 100644 index 0000000000000000000000000000000000000000..e7d04115e0d27e293295d3cf2ff096fa8617831d --- /dev/null +++ b/0001-modify-operator-throughput-speed.patch @@ -0,0 +1,352174 @@ +From 7f26a931153a32c5099746b6c18bbb88054ce70d Mon Sep 17 00:00:00 2001 +From: =?UTF-8?q?=E5=AD=99=E6=B5=B7=E4=BA=AE?= +Date: Mon, 17 Feb 2025 11:30:08 +0800 +Subject: [PATCH] =?UTF-8?q?=E5=88=9D=E5=A7=8B=E5=8C=96v0.6.2?= +MIME-Version: 1.0 +Content-Type: text/plain; charset=UTF-8 +Content-Transfer-Encoding: 8bit + +--- + .buildkite/check-wheel-size.py | 35 +- + .buildkite/generate_index.py | 24 + + .../configs/DeepSeek-V2-Lite-Chat.yaml | 12 + + ...lama-3-70B-Instruct-FBGEMM-nonuniform.yaml | 11 + + .../configs/Meta-Llama-3-70B-Instruct.yaml | 11 + + ...struct-Channelwise-compressed-tensors.yaml | 11 + + ...Llama-3-8B-Instruct-FBGEMM-nonuniform.yaml | 11 + + ...-3-8B-Instruct-FP8-compressed-tensors.yaml | 11 + + .../configs/Meta-Llama-3-8B-Instruct-FP8.yaml | 11 + + ...Instruct-INT8-compressed-tensors-asym.yaml | 11 + + ...3-8B-Instruct-INT8-compressed-tensors.yaml | 11 + + ...nstruct-nonuniform-compressed-tensors.yaml | 11 + + .../configs/Meta-Llama-3-8B-Instruct.yaml | 11 + + .../configs/Meta-Llama-3-8B-QQQ.yaml | 11 + + ...2-1B-Instruct-INT8-compressed-tensors.yaml | 11 + + .../configs/Minitron-4B-Base-FP8.yaml | 11 + + ...xtral-8x22B-Instruct-v0.1-FP8-Dynamic.yaml | 11 + + .../Mixtral-8x7B-Instruct-v0.1-FP8.yaml | 11 + + .../configs/Mixtral-8x7B-Instruct-v0.1.yaml | 11 + + .../configs/Qwen2-1.5B-Instruct-FP8W8.yaml | 11 + + ...1.5B-Instruct-INT8-compressed-tensors.yaml | 11 + + ....5B-Instruct-W8A16-compressed-tensors.yaml | 11 + + .../configs/Qwen2-57B-A14-Instruct.yaml | 11 + + .../lm-eval-harness/configs/models-large.txt | 5 + + .../lm-eval-harness/configs/models-small.txt | 10 + + .../run-lm-eval-gsm-hf-baseline.sh | 46 + + .../run-lm-eval-gsm-vllm-baseline.sh | 51 + + .buildkite/lm-eval-harness/run-tests.sh | 59 + + .../test_lm_eval_correctness.py | 63 + + .buildkite/nightly-benchmarks/README.md | 153 + + .../benchmark-pipeline.yaml | 92 + + .../nightly-benchmarks/nightly-annotation.md | 28 + + .../nightly-descriptions.md | 39 + + .../nightly-benchmarks/nightly-pipeline.yaml | 196 + + .../performance-benchmarks-descriptions.md | 62 + + .../convert-results-json-to-markdown.py | 204 + + .../scripts/download-tokenizer.py | 26 + + .../scripts/generate-nightly-markdown.py | 95 + + .../scripts/get-lmdeploy-modelname.py | 6 + + .../scripts/launch-server.sh | 228 ++ + .../scripts/nightly-annotate.sh | 78 + + .../scripts/run-nightly-benchmarks.sh | 355 ++ + .../scripts/run-performance-benchmarks.sh | 377 ++ + .../scripts/summary-nightly-results.py | 83 + + .../scripts/wait-for-image.sh | 19 + + .../tests/latency-tests.json | 32 + + .../tests/nightly-tests.json | 323 ++ + .../tests/serving-tests.json | 80 + + .../tests/throughput-tests.json | 35 + + .buildkite/release-pipeline.yaml | 72 + + .buildkite/run-amd-test.sh | 144 +- + .buildkite/run-benchmarks.sh | 15 +- + .buildkite/run-cpu-test-ppc64le.sh | 14 + + .buildkite/run-cpu-test.sh | 82 +- + .buildkite/run-gh200-test.sh | 28 + + .buildkite/run-hpu-test.sh | 16 + + .buildkite/run-multi-node-test.sh | 108 + + .buildkite/run-neuron-test.sh | 59 +- + .buildkite/run-openvino-test.sh | 16 + + .buildkite/run-tpu-test.sh | 26 + + .buildkite/run-xpu-test.sh | 19 + + .buildkite/test-pipeline.yaml | 628 ++- + .buildkite/upload-wheels.sh | 71 + + .clang-format | 26 + + .dockerignore | 32 + + .github/CODEOWNERS | 33 + + .github/FUNDING.yml | 2 + + .github/ISSUE_TEMPLATE/100-documentation.yml | 7 + + .github/ISSUE_TEMPLATE/200-installation.yml | 7 + + .github/ISSUE_TEMPLATE/300-usage.yml | 7 + + .github/ISSUE_TEMPLATE/400-bug-report.yml | 107 + + .../ISSUE_TEMPLATE/500-feature-request.yml | 38 + + .github/ISSUE_TEMPLATE/600-new-model.yml | 40 + + .../700-performance-discussion.yml | 59 + + .github/ISSUE_TEMPLATE/750-RFC.yml | 7 + + .../ISSUE_TEMPLATE/800-misc-discussion.yml | 28 + + .github/PULL_REQUEST_TEMPLATE.md | 61 +- + .github/dependabot.yml | 31 + + .github/mergify.yml | 60 + + .github/scripts/cleanup_pr_body.sh | 50 + + .github/workflows/actionlint.yml | 40 + + .github/workflows/add_label_automerge.yml | 21 + + .github/workflows/clang-format.yml | 53 + + .github/workflows/cleanup_pr_body.yml | 26 + + .github/workflows/codespell.yml | 45 + + .github/workflows/doc-lint.yml | 32 + + .github/workflows/lint-and-deploy.yaml | 82 + + .github/workflows/matchers/actionlint.json | 17 + + .github/workflows/matchers/mypy.json | 16 + + .github/workflows/matchers/ruff.json | 17 + + .github/workflows/mypy.yaml | 43 +- + .github/workflows/png-lint.yml | 37 + + .github/workflows/publish.yml | 126 +- + .github/workflows/reminder_comment.yml | 21 + + .github/workflows/ruff.yml | 53 +- + .github/workflows/scripts/build.sh | 10 +- + .github/workflows/scripts/cuda-install.sh | 8 +- + .github/workflows/scripts/pytorch-install.sh | 2 +- + .github/workflows/shellcheck.yml | 37 + + .github/workflows/stale.yml | 52 + + .github/workflows/yapf.yml | 35 +- + .gitignore | 21 +- + .readthedocs.yaml | 12 +- + .shellcheckrc | 9 + + CMakeLists.txt | 539 ++- + CODE_OF_CONDUCT.md | 128 + + CONTRIBUTING.md | 55 +- + DCO | 34 + + Dockerfile | 239 +- + Dockerfile.arm | 62 + + Dockerfile.cpu | 67 +- + Dockerfile.hpu | 21 + + Dockerfile.neuron | 47 +- + Dockerfile.openvino | 29 + + Dockerfile.ppc64le | 38 + + Dockerfile.rocm | 189 +- + Dockerfile.tpu | 28 + + Dockerfile.xpu | 69 + + README.md | 149 +- + SECURITY.md | 11 + + benchmarks/README.md | 11 + + benchmarks/backend_request_func.py | 119 +- + benchmarks/benchmark_guided.py | 494 +++ + benchmarks/benchmark_latency.py | 141 +- + .../benchmark_long_document_qa_throughput.py | 183 + + benchmarks/benchmark_prefix_caching.py | 242 +- + benchmarks/benchmark_prioritization.py | 177 + + benchmarks/benchmark_serving.py | 852 +++- + benchmarks/benchmark_serving_guided.py | 881 +++++ + benchmarks/benchmark_throughput.py | 486 ++- + .../cutlass_benchmarks/sparse_benchmarks.py | 384 ++ + benchmarks/cutlass_benchmarks/utils.py | 96 + + .../cutlass_benchmarks/w8a8_benchmarks.py | 365 ++ + .../cutlass_benchmarks/weight_shapes.py | 43 + + .../disagg_overhead_benchmark.sh | 145 + + .../disagg_performance_benchmark.sh | 163 + + .../disagg_prefill_proxy_server.py | 61 + + .../disagg_benchmarks/round_robin_proxy.py | 60 + + .../visualize_benchmark_results.py | 46 + + .../fused_kernels/layernorm_rms_benchmarks.py | 173 + + benchmarks/kernels/benchmark_aqlm.py | 14 +- + benchmarks/kernels/benchmark_layernorm.py | 86 + + benchmarks/kernels/benchmark_machete.py | 672 ++++ + benchmarks/kernels/benchmark_marlin.py | 254 ++ + benchmarks/kernels/benchmark_moe.py | 367 ++ + .../kernels/benchmark_paged_attention.py | 45 +- + benchmarks/kernels/benchmark_quant.py | 100 + + benchmarks/kernels/benchmark_rmsnorm.py | 262 ++ + benchmarks/kernels/benchmark_rope.py | 22 +- + benchmarks/kernels/benchmark_shapes.py | 75 + + benchmarks/kernels/graph_machete_bench.py | 63 + + benchmarks/kernels/requirements.txt | 1 + + benchmarks/kernels/weight_shapes.py | 49 + + benchmarks/launch_tgi_server.sh | 10 +- + benchmarks/overheads/benchmark_hashing.py | 59 + + .../structured_schema_1.json | 113 + + cmake/cpu_extension.cmake | 164 +- + cmake/utils.cmake | 315 +- + collect_env.py | 64 +- + csrc/activation_kernels.cu | 195 +- + csrc/attention/attention_generic.cuh | 19 +- + csrc/attention/attention_kernels.cuh | 676 ++++ + csrc/attention/attention_utils.cuh | 13 +- + csrc/attention/dtype_bfloat16.cuh | 82 +- + csrc/attention/dtype_float16.cuh | 92 +- + csrc/attention/dtype_float32.cuh | 88 +- + csrc/attention/dtype_fp8.cuh | 36 +- + csrc/attention/paged_attention_v1.cu | 193 + + csrc/attention/paged_attention_v2.cu | 203 + + csrc/cache.h | 49 +- + csrc/cache_kernels.cu | 448 ++- + csrc/core/exception.hpp | 3 + + csrc/core/math.hpp | 7 + + csrc/core/registration.h | 27 + + csrc/core/scalar_type.hpp | 347 ++ + csrc/cpu/activation.cpp | 79 +- + csrc/cpu/attention.cpp | 476 ++- + csrc/cpu/cache.cpp | 77 +- + csrc/cpu/cpu_types.hpp | 357 +- + csrc/cpu/cpu_types_arm.hpp | 572 +++ + csrc/cpu/cpu_types_vsx.hpp | 491 +++ + csrc/cpu/cpu_types_x86.hpp | 632 +++ + csrc/cpu/dnnl_helper.hpp | 174 + + csrc/cpu/layernorm.cpp | 32 +- + csrc/cpu/pos_encoding.cpp | 166 +- + csrc/cpu/quant.cpp | 613 +++ + csrc/cpu/torch_bindings.cpp | 160 + + csrc/cpu/utils.cpp | 103 + + csrc/cuda_compat.h | 17 +- + csrc/cuda_utils.h | 17 +- + csrc/cuda_utils_kernels.cu | 40 +- + csrc/custom_all_reduce.cu | 150 +- + csrc/custom_all_reduce.cuh | 293 +- + csrc/custom_all_reduce_test.cu | 79 +- + csrc/cutlass_extensions/common.cpp | 11 + + csrc/cutlass_extensions/common.hpp | 35 + + csrc/cutlass_extensions/cute_utils.cuh | 68 + + .../epilogue/broadcast_load_epilogue_c2x.hpp | 497 +++ + .../epilogue/broadcast_load_epilogue_c3x.hpp | 447 +++ + .../epilogue/scaled_mm_epilogues_c2x.hpp | 319 ++ + .../epilogue/scaled_mm_epilogues_c3x.hpp | 317 ++ + csrc/cutlass_extensions/torch_utils.hpp | 160 + + .../vllm_collective_builder.cuh | 43 + + csrc/cutlass_extensions/vllm_custom_types.cuh | 50 + + .../vllm_cutlass_library_extension.py | 78 + + .../vllm_numeric_conversion.cuh | 992 +++++ + csrc/cutlass_extensions/vllm_type_utils.cuh | 42 + + csrc/dispatch_utils.h | 58 +- + csrc/layernorm_kernels.cu | 312 +- + csrc/layernorm_quant_kernels.cu | 234 ++ + csrc/mamba/causal_conv1d/causal_conv1d.cu | 662 ++++ + csrc/mamba/causal_conv1d/causal_conv1d.h | 159 + + csrc/mamba/causal_conv1d/static_switch.h | 28 + + csrc/mamba/mamba_ssm/selective_scan.h | 266 ++ + csrc/mamba/mamba_ssm/selective_scan_fwd.cu | 658 ++++ + csrc/mamba/mamba_ssm/static_switch.h | 28 + + csrc/moe/marlin_kernels/marlin_moe_kernel.h | 1616 ++++++++ + .../marlin_kernels/marlin_moe_kernel_ku4.cu | 31 + + .../marlin_kernels/marlin_moe_kernel_ku4.h | 20 + + .../marlin_kernels/marlin_moe_kernel_ku4b8.cu | 31 + + .../marlin_kernels/marlin_moe_kernel_ku4b8.h | 20 + + .../marlin_moe_kernel_ku8b128.cu | 31 + + .../marlin_moe_kernel_ku8b128.h | 18 + + csrc/moe/marlin_moe_ops.cu | 588 +++ + csrc/moe/moe_align_sum_kernels.cu | 324 ++ + csrc/moe/moe_ops.h | 17 +- + csrc/moe/topk_softmax_kernels.cu | 29 +- + csrc/moe/torch_bindings.cpp | 39 + + csrc/ops.h | 419 +- + csrc/permute_cols.cu | 88 + + csrc/pos_encoding_kernels.cu | 235 +- + csrc/prepare_inputs/advance_step.cu | 327 ++ + csrc/prepare_inputs/advance_step.cuh | 19 + + csrc/quantization/aqlm/gemm_kernels.cu | 553 ++- + csrc/quantization/awq/dequantize.cuh | 139 +- + csrc/quantization/awq/gemm_kernels.cu | 620 +-- + .../compressed_tensors/int8_quant_kernels.cu | 286 ++ + csrc/quantization/cutlass_w8a8/Epilogues.md | 147 + + .../cutlass_w8a8/scaled_mm_c2x.cu | 199 + + .../cutlass_w8a8/scaled_mm_c2x.cuh | 220 ++ + .../scaled_mm_c2x_sm75_dispatch.cuh | 123 + + .../scaled_mm_c2x_sm80_dispatch.cuh | 139 + + .../scaled_mm_c2x_sm89_fp8_dispatch.cuh | 368 ++ + .../scaled_mm_c2x_sm89_int8_dispatch.cuh | 353 ++ + .../cutlass_w8a8/scaled_mm_c3x.cu | 87 + + .../cutlass_w8a8/scaled_mm_c3x.cuh | 160 + + .../scaled_mm_c3x_sm90_fp8_dispatch.cuh | 96 + + .../scaled_mm_c3x_sm90_int8_dispatch.cuh | 140 + + .../cutlass_w8a8/scaled_mm_entry.cu | 218 ++ + csrc/quantization/fp8/amd/hip_float8.h | 137 + + csrc/quantization/fp8/amd/hip_float8_impl.h | 316 ++ + csrc/quantization/fp8/amd/quant_utils.cuh | 577 +++ + csrc/quantization/fp8/common.cu | 149 + + csrc/quantization/fp8/common.cuh | 160 + + csrc/quantization/fp8/fp8_marlin.cu | 1311 +++++++ + csrc/quantization/fp8/nvidia/quant_utils.cuh | 573 +++ + ...fused_layernorm_dynamic_per_token_quant.cu | 160 + + .../fused_kernels/layernorm_utils.cuh | 327 ++ + .../fused_kernels/quant_conversions.cuh | 81 + + csrc/quantization/gguf/dequantize.cuh | 568 +++ + csrc/quantization/gguf/ggml-common.h | 1130 ++++++ + csrc/quantization/gguf/gguf_kernel.cu | 249 ++ + csrc/quantization/gguf/mmq.cuh | 600 +++ + csrc/quantization/gguf/mmvq.cuh | 190 + + csrc/quantization/gguf/vecdotq.cuh | 1810 +++++++++ + csrc/quantization/gptq/compat.cuh | 70 +- + csrc/quantization/gptq/matrix_view.cuh | 503 +-- + csrc/quantization/gptq/q_gemm.cu | 3443 ++++++++--------- + csrc/quantization/gptq/qdq_2.cuh | 107 +- + csrc/quantization/gptq/qdq_3.cuh | 246 +- + csrc/quantization/gptq/qdq_4.cuh | 203 +- + csrc/quantization/gptq/qdq_8.cuh | 34 +- + csrc/quantization/gptq/qdq_util.cuh | 58 +- + .../gptq_marlin/awq_marlin_repack.cu | 268 ++ + csrc/quantization/gptq_marlin/gptq_marlin.cu | 1547 ++++++-- + .../gptq_marlin/gptq_marlin_repack.cu | 132 +- + csrc/quantization/gptq_marlin/marlin.cuh | 87 + + .../gptq_marlin/marlin_dtypes.cuh | 79 + + csrc/quantization/machete/Readme.md | 45 + + csrc/quantization/machete/generate.py | 659 ++++ + .../machete/machete_collective_builder.cuh | 31 + + .../machete/machete_interleaving_utils.cuh | 35 + + .../quantization/machete/machete_mainloop.cuh | 1470 +++++++ + .../machete/machete_mm_kernel.cuh | 314 ++ + .../machete/machete_mm_launcher.cuh | 75 + + .../machete/machete_prepack_kernel.cuh | 76 + + .../machete/machete_prepack_launcher.cuh | 74 + + .../machete/machete_prepacked_layout.cuh | 249 ++ + csrc/quantization/machete/machete_pytorch.cu | 73 + + csrc/quantization/marlin/dense/LICENSE | 209 + + csrc/quantization/marlin/dense/common/base.h | 32 + + csrc/quantization/marlin/dense/common/mem.h | 89 + + .../marlin/dense/marlin_cuda_kernel.cu | 1073 +++++ + .../marlin/qqq/marlin_qqq_gemm_kernel.cu | 1248 ++++++ + csrc/quantization/marlin/sparse/LICENSE | 203 + + csrc/quantization/marlin/sparse/common/base.h | 51 + + csrc/quantization/marlin/sparse/common/mem.h | 136 + + csrc/quantization/marlin/sparse/common/mma.h | 191 + + .../marlin/sparse/marlin_24_cuda_kernel.cu | 1145 ++++++ + csrc/quantization/vectorization.cuh | 33 + + csrc/rocm/attention.cu | 1120 ++++++ + csrc/rocm/ops.h | 14 + + csrc/rocm/torch_bindings.cpp | 34 + + csrc/sparse/cutlass/sparse_compressor_c3x.cu | 165 + + .../sparse/cutlass/sparse_compressor_entry.cu | 42 + + csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu | 303 ++ + csrc/sparse/cutlass/sparse_scaled_mm_c3x.cuh | 496 +++ + csrc/sparse/cutlass/sparse_scaled_mm_entry.cu | 70 + + csrc/torch_bindings.cpp | 505 +++ + csrc/type_convert.cuh | 165 + + docs/Makefile | 4 + + docs/README.md | 1 + + docs/requirements-docs.txt | 24 +- + docs/source/_static/custom.js | 18 + + docs/source/_templates/sections/header.html | 39 + + docs/source/api/engine/async_llm_engine.md | 7 + + docs/source/api/engine/index.md | 17 + + docs/source/api/engine/llm_engine.md | 7 + + docs/source/api/inference_params.md | 21 + + docs/source/api/model/adapters.md | 9 + + docs/source/api/model/index.md | 11 + + docs/source/api/model/interfaces.md | 9 + + docs/source/api/model/interfaces_base.md | 9 + + docs/source/api/multimodal/index.md | 28 + + docs/source/api/multimodal/inputs.md | 49 + + docs/source/api/multimodal/parse.md | 9 + + docs/source/api/multimodal/processing.md | 9 + + docs/source/api/multimodal/profiling.md | 9 + + docs/source/api/multimodal/registry.md | 9 + + docs/source/api/offline_inference/index.md | 9 + + docs/source/api/offline_inference/llm.md | 7 + + .../api/offline_inference/llm_inputs.md | 19 + + .../dockerfile-stages-dependency.png | Bin 0 -> 118207 bytes + .../architecture_helm_deployment.png | Bin 0 -> 991484 bytes + .../arch_overview/entrypoints.excalidraw.png | Bin 0 -> 123422 bytes + .../arch_overview/llm_engine.excalidraw.png | Bin 0 -> 178116 bytes + docs/source/assets/design/hierarchy.png | Bin 0 -> 174150 bytes + .../features/disagg_prefill/abstraction.jpg | Bin 0 -> 104673 bytes + .../features/disagg_prefill/overview.jpg | Bin 0 -> 177439 bytes + docs/source/community/meetups.md | 15 + + docs/source/community/sponsors.md | 38 + + docs/source/conf.py | 146 +- + .../contributing/dockerfile/dockerfile.md | 50 + + docs/source/contributing/model/basic.md | 115 + + docs/source/contributing/model/index.md | 27 + + docs/source/contributing/model/multimodal.md | 395 ++ + .../source/contributing/model/registration.md | 55 + + docs/source/contributing/model/tests.md | 63 + + docs/source/contributing/overview.md | 149 + + .../contributing/profiling/profiling_index.md | 41 + + .../contributing/vulnerability_management.md | 43 + + docs/source/deployment/docker.md | 81 + + docs/source/deployment/frameworks/bentoml.md | 7 + + .../source/deployment/frameworks/cerebrium.md | 109 + + docs/source/deployment/frameworks/dstack.md | 102 + + docs/source/deployment/frameworks/helm.md | 250 ++ + docs/source/deployment/frameworks/index.md | 14 + + docs/source/deployment/frameworks/lws.md | 11 + + docs/source/deployment/frameworks/modal.md | 7 + + docs/source/deployment/frameworks/skypilot.md | 345 ++ + docs/source/deployment/frameworks/triton.md | 5 + + docs/source/deployment/integrations/index.md | 9 + + docs/source/deployment/integrations/kserve.md | 7 + + docs/source/deployment/integrations/kubeai.md | 15 + + .../deployment/integrations/llamastack.md | 38 + + docs/source/deployment/k8s.md | 249 ++ + docs/source/deployment/nginx.md | 133 + + docs/source/design/arch_overview.md | 252 ++ + .../source/design/automatic_prefix_caching.md | 42 + + docs/source/design/huggingface_integration.md | 36 + + docs/source/design/kernel/paged_attention.md | 529 +++ + docs/source/design/mm_processing.md | 64 + + docs/source/design/multiprocessing.md | 196 + + docs/source/design/plugin_system.md | 56 + + .../features/automatic_prefix_caching.md | 102 + + docs/source/features/compatibility_matrix.md | 468 +++ + docs/source/features/disagg_prefill.md | 68 + + docs/source/features/lora.md | 214 + + docs/source/features/quantization/auto_awq.md | 78 + + docs/source/features/quantization/bnb.md | 47 + + docs/source/features/quantization/fp8.md | 192 + + .../features/quantization/fp8_e4m3_kvcache.md | 44 + + .../features/quantization/fp8_e5m2_kvcache.md | 31 + + docs/source/features/quantization/gguf.md | 72 + + docs/source/features/quantization/index.md | 19 + + docs/source/features/quantization/int8.md | 136 + + .../quantization/supported_hardware.md | 131 + + docs/source/features/spec_decode.md | 265 ++ + docs/source/features/structured_outputs.md | 260 ++ + docs/source/features/tool_calling.md | 300 ++ + docs/source/generate_examples.py | 272 +- + docs/source/getting_started/faq.md | 37 + + .../getting_started/installation/cpu-apple.md | 48 + + .../getting_started/installation/cpu-arm.md | 46 + + .../getting_started/installation/cpu-x86.md | 154 + + .../getting_started/installation/gpu-cuda.md | 236 ++ + .../getting_started/installation/gpu-rocm.md | 163 + + .../getting_started/installation/hpu-gaudi.md | 389 ++ + .../getting_started/installation/index.md | 20 + + .../getting_started/installation/neuron.md | 132 + + .../getting_started/installation/openvino.md | 104 + + .../getting_started/installation/tpu.md | 191 + + .../getting_started/installation/xpu.md | 74 + + docs/source/getting_started/quickstart.md | 186 + + .../source/getting_started/troubleshooting.md | 203 + + docs/source/index.md | 192 + + docs/source/models/extensions/index.md | 8 + + .../models/extensions/runai_model_streamer.md | 53 + + docs/source/models/extensions/tensorizer.md | 16 + + docs/source/models/generative_models.md | 126 + + docs/source/models/pooling_models.md | 136 + + docs/source/models/supported_models.md | 868 +++++ + docs/source/performance/benchmarks.md | 28 + + docs/source/performance/optimization.md | 63 + + docs/source/serving/distributed_serving.md | 105 + + docs/source/serving/engine_args.md | 25 + + docs/source/serving/env_vars.md | 15 + + docs/source/serving/integrations/index.md | 8 + + docs/source/serving/integrations/langchain.md | 30 + + .../source/serving/integrations/llamaindex.md | 26 + + docs/source/serving/metrics.md | 38 + + docs/source/serving/multimodal_inputs.md | 533 +++ + docs/source/serving/offline_inference.md | 79 + + .../serving/openai_compatible_server.md | 439 ++- + docs/source/serving/usage_stats.md | 6 +- + examples/offline_inference/aqlm_example.py | 45 + + examples/offline_inference/arctic.py | 26 + + examples/offline_inference/audio_language.py | 131 + + examples/offline_inference/basic.py | 22 + + .../basic_with_model_default_sampling.py | 30 + + examples/offline_inference/chat.py | 80 + + examples/offline_inference/chat_with_tools.py | 138 + + examples/offline_inference/classification.py | 28 + + examples/offline_inference/cli.py | 80 + + examples/offline_inference/cpu_offload.py | 22 + + examples/offline_inference/distributed.py | 108 + + examples/offline_inference/embedding.py | 28 + + examples/offline_inference/encoder_decoder.py | 99 + + .../offline_inference/florence2_inference.py | 45 + + examples/offline_inference/gguf_inference.py | 32 + + .../offline_inference/llm_engine_example.py | 60 + + .../lora_with_quantization_inference.py | 134 + + examples/offline_inference/mlpspeculator.py | 56 + + .../offline_inference/multilora_inference.py | 106 + + examples/offline_inference/neuron.py | 36 + + .../neuron_int8_quantization.py | 50 + + .../offline_inference/openai/openai_batch.md | 205 + + .../openai/openai_example_batch.jsonl | 2 + + examples/offline_inference/pixtral.py | 165 + + examples/offline_inference/prefix_caching.py | 83 + + examples/offline_inference/profiling.py | 458 +++ + .../offline_inference/save_sharded_state.py | 75 + + examples/offline_inference/scoring.py | 23 + + .../offline_inference/simple_profiling.py | 40 + + .../offline_inference/structured_outputs.py | 78 + + examples/offline_inference/tpu.py | 28 + + examples/offline_inference/vision_language.py | 705 ++++ + .../vision_language_embedding.py | 170 + + .../vision_language_multi_image.py | 493 +++ + examples/offline_inference/whisper.py | 59 + + examples/online_serving/api_client.py | 84 + + .../online_serving/chart-helm/.helmignore | 6 + + examples/online_serving/chart-helm/Chart.yaml | 21 + + examples/online_serving/chart-helm/README.md | 21 + + examples/online_serving/chart-helm/ct.yaml | 3 + + .../online_serving/chart-helm/lintconf.yaml | 42 + + .../chart-helm/templates/_helpers.tpl | 164 + + .../chart-helm/templates/configmap.yaml | 11 + + .../chart-helm/templates/custom-objects.yaml | 6 + + .../chart-helm/templates/deployment.yaml | 122 + + .../chart-helm/templates/hpa.yaml | 31 + + .../chart-helm/templates/job.yaml | 37 + + .../templates/poddisruptionbudget.yaml | 7 + + .../chart-helm/templates/pvc.yaml | 13 + + .../chart-helm/templates/secrets.yaml | 10 + + .../chart-helm/templates/service.yaml | 14 + + .../chart-helm/values.schema.json | 265 ++ + .../online_serving/chart-helm/values.yaml | 119 + + .../online_serving/disaggregated_prefill.sh | 109 + + .../gradio_openai_chatbot_webserver.py | 82 + + examples/online_serving/gradio_webserver.py | 52 + + .../openai_chat_completion_client.py | 36 + + ...i_chat_completion_client_for_multimodal.py | 321 ++ + ...penai_chat_completion_client_with_tools.py | 162 + + ...enai_chat_completion_structured_outputs.py | 94 + + ...ai_chat_embedding_client_for_multimodal.py | 120 + + .../openai_completion_client.py | 31 + + .../openai_cross_encoder_score.py | 59 + + .../online_serving/openai_embedding_client.py | 25 + + .../online_serving/openai_pooling_client.py | 51 + + examples/online_serving/opentelemetry/Otel.md | 82 + + .../opentelemetry/dummy_client.py | 35 + + .../prometheus_grafana/README.md | 54 + + .../prometheus_grafana/docker-compose.yaml | 19 + + .../prometheus_grafana/grafana.json | 1557 ++++++++ + .../prometheus_grafana/prometheus.yaml | 10 + + examples/online_serving/run_cluster.sh | 49 + + .../online_serving/sagemaker-entrypoint.sh | 24 + + examples/other/fp8/README.md | 96 + + examples/other/fp8/extract_scales.py | 367 ++ + examples/other/fp8/quantizer/README.md | 32 + + examples/other/fp8/quantizer/quantize.py | 367 ++ + examples/other/logging_configuration.md | 172 + + examples/other/tensorize_vllm_model.py | 240 ++ + examples/template_blip2.jinja | 11 + + examples/template_dse_qwen2_vl.jinja | 7 + + examples/template_llava.jinja | 23 + + examples/template_pixtral_hf.jinja | 38 + + examples/template_vlm2vec.jinja | 16 + + examples/tool_chat_template_granite.jinja | 36 + + .../tool_chat_template_granite_20b_fc.jinja | 130 + + examples/tool_chat_template_hermes.jinja | 130 + + .../tool_chat_template_internlm2_tool.jinja | 60 + + .../tool_chat_template_llama3.1_json.jinja | 120 + + .../tool_chat_template_llama3.2_json.jinja | 133 + + ...tool_chat_template_llama3.2_pythonic.jinja | 98 + + examples/tool_chat_template_mistral.jinja | 86 + + .../tool_chat_template_mistral_parallel.jinja | 93 + + examples/tool_chat_template_toolace.jinja | 65 + + find_cuda_init.py | 33 + + format.sh | 141 +- + pyproject.toml | 64 +- + python_only_dev.py | 14 + + requirements-build.txt | 8 +- + requirements-common.txt | 45 +- + requirements-cpu.txt | 8 +- + requirements-cuda.txt | 11 +- + requirements-dev.txt | 36 +- + requirements-hpu.txt | 11 + + requirements-lint.txt | 15 + + requirements-neuron.txt | 4 +- + requirements-openvino.txt | 8 + + requirements-rocm.txt | 9 +- + requirements-test.in | 32 + + requirements-test.txt | 582 +++ + requirements-tpu.txt | 25 + + requirements-xpu.txt | 16 + + setup.py | 422 +- + tests/async_engine/__init__.py | 0 + tests/async_engine/api_server_async_engine.py | 13 +- + tests/async_engine/test_api_server.py | 11 +- + tests/async_engine/test_async_llm_engine.py | 300 +- + tests/async_engine/test_request_tracker.py | 27 +- + tests/basic_correctness/__init__.py | 0 + .../test_basic_correctness.py | 209 +- + .../basic_correctness/test_chunked_prefill.py | 310 +- + tests/basic_correctness/test_cpu_offload.py | 6 + + tests/basic_correctness/test_preemption.py | 227 +- + tests/compile/__init__.py | 0 + tests/compile/backend.py | 37 + + tests/compile/piecewise/__init__.py | 0 + tests/compile/piecewise/test_simple.py | 109 + + tests/compile/piecewise/test_toy_llama.py | 447 +++ + tests/compile/test_basic_correctness.py | 141 + + tests/compile/test_full_graph.py | 20 + + tests/compile/test_functionalization.py | 100 + + tests/compile/test_fusion.py | 116 + + tests/compile/test_pass_manager.py | 35 + + tests/compile/test_wrapper.py | 61 + + tests/compile/utils.py | 97 + + tests/conftest.py | 1026 ++++- + tests/core/block/e2e/__init__.py | 0 + tests/core/block/e2e/conftest.py | 30 +- + tests/core/block/e2e/test_correctness.py | 247 +- + .../e2e/test_correctness_sliding_window.py | 170 + + tests/core/block/test_block_manager.py | 491 +++ + tests/core/block/test_block_table.py | 19 +- + .../block/test_cpu_gpu_block_allocator.py | 32 +- + tests/core/block/test_naive_block.py | 49 +- + tests/core/block/test_prefix_caching_block.py | 549 ++- + tests/core/test_chunked_prefill_scheduler.py | 372 +- + tests/core/test_num_computed_tokens_update.py | 80 + + tests/core/test_scheduler.py | 777 ++-- + tests/core/test_scheduler_encoder_decoder.py | 104 + + tests/core/test_serialization.py | 33 + + tests/core/utils.py | 214 +- + tests/data/test_config.yaml | 5 + + tests/distributed/__init__.py | 0 + tests/distributed/test_ca_buffer_sharing.py | 59 + + tests/distributed/test_comm_ops.py | 134 +- + tests/distributed/test_custom_all_reduce.py | 99 +- + tests/distributed/test_distributed_oot.py | 6 + + .../distributed/test_multi_node_assignment.py | 64 + + tests/distributed/test_pipeline_parallel.py | 427 ++ + tests/distributed/test_pipeline_partition.py | 34 + + tests/distributed/test_pp_cudagraph.py | 30 + + tests/distributed/test_pynccl.py | 267 +- + tests/distributed/test_same_node.py | 34 + + tests/distributed/test_shm_broadcast.py | 116 + + tests/distributed/test_utils.py | 141 + + tests/encoder_decoder/__init__.py | 0 + tests/encoder_decoder/test_e2e_correctness.py | 119 + + tests/engine/__init__.py | 0 + tests/engine/output_processor/__init__.py | 0 + .../output_processor/test_multi_step.py | 23 +- + .../output_processor/test_stop_checker.py | 86 + + tests/engine/test_arg_utils.py | 142 + + tests/engine/test_custom_executor.py | 91 + + tests/engine/test_multiproc_workers.py | 6 +- + tests/engine/test_short_mm_context.py | 29 + + tests/engine/test_skip_tokenizer_init.py | 7 +- + tests/engine/test_stop_reason.py | 11 +- + tests/engine/test_stop_strings.py | 158 +- + tests/entrypoints/__init__.py | 0 + tests/entrypoints/conftest.py | 159 + + tests/entrypoints/llm/__init__.py | 0 + tests/entrypoints/llm/test_accuracy.py | 56 + + tests/entrypoints/llm/test_chat.py | 92 + + tests/entrypoints/llm/test_encode.py | 107 + + tests/entrypoints/llm/test_generate.py | 104 + + .../llm/test_generate_multiple_loras.py | 66 + + tests/entrypoints/llm/test_gpu_utilization.py | 25 + + tests/entrypoints/llm/test_guided_generate.py | 265 ++ + tests/entrypoints/llm/test_init.py | 22 + + tests/entrypoints/llm/test_lazy_outlines.py | 76 + + .../entrypoints/llm/test_prompt_validation.py | 24 + + tests/entrypoints/offline_mode/__init__.py | 0 + .../offline_mode/test_offline_mode.py | 82 + + tests/entrypoints/openai/__init__.py | 0 + tests/entrypoints/openai/test_accuracy.py | 85 + + .../openai/test_async_tokenization.py | 137 + + tests/entrypoints/openai/test_audio.py | 380 ++ + tests/entrypoints/openai/test_basic.py | 156 + + tests/entrypoints/openai/test_chat.py | 996 +++++ + tests/entrypoints/openai/test_chat_echo.py | 79 + + .../entrypoints/openai/test_chat_template.py | 117 + + .../entrypoints/openai/test_chunked_prompt.py | 126 + + tests/entrypoints/openai/test_cli_args.py | 131 + + tests/entrypoints/openai/test_completion.py | 779 ++++ + tests/entrypoints/openai/test_embedding.py | 274 ++ + .../openai/test_encoder_decoder.py | 52 + + .../entrypoints/openai/test_lora_adapters.py | 269 ++ + tests/entrypoints/openai/test_metrics.py | 236 ++ + tests/entrypoints/openai/test_models.py | 64 + + .../openai/test_oot_registration.py | 42 + + tests/entrypoints/openai/test_pooling.py | 238 ++ + .../openai/test_prompt_validation.py | 57 + + .../openai/test_return_tokens_as_ids.py | 87 + + tests/entrypoints/openai/test_root_path.py | 103 + + tests/entrypoints/openai/test_run_batch.py | 104 + + tests/entrypoints/openai/test_score.py | 93 + + tests/entrypoints/openai/test_serving_chat.py | 136 +- + .../entrypoints/openai/test_serving_models.py | 121 + + tests/entrypoints/openai/test_shutdown.py | 37 + + tests/entrypoints/openai/test_tokenization.py | 170 + + tests/entrypoints/openai/test_video.py | 348 ++ + tests/entrypoints/openai/test_vision.py | 349 ++ + .../openai/test_vision_embedding.py | 95 + + .../openai/tool_parsers/__init__.py | 0 + .../tool_parsers/test_pythonic_tool_parser.py | 160 + + .../entrypoints/openai/tool_parsers/utils.py | 123 + + tests/entrypoints/test_chat_utils.py | 796 ++++ + tests/kernels/__init__.py | 0 + tests/kernels/quant_utils.py | 88 + + tests/kernels/test_activation.py | 67 +- + tests/kernels/test_aqlm.py | 37 + + tests/kernels/test_attention.py | 181 +- + tests/kernels/test_attention_selector.py | 100 + + tests/kernels/test_awq.py | 43 + + tests/kernels/test_awq_marlin.py | 167 + + tests/kernels/test_awq_triton.py | 170 + + tests/kernels/test_block_fp8.py | 265 ++ + tests/kernels/test_blocksparse_attention.py | 439 +++ + tests/kernels/test_cache.py | 235 +- + tests/kernels/test_cascade_flash_attn.py | 182 + + tests/kernels/test_causal_conv1d.py | 435 +++ + tests/kernels/test_cutlass.py | 455 +++ + tests/kernels/test_encoder_decoder_attn.py | 1101 ++++++ + tests/kernels/test_flash_attn.py | 241 ++ + tests/kernels/test_flashinfer.py | 470 +++ + tests/kernels/test_fp8_quant.py | 114 + + tests/kernels/test_fused_quant_layernorm.py | 171 + + tests/kernels/test_ggml.py | 22 + + tests/kernels/test_gguf.py | 127 + + tests/kernels/test_gptq.py | 29 + + tests/kernels/test_int8_quant.py | 190 + + tests/kernels/test_layernorm.py | 96 +- + tests/kernels/test_machete_mm.py | 406 ++ + tests/kernels/test_mamba_ssm.py | 720 ++++ + tests/kernels/test_marlin_gemm.py | 616 +++ + tests/kernels/test_moe.py | 324 +- + tests/kernels/test_permute_cols.py | 15 + + tests/kernels/test_pos_encoding.py | 126 +- + tests/kernels/test_prefix_prefill.py | 334 +- + tests/kernels/test_rotary_embedding.py | 62 + + tests/kernels/test_semi_structured.py | 134 + + tests/kernels/test_triton_scaled_mm.py | 106 + + tests/kernels/test_utils.py | 24 + + tests/kernels/utils.py | 1100 ++++++ + tests/kv_transfer/disagg_test.py | 119 + + tests/kv_transfer/module_test.py | 64 + + tests/kv_transfer/test_lookup_buffer.py | 160 + + tests/kv_transfer/test_lookup_buffer.sh | 8 + + tests/kv_transfer/test_send_recv.py | 158 + + tests/kv_transfer/test_send_recv.sh | 9 + + tests/lora/conftest.py | 211 +- + tests/lora/data/__init__.py | 0 + tests/lora/data/long_context_test_data.py | 119 + + tests/lora/test_baichuan.py | 36 +- + tests/lora/test_chatglm3_tp.py | 105 + + tests/lora/test_gemma.py | 19 +- + tests/lora/test_jamba.py | 54 + + tests/lora/test_layers.py | 791 +++- + tests/lora/test_llama_tp.py | 159 + + tests/lora/test_long_context.py | 299 ++ + tests/lora/test_lora_bias_e2e.py | 52 + + tests/lora/test_lora_checkpoints.py | 56 +- + tests/lora/test_lora_huggingface.py | 39 + + tests/lora/test_lora_manager.py | 570 ++- + tests/lora/test_minicpmv_tp.py | 122 + + tests/lora/test_mixtral.py | 90 +- + tests/lora/test_phi.py | 70 + + tests/lora/test_punica_ops_sizes.py | 400 ++ + tests/lora/test_punica_ops_variation.py | 316 ++ + tests/lora/test_quant_model.py | 115 +- + tests/lora/test_qwen2vl.py | 81 + + tests/lora/test_tokenizer_group.py | 24 +- + tests/lora/test_utils.py | 83 +- + tests/lora/test_worker.py | 17 +- + tests/lora/utils.py | 280 +- + tests/metrics/__init__.py | 0 + tests/metrics/test_metrics.py | 324 +- + tests/model_executor/__init__.py | 0 + tests/model_executor/conftest.py | 49 + + .../model_executor/test_enabled_custom_ops.py | 89 + + .../model_executor/test_guided_processors.py | 128 + + .../test_model_load_with_params.py | 119 + + tests/models/__init__.py | 0 + tests/models/decoder_only/__init__.py | 0 + .../decoder_only/audio_language/__init__.py | 0 + .../audio_language/test_ultravox.py | 268 ++ + .../models/decoder_only/language/__init__.py | 0 + .../models/decoder_only/language/test_aqlm.py | 69 + + .../models/decoder_only/language/test_fp8.py | 100 + + .../models/decoder_only/language/test_gguf.py | 130 + + .../decoder_only/language/test_gptq_marlin.py | 84 + + .../language/test_gptq_marlin_24.py | 73 + + .../decoder_only/language/test_granite.py | 41 + + .../decoder_only/language/test_jamba.py | 339 ++ + .../decoder_only/language/test_mamba.py | 323 ++ + .../decoder_only/language/test_mistral.py | 335 ++ + .../decoder_only/language/test_modelopt.py | 80 + + .../decoder_only/language/test_models.py | 86 + + .../decoder_only/language/test_phimoe.py | 102 + + .../decoder_only/vision_language/__init__.py | 0 + .../decoder_only/vision_language/test_awq.py | 120 + + .../vision_language/test_h2ovl.py | 129 + + .../vision_language/test_intern_vit.py | 77 + + .../vision_language/test_models.py | 742 ++++ + .../vision_language/test_phi3v.py | 234 ++ + .../vision_language/test_pixtral.py | 270 ++ + .../vision_language/test_qwen2_vl.py | 429 ++ + .../vision_language/vlm_utils/__init__.py | 0 + .../vision_language/vlm_utils/builders.py | 236 ++ + .../vlm_utils/case_filtering.py | 157 + + .../vision_language/vlm_utils/core.py | 156 + + .../vlm_utils/custom_inputs.py | 103 + + .../vision_language/vlm_utils/model_utils.py | 582 +++ + .../vision_language/vlm_utils/runners.py | 139 + + .../vision_language/vlm_utils/types.py | 198 + + tests/models/embedding/__init__.py | 0 + tests/models/embedding/language/__init__.py | 0 + .../embedding/language/test_cls_models.py | 42 + + .../embedding/language/test_embedding.py | 75 + + .../models/embedding/language/test_gritlm.py | 200 + + .../models/embedding/language/test_scoring.py | 89 + + tests/models/embedding/utils.py | 30 + + .../embedding/vision_language/__init__.py | 0 + .../vision_language/test_dse_qwen2_vl.py | 209 + + .../vision_language/test_llava_next.py | 140 + + .../embedding/vision_language/test_phi3v.py | 126 + + tests/models/encoder_decoder/__init__.py | 0 + .../audio_language/__init__.py | 0 + .../audio_language/test_whisper.py | 136 + + .../encoder_decoder/language/__init__.py | 0 + .../encoder_decoder/language/test_bart.py | 222 ++ + .../vision_language/__init__.py | 0 + .../vision_language/test_broadcast.py | 35 + + .../vision_language/test_florence2.py | 102 + + .../vision_language/test_mllama.py | 367 ++ + tests/models/fixtures/pixtral_chat.json | 1 + + .../models/fixtures/pixtral_chat_engine.json | 1 + + tests/models/multimodal/__init__.py | 0 + .../models/multimodal/processing/__init__.py | 0 + .../multimodal/processing/test_common.py | 201 + + .../multimodal/processing/test_idefics3.py | 178 + + .../multimodal/processing/test_internvl.py | 206 + + .../multimodal/processing/test_llava_next.py | 132 + + .../processing/test_llava_onevision.py | 132 + + .../multimodal/processing/test_phi3v.py | 55 + + .../models/multimodal/processing/test_qwen.py | 144 + + .../multimodal/processing/test_qwen2_vl.py | 54 + + tests/models/registry.py | 248 ++ + tests/models/test_initialization.py | 63 + + tests/models/test_oot_registration.py | 80 +- + tests/models/test_registry.py | 94 + + tests/models/utils.py | 282 +- + tests/mq_llm_engine/__init__.py | 0 + tests/mq_llm_engine/test_abort.py | 67 + + tests/mq_llm_engine/test_error_handling.py | 293 ++ + tests/mq_llm_engine/test_load.py | 57 + + tests/mq_llm_engine/utils.py | 78 + + tests/multi_step/__init__.py | 0 + .../multi_step/test_correctness_async_llm.py | 224 ++ + tests/multi_step/test_correctness_llm.py | 352 ++ + tests/multimodal/__init__.py | 0 + tests/multimodal/test_inputs.py | 95 + + tests/multimodal/test_processing.py | 613 +++ + tests/multimodal/test_processor_kwargs.py | 400 ++ + tests/multimodal/test_utils.py | 400 ++ + tests/multimodal/utils.py | 33 + + tests/plugins/vllm_add_dummy_model/setup.py | 9 + + .../vllm_add_dummy_model/__init__.py | 20 + + .../my_gemma_embedding.py | 70 + + .../vllm_add_dummy_model/my_llava.py | 26 + + .../vllm_add_dummy_model/my_opt.py | 19 + + .../plugins/vllm_add_dummy_platform/setup.py | 11 + + .../vllm_add_dummy_platform/__init__.py | 5 + + .../vllm_add_dummy_platform/dummy_platform.py | 5 + + tests/plugins_tests/test_platform_plugins.py | 16 + + tests/prefix_caching/__init__.py | 0 + .../test_disable_sliding_window.py | 44 + + tests/prefix_caching/test_prefix_caching.py | 254 +- + tests/prompt_adapter/test_bloom.py | 45 + + .../test_multi_adapter_inference.py | 53 + + tests/prompt_adapter/test_pa_lora.py | 61 + + tests/quantization/__init__.py | 0 + tests/quantization/test_bitsandbytes.py | 168 + + tests/quantization/test_compressed_tensors.py | 313 ++ + tests/quantization/test_configs.py | 10 +- + tests/quantization/test_cpu_offload.py | 68 + + tests/quantization/test_experts_int8.py | 28 + + tests/quantization/test_fp8.py | 144 +- + tests/quantization/test_ipex_quant.py | 30 + + tests/quantization/test_lm_head.py | 47 + + tests/quantization/utils.py | 15 + + tests/runai_model_streamer/__init__.py | 0 + .../test_runai_model_streamer_loader.py | 31 + + .../runai_model_streamer/test_weight_utils.py | 39 + + tests/samplers/__init__.py | 0 + tests/samplers/test_beam_search.py | 35 +- + tests/samplers/test_ignore_eos.py | 24 +- + tests/samplers/test_logits_processor.py | 85 +- + tests/samplers/test_logprobs.py | 118 +- + tests/samplers/test_no_bad_words.py | 185 + + tests/samplers/test_ranks.py | 38 +- + tests/samplers/test_rejection_sampler.py | 231 +- + tests/samplers/test_sampler.py | 315 +- + tests/samplers/test_seeded_generate.py | 11 +- + .../test_typical_acceptance_sampler.py | 470 +++ + tests/spec_decode/e2e/conftest.py | 519 ++- + tests/spec_decode/e2e/test_compatibility.py | 134 +- + .../spec_decode/e2e/test_eagle_correctness.py | 309 ++ + tests/spec_decode/e2e/test_integration.py | 140 + + .../e2e/test_integration_dist_tp2.py | 174 + + .../e2e/test_integration_dist_tp4.py | 121 + + tests/spec_decode/e2e/test_logprobs.py | 349 +- + .../e2e/test_medusa_correctness.py | 383 ++ + tests/spec_decode/e2e/test_mlp_correctness.py | 478 +++ + .../e2e/test_multistep_correctness.py | 486 ++- + .../spec_decode/e2e/test_ngram_correctness.py | 262 +- + tests/spec_decode/e2e/test_seed.py | 67 + + tests/spec_decode/test_batch_expansion.py | 23 +- + tests/spec_decode/test_dynamic_spec_decode.py | 87 + + tests/spec_decode/test_metrics.py | 137 +- + tests/spec_decode/test_multi_step_worker.py | 441 ++- + tests/spec_decode/test_ngram_worker.py | 50 +- + tests/spec_decode/test_scorer.py | 114 + + tests/spec_decode/test_spec_decode_worker.py | 464 ++- + tests/spec_decode/test_utils.py | 76 +- + tests/spec_decode/utils.py | 143 +- + tests/standalone_tests/lazy_torch_compile.py | 28 + + tests/standalone_tests/python_only_compile.sh | 30 + + tests/system_messages/sonnet3.5_nov2024.txt | 71 + + tests/tensorizer_loader/conftest.py | 47 + + tests/tensorizer_loader/test_tensorizer.py | 448 ++- + tests/test_cache_block_hashing.py | 16 +- + tests/test_config.py | 252 +- + tests/test_embedded_commit.py | 8 + + tests/test_inputs.py | 79 + + tests/test_logger.py | 18 +- + tests/test_logits_processor.py | 33 +- + tests/test_regression.py | 21 + + tests/test_scalartype.py | 36 + + tests/test_sequence.py | 46 +- + tests/test_sharded_state_loader.py | 131 + + tests/test_utils.py | 447 +++ + tests/tokenization/test_detokenize.py | 189 +- + tests/tokenization/test_get_eos.py | 31 + + tests/tokenization/test_tokenizer_group.py | 120 +- + tests/tool_use/__init__.py | 0 + tests/tool_use/conftest.py | 38 + + ...est_chat_completion_request_validations.py | 71 + + tests/tool_use/test_chat_completions.py | 146 + + tests/tool_use/test_jamba_tool_parser.py | 275 ++ + tests/tool_use/test_parallel_tool_calls.py | 205 + + tests/tool_use/test_tool_calls.py | 192 + + tests/tool_use/utils.py | 313 ++ + tests/tpu/__init__.py | 0 + tests/tpu/test_compilation.py | 79 + + tests/tpu/test_custom_dispatcher.py | 22 + + tests/tpu/test_quantization_accuracy.py | 49 + + tests/tracing/__init__.py | 0 + tests/tracing/test_tracing.py | 202 + + tests/utils.py | 825 ++++ + tests/v1/__init__.py | 0 + tests/v1/core/test_kv_cache_utils.py | 245 ++ + tests/v1/core/test_prefix_caching.py | 566 +++ + tests/v1/e2e/__init__.py | 0 + tests/v1/e2e/test_cascade_attention.py | 22 + + tests/v1/engine/__init__.py | 0 + tests/v1/engine/test_async_llm.py | 116 + + tests/v1/engine/test_engine_args.py | 46 + + tests/v1/engine/test_engine_core.py | 177 + + tests/v1/engine/test_engine_core_client.py | 202 + + tests/v1/engine/test_output_processor.py | 295 ++ + tests/v1/sample/__init__.py | 0 + tests/v1/sample/test_sampler.py | 321 ++ + tests/v1/worker/__init__.py | 0 + tests/v1/worker/test_gpu_input_batch.py | 224 ++ + tests/vllm_test_utils/setup.py | 7 + + .../vllm_test_utils/__init__.py | 9 + + .../vllm_test_utils/vllm_test_utils/blame.py | 53 + + .../vllm_test_utils/monitor.py | 68 + + tests/weight_loading/models-large.txt | 5 + + tests/weight_loading/models.txt | 33 + + .../run_model_weight_loading_test.sh | 49 + + tests/weight_loading/test_weight_loading.py | 32 + + .../test_encoder_decoder_model_runner.py | 646 ++++ + tests/worker/test_model_input.py | 241 ++ + tests/worker/test_model_runner.py | 281 +- + tests/worker/test_profile.py | 65 + + tests/worker/test_swap.py | 37 +- + tools/actionlint.sh | 13 + + tools/check_repo.sh | 14 + + tools/doc-lint.sh | 3 + + tools/mypy.sh | 33 + + tools/png-lint.sh | 15 + + tools/profiler/print_layerwise_table.py | 82 + + tools/profiler/visualize_layerwise_profile.py | 590 +++ + tools/report_build_time_ninja.py | 312 ++ + tools/shellcheck.sh | 22 + + use_existing_torch.py | 18 + + vllm/__init__.py | 24 +- + vllm/_custom_ops.py | 967 ++++- + vllm/_ipex_ops.py | 226 ++ + vllm/adapter_commons/__init__.py | 0 + vllm/adapter_commons/layers.py | 14 + + vllm/adapter_commons/models.py | 103 + + vllm/adapter_commons/request.py | 23 + + vllm/adapter_commons/utils.py | 90 + + vllm/adapter_commons/worker_manager.py | 36 + + vllm/assets/__init__.py | 0 + vllm/assets/audio.py | 31 + + vllm/assets/base.py | 38 + + vllm/assets/image.py | 29 + + vllm/assets/video.py | 82 + + vllm/attention/__init__.py | 8 +- + vllm/attention/backends/abstract.py | 205 +- + vllm/attention/backends/blocksparse_attn.py | 454 +++ + vllm/attention/backends/flash_attn.py | 899 ++++- + vllm/attention/backends/flashinfer.py | 836 +++- + vllm/attention/backends/hpu_attn.py | 281 ++ + vllm/attention/backends/ipex_attn.py | 386 ++ + vllm/attention/backends/openvino.py | 140 + + vllm/attention/backends/pallas.py | 345 ++ + vllm/attention/backends/placeholder_attn.py | 403 ++ + vllm/attention/backends/rocm_flash_attn.py | 472 ++- + vllm/attention/backends/torch_sdpa.py | 613 ++- + vllm/attention/backends/utils.py | 574 +++ + vllm/attention/backends/xformers.py | 577 ++- + vllm/attention/layer.py | 275 +- + .../ops/blocksparse_attention/__init__.py | 0 + .../blocksparse_attention_kernel.py | 430 ++ + .../ops/blocksparse_attention/interface.py | 236 ++ + .../ops/blocksparse_attention/utils.py | 242 ++ + vllm/attention/ops/hpu_paged_attn.py | 103 + + vllm/attention/ops/ipex_attn.py | 191 + + vllm/attention/ops/paged_attn.py | 69 +- + vllm/attention/ops/prefix_prefill.py | 180 +- + vllm/attention/ops/triton_flash_attention.py | 10 + + vllm/attention/selector.py | 225 +- + vllm/beam_search.py | 71 + + vllm/compilation/__init__.py | 0 + vllm/compilation/backends.py | 802 ++++ + vllm/compilation/counter.py | 31 + + vllm/compilation/decorators.py | 235 ++ + vllm/compilation/fix_functionalization.py | 180 + + vllm/compilation/fusion.py | 615 +++ + vllm/compilation/fx_utils.py | 42 + + vllm/compilation/inductor_pass.py | 84 + + vllm/compilation/monitor.py | 36 + + vllm/compilation/multi_output_match.py | 106 + + vllm/compilation/pass_manager.py | 77 + + vllm/compilation/reshapes.py | 88 + + vllm/compilation/vllm_inductor_pass.py | 49 + + vllm/compilation/wrapper.py | 105 + + vllm/config.py | 2904 ++++++++++++-- + vllm/connections.py | 167 + + vllm/core/block/block_table.py | 195 +- + vllm/core/block/common.py | 276 +- + vllm/core/block/cpu_gpu_block_allocator.py | 249 +- + vllm/core/block/interfaces.py | 127 +- + vllm/core/block/naive_block.py | 266 +- + vllm/core/block/prefix_caching_block.py | 879 ++++- + vllm/core/block/utils.py | 26 + + vllm/core/block_manager.py | 516 +++ + vllm/core/evictor.py | 154 + + vllm/core/interfaces.py | 40 +- + vllm/core/placeholder_block_space_manager.py | 94 + + vllm/core/scheduler.py | 1297 +++++-- + vllm/distributed/communication_op.py | 233 +- + .../device_communicators/cuda_wrapper.py | 172 + + .../device_communicators/custom_all_reduce.py | 469 +-- + .../custom_all_reduce_utils.py | 255 ++ + .../device_communicators/hpu_communicator.py | 48 + + .../device_communicators/pynccl.py | 406 +- + .../device_communicators/pynccl_wrapper.py | 338 ++ + .../device_communicators/shm_broadcast.py | 528 +++ + .../device_communicators/tpu_communicator.py | 61 + + .../device_communicators/xpu_communicator.py | 47 + + vllm/distributed/kv_transfer/README.md | 30 + + vllm/distributed/kv_transfer/__init__.py | 0 + .../kv_transfer/disagg_prefill_workflow.jpg | Bin 0 -> 142656 bytes + .../kv_transfer/kv_connector/__init__.py | 0 + .../kv_transfer/kv_connector/base.py | 122 + + .../kv_transfer/kv_connector/factory.py | 48 + + .../kv_connector/simple_connector.py | 312 ++ + .../kv_transfer/kv_lookup_buffer/__init__.py | 0 + .../kv_transfer/kv_lookup_buffer/base.py | 108 + + .../kv_lookup_buffer/simple_buffer.py | 242 ++ + .../kv_transfer/kv_pipe/__init__.py | 0 + vllm/distributed/kv_transfer/kv_pipe/base.py | 65 + + .../kv_transfer/kv_pipe/mooncake_pipe.py | 272 ++ + .../kv_transfer/kv_pipe/pynccl_pipe.py | 276 ++ + .../kv_transfer/kv_transfer_agent.py | 75 + + vllm/distributed/parallel_state.py | 1307 ++++++- + vllm/distributed/utils.py | 255 +- + vllm/engine/arg_utils.py | 1028 ++++- + vllm/engine/async_llm_engine.py | 1130 ++++-- + vllm/engine/async_timeout.py | 189 + + vllm/engine/llm_engine.py | 1847 +++++++-- + vllm/engine/metrics.py | 683 +++- + vllm/engine/metrics_types.py | 101 + + vllm/engine/multiprocessing/__init__.py | 152 + + vllm/engine/multiprocessing/client.py | 689 ++++ + vllm/engine/multiprocessing/engine.py | 389 ++ + vllm/engine/output_processor/interfaces.py | 20 +- + vllm/engine/output_processor/multi_step.py | 149 +- + vllm/engine/output_processor/single_step.py | 318 +- + vllm/engine/output_processor/stop_checker.py | 77 +- + vllm/engine/output_processor/util.py | 16 +- + vllm/engine/protocol.py | 277 ++ + vllm/entrypoints/api_server.py | 101 +- + vllm/entrypoints/chat_utils.py | 1001 +++++ + vllm/entrypoints/launcher.py | 103 + + vllm/entrypoints/llm.py | 1198 +++++- + vllm/entrypoints/logger.py | 42 + + vllm/entrypoints/openai/api_server.py | 816 +++- + vllm/entrypoints/openai/cli_args.py | 179 +- + vllm/entrypoints/openai/logits_processors.py | 86 + + vllm/entrypoints/openai/protocol.py | 1118 +++++- + vllm/entrypoints/openai/run_batch.py | 317 ++ + vllm/entrypoints/openai/serving_chat.py | 918 +++-- + vllm/entrypoints/openai/serving_completion.py | 541 ++- + vllm/entrypoints/openai/serving_embedding.py | 240 ++ + vllm/entrypoints/openai/serving_engine.py | 609 ++- + vllm/entrypoints/openai/serving_models.py | 250 ++ + vllm/entrypoints/openai/serving_pooling.py | 233 ++ + vllm/entrypoints/openai/serving_score.py | 226 ++ + .../openai/serving_tokenization.py | 144 + + .../openai/tool_parsers/__init__.py | 16 + + .../tool_parsers/abstract_tool_parser.py | 160 + + .../granite_20b_fc_tool_parser.py | 251 ++ + .../tool_parsers/granite_tool_parser.py | 229 ++ + .../openai/tool_parsers/hermes_tool_parser.py | 367 ++ + .../tool_parsers/internlm2_tool_parser.py | 208 + + .../openai/tool_parsers/jamba_tool_parser.py | 300 ++ + .../openai/tool_parsers/llama_tool_parser.py | 258 ++ + .../tool_parsers/mistral_tool_parser.py | 322 ++ + .../tool_parsers/pythonic_tool_parser.py | 289 ++ + vllm/entrypoints/openai/tool_parsers/utils.py | 121 + + vllm/entrypoints/utils.py | 57 + + vllm/envs.py | 324 +- + vllm/executor/cpu_executor.py | 333 +- + vllm/executor/distributed_gpu_executor.py | 149 +- + vllm/executor/executor_base.py | 74 +- + vllm/executor/gpu_executor.py | 131 +- + vllm/executor/hpu_executor.py | 202 + + vllm/executor/msgspec_utils.py | 27 + + vllm/executor/multiproc_gpu_executor.py | 223 ++ + vllm/executor/multiproc_worker_utils.py | 77 +- + vllm/executor/multiproc_xpu_executor.py | 26 + + vllm/executor/neuron_executor.py | 59 +- + vllm/executor/openvino_executor.py | 125 + + vllm/executor/ray_gpu_executor.py | 541 ++- + vllm/executor/ray_hpu_executor.py | 515 +++ + vllm/executor/ray_tpu_executor.py | 343 ++ + vllm/executor/ray_utils.py | 308 +- + vllm/executor/ray_xpu_executor.py | 40 + + vllm/executor/tpu_executor.py | 142 + + vllm/executor/xpu_executor.py | 39 + + vllm/forward_context.py | 99 + + vllm/inputs/__init__.py | 37 + + vllm/inputs/data.py | 403 ++ + vllm/inputs/parse.py | 112 + + vllm/inputs/preprocess.py | 707 ++++ + vllm/inputs/registry.py | 464 +++ + vllm/logger.py | 80 +- + vllm/logging_utils/__init__.py | 5 + + vllm/logging_utils/formatter.py | 15 + + vllm/logits_process.py | 119 + + vllm/lora/fully_sharded_layers.py | 279 +- + vllm/lora/layers.py | 1260 +++--- + vllm/lora/lora.py | 35 +- + vllm/lora/models.py | 612 +-- + vllm/lora/ops/__init__.py | 0 + vllm/lora/ops/torch_ops/__init__.py | 13 + + vllm/lora/ops/torch_ops/lora_ops.py | 113 + + vllm/lora/ops/triton_ops/__init__.py | 13 + + vllm/lora/ops/triton_ops/bgmv_expand.py | 187 + + vllm/lora/ops/triton_ops/bgmv_expand_slice.py | 206 + + vllm/lora/ops/triton_ops/bgmv_shrink.py | 167 + + vllm/lora/ops/triton_ops/sgmv_expand.py | 278 ++ + vllm/lora/ops/triton_ops/sgmv_shrink.py | 239 ++ + vllm/lora/ops/triton_ops/utils.py | 165 + + vllm/lora/peft_helper.py | 80 + + vllm/lora/punica_wrapper/__init__.py | 7 + + vllm/lora/punica_wrapper/punica_base.py | 482 +++ + vllm/lora/punica_wrapper/punica_cpu.py | 346 ++ + vllm/lora/punica_wrapper/punica_gpu.py | 314 ++ + vllm/lora/punica_wrapper/punica_hpu.py | 87 + + vllm/lora/punica_wrapper/punica_selector.py | 26 + + vllm/lora/punica_wrapper/utils.py | 159 + + vllm/lora/request.py | 85 +- + vllm/lora/utils.py | 145 +- + vllm/lora/worker_manager.py | 247 +- + vllm/model_executor/__init__.py | 8 +- + vllm/model_executor/custom_op.py | 144 + + .../guided_decoding/__init__.py | 150 +- + .../guided_decoding/guided_fields.py | 39 + + .../lm_format_enforcer_decoding.py | 46 +- + .../guided_decoding/outlines_decoding.py | 94 +- + .../outlines_logits_processors.py | 107 +- + vllm/model_executor/guided_decoding/utils.py | 228 ++ + .../guided_decoding/xgrammar_decoding.py | 316 ++ + vllm/model_executor/layers/activation.py | 241 +- + .../layers/fused_moe/__init__.py | 47 +- + ...VIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json | 146 + + ...336,device_name=NVIDIA_A100-SXM4-80GB.json | 146 + + ...VIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json | 218 ++ + ...792,device_name=NVIDIA_A100-SXM4-80GB.json | 218 ++ + ...VIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json | 218 ++ + ...VIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json | 218 ++ + ...072,device_name=NVIDIA_H100_80GB_HBM3.json | 218 ++ + ...VIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json | 218 ++ + ...584,device_name=NVIDIA_A100-SXM4-80GB.json | 218 ++ + ...VIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json | 218 ++ + ...168,device_name=NVIDIA_A100-SXM4-80GB.json | 218 ++ + ...VIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json | 146 + + ...336,device_name=NVIDIA_A100-SXM4-80GB.json | 146 + + ...VIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json | 218 ++ + ...792,device_name=NVIDIA_A100-SXM4-80GB.json | 218 ++ + ...VIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json | 146 + + ...VIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json | 146 + + ...=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json | 130 + + ...VIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json | 146 + + ...584,device_name=NVIDIA_A100-SXM4-80GB.json | 218 ++ + ...=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json | 130 + + ...VIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json | 146 + + ...168,device_name=NVIDIA_A100-SXM4-80GB.json | 146 + + ...VIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json | 146 + + ...=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json | 130 + + ...280,device_name=NVIDIA_A100-SXM4-80GB.json | 146 + + ...280,device_name=NVIDIA_H100_80GB_HBM3.json | 146 + + ...640,device_name=NVIDIA_A100-SXM4-80GB.json | 146 + + ...640,device_name=NVIDIA_H100_80GB_HBM3.json | 146 + + ...14336,device_name=AMD_Instinct_MI300X.json | 200 + + ...=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json | 138 + + ...=1792,device_name=AMD_Instinct_MI300X.json | 200 + + ...=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json | 146 + + ...=3584,device_name=AMD_Instinct_MI300X.json | 200 + + ...=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json | 146 + + .../E=8,N=3584,device_name=NVIDIA_L40S.json | 173 + + ...=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json | 146 + + ...=7168,device_name=AMD_Instinct_MI300X.json | 200 + + ...=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json | 146 + + ...=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json | 146 + + .../layers/fused_moe/fused_marlin_moe.py | 359 ++ + .../layers/fused_moe/fused_moe.py | 772 +++- + vllm/model_executor/layers/fused_moe/layer.py | 619 +++ + .../layers/fused_moe/moe_pallas.py | 62 + + .../layers/fused_moe/moe_torch_iterative.py | 51 + + vllm/model_executor/layers/layernorm.py | 155 +- + vllm/model_executor/layers/linear.py | 756 +++- + .../model_executor/layers/logits_processor.py | 97 +- + vllm/model_executor/layers/mamba/__init__.py | 0 + .../layers/mamba/mamba_mixer.py | 241 ++ + .../layers/mamba/ops/__init__.py | 0 + .../layers/mamba/ops/causal_conv1d.py | 102 + + .../layers/mamba/ops/mamba_ssm.py | 411 ++ + vllm/model_executor/layers/pooler.py | 320 ++ + .../layers/quantization/__init__.py | 96 +- + .../layers/quantization/aqlm.py | 21 +- + .../model_executor/layers/quantization/awq.py | 100 +- + .../layers/quantization/awq_marlin.py | 475 +++ + .../layers/quantization/awq_triton.py | 317 ++ + .../layers/quantization/base_config.py | 62 +- + .../layers/quantization/bitsandbytes.py | 357 ++ + .../compressed_tensors/__init__.py | 0 + .../compressed_tensors/compressed_tensors.py | 555 +++ + .../compressed_tensors_moe.py | 519 +++ + .../compressed_tensors/schemes/__init__.py | 18 + + .../schemes/compressed_tensors_24.py | 208 + + .../schemes/compressed_tensors_scheme.py | 52 + + .../schemes/compressed_tensors_w4a16_24.py | 157 + + .../schemes/compressed_tensors_w8a16_fp8.py | 117 + + .../schemes/compressed_tensors_w8a8_fp8.py | 146 + + .../schemes/compressed_tensors_w8a8_int8.py | 108 + + .../schemes/compressed_tensors_wNa16.py | 162 + + .../compressed_tensors/triton_scaled_mm.py | 199 + + .../quantization/compressed_tensors/utils.py | 171 + + .../layers/quantization/deepspeedfp.py | 190 + + .../layers/quantization/experts_int8.py | 180 + + .../layers/quantization/fbgemm_fp8.py | 165 + + .../model_executor/layers/quantization/fp8.py | 653 +++- + .../layers/quantization/gguf.py | 226 ++ + .../layers/quantization/gptq.py | 142 +- + .../layers/quantization/gptq_marlin.py | 749 ++-- + .../layers/quantization/gptq_marlin_24.py | 292 ++ + .../layers/quantization/hqq_marlin.py | 325 ++ + .../layers/quantization/ipex_quant.py | 247 ++ + .../layers/quantization/kernels/__init__.py | 0 + .../kernels/mixed_precision/MPLinearKernel.py | 87 + + .../kernels/mixed_precision/__init__.py | 74 + + .../kernels/mixed_precision/exllama.py | 140 + + .../kernels/mixed_precision/machete.py | 120 + + .../kernels/mixed_precision/marlin.py | 133 + + .../kernels/scaled_mm/ScaledMMLinearKernel.py | 64 + + .../kernels/scaled_mm/__init__.py | 84 + + .../quantization/kernels/scaled_mm/cutlass.py | 134 + + .../quantization/kernels/scaled_mm/xla.py | 101 + + .../layers/quantization/kv_cache.py | 78 + + .../layers/quantization/marlin.py | 110 +- + .../layers/quantization/modelopt.py | 163 + + .../layers/quantization/neuron_quant.py | 64 + + .../model_executor/layers/quantization/qqq.py | 270 ++ + .../layers/quantization/tpu_int8.py | 116 + + .../layers/quantization/utils/__init__.py | 3 + + .../layers/quantization/utils/fp8_utils.py | 353 ++ + .../layers/quantization/utils/layer_utils.py | 37 + + .../quantization/utils/machete_utils.py | 30 + + .../layers/quantization/utils/marlin_utils.py | 350 ++ + .../quantization/utils/marlin_utils_fp8.py | 108 + + .../quantization/utils/marlin_utils_test.py | 163 + + .../utils/marlin_utils_test_24.py | 463 +++ + .../utils/marlin_utils_test_qqq.py | 125 + + .../layers/quantization/utils/quant_utils.py | 454 +++ + .../layers/quantization/utils/w8a8_utils.py | 225 ++ + .../layers/rejection_sampler.py | 389 +- + vllm/model_executor/layers/resampler.py | 267 ++ + .../model_executor/layers/rotary_embedding.py | 687 +++- + vllm/model_executor/layers/sampler.py | 893 +++-- + .../layers/spec_decode_base_sampler.py | 254 ++ + .../layers/typical_acceptance_sampler.py | 170 + + vllm/model_executor/layers/utils.py | 57 + + .../layers/vocab_parallel_embedding.py | 415 +- + vllm/model_executor/model_loader/__init__.py | 20 +- + vllm/model_executor/model_loader/loader.py | 1331 ++++++- + vllm/model_executor/model_loader/neuron.py | 155 +- + vllm/model_executor/model_loader/openvino.py | 203 + + .../model_executor/model_loader/tensorizer.py | 242 +- + vllm/model_executor/model_loader/utils.py | 26 +- + .../model_loader/weight_utils.py | 371 +- + vllm/model_executor/models/__init__.py | 133 +- + vllm/model_executor/models/adapters.py | 248 ++ + vllm/model_executor/models/arctic.py | 581 +++ + vllm/model_executor/models/aria.py | 688 ++++ + vllm/model_executor/models/baichuan.py | 200 +- + vllm/model_executor/models/bart.py | 998 +++++ + vllm/model_executor/models/bert.py | 532 +++ + vllm/model_executor/models/blip.py | 333 ++ + vllm/model_executor/models/blip2.py | 739 ++++ + vllm/model_executor/models/bloom.py | 136 +- + vllm/model_executor/models/chameleon.py | 1166 ++++++ + vllm/model_executor/models/chatglm.py | 548 ++- + vllm/model_executor/models/clip.py | 544 +++ + vllm/model_executor/models/commandr.py | 234 +- + vllm/model_executor/models/dbrx.py | 271 +- + vllm/model_executor/models/decilm.py | 32 +- + vllm/model_executor/models/deepseek.py | 143 +- + vllm/model_executor/models/deepseek_v2.py | 652 ++++ + vllm/model_executor/models/deepseek_v3.py | 663 ++++ + vllm/model_executor/models/deepseek_vl2.py | 662 ++++ + vllm/model_executor/models/eagle.py | 212 + + vllm/model_executor/models/exaone.py | 614 +++ + vllm/model_executor/models/falcon.py | 201 +- + vllm/model_executor/models/florence2.py | 264 ++ + vllm/model_executor/models/fuyu.py | 406 ++ + vllm/model_executor/models/gemma.py | 214 +- + vllm/model_executor/models/gemma2.py | 471 +++ + vllm/model_executor/models/glm.py | 21 + + .../models/glm4_vision_encoder.py | 296 ++ + vllm/model_executor/models/gpt2.py | 164 +- + vllm/model_executor/models/gpt_bigcode.py | 175 +- + vllm/model_executor/models/gpt_j.py | 136 +- + vllm/model_executor/models/gpt_neox.py | 131 +- + vllm/model_executor/models/granite.py | 553 +++ + vllm/model_executor/models/granitemoe.py | 458 +++ + vllm/model_executor/models/gritlm.py | 248 ++ + vllm/model_executor/models/h2ovl.py | 400 ++ + .../models/idefics2_vision_model.py | 344 ++ + vllm/model_executor/models/idefics3.py | 777 ++++ + vllm/model_executor/models/interfaces.py | 441 +++ + vllm/model_executor/models/interfaces_base.py | 177 + + vllm/model_executor/models/intern_vit.py | 474 +++ + vllm/model_executor/models/internlm2.py | 316 +- + vllm/model_executor/models/internlm2_ve.py | 154 + + vllm/model_executor/models/internvl.py | 777 ++++ + vllm/model_executor/models/jais.py | 162 +- + vllm/model_executor/models/jamba.py | 631 +++ + vllm/model_executor/models/llama.py | 523 ++- + vllm/model_executor/models/llava.py | 923 ++++- + vllm/model_executor/models/llava_next.py | 587 +++ + .../model_executor/models/llava_next_video.py | 493 +++ + vllm/model_executor/models/llava_onevision.py | 903 +++++ + vllm/model_executor/models/mamba.py | 302 ++ + vllm/model_executor/models/mamba_cache.py | 158 + + vllm/model_executor/models/medusa.py | 208 + + vllm/model_executor/models/minicpm.py | 381 +- + vllm/model_executor/models/minicpm3.py | 251 ++ + vllm/model_executor/models/minicpmv.py | 1023 +++++ + vllm/model_executor/models/mixtral.py | 466 +-- + vllm/model_executor/models/mixtral_quant.py | 166 +- + vllm/model_executor/models/mllama.py | 1527 ++++++++ + vllm/model_executor/models/mlp_speculator.py | 203 + + vllm/model_executor/models/module_mapping.py | 69 + + vllm/model_executor/models/molmo.py | 1412 +++++++ + vllm/model_executor/models/mpt.py | 126 +- + vllm/model_executor/models/nemotron.py | 531 +++ + vllm/model_executor/models/nvlm_d.py | 88 + + vllm/model_executor/models/olmo.py | 138 +- + vllm/model_executor/models/olmo2.py | 432 +++ + vllm/model_executor/models/olmoe.py | 466 +++ + vllm/model_executor/models/opt.py | 189 +- + vllm/model_executor/models/orion.py | 138 +- + vllm/model_executor/models/paligemma.py | 321 ++ + vllm/model_executor/models/persimmon.py | 368 ++ + vllm/model_executor/models/phi.py | 179 +- + vllm/model_executor/models/phi3.py | 20 + + vllm/model_executor/models/phi3_small.py | 482 +++ + vllm/model_executor/models/phi3v.py | 732 ++++ + vllm/model_executor/models/phimoe.py | 676 ++++ + vllm/model_executor/models/pixtral.py | 1123 ++++++ + vllm/model_executor/models/qwen.py | 899 ++++- + vllm/model_executor/models/qwen2.py | 407 +- + vllm/model_executor/models/qwen2_audio.py | 417 ++ + vllm/model_executor/models/qwen2_moe.py | 306 +- + vllm/model_executor/models/qwen2_rm.py | 117 + + vllm/model_executor/models/qwen2_vl.py | 1355 +++++++ + vllm/model_executor/models/registry.py | 514 +++ + vllm/model_executor/models/roberta.py | 256 ++ + vllm/model_executor/models/siglip.py | 655 ++++ + vllm/model_executor/models/solar.py | 573 +++ + vllm/model_executor/models/stablelm.py | 159 +- + vllm/model_executor/models/starcoder2.py | 171 +- + vllm/model_executor/models/telechat2.py | 132 + + vllm/model_executor/models/ultravox.py | 556 +++ + vllm/model_executor/models/utils.py | 642 +++ + vllm/model_executor/models/vision.py | 145 + + vllm/model_executor/models/whisper.py | 735 ++++ + vllm/model_executor/parameter.py | 425 ++ + vllm/model_executor/pooling_metadata.py | 69 + + vllm/model_executor/sampling_metadata.py | 399 +- + vllm/model_executor/utils.py | 31 +- + vllm/multimodal/__init__.py | 31 + + vllm/multimodal/audio.py | 75 + + vllm/multimodal/base.py | 461 +++ + vllm/multimodal/hasher.py | 100 + + vllm/multimodal/image.py | 137 + + vllm/multimodal/inputs.py | 523 +++ + vllm/multimodal/parse.py | 366 ++ + vllm/multimodal/processing.py | 1190 ++++++ + vllm/multimodal/profiling.py | 206 + + vllm/multimodal/registry.py | 423 ++ + vllm/multimodal/utils.py | 479 +++ + vllm/multimodal/video.py | 188 + + vllm/outputs.py | 481 ++- + vllm/platforms/__init__.py | 223 ++ + vllm/platforms/cpu.py | 111 + + vllm/platforms/cuda.py | 365 ++ + vllm/platforms/hpu.py | 64 + + vllm/platforms/interface.py | 272 ++ + vllm/platforms/neuron.py | 46 + + vllm/platforms/openvino.py | 143 + + vllm/platforms/rocm.py | 153 + + vllm/platforms/tpu.py | 81 + + vllm/platforms/xpu.py | 95 + + vllm/plugins/__init__.py | 88 + + vllm/pooling_params.py | 23 + + vllm/profiler/__init__.py | 5 + + vllm/profiler/layerwise_profile.py | 372 ++ + vllm/profiler/utils.py | 145 + + vllm/prompt_adapter/__init__.py | 0 + vllm/prompt_adapter/layers.py | 80 + + vllm/prompt_adapter/models.py | 355 ++ + vllm/prompt_adapter/request.py | 34 + + vllm/prompt_adapter/utils.py | 94 + + vllm/prompt_adapter/worker_manager.py | 176 + + vllm/sampling_params.py | 445 ++- + vllm/scalar_type.py | 330 ++ + vllm/scripts.py | 207 + + vllm/sequence.py | 1310 +++++-- + vllm/spec_decode/batch_expansion.py | 311 +- + vllm/spec_decode/draft_model_runner.py | 323 ++ + vllm/spec_decode/interfaces.py | 21 +- + vllm/spec_decode/medusa_worker.py | 137 + + vllm/spec_decode/metrics.py | 56 +- + vllm/spec_decode/mlp_speculator_worker.py | 91 + + vllm/spec_decode/mqa_scorer.py | 113 + + vllm/spec_decode/multi_step_worker.py | 324 +- + vllm/spec_decode/ngram_worker.py | 159 +- + vllm/spec_decode/proposer_worker_base.py | 56 + + .../spec_decode/smaller_tp_proposer_worker.py | 161 + + vllm/spec_decode/spec_decode_worker.py | 952 ++++- + vllm/spec_decode/target_model_runner.py | 42 + + vllm/spec_decode/top1_proposer.py | 158 +- + vllm/spec_decode/util.py | 178 +- + vllm/tracing.py | 119 + + vllm/transformers_utils/__init__.py | 17 + + vllm/transformers_utils/config.py | 593 ++- + vllm/transformers_utils/configs/__init__.py | 32 +- + vllm/transformers_utils/configs/arctic.py | 204 + + vllm/transformers_utils/configs/aria.py | 47 + + vllm/transformers_utils/configs/chatglm.py | 3 +- + vllm/transformers_utils/configs/cohere2.py | 192 + + .../configs/deepseek_vl2.py | 214 + + vllm/transformers_utils/configs/eagle.py | 49 + + vllm/transformers_utils/configs/exaone.py | 189 + + vllm/transformers_utils/configs/h2ovl.py | 13 + + vllm/transformers_utils/configs/internvl.py | 51 + + vllm/transformers_utils/configs/jais.py | 1 - + vllm/transformers_utils/configs/medusa.py | 60 + + vllm/transformers_utils/configs/mllama.py | 28 + + .../configs/mlp_speculator.py | 65 + + vllm/transformers_utils/configs/mpt.py | 7 +- + vllm/transformers_utils/configs/nemotron.py | 202 + + vllm/transformers_utils/configs/nvlm_d.py | 12 + + vllm/transformers_utils/configs/olmo2.py | 166 + + vllm/transformers_utils/configs/solar.py | 244 ++ + vllm/transformers_utils/configs/telechat2.py | 61 + + vllm/transformers_utils/configs/ultravox.py | 99 + + vllm/transformers_utils/detokenizer.py | 202 +- + vllm/transformers_utils/detokenizer_utils.py | 167 + + vllm/transformers_utils/processor.py | 104 + + vllm/transformers_utils/s3_utils.py | 151 + + vllm/transformers_utils/tokenizer.py | 180 +- + .../tokenizer_group/__init__.py | 46 +- + .../tokenizer_group/base_tokenizer_group.py | 39 +- + .../tokenizer_group/ray_tokenizer_group.py | 169 +- + .../tokenizer_group/tokenizer_group.py | 66 +- + .../transformers_utils/tokenizers/__init__.py | 6 +- + vllm/transformers_utils/tokenizers/mistral.py | 366 ++ + vllm/transformers_utils/utils.py | 20 + + vllm/triton_utils/__init__.py | 10 + + vllm/triton_utils/custom_cache_manager.py | 53 + + vllm/triton_utils/importing.py | 15 + + vllm/usage/usage_lib.py | 36 +- + vllm/utils.py | 1980 ++++++++-- + vllm/v1/__init__.py | 0 + vllm/v1/attention/__init__.py | 0 + vllm/v1/attention/backends/__init__.py | 0 + vllm/v1/attention/backends/flash_attn.py | 430 ++ + vllm/v1/core/__init__.py | 0 + vllm/v1/core/encoder_cache_manager.py | 48 + + vllm/v1/core/kv_cache_manager.py | 479 +++ + vllm/v1/core/kv_cache_utils.py | 307 ++ + vllm/v1/core/scheduler.py | 618 +++ + vllm/v1/engine/__init__.py | 79 + + vllm/v1/engine/async_llm.py | 342 ++ + vllm/v1/engine/core.py | 286 ++ + vllm/v1/engine/core_client.py | 268 ++ + vllm/v1/engine/detokenizer.py | 180 + + vllm/v1/engine/llm_engine.py | 179 + + vllm/v1/engine/mm_input_mapper.py | 142 + + vllm/v1/engine/output_processor.py | 200 + + vllm/v1/engine/processor.py | 223 ++ + vllm/v1/executor/__init__.py | 0 + vllm/v1/executor/abstract.py | 57 + + vllm/v1/executor/multiproc_executor.py | 405 ++ + vllm/v1/executor/ray_executor.py | 342 ++ + vllm/v1/executor/ray_utils.py | 280 ++ + vllm/v1/executor/uniproc_executor.py | 84 + + vllm/v1/metrics/__init__.py | 0 + vllm/v1/metrics/loggers.py | 38 + + vllm/v1/metrics/stats.py | 39 + + vllm/v1/outputs.py | 39 + + vllm/v1/request.py | 171 + + vllm/v1/sample/__init__.py | 0 + vllm/v1/sample/metadata.py | 31 + + vllm/v1/sample/ops/__init__.py | 0 + vllm/v1/sample/ops/penalties.py | 59 + + vllm/v1/sample/ops/topk_topp_sampler.py | 201 + + vllm/v1/sample/sampler.py | 136 + + vllm/v1/serial_utils.py | 10 + + vllm/v1/utils.py | 136 + + vllm/v1/worker/__init__.py | 0 + vllm/v1/worker/block_table.py | 78 + + vllm/v1/worker/gpu_input_batch.py | 435 +++ + vllm/v1/worker/gpu_model_runner.py | 866 +++++ + vllm/v1/worker/gpu_worker.py | 273 ++ + vllm/version.py | 11 + + vllm/worker/cache_engine.py | 58 +- + vllm/worker/cpu_enc_dec_model_runner.py | 325 ++ + vllm/worker/cpu_model_runner.py | 903 +++-- + vllm/worker/cpu_pooling_model_runner.py | 134 + + vllm/worker/cpu_worker.py | 247 +- + vllm/worker/enc_dec_model_runner.py | 526 +++ + vllm/worker/hpu_model_runner.py | 2016 ++++++++++ + vllm/worker/hpu_worker.py | 410 ++ + vllm/worker/model_runner.py | 2698 ++++++++----- + vllm/worker/model_runner_base.py | 305 ++ + vllm/worker/multi_step_model_runner.py | 907 +++++ + vllm/worker/multi_step_tpu_worker.py | 105 + + vllm/worker/multi_step_worker.py | 194 + + vllm/worker/neuron_model_runner.py | 232 +- + vllm/worker/neuron_worker.py | 83 +- + vllm/worker/openvino_model_runner.py | 369 ++ + vllm/worker/openvino_worker.py | 588 +++ + vllm/worker/pooling_model_runner.py | 201 + + vllm/worker/tpu_model_runner.py | 896 +++++ + vllm/worker/tpu_worker.py | 294 ++ + vllm/worker/utils.py | 51 + + vllm/worker/worker.py | 473 ++- + vllm/worker/worker_base.py | 392 +- + vllm/worker/xpu_model_runner.py | 609 +++ + vllm/worker/xpu_worker.py | 184 + + 1537 files changed, 284998 insertions(+), 25880 deletions(-) + create mode 100644 .buildkite/generate_index.py + create mode 100644 .buildkite/lm-eval-harness/configs/DeepSeek-V2-Lite-Chat.yaml + create mode 100644 .buildkite/lm-eval-harness/configs/Meta-Llama-3-70B-Instruct-FBGEMM-nonuniform.yaml + create mode 100644 .buildkite/lm-eval-harness/configs/Meta-Llama-3-70B-Instruct.yaml + create mode 100644 .buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-Channelwise-compressed-tensors.yaml + create mode 100644 .buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-FBGEMM-nonuniform.yaml + create mode 100644 .buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-FP8-compressed-tensors.yaml + create mode 100644 .buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-FP8.yaml + create mode 100644 .buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-INT8-compressed-tensors-asym.yaml + create mode 100644 .buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-INT8-compressed-tensors.yaml + create mode 100644 .buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-nonuniform-compressed-tensors.yaml + create mode 100644 .buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct.yaml + create mode 100644 .buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-QQQ.yaml + create mode 100644 .buildkite/lm-eval-harness/configs/Meta-Llama-3.2-1B-Instruct-INT8-compressed-tensors.yaml + create mode 100644 .buildkite/lm-eval-harness/configs/Minitron-4B-Base-FP8.yaml + create mode 100644 .buildkite/lm-eval-harness/configs/Mixtral-8x22B-Instruct-v0.1-FP8-Dynamic.yaml + create mode 100644 .buildkite/lm-eval-harness/configs/Mixtral-8x7B-Instruct-v0.1-FP8.yaml + create mode 100644 .buildkite/lm-eval-harness/configs/Mixtral-8x7B-Instruct-v0.1.yaml + create mode 100644 .buildkite/lm-eval-harness/configs/Qwen2-1.5B-Instruct-FP8W8.yaml + create mode 100644 .buildkite/lm-eval-harness/configs/Qwen2-1.5B-Instruct-INT8-compressed-tensors.yaml + create mode 100644 .buildkite/lm-eval-harness/configs/Qwen2-1.5B-Instruct-W8A16-compressed-tensors.yaml + create mode 100644 .buildkite/lm-eval-harness/configs/Qwen2-57B-A14-Instruct.yaml + create mode 100644 .buildkite/lm-eval-harness/configs/models-large.txt + create mode 100644 .buildkite/lm-eval-harness/configs/models-small.txt + create mode 100644 .buildkite/lm-eval-harness/run-lm-eval-gsm-hf-baseline.sh + create mode 100644 .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh + create mode 100644 .buildkite/lm-eval-harness/run-tests.sh + create mode 100644 .buildkite/lm-eval-harness/test_lm_eval_correctness.py + create mode 100644 .buildkite/nightly-benchmarks/README.md + create mode 100644 .buildkite/nightly-benchmarks/benchmark-pipeline.yaml + create mode 100644 .buildkite/nightly-benchmarks/nightly-annotation.md + create mode 100644 .buildkite/nightly-benchmarks/nightly-descriptions.md + create mode 100644 .buildkite/nightly-benchmarks/nightly-pipeline.yaml + create mode 100644 .buildkite/nightly-benchmarks/performance-benchmarks-descriptions.md + create mode 100644 .buildkite/nightly-benchmarks/scripts/convert-results-json-to-markdown.py + create mode 100644 .buildkite/nightly-benchmarks/scripts/download-tokenizer.py + create mode 100644 .buildkite/nightly-benchmarks/scripts/generate-nightly-markdown.py + create mode 100644 .buildkite/nightly-benchmarks/scripts/get-lmdeploy-modelname.py + create mode 100644 .buildkite/nightly-benchmarks/scripts/launch-server.sh + create mode 100644 .buildkite/nightly-benchmarks/scripts/nightly-annotate.sh + create mode 100644 .buildkite/nightly-benchmarks/scripts/run-nightly-benchmarks.sh + create mode 100644 .buildkite/nightly-benchmarks/scripts/run-performance-benchmarks.sh + create mode 100644 .buildkite/nightly-benchmarks/scripts/summary-nightly-results.py + create mode 100644 .buildkite/nightly-benchmarks/scripts/wait-for-image.sh + create mode 100644 .buildkite/nightly-benchmarks/tests/latency-tests.json + create mode 100644 .buildkite/nightly-benchmarks/tests/nightly-tests.json + create mode 100644 .buildkite/nightly-benchmarks/tests/serving-tests.json + create mode 100644 .buildkite/nightly-benchmarks/tests/throughput-tests.json + create mode 100644 .buildkite/release-pipeline.yaml + create mode 100644 .buildkite/run-cpu-test-ppc64le.sh + create mode 100644 .buildkite/run-gh200-test.sh + create mode 100644 .buildkite/run-hpu-test.sh + create mode 100644 .buildkite/run-multi-node-test.sh + create mode 100644 .buildkite/run-openvino-test.sh + create mode 100644 .buildkite/run-tpu-test.sh + create mode 100644 .buildkite/run-xpu-test.sh + create mode 100644 .buildkite/upload-wheels.sh + create mode 100644 .clang-format + create mode 100644 .github/CODEOWNERS + create mode 100644 .github/FUNDING.yml + create mode 100644 .github/ISSUE_TEMPLATE/400-bug-report.yml + create mode 100644 .github/ISSUE_TEMPLATE/500-feature-request.yml + create mode 100644 .github/ISSUE_TEMPLATE/600-new-model.yml + create mode 100644 .github/ISSUE_TEMPLATE/700-performance-discussion.yml + create mode 100644 .github/ISSUE_TEMPLATE/800-misc-discussion.yml + create mode 100644 .github/dependabot.yml + create mode 100644 .github/mergify.yml + create mode 100644 .github/scripts/cleanup_pr_body.sh + create mode 100644 .github/workflows/actionlint.yml + create mode 100644 .github/workflows/add_label_automerge.yml + create mode 100644 .github/workflows/clang-format.yml + create mode 100644 .github/workflows/cleanup_pr_body.yml + create mode 100644 .github/workflows/codespell.yml + create mode 100644 .github/workflows/doc-lint.yml + create mode 100644 .github/workflows/lint-and-deploy.yaml + create mode 100644 .github/workflows/matchers/actionlint.json + create mode 100644 .github/workflows/matchers/mypy.json + create mode 100644 .github/workflows/matchers/ruff.json + create mode 100644 .github/workflows/png-lint.yml + create mode 100644 .github/workflows/reminder_comment.yml + create mode 100644 .github/workflows/shellcheck.yml + create mode 100644 .github/workflows/stale.yml + create mode 100644 .shellcheckrc + create mode 100644 CODE_OF_CONDUCT.md + create mode 100644 DCO + create mode 100644 Dockerfile.arm + create mode 100644 Dockerfile.hpu + create mode 100644 Dockerfile.openvino + create mode 100644 Dockerfile.ppc64le + create mode 100644 Dockerfile.tpu + create mode 100644 Dockerfile.xpu + create mode 100644 SECURITY.md + create mode 100644 benchmarks/benchmark_guided.py + create mode 100644 benchmarks/benchmark_long_document_qa_throughput.py + create mode 100644 benchmarks/benchmark_prioritization.py + create mode 100644 benchmarks/benchmark_serving_guided.py + create mode 100644 benchmarks/cutlass_benchmarks/sparse_benchmarks.py + create mode 100644 benchmarks/cutlass_benchmarks/utils.py + create mode 100644 benchmarks/cutlass_benchmarks/w8a8_benchmarks.py + create mode 100644 benchmarks/cutlass_benchmarks/weight_shapes.py + create mode 100644 benchmarks/disagg_benchmarks/disagg_overhead_benchmark.sh + create mode 100644 benchmarks/disagg_benchmarks/disagg_performance_benchmark.sh + create mode 100644 benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py + create mode 100644 benchmarks/disagg_benchmarks/round_robin_proxy.py + create mode 100644 benchmarks/disagg_benchmarks/visualize_benchmark_results.py + create mode 100644 benchmarks/fused_kernels/layernorm_rms_benchmarks.py + create mode 100644 benchmarks/kernels/benchmark_layernorm.py + create mode 100644 benchmarks/kernels/benchmark_machete.py + create mode 100644 benchmarks/kernels/benchmark_marlin.py + create mode 100644 benchmarks/kernels/benchmark_moe.py + create mode 100644 benchmarks/kernels/benchmark_quant.py + create mode 100644 benchmarks/kernels/benchmark_rmsnorm.py + create mode 100644 benchmarks/kernels/benchmark_shapes.py + create mode 100644 benchmarks/kernels/graph_machete_bench.py + create mode 100644 benchmarks/kernels/requirements.txt + create mode 100644 benchmarks/kernels/weight_shapes.py + create mode 100644 benchmarks/overheads/benchmark_hashing.py + create mode 100644 benchmarks/structured_schemas/structured_schema_1.json + create mode 100644 csrc/attention/attention_kernels.cuh + create mode 100644 csrc/attention/paged_attention_v1.cu + create mode 100644 csrc/attention/paged_attention_v2.cu + create mode 100644 csrc/core/exception.hpp + create mode 100644 csrc/core/math.hpp + create mode 100644 csrc/core/registration.h + create mode 100644 csrc/core/scalar_type.hpp + create mode 100644 csrc/cpu/cpu_types_arm.hpp + create mode 100644 csrc/cpu/cpu_types_vsx.hpp + create mode 100644 csrc/cpu/cpu_types_x86.hpp + create mode 100644 csrc/cpu/dnnl_helper.hpp + create mode 100644 csrc/cpu/quant.cpp + create mode 100644 csrc/cpu/torch_bindings.cpp + create mode 100644 csrc/cpu/utils.cpp + create mode 100644 csrc/cutlass_extensions/common.cpp + create mode 100644 csrc/cutlass_extensions/common.hpp + create mode 100644 csrc/cutlass_extensions/cute_utils.cuh + create mode 100644 csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_c2x.hpp + create mode 100644 csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp + create mode 100644 csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp + create mode 100644 csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp + create mode 100644 csrc/cutlass_extensions/torch_utils.hpp + create mode 100644 csrc/cutlass_extensions/vllm_collective_builder.cuh + create mode 100644 csrc/cutlass_extensions/vllm_custom_types.cuh + create mode 100644 csrc/cutlass_extensions/vllm_cutlass_library_extension.py + create mode 100644 csrc/cutlass_extensions/vllm_numeric_conversion.cuh + create mode 100644 csrc/cutlass_extensions/vllm_type_utils.cuh + create mode 100644 csrc/layernorm_quant_kernels.cu + create mode 100644 csrc/mamba/causal_conv1d/causal_conv1d.cu + create mode 100644 csrc/mamba/causal_conv1d/causal_conv1d.h + create mode 100644 csrc/mamba/causal_conv1d/static_switch.h + create mode 100644 csrc/mamba/mamba_ssm/selective_scan.h + create mode 100644 csrc/mamba/mamba_ssm/selective_scan_fwd.cu + create mode 100644 csrc/mamba/mamba_ssm/static_switch.h + create mode 100644 csrc/moe/marlin_kernels/marlin_moe_kernel.h + create mode 100644 csrc/moe/marlin_kernels/marlin_moe_kernel_ku4.cu + create mode 100644 csrc/moe/marlin_kernels/marlin_moe_kernel_ku4.h + create mode 100644 csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.cu + create mode 100644 csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.h + create mode 100644 csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.cu + create mode 100644 csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.h + create mode 100644 csrc/moe/marlin_moe_ops.cu + create mode 100644 csrc/moe/moe_align_sum_kernels.cu + create mode 100644 csrc/moe/torch_bindings.cpp + create mode 100644 csrc/permute_cols.cu + create mode 100644 csrc/prepare_inputs/advance_step.cu + create mode 100644 csrc/prepare_inputs/advance_step.cuh + create mode 100644 csrc/quantization/compressed_tensors/int8_quant_kernels.cu + create mode 100644 csrc/quantization/cutlass_w8a8/Epilogues.md + create mode 100644 csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu + create mode 100644 csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh + create mode 100644 csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm75_dispatch.cuh + create mode 100644 csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm80_dispatch.cuh + create mode 100644 csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm89_fp8_dispatch.cuh + create mode 100644 csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm89_int8_dispatch.cuh + create mode 100644 csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu + create mode 100644 csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cuh + create mode 100644 csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm90_fp8_dispatch.cuh + create mode 100644 csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm90_int8_dispatch.cuh + create mode 100644 csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu + create mode 100644 csrc/quantization/fp8/amd/hip_float8.h + create mode 100644 csrc/quantization/fp8/amd/hip_float8_impl.h + create mode 100644 csrc/quantization/fp8/amd/quant_utils.cuh + create mode 100644 csrc/quantization/fp8/common.cu + create mode 100644 csrc/quantization/fp8/common.cuh + create mode 100644 csrc/quantization/fp8/fp8_marlin.cu + create mode 100644 csrc/quantization/fp8/nvidia/quant_utils.cuh + create mode 100644 csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu + create mode 100644 csrc/quantization/fused_kernels/layernorm_utils.cuh + create mode 100644 csrc/quantization/fused_kernels/quant_conversions.cuh + create mode 100644 csrc/quantization/gguf/dequantize.cuh + create mode 100644 csrc/quantization/gguf/ggml-common.h + create mode 100644 csrc/quantization/gguf/gguf_kernel.cu + create mode 100644 csrc/quantization/gguf/mmq.cuh + create mode 100644 csrc/quantization/gguf/mmvq.cuh + create mode 100644 csrc/quantization/gguf/vecdotq.cuh + create mode 100644 csrc/quantization/gptq_marlin/awq_marlin_repack.cu + create mode 100644 csrc/quantization/gptq_marlin/marlin.cuh + create mode 100644 csrc/quantization/gptq_marlin/marlin_dtypes.cuh + create mode 100644 csrc/quantization/machete/Readme.md + create mode 100644 csrc/quantization/machete/generate.py + create mode 100644 csrc/quantization/machete/machete_collective_builder.cuh + create mode 100644 csrc/quantization/machete/machete_interleaving_utils.cuh + create mode 100644 csrc/quantization/machete/machete_mainloop.cuh + create mode 100644 csrc/quantization/machete/machete_mm_kernel.cuh + create mode 100644 csrc/quantization/machete/machete_mm_launcher.cuh + create mode 100644 csrc/quantization/machete/machete_prepack_kernel.cuh + create mode 100644 csrc/quantization/machete/machete_prepack_launcher.cuh + create mode 100644 csrc/quantization/machete/machete_prepacked_layout.cuh + create mode 100644 csrc/quantization/machete/machete_pytorch.cu + create mode 100644 csrc/quantization/marlin/dense/LICENSE + create mode 100644 csrc/quantization/marlin/dense/common/base.h + create mode 100644 csrc/quantization/marlin/dense/common/mem.h + create mode 100644 csrc/quantization/marlin/dense/marlin_cuda_kernel.cu + create mode 100644 csrc/quantization/marlin/qqq/marlin_qqq_gemm_kernel.cu + create mode 100644 csrc/quantization/marlin/sparse/LICENSE + create mode 100644 csrc/quantization/marlin/sparse/common/base.h + create mode 100644 csrc/quantization/marlin/sparse/common/mem.h + create mode 100644 csrc/quantization/marlin/sparse/common/mma.h + create mode 100644 csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu + create mode 100644 csrc/quantization/vectorization.cuh + create mode 100644 csrc/rocm/attention.cu + create mode 100644 csrc/rocm/ops.h + create mode 100644 csrc/rocm/torch_bindings.cpp + create mode 100644 csrc/sparse/cutlass/sparse_compressor_c3x.cu + create mode 100644 csrc/sparse/cutlass/sparse_compressor_entry.cu + create mode 100644 csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu + create mode 100644 csrc/sparse/cutlass/sparse_scaled_mm_c3x.cuh + create mode 100644 csrc/sparse/cutlass/sparse_scaled_mm_entry.cu + create mode 100644 csrc/torch_bindings.cpp + create mode 100644 csrc/type_convert.cuh + create mode 100644 docs/source/_static/custom.js + create mode 100644 docs/source/_templates/sections/header.html + create mode 100644 docs/source/api/engine/async_llm_engine.md + create mode 100644 docs/source/api/engine/index.md + create mode 100644 docs/source/api/engine/llm_engine.md + create mode 100644 docs/source/api/inference_params.md + create mode 100644 docs/source/api/model/adapters.md + create mode 100644 docs/source/api/model/index.md + create mode 100644 docs/source/api/model/interfaces.md + create mode 100644 docs/source/api/model/interfaces_base.md + create mode 100644 docs/source/api/multimodal/index.md + create mode 100644 docs/source/api/multimodal/inputs.md + create mode 100644 docs/source/api/multimodal/parse.md + create mode 100644 docs/source/api/multimodal/processing.md + create mode 100644 docs/source/api/multimodal/profiling.md + create mode 100644 docs/source/api/multimodal/registry.md + create mode 100644 docs/source/api/offline_inference/index.md + create mode 100644 docs/source/api/offline_inference/llm.md + create mode 100644 docs/source/api/offline_inference/llm_inputs.md + create mode 100644 docs/source/assets/contributing/dockerfile-stages-dependency.png + create mode 100644 docs/source/assets/deployment/architecture_helm_deployment.png + create mode 100644 docs/source/assets/design/arch_overview/entrypoints.excalidraw.png + create mode 100644 docs/source/assets/design/arch_overview/llm_engine.excalidraw.png + create mode 100644 docs/source/assets/design/hierarchy.png + create mode 100644 docs/source/assets/features/disagg_prefill/abstraction.jpg + create mode 100644 docs/source/assets/features/disagg_prefill/overview.jpg + create mode 100644 docs/source/community/meetups.md + create mode 100644 docs/source/community/sponsors.md + create mode 100644 docs/source/contributing/dockerfile/dockerfile.md + create mode 100644 docs/source/contributing/model/basic.md + create mode 100644 docs/source/contributing/model/index.md + create mode 100644 docs/source/contributing/model/multimodal.md + create mode 100644 docs/source/contributing/model/registration.md + create mode 100644 docs/source/contributing/model/tests.md + create mode 100644 docs/source/contributing/overview.md + create mode 100644 docs/source/contributing/profiling/profiling_index.md + create mode 100644 docs/source/contributing/vulnerability_management.md + create mode 100644 docs/source/deployment/docker.md + create mode 100644 docs/source/deployment/frameworks/bentoml.md + create mode 100644 docs/source/deployment/frameworks/cerebrium.md + create mode 100644 docs/source/deployment/frameworks/dstack.md + create mode 100644 docs/source/deployment/frameworks/helm.md + create mode 100644 docs/source/deployment/frameworks/index.md + create mode 100644 docs/source/deployment/frameworks/lws.md + create mode 100644 docs/source/deployment/frameworks/modal.md + create mode 100644 docs/source/deployment/frameworks/skypilot.md + create mode 100644 docs/source/deployment/frameworks/triton.md + create mode 100644 docs/source/deployment/integrations/index.md + create mode 100644 docs/source/deployment/integrations/kserve.md + create mode 100644 docs/source/deployment/integrations/kubeai.md + create mode 100644 docs/source/deployment/integrations/llamastack.md + create mode 100644 docs/source/deployment/k8s.md + create mode 100644 docs/source/deployment/nginx.md + create mode 100644 docs/source/design/arch_overview.md + create mode 100644 docs/source/design/automatic_prefix_caching.md + create mode 100644 docs/source/design/huggingface_integration.md + create mode 100644 docs/source/design/kernel/paged_attention.md + create mode 100644 docs/source/design/mm_processing.md + create mode 100644 docs/source/design/multiprocessing.md + create mode 100644 docs/source/design/plugin_system.md + create mode 100644 docs/source/features/automatic_prefix_caching.md + create mode 100644 docs/source/features/compatibility_matrix.md + create mode 100644 docs/source/features/disagg_prefill.md + create mode 100644 docs/source/features/lora.md + create mode 100644 docs/source/features/quantization/auto_awq.md + create mode 100644 docs/source/features/quantization/bnb.md + create mode 100644 docs/source/features/quantization/fp8.md + create mode 100644 docs/source/features/quantization/fp8_e4m3_kvcache.md + create mode 100644 docs/source/features/quantization/fp8_e5m2_kvcache.md + create mode 100644 docs/source/features/quantization/gguf.md + create mode 100644 docs/source/features/quantization/index.md + create mode 100644 docs/source/features/quantization/int8.md + create mode 100644 docs/source/features/quantization/supported_hardware.md + create mode 100644 docs/source/features/spec_decode.md + create mode 100644 docs/source/features/structured_outputs.md + create mode 100644 docs/source/features/tool_calling.md + create mode 100644 docs/source/getting_started/faq.md + create mode 100644 docs/source/getting_started/installation/cpu-apple.md + create mode 100644 docs/source/getting_started/installation/cpu-arm.md + create mode 100644 docs/source/getting_started/installation/cpu-x86.md + create mode 100644 docs/source/getting_started/installation/gpu-cuda.md + create mode 100644 docs/source/getting_started/installation/gpu-rocm.md + create mode 100644 docs/source/getting_started/installation/hpu-gaudi.md + create mode 100644 docs/source/getting_started/installation/index.md + create mode 100644 docs/source/getting_started/installation/neuron.md + create mode 100644 docs/source/getting_started/installation/openvino.md + create mode 100644 docs/source/getting_started/installation/tpu.md + create mode 100644 docs/source/getting_started/installation/xpu.md + create mode 100644 docs/source/getting_started/quickstart.md + create mode 100644 docs/source/getting_started/troubleshooting.md + create mode 100644 docs/source/index.md + create mode 100644 docs/source/models/extensions/index.md + create mode 100644 docs/source/models/extensions/runai_model_streamer.md + create mode 100644 docs/source/models/extensions/tensorizer.md + create mode 100644 docs/source/models/generative_models.md + create mode 100644 docs/source/models/pooling_models.md + create mode 100644 docs/source/models/supported_models.md + create mode 100644 docs/source/performance/benchmarks.md + create mode 100644 docs/source/performance/optimization.md + create mode 100644 docs/source/serving/distributed_serving.md + create mode 100644 docs/source/serving/engine_args.md + create mode 100644 docs/source/serving/env_vars.md + create mode 100644 docs/source/serving/integrations/index.md + create mode 100644 docs/source/serving/integrations/langchain.md + create mode 100644 docs/source/serving/integrations/llamaindex.md + create mode 100644 docs/source/serving/metrics.md + create mode 100644 docs/source/serving/multimodal_inputs.md + create mode 100644 docs/source/serving/offline_inference.md + create mode 100644 examples/offline_inference/aqlm_example.py + create mode 100644 examples/offline_inference/arctic.py + create mode 100644 examples/offline_inference/audio_language.py + create mode 100644 examples/offline_inference/basic.py + create mode 100644 examples/offline_inference/basic_with_model_default_sampling.py + create mode 100644 examples/offline_inference/chat.py + create mode 100644 examples/offline_inference/chat_with_tools.py + create mode 100644 examples/offline_inference/classification.py + create mode 100644 examples/offline_inference/cli.py + create mode 100644 examples/offline_inference/cpu_offload.py + create mode 100644 examples/offline_inference/distributed.py + create mode 100644 examples/offline_inference/embedding.py + create mode 100644 examples/offline_inference/encoder_decoder.py + create mode 100644 examples/offline_inference/florence2_inference.py + create mode 100644 examples/offline_inference/gguf_inference.py + create mode 100644 examples/offline_inference/llm_engine_example.py + create mode 100644 examples/offline_inference/lora_with_quantization_inference.py + create mode 100644 examples/offline_inference/mlpspeculator.py + create mode 100644 examples/offline_inference/multilora_inference.py + create mode 100644 examples/offline_inference/neuron.py + create mode 100644 examples/offline_inference/neuron_int8_quantization.py + create mode 100644 examples/offline_inference/openai/openai_batch.md + create mode 100644 examples/offline_inference/openai/openai_example_batch.jsonl + create mode 100644 examples/offline_inference/pixtral.py + create mode 100644 examples/offline_inference/prefix_caching.py + create mode 100644 examples/offline_inference/profiling.py + create mode 100644 examples/offline_inference/save_sharded_state.py + create mode 100644 examples/offline_inference/scoring.py + create mode 100644 examples/offline_inference/simple_profiling.py + create mode 100644 examples/offline_inference/structured_outputs.py + create mode 100644 examples/offline_inference/tpu.py + create mode 100644 examples/offline_inference/vision_language.py + create mode 100644 examples/offline_inference/vision_language_embedding.py + create mode 100644 examples/offline_inference/vision_language_multi_image.py + create mode 100644 examples/offline_inference/whisper.py + create mode 100644 examples/online_serving/api_client.py + create mode 100644 examples/online_serving/chart-helm/.helmignore + create mode 100644 examples/online_serving/chart-helm/Chart.yaml + create mode 100644 examples/online_serving/chart-helm/README.md + create mode 100644 examples/online_serving/chart-helm/ct.yaml + create mode 100644 examples/online_serving/chart-helm/lintconf.yaml + create mode 100644 examples/online_serving/chart-helm/templates/_helpers.tpl + create mode 100644 examples/online_serving/chart-helm/templates/configmap.yaml + create mode 100644 examples/online_serving/chart-helm/templates/custom-objects.yaml + create mode 100644 examples/online_serving/chart-helm/templates/deployment.yaml + create mode 100644 examples/online_serving/chart-helm/templates/hpa.yaml + create mode 100644 examples/online_serving/chart-helm/templates/job.yaml + create mode 100644 examples/online_serving/chart-helm/templates/poddisruptionbudget.yaml + create mode 100644 examples/online_serving/chart-helm/templates/pvc.yaml + create mode 100644 examples/online_serving/chart-helm/templates/secrets.yaml + create mode 100644 examples/online_serving/chart-helm/templates/service.yaml + create mode 100644 examples/online_serving/chart-helm/values.schema.json + create mode 100644 examples/online_serving/chart-helm/values.yaml + create mode 100644 examples/online_serving/disaggregated_prefill.sh + create mode 100644 examples/online_serving/gradio_openai_chatbot_webserver.py + create mode 100644 examples/online_serving/gradio_webserver.py + create mode 100644 examples/online_serving/openai_chat_completion_client.py + create mode 100644 examples/online_serving/openai_chat_completion_client_for_multimodal.py + create mode 100644 examples/online_serving/openai_chat_completion_client_with_tools.py + create mode 100644 examples/online_serving/openai_chat_completion_structured_outputs.py + create mode 100644 examples/online_serving/openai_chat_embedding_client_for_multimodal.py + create mode 100644 examples/online_serving/openai_completion_client.py + create mode 100644 examples/online_serving/openai_cross_encoder_score.py + create mode 100644 examples/online_serving/openai_embedding_client.py + create mode 100644 examples/online_serving/openai_pooling_client.py + create mode 100644 examples/online_serving/opentelemetry/Otel.md + create mode 100644 examples/online_serving/opentelemetry/dummy_client.py + create mode 100644 examples/online_serving/prometheus_grafana/README.md + create mode 100644 examples/online_serving/prometheus_grafana/docker-compose.yaml + create mode 100644 examples/online_serving/prometheus_grafana/grafana.json + create mode 100644 examples/online_serving/prometheus_grafana/prometheus.yaml + create mode 100644 examples/online_serving/run_cluster.sh + create mode 100644 examples/online_serving/sagemaker-entrypoint.sh + create mode 100644 examples/other/fp8/README.md + create mode 100644 examples/other/fp8/extract_scales.py + create mode 100644 examples/other/fp8/quantizer/README.md + create mode 100644 examples/other/fp8/quantizer/quantize.py + create mode 100644 examples/other/logging_configuration.md + create mode 100644 examples/other/tensorize_vllm_model.py + create mode 100644 examples/template_blip2.jinja + create mode 100644 examples/template_dse_qwen2_vl.jinja + create mode 100644 examples/template_llava.jinja + create mode 100644 examples/template_pixtral_hf.jinja + create mode 100644 examples/template_vlm2vec.jinja + create mode 100644 examples/tool_chat_template_granite.jinja + create mode 100644 examples/tool_chat_template_granite_20b_fc.jinja + create mode 100644 examples/tool_chat_template_hermes.jinja + create mode 100644 examples/tool_chat_template_internlm2_tool.jinja + create mode 100644 examples/tool_chat_template_llama3.1_json.jinja + create mode 100644 examples/tool_chat_template_llama3.2_json.jinja + create mode 100644 examples/tool_chat_template_llama3.2_pythonic.jinja + create mode 100644 examples/tool_chat_template_mistral.jinja + create mode 100644 examples/tool_chat_template_mistral_parallel.jinja + create mode 100644 examples/tool_chat_template_toolace.jinja + create mode 100644 find_cuda_init.py + create mode 100644 python_only_dev.py + create mode 100644 requirements-hpu.txt + create mode 100644 requirements-lint.txt + create mode 100644 requirements-openvino.txt + create mode 100644 requirements-test.in + create mode 100644 requirements-test.txt + create mode 100644 requirements-tpu.txt + create mode 100644 requirements-xpu.txt + create mode 100644 tests/async_engine/__init__.py + create mode 100644 tests/basic_correctness/__init__.py + create mode 100644 tests/basic_correctness/test_cpu_offload.py + create mode 100644 tests/compile/__init__.py + create mode 100644 tests/compile/backend.py + create mode 100644 tests/compile/piecewise/__init__.py + create mode 100644 tests/compile/piecewise/test_simple.py + create mode 100644 tests/compile/piecewise/test_toy_llama.py + create mode 100644 tests/compile/test_basic_correctness.py + create mode 100644 tests/compile/test_full_graph.py + create mode 100644 tests/compile/test_functionalization.py + create mode 100644 tests/compile/test_fusion.py + create mode 100644 tests/compile/test_pass_manager.py + create mode 100644 tests/compile/test_wrapper.py + create mode 100644 tests/compile/utils.py + create mode 100644 tests/core/block/e2e/__init__.py + create mode 100644 tests/core/block/e2e/test_correctness_sliding_window.py + create mode 100644 tests/core/block/test_block_manager.py + create mode 100644 tests/core/test_num_computed_tokens_update.py + create mode 100644 tests/core/test_scheduler_encoder_decoder.py + create mode 100644 tests/core/test_serialization.py + create mode 100644 tests/data/test_config.yaml + create mode 100644 tests/distributed/__init__.py + create mode 100644 tests/distributed/test_ca_buffer_sharing.py + create mode 100644 tests/distributed/test_distributed_oot.py + create mode 100644 tests/distributed/test_multi_node_assignment.py + create mode 100644 tests/distributed/test_pipeline_parallel.py + create mode 100644 tests/distributed/test_pipeline_partition.py + create mode 100644 tests/distributed/test_pp_cudagraph.py + create mode 100644 tests/distributed/test_same_node.py + create mode 100644 tests/distributed/test_shm_broadcast.py + create mode 100644 tests/distributed/test_utils.py + create mode 100644 tests/encoder_decoder/__init__.py + create mode 100644 tests/encoder_decoder/test_e2e_correctness.py + create mode 100644 tests/engine/__init__.py + create mode 100644 tests/engine/output_processor/__init__.py + create mode 100644 tests/engine/output_processor/test_stop_checker.py + create mode 100644 tests/engine/test_arg_utils.py + create mode 100644 tests/engine/test_custom_executor.py + create mode 100644 tests/engine/test_short_mm_context.py + create mode 100644 tests/entrypoints/__init__.py + create mode 100644 tests/entrypoints/conftest.py + create mode 100644 tests/entrypoints/llm/__init__.py + create mode 100644 tests/entrypoints/llm/test_accuracy.py + create mode 100644 tests/entrypoints/llm/test_chat.py + create mode 100644 tests/entrypoints/llm/test_encode.py + create mode 100644 tests/entrypoints/llm/test_generate.py + create mode 100644 tests/entrypoints/llm/test_generate_multiple_loras.py + create mode 100644 tests/entrypoints/llm/test_gpu_utilization.py + create mode 100644 tests/entrypoints/llm/test_guided_generate.py + create mode 100644 tests/entrypoints/llm/test_init.py + create mode 100644 tests/entrypoints/llm/test_lazy_outlines.py + create mode 100644 tests/entrypoints/llm/test_prompt_validation.py + create mode 100644 tests/entrypoints/offline_mode/__init__.py + create mode 100644 tests/entrypoints/offline_mode/test_offline_mode.py + create mode 100644 tests/entrypoints/openai/__init__.py + create mode 100644 tests/entrypoints/openai/test_accuracy.py + create mode 100644 tests/entrypoints/openai/test_async_tokenization.py + create mode 100644 tests/entrypoints/openai/test_audio.py + create mode 100644 tests/entrypoints/openai/test_basic.py + create mode 100644 tests/entrypoints/openai/test_chat.py + create mode 100644 tests/entrypoints/openai/test_chat_echo.py + create mode 100644 tests/entrypoints/openai/test_chat_template.py + create mode 100644 tests/entrypoints/openai/test_chunked_prompt.py + create mode 100644 tests/entrypoints/openai/test_cli_args.py + create mode 100644 tests/entrypoints/openai/test_completion.py + create mode 100644 tests/entrypoints/openai/test_embedding.py + create mode 100644 tests/entrypoints/openai/test_encoder_decoder.py + create mode 100644 tests/entrypoints/openai/test_lora_adapters.py + create mode 100644 tests/entrypoints/openai/test_metrics.py + create mode 100644 tests/entrypoints/openai/test_models.py + create mode 100644 tests/entrypoints/openai/test_oot_registration.py + create mode 100644 tests/entrypoints/openai/test_pooling.py + create mode 100644 tests/entrypoints/openai/test_prompt_validation.py + create mode 100644 tests/entrypoints/openai/test_return_tokens_as_ids.py + create mode 100644 tests/entrypoints/openai/test_root_path.py + create mode 100644 tests/entrypoints/openai/test_run_batch.py + create mode 100644 tests/entrypoints/openai/test_score.py + create mode 100644 tests/entrypoints/openai/test_serving_models.py + create mode 100644 tests/entrypoints/openai/test_shutdown.py + create mode 100644 tests/entrypoints/openai/test_tokenization.py + create mode 100644 tests/entrypoints/openai/test_video.py + create mode 100644 tests/entrypoints/openai/test_vision.py + create mode 100644 tests/entrypoints/openai/test_vision_embedding.py + create mode 100644 tests/entrypoints/openai/tool_parsers/__init__.py + create mode 100644 tests/entrypoints/openai/tool_parsers/test_pythonic_tool_parser.py + create mode 100644 tests/entrypoints/openai/tool_parsers/utils.py + create mode 100644 tests/entrypoints/test_chat_utils.py + create mode 100644 tests/kernels/__init__.py + create mode 100644 tests/kernels/quant_utils.py + create mode 100644 tests/kernels/test_aqlm.py + create mode 100644 tests/kernels/test_attention_selector.py + create mode 100644 tests/kernels/test_awq.py + create mode 100644 tests/kernels/test_awq_marlin.py + create mode 100644 tests/kernels/test_awq_triton.py + create mode 100644 tests/kernels/test_block_fp8.py + create mode 100644 tests/kernels/test_blocksparse_attention.py + create mode 100644 tests/kernels/test_cascade_flash_attn.py + create mode 100644 tests/kernels/test_causal_conv1d.py + create mode 100644 tests/kernels/test_cutlass.py + create mode 100644 tests/kernels/test_encoder_decoder_attn.py + create mode 100644 tests/kernels/test_flash_attn.py + create mode 100644 tests/kernels/test_flashinfer.py + create mode 100644 tests/kernels/test_fp8_quant.py + create mode 100644 tests/kernels/test_fused_quant_layernorm.py + create mode 100644 tests/kernels/test_ggml.py + create mode 100644 tests/kernels/test_gguf.py + create mode 100644 tests/kernels/test_gptq.py + create mode 100644 tests/kernels/test_int8_quant.py + create mode 100644 tests/kernels/test_machete_mm.py + create mode 100644 tests/kernels/test_mamba_ssm.py + create mode 100644 tests/kernels/test_marlin_gemm.py + create mode 100644 tests/kernels/test_permute_cols.py + create mode 100644 tests/kernels/test_rotary_embedding.py + create mode 100644 tests/kernels/test_semi_structured.py + create mode 100644 tests/kernels/test_triton_scaled_mm.py + create mode 100644 tests/kernels/test_utils.py + create mode 100644 tests/kernels/utils.py + create mode 100644 tests/kv_transfer/disagg_test.py + create mode 100644 tests/kv_transfer/module_test.py + create mode 100644 tests/kv_transfer/test_lookup_buffer.py + create mode 100644 tests/kv_transfer/test_lookup_buffer.sh + create mode 100644 tests/kv_transfer/test_send_recv.py + create mode 100644 tests/kv_transfer/test_send_recv.sh + create mode 100644 tests/lora/data/__init__.py + create mode 100644 tests/lora/data/long_context_test_data.py + create mode 100644 tests/lora/test_chatglm3_tp.py + create mode 100644 tests/lora/test_jamba.py + create mode 100644 tests/lora/test_llama_tp.py + create mode 100644 tests/lora/test_long_context.py + create mode 100644 tests/lora/test_lora_bias_e2e.py + create mode 100644 tests/lora/test_lora_huggingface.py + create mode 100644 tests/lora/test_minicpmv_tp.py + create mode 100644 tests/lora/test_phi.py + create mode 100644 tests/lora/test_punica_ops_sizes.py + create mode 100644 tests/lora/test_punica_ops_variation.py + create mode 100644 tests/lora/test_qwen2vl.py + create mode 100644 tests/metrics/__init__.py + create mode 100644 tests/model_executor/__init__.py + create mode 100644 tests/model_executor/conftest.py + create mode 100644 tests/model_executor/test_enabled_custom_ops.py + create mode 100644 tests/model_executor/test_guided_processors.py + create mode 100644 tests/model_executor/test_model_load_with_params.py + create mode 100644 tests/models/__init__.py + create mode 100644 tests/models/decoder_only/__init__.py + create mode 100644 tests/models/decoder_only/audio_language/__init__.py + create mode 100644 tests/models/decoder_only/audio_language/test_ultravox.py + create mode 100644 tests/models/decoder_only/language/__init__.py + create mode 100644 tests/models/decoder_only/language/test_aqlm.py + create mode 100644 tests/models/decoder_only/language/test_fp8.py + create mode 100644 tests/models/decoder_only/language/test_gguf.py + create mode 100644 tests/models/decoder_only/language/test_gptq_marlin.py + create mode 100644 tests/models/decoder_only/language/test_gptq_marlin_24.py + create mode 100644 tests/models/decoder_only/language/test_granite.py + create mode 100644 tests/models/decoder_only/language/test_jamba.py + create mode 100644 tests/models/decoder_only/language/test_mamba.py + create mode 100644 tests/models/decoder_only/language/test_mistral.py + create mode 100644 tests/models/decoder_only/language/test_modelopt.py + create mode 100644 tests/models/decoder_only/language/test_models.py + create mode 100644 tests/models/decoder_only/language/test_phimoe.py + create mode 100644 tests/models/decoder_only/vision_language/__init__.py + create mode 100644 tests/models/decoder_only/vision_language/test_awq.py + create mode 100644 tests/models/decoder_only/vision_language/test_h2ovl.py + create mode 100644 tests/models/decoder_only/vision_language/test_intern_vit.py + create mode 100644 tests/models/decoder_only/vision_language/test_models.py + create mode 100644 tests/models/decoder_only/vision_language/test_phi3v.py + create mode 100644 tests/models/decoder_only/vision_language/test_pixtral.py + create mode 100644 tests/models/decoder_only/vision_language/test_qwen2_vl.py + create mode 100644 tests/models/decoder_only/vision_language/vlm_utils/__init__.py + create mode 100644 tests/models/decoder_only/vision_language/vlm_utils/builders.py + create mode 100644 tests/models/decoder_only/vision_language/vlm_utils/case_filtering.py + create mode 100644 tests/models/decoder_only/vision_language/vlm_utils/core.py + create mode 100644 tests/models/decoder_only/vision_language/vlm_utils/custom_inputs.py + create mode 100644 tests/models/decoder_only/vision_language/vlm_utils/model_utils.py + create mode 100644 tests/models/decoder_only/vision_language/vlm_utils/runners.py + create mode 100644 tests/models/decoder_only/vision_language/vlm_utils/types.py + create mode 100644 tests/models/embedding/__init__.py + create mode 100644 tests/models/embedding/language/__init__.py + create mode 100644 tests/models/embedding/language/test_cls_models.py + create mode 100644 tests/models/embedding/language/test_embedding.py + create mode 100644 tests/models/embedding/language/test_gritlm.py + create mode 100644 tests/models/embedding/language/test_scoring.py + create mode 100644 tests/models/embedding/utils.py + create mode 100644 tests/models/embedding/vision_language/__init__.py + create mode 100644 tests/models/embedding/vision_language/test_dse_qwen2_vl.py + create mode 100644 tests/models/embedding/vision_language/test_llava_next.py + create mode 100644 tests/models/embedding/vision_language/test_phi3v.py + create mode 100644 tests/models/encoder_decoder/__init__.py + create mode 100644 tests/models/encoder_decoder/audio_language/__init__.py + create mode 100644 tests/models/encoder_decoder/audio_language/test_whisper.py + create mode 100644 tests/models/encoder_decoder/language/__init__.py + create mode 100644 tests/models/encoder_decoder/language/test_bart.py + create mode 100644 tests/models/encoder_decoder/vision_language/__init__.py + create mode 100644 tests/models/encoder_decoder/vision_language/test_broadcast.py + create mode 100644 tests/models/encoder_decoder/vision_language/test_florence2.py + create mode 100644 tests/models/encoder_decoder/vision_language/test_mllama.py + create mode 100644 tests/models/fixtures/pixtral_chat.json + create mode 100644 tests/models/fixtures/pixtral_chat_engine.json + create mode 100644 tests/models/multimodal/__init__.py + create mode 100644 tests/models/multimodal/processing/__init__.py + create mode 100644 tests/models/multimodal/processing/test_common.py + create mode 100644 tests/models/multimodal/processing/test_idefics3.py + create mode 100644 tests/models/multimodal/processing/test_internvl.py + create mode 100644 tests/models/multimodal/processing/test_llava_next.py + create mode 100644 tests/models/multimodal/processing/test_llava_onevision.py + create mode 100644 tests/models/multimodal/processing/test_phi3v.py + create mode 100644 tests/models/multimodal/processing/test_qwen.py + create mode 100644 tests/models/multimodal/processing/test_qwen2_vl.py + create mode 100644 tests/models/registry.py + create mode 100644 tests/models/test_initialization.py + create mode 100644 tests/models/test_registry.py + create mode 100644 tests/mq_llm_engine/__init__.py + create mode 100644 tests/mq_llm_engine/test_abort.py + create mode 100644 tests/mq_llm_engine/test_error_handling.py + create mode 100644 tests/mq_llm_engine/test_load.py + create mode 100644 tests/mq_llm_engine/utils.py + create mode 100644 tests/multi_step/__init__.py + create mode 100644 tests/multi_step/test_correctness_async_llm.py + create mode 100644 tests/multi_step/test_correctness_llm.py + create mode 100644 tests/multimodal/__init__.py + create mode 100644 tests/multimodal/test_inputs.py + create mode 100644 tests/multimodal/test_processing.py + create mode 100644 tests/multimodal/test_processor_kwargs.py + create mode 100644 tests/multimodal/test_utils.py + create mode 100644 tests/multimodal/utils.py + create mode 100644 tests/plugins/vllm_add_dummy_model/setup.py + create mode 100644 tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/__init__.py + create mode 100644 tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_gemma_embedding.py + create mode 100644 tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_llava.py + create mode 100644 tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_opt.py + create mode 100644 tests/plugins/vllm_add_dummy_platform/setup.py + create mode 100644 tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/__init__.py + create mode 100644 tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/dummy_platform.py + create mode 100644 tests/plugins_tests/test_platform_plugins.py + create mode 100644 tests/prefix_caching/__init__.py + create mode 100644 tests/prefix_caching/test_disable_sliding_window.py + create mode 100644 tests/prompt_adapter/test_bloom.py + create mode 100644 tests/prompt_adapter/test_multi_adapter_inference.py + create mode 100644 tests/prompt_adapter/test_pa_lora.py + create mode 100644 tests/quantization/__init__.py + create mode 100644 tests/quantization/test_bitsandbytes.py + create mode 100644 tests/quantization/test_compressed_tensors.py + create mode 100644 tests/quantization/test_cpu_offload.py + create mode 100644 tests/quantization/test_experts_int8.py + create mode 100644 tests/quantization/test_ipex_quant.py + create mode 100644 tests/quantization/test_lm_head.py + create mode 100644 tests/quantization/utils.py + create mode 100644 tests/runai_model_streamer/__init__.py + create mode 100644 tests/runai_model_streamer/test_runai_model_streamer_loader.py + create mode 100644 tests/runai_model_streamer/test_weight_utils.py + create mode 100644 tests/samplers/__init__.py + create mode 100644 tests/samplers/test_no_bad_words.py + create mode 100644 tests/samplers/test_typical_acceptance_sampler.py + create mode 100644 tests/spec_decode/e2e/test_eagle_correctness.py + create mode 100644 tests/spec_decode/e2e/test_integration.py + create mode 100644 tests/spec_decode/e2e/test_integration_dist_tp2.py + create mode 100644 tests/spec_decode/e2e/test_integration_dist_tp4.py + create mode 100644 tests/spec_decode/e2e/test_medusa_correctness.py + create mode 100644 tests/spec_decode/e2e/test_mlp_correctness.py + create mode 100644 tests/spec_decode/e2e/test_seed.py + create mode 100644 tests/spec_decode/test_dynamic_spec_decode.py + create mode 100644 tests/spec_decode/test_scorer.py + create mode 100644 tests/standalone_tests/lazy_torch_compile.py + create mode 100644 tests/standalone_tests/python_only_compile.sh + create mode 100644 tests/system_messages/sonnet3.5_nov2024.txt + create mode 100644 tests/tensorizer_loader/conftest.py + create mode 100644 tests/test_embedded_commit.py + create mode 100644 tests/test_inputs.py + create mode 100644 tests/test_scalartype.py + create mode 100644 tests/test_sharded_state_loader.py + create mode 100644 tests/test_utils.py + create mode 100644 tests/tokenization/test_get_eos.py + create mode 100644 tests/tool_use/__init__.py + create mode 100644 tests/tool_use/conftest.py + create mode 100644 tests/tool_use/test_chat_completion_request_validations.py + create mode 100644 tests/tool_use/test_chat_completions.py + create mode 100644 tests/tool_use/test_jamba_tool_parser.py + create mode 100644 tests/tool_use/test_parallel_tool_calls.py + create mode 100644 tests/tool_use/test_tool_calls.py + create mode 100644 tests/tool_use/utils.py + create mode 100644 tests/tpu/__init__.py + create mode 100644 tests/tpu/test_compilation.py + create mode 100644 tests/tpu/test_custom_dispatcher.py + create mode 100644 tests/tpu/test_quantization_accuracy.py + create mode 100644 tests/tracing/__init__.py + create mode 100644 tests/tracing/test_tracing.py + create mode 100644 tests/utils.py + create mode 100644 tests/v1/__init__.py + create mode 100644 tests/v1/core/test_kv_cache_utils.py + create mode 100644 tests/v1/core/test_prefix_caching.py + create mode 100644 tests/v1/e2e/__init__.py + create mode 100644 tests/v1/e2e/test_cascade_attention.py + create mode 100644 tests/v1/engine/__init__.py + create mode 100644 tests/v1/engine/test_async_llm.py + create mode 100644 tests/v1/engine/test_engine_args.py + create mode 100644 tests/v1/engine/test_engine_core.py + create mode 100644 tests/v1/engine/test_engine_core_client.py + create mode 100644 tests/v1/engine/test_output_processor.py + create mode 100644 tests/v1/sample/__init__.py + create mode 100644 tests/v1/sample/test_sampler.py + create mode 100644 tests/v1/worker/__init__.py + create mode 100644 tests/v1/worker/test_gpu_input_batch.py + create mode 100644 tests/vllm_test_utils/setup.py + create mode 100644 tests/vllm_test_utils/vllm_test_utils/__init__.py + create mode 100644 tests/vllm_test_utils/vllm_test_utils/blame.py + create mode 100644 tests/vllm_test_utils/vllm_test_utils/monitor.py + create mode 100644 tests/weight_loading/models-large.txt + create mode 100644 tests/weight_loading/models.txt + create mode 100644 tests/weight_loading/run_model_weight_loading_test.sh + create mode 100644 tests/weight_loading/test_weight_loading.py + create mode 100644 tests/worker/test_encoder_decoder_model_runner.py + create mode 100644 tests/worker/test_model_input.py + create mode 100644 tests/worker/test_profile.py + create mode 100644 tools/actionlint.sh + create mode 100644 tools/check_repo.sh + create mode 100644 tools/doc-lint.sh + create mode 100644 tools/mypy.sh + create mode 100644 tools/png-lint.sh + create mode 100644 tools/profiler/print_layerwise_table.py + create mode 100644 tools/profiler/visualize_layerwise_profile.py + create mode 100644 tools/report_build_time_ninja.py + create mode 100644 tools/shellcheck.sh + create mode 100644 use_existing_torch.py + create mode 100644 vllm/_ipex_ops.py + create mode 100644 vllm/adapter_commons/__init__.py + create mode 100644 vllm/adapter_commons/layers.py + create mode 100644 vllm/adapter_commons/models.py + create mode 100644 vllm/adapter_commons/request.py + create mode 100644 vllm/adapter_commons/utils.py + create mode 100644 vllm/adapter_commons/worker_manager.py + create mode 100644 vllm/assets/__init__.py + create mode 100644 vllm/assets/audio.py + create mode 100644 vllm/assets/base.py + create mode 100644 vllm/assets/image.py + create mode 100644 vllm/assets/video.py + create mode 100644 vllm/attention/backends/blocksparse_attn.py + create mode 100644 vllm/attention/backends/hpu_attn.py + create mode 100644 vllm/attention/backends/ipex_attn.py + create mode 100644 vllm/attention/backends/openvino.py + create mode 100644 vllm/attention/backends/pallas.py + create mode 100644 vllm/attention/backends/placeholder_attn.py + create mode 100644 vllm/attention/backends/utils.py + create mode 100644 vllm/attention/ops/blocksparse_attention/__init__.py + create mode 100644 vllm/attention/ops/blocksparse_attention/blocksparse_attention_kernel.py + create mode 100644 vllm/attention/ops/blocksparse_attention/interface.py + create mode 100644 vllm/attention/ops/blocksparse_attention/utils.py + create mode 100644 vllm/attention/ops/hpu_paged_attn.py + create mode 100644 vllm/attention/ops/ipex_attn.py + create mode 100644 vllm/beam_search.py + create mode 100644 vllm/compilation/__init__.py + create mode 100644 vllm/compilation/backends.py + create mode 100644 vllm/compilation/counter.py + create mode 100644 vllm/compilation/decorators.py + create mode 100644 vllm/compilation/fix_functionalization.py + create mode 100644 vllm/compilation/fusion.py + create mode 100644 vllm/compilation/fx_utils.py + create mode 100644 vllm/compilation/inductor_pass.py + create mode 100644 vllm/compilation/monitor.py + create mode 100644 vllm/compilation/multi_output_match.py + create mode 100644 vllm/compilation/pass_manager.py + create mode 100644 vllm/compilation/reshapes.py + create mode 100644 vllm/compilation/vllm_inductor_pass.py + create mode 100644 vllm/compilation/wrapper.py + create mode 100644 vllm/connections.py + create mode 100644 vllm/core/block/utils.py + create mode 100644 vllm/core/block_manager.py + create mode 100644 vllm/core/evictor.py + create mode 100644 vllm/core/placeholder_block_space_manager.py + create mode 100644 vllm/distributed/device_communicators/cuda_wrapper.py + create mode 100644 vllm/distributed/device_communicators/custom_all_reduce_utils.py + create mode 100644 vllm/distributed/device_communicators/hpu_communicator.py + create mode 100644 vllm/distributed/device_communicators/pynccl_wrapper.py + create mode 100644 vllm/distributed/device_communicators/shm_broadcast.py + create mode 100644 vllm/distributed/device_communicators/tpu_communicator.py + create mode 100644 vllm/distributed/device_communicators/xpu_communicator.py + create mode 100644 vllm/distributed/kv_transfer/README.md + create mode 100644 vllm/distributed/kv_transfer/__init__.py + create mode 100644 vllm/distributed/kv_transfer/disagg_prefill_workflow.jpg + create mode 100644 vllm/distributed/kv_transfer/kv_connector/__init__.py + create mode 100644 vllm/distributed/kv_transfer/kv_connector/base.py + create mode 100644 vllm/distributed/kv_transfer/kv_connector/factory.py + create mode 100644 vllm/distributed/kv_transfer/kv_connector/simple_connector.py + create mode 100644 vllm/distributed/kv_transfer/kv_lookup_buffer/__init__.py + create mode 100644 vllm/distributed/kv_transfer/kv_lookup_buffer/base.py + create mode 100644 vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py + create mode 100644 vllm/distributed/kv_transfer/kv_pipe/__init__.py + create mode 100644 vllm/distributed/kv_transfer/kv_pipe/base.py + create mode 100644 vllm/distributed/kv_transfer/kv_pipe/mooncake_pipe.py + create mode 100644 vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py + create mode 100644 vllm/distributed/kv_transfer/kv_transfer_agent.py + create mode 100644 vllm/engine/async_timeout.py + create mode 100644 vllm/engine/metrics_types.py + create mode 100644 vllm/engine/multiprocessing/__init__.py + create mode 100644 vllm/engine/multiprocessing/client.py + create mode 100644 vllm/engine/multiprocessing/engine.py + create mode 100644 vllm/engine/protocol.py + create mode 100644 vllm/entrypoints/chat_utils.py + create mode 100644 vllm/entrypoints/launcher.py + create mode 100644 vllm/entrypoints/logger.py + create mode 100644 vllm/entrypoints/openai/logits_processors.py + create mode 100644 vllm/entrypoints/openai/run_batch.py + create mode 100644 vllm/entrypoints/openai/serving_embedding.py + create mode 100644 vllm/entrypoints/openai/serving_models.py + create mode 100644 vllm/entrypoints/openai/serving_pooling.py + create mode 100644 vllm/entrypoints/openai/serving_score.py + create mode 100644 vllm/entrypoints/openai/serving_tokenization.py + create mode 100644 vllm/entrypoints/openai/tool_parsers/__init__.py + create mode 100644 vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py + create mode 100644 vllm/entrypoints/openai/tool_parsers/granite_20b_fc_tool_parser.py + create mode 100644 vllm/entrypoints/openai/tool_parsers/granite_tool_parser.py + create mode 100644 vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py + create mode 100644 vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py + create mode 100644 vllm/entrypoints/openai/tool_parsers/jamba_tool_parser.py + create mode 100644 vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py + create mode 100644 vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py + create mode 100644 vllm/entrypoints/openai/tool_parsers/pythonic_tool_parser.py + create mode 100644 vllm/entrypoints/openai/tool_parsers/utils.py + create mode 100644 vllm/entrypoints/utils.py + create mode 100644 vllm/executor/hpu_executor.py + create mode 100644 vllm/executor/msgspec_utils.py + create mode 100644 vllm/executor/multiproc_gpu_executor.py + create mode 100644 vllm/executor/multiproc_xpu_executor.py + create mode 100644 vllm/executor/openvino_executor.py + create mode 100644 vllm/executor/ray_hpu_executor.py + create mode 100644 vllm/executor/ray_tpu_executor.py + create mode 100644 vllm/executor/ray_xpu_executor.py + create mode 100644 vllm/executor/tpu_executor.py + create mode 100644 vllm/executor/xpu_executor.py + create mode 100644 vllm/forward_context.py + create mode 100644 vllm/inputs/__init__.py + create mode 100644 vllm/inputs/data.py + create mode 100644 vllm/inputs/parse.py + create mode 100644 vllm/inputs/preprocess.py + create mode 100644 vllm/inputs/registry.py + create mode 100644 vllm/logging_utils/__init__.py + create mode 100644 vllm/logging_utils/formatter.py + create mode 100644 vllm/logits_process.py + create mode 100644 vllm/lora/ops/__init__.py + create mode 100644 vllm/lora/ops/torch_ops/__init__.py + create mode 100644 vllm/lora/ops/torch_ops/lora_ops.py + create mode 100644 vllm/lora/ops/triton_ops/__init__.py + create mode 100644 vllm/lora/ops/triton_ops/bgmv_expand.py + create mode 100644 vllm/lora/ops/triton_ops/bgmv_expand_slice.py + create mode 100644 vllm/lora/ops/triton_ops/bgmv_shrink.py + create mode 100644 vllm/lora/ops/triton_ops/sgmv_expand.py + create mode 100644 vllm/lora/ops/triton_ops/sgmv_shrink.py + create mode 100644 vllm/lora/ops/triton_ops/utils.py + create mode 100644 vllm/lora/peft_helper.py + create mode 100644 vllm/lora/punica_wrapper/__init__.py + create mode 100644 vllm/lora/punica_wrapper/punica_base.py + create mode 100644 vllm/lora/punica_wrapper/punica_cpu.py + create mode 100644 vllm/lora/punica_wrapper/punica_gpu.py + create mode 100644 vllm/lora/punica_wrapper/punica_hpu.py + create mode 100644 vllm/lora/punica_wrapper/punica_selector.py + create mode 100644 vllm/lora/punica_wrapper/utils.py + create mode 100644 vllm/model_executor/custom_op.py + create mode 100644 vllm/model_executor/guided_decoding/guided_fields.py + create mode 100644 vllm/model_executor/guided_decoding/utils.py + create mode 100644 vllm/model_executor/guided_decoding/xgrammar_decoding.py + create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=1,N=14336,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json + create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=1,N=14336,device_name=NVIDIA_A100-SXM4-80GB.json + create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=1,N=1792,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json + create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=1,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json + create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=1,N=3072,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json + create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=1,N=3072,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json + create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=1,N=3072,device_name=NVIDIA_H100_80GB_HBM3.json + create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=1,N=3584,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json + create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=1,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json + create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=1,N=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json + create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=1,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json + create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=16,N=14336,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json + create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=16,N=14336,device_name=NVIDIA_A100-SXM4-80GB.json + create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=16,N=1792,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json + create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=16,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json + create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=16,N=3072,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json + create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=16,N=3072,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json + create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=16,N=3200,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json + create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=16,N=3584,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json + create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=16,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json + create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=16,N=6400,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json + create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=16,N=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json + create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=16,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json + create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=16,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json + create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=16,N=800,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json + create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=64,N=1280,device_name=NVIDIA_A100-SXM4-80GB.json + create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=64,N=1280,device_name=NVIDIA_H100_80GB_HBM3.json + create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=64,N=640,device_name=NVIDIA_A100-SXM4-80GB.json + create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=64,N=640,device_name=NVIDIA_H100_80GB_HBM3.json + create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json + create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json + create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json + create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json + create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json + create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json + create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_L40S.json + create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json + create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json + create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json + create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=8,N=8192,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json + create mode 100644 vllm/model_executor/layers/fused_moe/fused_marlin_moe.py + create mode 100644 vllm/model_executor/layers/fused_moe/layer.py + create mode 100644 vllm/model_executor/layers/fused_moe/moe_pallas.py + create mode 100644 vllm/model_executor/layers/fused_moe/moe_torch_iterative.py + create mode 100644 vllm/model_executor/layers/mamba/__init__.py + create mode 100644 vllm/model_executor/layers/mamba/mamba_mixer.py + create mode 100644 vllm/model_executor/layers/mamba/ops/__init__.py + create mode 100644 vllm/model_executor/layers/mamba/ops/causal_conv1d.py + create mode 100644 vllm/model_executor/layers/mamba/ops/mamba_ssm.py + create mode 100644 vllm/model_executor/layers/pooler.py + create mode 100644 vllm/model_executor/layers/quantization/awq_marlin.py + create mode 100644 vllm/model_executor/layers/quantization/awq_triton.py + create mode 100644 vllm/model_executor/layers/quantization/bitsandbytes.py + create mode 100644 vllm/model_executor/layers/quantization/compressed_tensors/__init__.py + create mode 100644 vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py + create mode 100644 vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py + create mode 100644 vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py + create mode 100644 vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py + create mode 100644 vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py + create mode 100644 vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py + create mode 100644 vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py + create mode 100644 vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py + create mode 100644 vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py + create mode 100644 vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py + create mode 100644 vllm/model_executor/layers/quantization/compressed_tensors/triton_scaled_mm.py + create mode 100644 vllm/model_executor/layers/quantization/compressed_tensors/utils.py + create mode 100644 vllm/model_executor/layers/quantization/deepspeedfp.py + create mode 100644 vllm/model_executor/layers/quantization/experts_int8.py + create mode 100644 vllm/model_executor/layers/quantization/fbgemm_fp8.py + create mode 100644 vllm/model_executor/layers/quantization/gguf.py + create mode 100644 vllm/model_executor/layers/quantization/gptq_marlin_24.py + create mode 100644 vllm/model_executor/layers/quantization/hqq_marlin.py + create mode 100644 vllm/model_executor/layers/quantization/ipex_quant.py + create mode 100644 vllm/model_executor/layers/quantization/kernels/__init__.py + create mode 100644 vllm/model_executor/layers/quantization/kernels/mixed_precision/MPLinearKernel.py + create mode 100644 vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py + create mode 100644 vllm/model_executor/layers/quantization/kernels/mixed_precision/exllama.py + create mode 100644 vllm/model_executor/layers/quantization/kernels/mixed_precision/machete.py + create mode 100644 vllm/model_executor/layers/quantization/kernels/mixed_precision/marlin.py + create mode 100644 vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py + create mode 100644 vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py + create mode 100644 vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py + create mode 100644 vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py + create mode 100644 vllm/model_executor/layers/quantization/kv_cache.py + create mode 100644 vllm/model_executor/layers/quantization/modelopt.py + create mode 100644 vllm/model_executor/layers/quantization/neuron_quant.py + create mode 100644 vllm/model_executor/layers/quantization/qqq.py + create mode 100644 vllm/model_executor/layers/quantization/tpu_int8.py + create mode 100644 vllm/model_executor/layers/quantization/utils/__init__.py + create mode 100644 vllm/model_executor/layers/quantization/utils/fp8_utils.py + create mode 100644 vllm/model_executor/layers/quantization/utils/layer_utils.py + create mode 100644 vllm/model_executor/layers/quantization/utils/machete_utils.py + create mode 100644 vllm/model_executor/layers/quantization/utils/marlin_utils.py + create mode 100644 vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py + create mode 100644 vllm/model_executor/layers/quantization/utils/marlin_utils_test.py + create mode 100644 vllm/model_executor/layers/quantization/utils/marlin_utils_test_24.py + create mode 100644 vllm/model_executor/layers/quantization/utils/marlin_utils_test_qqq.py + create mode 100644 vllm/model_executor/layers/quantization/utils/quant_utils.py + create mode 100644 vllm/model_executor/layers/quantization/utils/w8a8_utils.py + create mode 100644 vllm/model_executor/layers/resampler.py + create mode 100644 vllm/model_executor/layers/spec_decode_base_sampler.py + create mode 100644 vllm/model_executor/layers/typical_acceptance_sampler.py + create mode 100644 vllm/model_executor/layers/utils.py + create mode 100644 vllm/model_executor/model_loader/openvino.py + create mode 100644 vllm/model_executor/models/adapters.py + create mode 100644 vllm/model_executor/models/arctic.py + create mode 100644 vllm/model_executor/models/aria.py + create mode 100644 vllm/model_executor/models/bart.py + create mode 100644 vllm/model_executor/models/bert.py + create mode 100644 vllm/model_executor/models/blip.py + create mode 100644 vllm/model_executor/models/blip2.py + create mode 100644 vllm/model_executor/models/chameleon.py + create mode 100644 vllm/model_executor/models/clip.py + create mode 100644 vllm/model_executor/models/deepseek_v2.py + create mode 100644 vllm/model_executor/models/deepseek_v3.py + create mode 100644 vllm/model_executor/models/deepseek_vl2.py + create mode 100644 vllm/model_executor/models/eagle.py + create mode 100644 vllm/model_executor/models/exaone.py + create mode 100644 vllm/model_executor/models/florence2.py + create mode 100644 vllm/model_executor/models/fuyu.py + create mode 100644 vllm/model_executor/models/gemma2.py + create mode 100644 vllm/model_executor/models/glm.py + create mode 100644 vllm/model_executor/models/glm4_vision_encoder.py + create mode 100644 vllm/model_executor/models/granite.py + create mode 100644 vllm/model_executor/models/granitemoe.py + create mode 100644 vllm/model_executor/models/gritlm.py + create mode 100644 vllm/model_executor/models/h2ovl.py + create mode 100644 vllm/model_executor/models/idefics2_vision_model.py + create mode 100644 vllm/model_executor/models/idefics3.py + create mode 100644 vllm/model_executor/models/interfaces.py + create mode 100644 vllm/model_executor/models/interfaces_base.py + create mode 100644 vllm/model_executor/models/intern_vit.py + create mode 100644 vllm/model_executor/models/internlm2_ve.py + create mode 100644 vllm/model_executor/models/internvl.py + create mode 100644 vllm/model_executor/models/jamba.py + create mode 100644 vllm/model_executor/models/llava_next.py + create mode 100644 vllm/model_executor/models/llava_next_video.py + create mode 100644 vllm/model_executor/models/llava_onevision.py + create mode 100644 vllm/model_executor/models/mamba.py + create mode 100644 vllm/model_executor/models/mamba_cache.py + create mode 100644 vllm/model_executor/models/medusa.py + create mode 100644 vllm/model_executor/models/minicpm3.py + create mode 100644 vllm/model_executor/models/minicpmv.py + create mode 100644 vllm/model_executor/models/mllama.py + create mode 100644 vllm/model_executor/models/mlp_speculator.py + create mode 100644 vllm/model_executor/models/module_mapping.py + create mode 100644 vllm/model_executor/models/molmo.py + create mode 100644 vllm/model_executor/models/nemotron.py + create mode 100644 vllm/model_executor/models/nvlm_d.py + create mode 100644 vllm/model_executor/models/olmo2.py + create mode 100644 vllm/model_executor/models/olmoe.py + create mode 100644 vllm/model_executor/models/paligemma.py + create mode 100644 vllm/model_executor/models/persimmon.py + create mode 100644 vllm/model_executor/models/phi3.py + create mode 100644 vllm/model_executor/models/phi3_small.py + create mode 100644 vllm/model_executor/models/phi3v.py + create mode 100644 vllm/model_executor/models/phimoe.py + create mode 100644 vllm/model_executor/models/pixtral.py + create mode 100644 vllm/model_executor/models/qwen2_audio.py + create mode 100644 vllm/model_executor/models/qwen2_rm.py + create mode 100644 vllm/model_executor/models/qwen2_vl.py + create mode 100644 vllm/model_executor/models/registry.py + create mode 100644 vllm/model_executor/models/roberta.py + create mode 100644 vllm/model_executor/models/siglip.py + create mode 100644 vllm/model_executor/models/solar.py + create mode 100644 vllm/model_executor/models/telechat2.py + create mode 100644 vllm/model_executor/models/ultravox.py + create mode 100644 vllm/model_executor/models/utils.py + create mode 100644 vllm/model_executor/models/vision.py + create mode 100644 vllm/model_executor/models/whisper.py + create mode 100644 vllm/model_executor/parameter.py + create mode 100644 vllm/model_executor/pooling_metadata.py + create mode 100644 vllm/multimodal/__init__.py + create mode 100644 vllm/multimodal/audio.py + create mode 100644 vllm/multimodal/base.py + create mode 100644 vllm/multimodal/hasher.py + create mode 100644 vllm/multimodal/image.py + create mode 100644 vllm/multimodal/inputs.py + create mode 100644 vllm/multimodal/parse.py + create mode 100644 vllm/multimodal/processing.py + create mode 100644 vllm/multimodal/profiling.py + create mode 100644 vllm/multimodal/registry.py + create mode 100644 vllm/multimodal/utils.py + create mode 100644 vllm/multimodal/video.py + create mode 100644 vllm/platforms/__init__.py + create mode 100644 vllm/platforms/cpu.py + create mode 100644 vllm/platforms/cuda.py + create mode 100644 vllm/platforms/hpu.py + create mode 100644 vllm/platforms/interface.py + create mode 100644 vllm/platforms/neuron.py + create mode 100644 vllm/platforms/openvino.py + create mode 100644 vllm/platforms/rocm.py + create mode 100644 vllm/platforms/tpu.py + create mode 100644 vllm/platforms/xpu.py + create mode 100644 vllm/plugins/__init__.py + create mode 100644 vllm/pooling_params.py + create mode 100644 vllm/profiler/__init__.py + create mode 100644 vllm/profiler/layerwise_profile.py + create mode 100644 vllm/profiler/utils.py + create mode 100644 vllm/prompt_adapter/__init__.py + create mode 100644 vllm/prompt_adapter/layers.py + create mode 100644 vllm/prompt_adapter/models.py + create mode 100644 vllm/prompt_adapter/request.py + create mode 100644 vllm/prompt_adapter/utils.py + create mode 100644 vllm/prompt_adapter/worker_manager.py + create mode 100644 vllm/scalar_type.py + create mode 100644 vllm/scripts.py + create mode 100644 vllm/spec_decode/draft_model_runner.py + create mode 100644 vllm/spec_decode/medusa_worker.py + create mode 100644 vllm/spec_decode/mlp_speculator_worker.py + create mode 100644 vllm/spec_decode/mqa_scorer.py + create mode 100644 vllm/spec_decode/proposer_worker_base.py + create mode 100644 vllm/spec_decode/smaller_tp_proposer_worker.py + create mode 100644 vllm/spec_decode/target_model_runner.py + create mode 100644 vllm/tracing.py + create mode 100644 vllm/transformers_utils/configs/arctic.py + create mode 100644 vllm/transformers_utils/configs/aria.py + create mode 100644 vllm/transformers_utils/configs/cohere2.py + create mode 100644 vllm/transformers_utils/configs/deepseek_vl2.py + create mode 100644 vllm/transformers_utils/configs/eagle.py + create mode 100644 vllm/transformers_utils/configs/exaone.py + create mode 100644 vllm/transformers_utils/configs/h2ovl.py + create mode 100644 vllm/transformers_utils/configs/internvl.py + create mode 100644 vllm/transformers_utils/configs/medusa.py + create mode 100644 vllm/transformers_utils/configs/mllama.py + create mode 100644 vllm/transformers_utils/configs/mlp_speculator.py + create mode 100644 vllm/transformers_utils/configs/nemotron.py + create mode 100644 vllm/transformers_utils/configs/nvlm_d.py + create mode 100644 vllm/transformers_utils/configs/olmo2.py + create mode 100644 vllm/transformers_utils/configs/solar.py + create mode 100644 vllm/transformers_utils/configs/telechat2.py + create mode 100644 vllm/transformers_utils/configs/ultravox.py + create mode 100644 vllm/transformers_utils/detokenizer_utils.py + create mode 100644 vllm/transformers_utils/processor.py + create mode 100644 vllm/transformers_utils/s3_utils.py + create mode 100644 vllm/transformers_utils/tokenizers/mistral.py + create mode 100644 vllm/transformers_utils/utils.py + create mode 100644 vllm/triton_utils/__init__.py + create mode 100644 vllm/triton_utils/custom_cache_manager.py + create mode 100644 vllm/triton_utils/importing.py + create mode 100644 vllm/v1/__init__.py + create mode 100644 vllm/v1/attention/__init__.py + create mode 100644 vllm/v1/attention/backends/__init__.py + create mode 100644 vllm/v1/attention/backends/flash_attn.py + create mode 100644 vllm/v1/core/__init__.py + create mode 100644 vllm/v1/core/encoder_cache_manager.py + create mode 100644 vllm/v1/core/kv_cache_manager.py + create mode 100644 vllm/v1/core/kv_cache_utils.py + create mode 100644 vllm/v1/core/scheduler.py + create mode 100644 vllm/v1/engine/__init__.py + create mode 100644 vllm/v1/engine/async_llm.py + create mode 100644 vllm/v1/engine/core.py + create mode 100644 vllm/v1/engine/core_client.py + create mode 100644 vllm/v1/engine/detokenizer.py + create mode 100644 vllm/v1/engine/llm_engine.py + create mode 100644 vllm/v1/engine/mm_input_mapper.py + create mode 100644 vllm/v1/engine/output_processor.py + create mode 100644 vllm/v1/engine/processor.py + create mode 100644 vllm/v1/executor/__init__.py + create mode 100644 vllm/v1/executor/abstract.py + create mode 100644 vllm/v1/executor/multiproc_executor.py + create mode 100644 vllm/v1/executor/ray_executor.py + create mode 100644 vllm/v1/executor/ray_utils.py + create mode 100644 vllm/v1/executor/uniproc_executor.py + create mode 100644 vllm/v1/metrics/__init__.py + create mode 100644 vllm/v1/metrics/loggers.py + create mode 100644 vllm/v1/metrics/stats.py + create mode 100644 vllm/v1/outputs.py + create mode 100644 vllm/v1/request.py + create mode 100644 vllm/v1/sample/__init__.py + create mode 100644 vllm/v1/sample/metadata.py + create mode 100644 vllm/v1/sample/ops/__init__.py + create mode 100644 vllm/v1/sample/ops/penalties.py + create mode 100644 vllm/v1/sample/ops/topk_topp_sampler.py + create mode 100644 vllm/v1/sample/sampler.py + create mode 100644 vllm/v1/serial_utils.py + create mode 100644 vllm/v1/utils.py + create mode 100644 vllm/v1/worker/__init__.py + create mode 100644 vllm/v1/worker/block_table.py + create mode 100644 vllm/v1/worker/gpu_input_batch.py + create mode 100644 vllm/v1/worker/gpu_model_runner.py + create mode 100644 vllm/v1/worker/gpu_worker.py + create mode 100644 vllm/version.py + create mode 100644 vllm/worker/cpu_enc_dec_model_runner.py + create mode 100644 vllm/worker/cpu_pooling_model_runner.py + create mode 100644 vllm/worker/enc_dec_model_runner.py + create mode 100644 vllm/worker/hpu_model_runner.py + create mode 100644 vllm/worker/hpu_worker.py + create mode 100644 vllm/worker/model_runner_base.py + create mode 100644 vllm/worker/multi_step_model_runner.py + create mode 100644 vllm/worker/multi_step_tpu_worker.py + create mode 100644 vllm/worker/multi_step_worker.py + create mode 100644 vllm/worker/openvino_model_runner.py + create mode 100644 vllm/worker/openvino_worker.py + create mode 100644 vllm/worker/pooling_model_runner.py + create mode 100644 vllm/worker/tpu_model_runner.py + create mode 100644 vllm/worker/tpu_worker.py + create mode 100644 vllm/worker/utils.py + create mode 100644 vllm/worker/xpu_model_runner.py + create mode 100644 vllm/worker/xpu_worker.py + +diff --git a/.buildkite/check-wheel-size.py b/.buildkite/check-wheel-size.py +index 90a5e54..0412c5f 100644 +--- a/.buildkite/check-wheel-size.py ++++ b/.buildkite/check-wheel-size.py +@@ -1,36 +1,43 @@ + import os ++import sys + import zipfile + +-MAX_SIZE_MB = 100 ++# Read the VLLM_MAX_SIZE_MB environment variable, defaulting to 250 MB ++VLLM_MAX_SIZE_MB = int(os.environ.get('VLLM_MAX_SIZE_MB', 250)) + + + def print_top_10_largest_files(zip_file): ++ """Print the top 10 largest files in the given zip file.""" + with zipfile.ZipFile(zip_file, 'r') as z: + file_sizes = [(f, z.getinfo(f).file_size) for f in z.namelist()] + file_sizes.sort(key=lambda x: x[1], reverse=True) + for f, size in file_sizes[:10]: +- print(f"{f}: {size/(1024*1024)} MBs uncompressed.") ++ print(f"{f}: {size / (1024 * 1024):.2f} MBs uncompressed.") + + + def check_wheel_size(directory): ++ """Check the size of .whl files in the given directory.""" + for root, _, files in os.walk(directory): +- for f in files: +- if f.endswith(".whl"): +- wheel_path = os.path.join(root, f) +- wheel_size = os.path.getsize(wheel_path) +- wheel_size_mb = wheel_size / (1024 * 1024) +- if wheel_size_mb > MAX_SIZE_MB: +- print( +- f"Wheel {wheel_path} is too large ({wheel_size_mb} MB) " +- f"compare to the allowed size ({MAX_SIZE_MB} MB).") ++ for file_name in files: ++ if file_name.endswith(".whl"): ++ wheel_path = os.path.join(root, file_name) ++ wheel_size_mb = os.path.getsize(wheel_path) / (1024 * 1024) ++ if wheel_size_mb > VLLM_MAX_SIZE_MB: ++ print(f"Not allowed: Wheel {wheel_path} is larger " ++ f"({wheel_size_mb:.2f} MB) than the limit " ++ f"({VLLM_MAX_SIZE_MB} MB).") + print_top_10_largest_files(wheel_path) + return 1 + else: + print(f"Wheel {wheel_path} is within the allowed size " +- f"({wheel_size_mb} MB).") ++ f"({wheel_size_mb:.2f} MB).") + return 0 + + + if __name__ == "__main__": +- import sys +- sys.exit(check_wheel_size(sys.argv[1])) ++ if len(sys.argv) < 2: ++ print("Usage: python check-wheel-size.py ") ++ sys.exit(1) ++ ++ directory = sys.argv[1] ++ sys.exit(check_wheel_size(directory)) +\ No newline at end of file +diff --git a/.buildkite/generate_index.py b/.buildkite/generate_index.py +new file mode 100644 +index 0000000..8350e27 +--- /dev/null ++++ b/.buildkite/generate_index.py +@@ -0,0 +1,24 @@ ++import argparse ++import os ++ ++template = """ ++ ++ ++

Links for vLLM

++ {wheel}
++ ++ ++""" ++ ++parser = argparse.ArgumentParser() ++parser.add_argument("--wheel", help="The wheel path.", required=True) ++args = parser.parse_args() ++ ++filename = os.path.basename(args.wheel) ++ ++with open("index.html", "w") as f: ++ print(f"Generated index.html for {args.wheel}") ++ # cloudfront requires escaping the '+' character ++ f.write( ++ template.format(wheel=filename, ++ wheel_html_escaped=filename.replace("+", "%2B"))) +diff --git a/.buildkite/lm-eval-harness/configs/DeepSeek-V2-Lite-Chat.yaml b/.buildkite/lm-eval-harness/configs/DeepSeek-V2-Lite-Chat.yaml +new file mode 100644 +index 0000000..d70ecb2 +--- /dev/null ++++ b/.buildkite/lm-eval-harness/configs/DeepSeek-V2-Lite-Chat.yaml +@@ -0,0 +1,12 @@ ++# bash ./run-lm-eval-gsm-vllm-baseline.sh -m deepseek-ai/DeepSeek-V2-Lite-Chat -b "auto" -l 1000 -f 5 -t 2 ++model_name: "deepseek-ai/DeepSeek-V2-Lite-Chat" ++tasks: ++- name: "gsm8k" ++ metrics: ++ - name: "exact_match,strict-match" ++ value: 0.671 ++ - name: "exact_match,flexible-extract" ++ value: 0.664 ++limit: 1000 ++num_fewshot: 5 ++trust_remote_code: True +\ No newline at end of file +diff --git a/.buildkite/lm-eval-harness/configs/Meta-Llama-3-70B-Instruct-FBGEMM-nonuniform.yaml b/.buildkite/lm-eval-harness/configs/Meta-Llama-3-70B-Instruct-FBGEMM-nonuniform.yaml +new file mode 100644 +index 0000000..4397eff +--- /dev/null ++++ b/.buildkite/lm-eval-harness/configs/Meta-Llama-3-70B-Instruct-FBGEMM-nonuniform.yaml +@@ -0,0 +1,11 @@ ++# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-hf-baseline.sh -m nm-testing/Meta-Llama-3-70B-Instruct-FBGEMM-nonuniform -b auto -l 1000 -f 5 ++model_name: "nm-testing/Meta-Llama-3-70B-Instruct-FBGEMM-nonuniform" ++tasks: ++- name: "gsm8k" ++ metrics: ++ - name: "exact_match,strict-match" ++ value: 0.905 ++ - name: "exact_match,flexible-extract" ++ value: 0.905 ++limit: 1000 ++num_fewshot: 5 +diff --git a/.buildkite/lm-eval-harness/configs/Meta-Llama-3-70B-Instruct.yaml b/.buildkite/lm-eval-harness/configs/Meta-Llama-3-70B-Instruct.yaml +new file mode 100644 +index 0000000..fa6ea23 +--- /dev/null ++++ b/.buildkite/lm-eval-harness/configs/Meta-Llama-3-70B-Instruct.yaml +@@ -0,0 +1,11 @@ ++# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-hf-baseline.sh -m meta-llama/Meta-Llama-3-70B-Instruct -b 32 -l 250 -f 5 ++model_name: "meta-llama/Meta-Llama-3-70B-Instruct" ++tasks: ++- name: "gsm8k" ++ metrics: ++ - name: "exact_match,strict-match" ++ value: 0.892 ++ - name: "exact_match,flexible-extract" ++ value: 0.892 ++limit: 250 ++num_fewshot: 5 +diff --git a/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-Channelwise-compressed-tensors.yaml b/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-Channelwise-compressed-tensors.yaml +new file mode 100644 +index 0000000..c513159 +--- /dev/null ++++ b/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-Channelwise-compressed-tensors.yaml +@@ -0,0 +1,11 @@ ++# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/Meta-Llama-3-8B-Instruct-W8A8-FP8-Channelwise-compressed-tensors -b auto -l 1000 -f 5 -t 1 ++model_name: "nm-testing/Meta-Llama-3-8B-Instruct-W8A8-FP8-Channelwise-compressed-tensors" ++tasks: ++- name: "gsm8k" ++ metrics: ++ - name: "exact_match,strict-match" ++ value: 0.752 ++ - name: "exact_match,flexible-extract" ++ value: 0.754 ++limit: 1000 ++num_fewshot: 5 +diff --git a/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-FBGEMM-nonuniform.yaml b/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-FBGEMM-nonuniform.yaml +new file mode 100644 +index 0000000..5e57fcb +--- /dev/null ++++ b/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-FBGEMM-nonuniform.yaml +@@ -0,0 +1,11 @@ ++# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/Meta-Llama-3-8B-Instruct-FBGEMM-nonuniform -b auto -l 1000 -f 5 -t 1 ++model_name: "nm-testing/Meta-Llama-3-8B-Instruct-FBGEMM-nonuniform" ++tasks: ++- name: "gsm8k" ++ metrics: ++ - name: "exact_match,strict-match" ++ value: 0.753 ++ - name: "exact_match,flexible-extract" ++ value: 0.753 ++limit: 1000 ++num_fewshot: 5 +diff --git a/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-FP8-compressed-tensors.yaml b/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-FP8-compressed-tensors.yaml +new file mode 100644 +index 0000000..374171f +--- /dev/null ++++ b/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-FP8-compressed-tensors.yaml +@@ -0,0 +1,11 @@ ++# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/Meta-Llama-3-8B-FP8-compressed-tensors-test -b 32 -l 1000 -f 5 -t 1 ++model_name: "nm-testing/Meta-Llama-3-8B-FP8-compressed-tensors-test" ++tasks: ++- name: "gsm8k" ++ metrics: ++ - name: "exact_match,strict-match" ++ value: 0.755 ++ - name: "exact_match,flexible-extract" ++ value: 0.755 ++limit: 1000 ++num_fewshot: 5 +diff --git a/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-FP8.yaml b/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-FP8.yaml +new file mode 100644 +index 0000000..dc36b70 +--- /dev/null ++++ b/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-FP8.yaml +@@ -0,0 +1,11 @@ ++# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m neuralmagic/Meta-Llama-3-8B-Instruct-FP8 -b 32 -l 250 -f 5 -t 1 ++model_name: "neuralmagic/Meta-Llama-3-8B-Instruct-FP8" ++tasks: ++- name: "gsm8k" ++ metrics: ++ - name: "exact_match,strict-match" ++ value: 0.753 ++ - name: "exact_match,flexible-extract" ++ value: 0.753 ++limit: 1000 ++num_fewshot: 5 +diff --git a/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-INT8-compressed-tensors-asym.yaml b/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-INT8-compressed-tensors-asym.yaml +new file mode 100644 +index 0000000..0ecfc01 +--- /dev/null ++++ b/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-INT8-compressed-tensors-asym.yaml +@@ -0,0 +1,11 @@ ++# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/Meta-Llama-3-8B-Instruct-W8-Channel-A8-Dynamic-Asym-Per-Token-Test -b "auto" -l 250 -f 5 -t 1 ++model_name: "nm-testing/Meta-Llama-3-8B-Instruct-W8-Channel-A8-Dynamic-Asym-Per-Token-Test" ++tasks: ++- name: "gsm8k" ++ metrics: ++ - name: "exact_match,strict-match" ++ value: 0.764 ++ - name: "exact_match,flexible-extract" ++ value: 0.764 ++limit: 250 ++num_fewshot: 5 +diff --git a/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-INT8-compressed-tensors.yaml b/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-INT8-compressed-tensors.yaml +new file mode 100644 +index 0000000..bc29002 +--- /dev/null ++++ b/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-INT8-compressed-tensors.yaml +@@ -0,0 +1,11 @@ ++# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/Meta-Llama-3-8B-Instruct-W8-Channel-A8-Dynamic-Per-Token-Test -b "auto" -l 250 -f 5 -t 1 ++model_name: "nm-testing/Meta-Llama-3-8B-Instruct-W8-Channel-A8-Dynamic-Per-Token-Test" ++tasks: ++- name: "gsm8k" ++ metrics: ++ - name: "exact_match,strict-match" ++ value: 0.728 ++ - name: "exact_match,flexible-extract" ++ value: 0.728 ++limit: 250 ++num_fewshot: 5 +diff --git a/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-nonuniform-compressed-tensors.yaml b/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-nonuniform-compressed-tensors.yaml +new file mode 100644 +index 0000000..3964f3b +--- /dev/null ++++ b/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-nonuniform-compressed-tensors.yaml +@@ -0,0 +1,11 @@ ++# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/Meta-Llama-3-8B-Instruct-nonuniform-test -b auto -l 1000 -f 5 -t 1 ++model_name: "nm-testing/Meta-Llama-3-8B-Instruct-nonuniform-test" ++tasks: ++- name: "gsm8k" ++ metrics: ++ - name: "exact_match,strict-match" ++ value: 0.758 ++ - name: "exact_match,flexible-extract" ++ value: 0.759 ++limit: 1000 ++num_fewshot: 5 +diff --git a/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct.yaml b/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct.yaml +new file mode 100644 +index 0000000..fb4b491 +--- /dev/null ++++ b/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct.yaml +@@ -0,0 +1,11 @@ ++# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-hf-baseline.sh -m meta-llama/Meta-Llama-3-8B-Instruct -b 32 -l 250 -f 5 -t 1 ++model_name: "meta-llama/Meta-Llama-3-8B-Instruct" ++tasks: ++- name: "gsm8k" ++ metrics: ++ - name: "exact_match,strict-match" ++ value: 0.756 ++ - name: "exact_match,flexible-extract" ++ value: 0.752 ++limit: 250 ++num_fewshot: 5 +diff --git a/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-QQQ.yaml b/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-QQQ.yaml +new file mode 100644 +index 0000000..0424586 +--- /dev/null ++++ b/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-QQQ.yaml +@@ -0,0 +1,11 @@ ++# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m HandH1998/QQQ-Llama-3-8b-g128 -b 32 -l 1000 -f 5 -t 1 ++model_name: "HandH1998/QQQ-Llama-3-8b-g128" ++tasks: ++- name: "gsm8k" ++ metrics: ++ - name: "exact_match,strict-match" ++ value: 0.419 ++ - name: "exact_match,flexible-extract" ++ value: 0.416 ++limit: 1000 ++num_fewshot: 5 +diff --git a/.buildkite/lm-eval-harness/configs/Meta-Llama-3.2-1B-Instruct-INT8-compressed-tensors.yaml b/.buildkite/lm-eval-harness/configs/Meta-Llama-3.2-1B-Instruct-INT8-compressed-tensors.yaml +new file mode 100644 +index 0000000..78347f6 +--- /dev/null ++++ b/.buildkite/lm-eval-harness/configs/Meta-Llama-3.2-1B-Instruct-INT8-compressed-tensors.yaml +@@ -0,0 +1,11 @@ ++# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m neuralmagic/Llama-3.2-1B-Instruct-quantized.w8a8 -b "auto" -l 1000 -f 5 -t 1 ++model_name: "neuralmagic/Llama-3.2-1B-Instruct-quantized.w8a8" ++tasks: ++- name: "gsm8k" ++ metrics: ++ - name: "exact_match,strict-match" ++ value: 0.356 ++ - name: "exact_match,flexible-extract" ++ value: 0.358 ++limit: 1000 ++num_fewshot: 5 +diff --git a/.buildkite/lm-eval-harness/configs/Minitron-4B-Base-FP8.yaml b/.buildkite/lm-eval-harness/configs/Minitron-4B-Base-FP8.yaml +new file mode 100644 +index 0000000..3ea0b7b +--- /dev/null ++++ b/.buildkite/lm-eval-harness/configs/Minitron-4B-Base-FP8.yaml +@@ -0,0 +1,11 @@ ++# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m mgoin/Minitron-4B-Base-FP8 -b auto -l 1000 -f 5 -t 1 ++model_name: "mgoin/Minitron-4B-Base-FP8" ++tasks: ++- name: "gsm8k" ++ metrics: ++ - name: "exact_match,strict-match" ++ value: 0.233 ++ - name: "exact_match,flexible-extract" ++ value: 0.236 ++limit: 1000 ++num_fewshot: 5 +diff --git a/.buildkite/lm-eval-harness/configs/Mixtral-8x22B-Instruct-v0.1-FP8-Dynamic.yaml b/.buildkite/lm-eval-harness/configs/Mixtral-8x22B-Instruct-v0.1-FP8-Dynamic.yaml +new file mode 100644 +index 0000000..75a24e4 +--- /dev/null ++++ b/.buildkite/lm-eval-harness/configs/Mixtral-8x22B-Instruct-v0.1-FP8-Dynamic.yaml +@@ -0,0 +1,11 @@ ++# bash ./run-lm-eval-gsm-vllm-baseline.sh -m neuralmagic/Mixtral-8x22B-Instruct-v0.1-FP8-dynamic -b "auto" -l 250 -f 5 -t 8 ++model_name: "neuralmagic/Mixtral-8x22B-Instruct-v0.1-FP8-dynamic" ++tasks: ++- name: "gsm8k" ++ metrics: ++ - name: "exact_match,strict-match" ++ value: 0.86 ++ - name: "exact_match,flexible-extract" ++ value: 0.86 ++limit: 250 ++num_fewshot: 5 +diff --git a/.buildkite/lm-eval-harness/configs/Mixtral-8x7B-Instruct-v0.1-FP8.yaml b/.buildkite/lm-eval-harness/configs/Mixtral-8x7B-Instruct-v0.1-FP8.yaml +new file mode 100644 +index 0000000..436ec21 +--- /dev/null ++++ b/.buildkite/lm-eval-harness/configs/Mixtral-8x7B-Instruct-v0.1-FP8.yaml +@@ -0,0 +1,11 @@ ++# bash ./run-lm-eval-gsm-vllm-baseline.sh -m neuralmagic/Mixtral-8x7B-Instruct-v0.1-FP8 -b "auto" -l 250 -f 5 -t 4 ++model_name: "neuralmagic/Mixtral-8x7B-Instruct-v0.1-FP8" ++tasks: ++- name: "gsm8k" ++ metrics: ++ - name: "exact_match,strict-match" ++ value: 0.624 ++ - name: "exact_match,flexible-extract" ++ value: 0.624 ++limit: 250 ++num_fewshot: 5 +diff --git a/.buildkite/lm-eval-harness/configs/Mixtral-8x7B-Instruct-v0.1.yaml b/.buildkite/lm-eval-harness/configs/Mixtral-8x7B-Instruct-v0.1.yaml +new file mode 100644 +index 0000000..dec9164 +--- /dev/null ++++ b/.buildkite/lm-eval-harness/configs/Mixtral-8x7B-Instruct-v0.1.yaml +@@ -0,0 +1,11 @@ ++# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-hf-baseline.sh -m neuralmagic/Mixtral-8x7B-Instruct-v0.1 -b 32 -l 250 -f 5 -t 4 ++model_name: "mistralai/Mixtral-8x7B-Instruct-v0.1" ++tasks: ++- name: "gsm8k" ++ metrics: ++ - name: "exact_match,strict-match" ++ value: 0.616 ++ - name: "exact_match,flexible-extract" ++ value: 0.632 ++limit: 250 ++num_fewshot: 5 +diff --git a/.buildkite/lm-eval-harness/configs/Qwen2-1.5B-Instruct-FP8W8.yaml b/.buildkite/lm-eval-harness/configs/Qwen2-1.5B-Instruct-FP8W8.yaml +new file mode 100644 +index 0000000..42936fb +--- /dev/null ++++ b/.buildkite/lm-eval-harness/configs/Qwen2-1.5B-Instruct-FP8W8.yaml +@@ -0,0 +1,11 @@ ++# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/Qwen2-1.5B-Instruct-FP8W8 -b auto -l 1000 -f 5 -t 1 ++model_name: "nm-testing/Qwen2-1.5B-Instruct-FP8W8" ++tasks: ++- name: "gsm8k" ++ metrics: ++ - name: "exact_match,strict-match" ++ value: 0.578 ++ - name: "exact_match,flexible-extract" ++ value: 0.585 ++limit: 1000 ++num_fewshot: 5 +diff --git a/.buildkite/lm-eval-harness/configs/Qwen2-1.5B-Instruct-INT8-compressed-tensors.yaml b/.buildkite/lm-eval-harness/configs/Qwen2-1.5B-Instruct-INT8-compressed-tensors.yaml +new file mode 100644 +index 0000000..43ff2bc +--- /dev/null ++++ b/.buildkite/lm-eval-harness/configs/Qwen2-1.5B-Instruct-INT8-compressed-tensors.yaml +@@ -0,0 +1,11 @@ ++# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m neuralmagic/Qwen2-1.5B-Instruct-quantized.w8a8 -b "auto" -l 1000 -f 5 -t 1 ++model_name: "neuralmagic/Qwen2-1.5B-Instruct-quantized.w8a8" ++tasks: ++- name: "gsm8k" ++ metrics: ++ - name: "exact_match,strict-match" ++ value: 0.593 ++ - name: "exact_match,flexible-extract" ++ value: 0.588 ++limit: 1000 ++num_fewshot: 5 +diff --git a/.buildkite/lm-eval-harness/configs/Qwen2-1.5B-Instruct-W8A16-compressed-tensors.yaml b/.buildkite/lm-eval-harness/configs/Qwen2-1.5B-Instruct-W8A16-compressed-tensors.yaml +new file mode 100644 +index 0000000..259799b +--- /dev/null ++++ b/.buildkite/lm-eval-harness/configs/Qwen2-1.5B-Instruct-W8A16-compressed-tensors.yaml +@@ -0,0 +1,11 @@ ++# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/Qwen2-1.5B-Instruct-W8A16-Channelwise -b "auto" -l 1000 -f 5 -t 1 ++model_name: "nm-testing/Qwen2-1.5B-Instruct-W8A16-Channelwise" ++tasks: ++- name: "gsm8k" ++ metrics: ++ - name: "exact_match,strict-match" ++ value: 0.595 ++ - name: "exact_match,flexible-extract" ++ value: 0.582 ++limit: 1000 ++num_fewshot: 5 +diff --git a/.buildkite/lm-eval-harness/configs/Qwen2-57B-A14-Instruct.yaml b/.buildkite/lm-eval-harness/configs/Qwen2-57B-A14-Instruct.yaml +new file mode 100644 +index 0000000..45d5efc +--- /dev/null ++++ b/.buildkite/lm-eval-harness/configs/Qwen2-57B-A14-Instruct.yaml +@@ -0,0 +1,11 @@ ++# bash ./run-lm-eval-gsm-vllm-baseline.sh -m Qwen/Qwen2-57B-A14B-Instruct -b "auto" -l 250 -f 5 -t 4 ++model_name: "Qwen/Qwen2-57B-A14B-Instruct" ++tasks: ++- name: "gsm8k" ++ metrics: ++ - name: "exact_match,strict-match" ++ value: 0.792 ++ - name: "exact_match,flexible-extract" ++ value: 0.824 ++limit: 250 ++num_fewshot: 5 +diff --git a/.buildkite/lm-eval-harness/configs/models-large.txt b/.buildkite/lm-eval-harness/configs/models-large.txt +new file mode 100644 +index 0000000..37eeac8 +--- /dev/null ++++ b/.buildkite/lm-eval-harness/configs/models-large.txt +@@ -0,0 +1,5 @@ ++Meta-Llama-3-70B-Instruct-FBGEMM-nonuniform.yaml ++Meta-Llama-3-70B-Instruct.yaml ++Mixtral-8x7B-Instruct-v0.1.yaml ++Qwen2-57B-A14-Instruct.yaml ++DeepSeek-V2-Lite-Chat.yaml +diff --git a/.buildkite/lm-eval-harness/configs/models-small.txt b/.buildkite/lm-eval-harness/configs/models-small.txt +new file mode 100644 +index 0000000..6057229 +--- /dev/null ++++ b/.buildkite/lm-eval-harness/configs/models-small.txt +@@ -0,0 +1,10 @@ ++Meta-Llama-3-8B-Instruct.yaml ++Meta-Llama-3-8B-Instruct-FP8-compressed-tensors.yaml ++Meta-Llama-3.2-1B-Instruct-INT8-compressed-tensors.yaml ++Meta-Llama-3-8B-Instruct-INT8-compressed-tensors-asym.yaml ++Meta-Llama-3-8B-Instruct-nonuniform-compressed-tensors.yaml ++Meta-Llama-3-8B-Instruct-Channelwise-compressed-tensors.yaml ++Minitron-4B-Base-FP8.yaml ++Qwen2-1.5B-Instruct-INT8-compressed-tensors.yaml ++Qwen2-1.5B-Instruct-FP8W8.yaml ++Meta-Llama-3-8B-QQQ.yaml +diff --git a/.buildkite/lm-eval-harness/run-lm-eval-gsm-hf-baseline.sh b/.buildkite/lm-eval-harness/run-lm-eval-gsm-hf-baseline.sh +new file mode 100644 +index 0000000..a67fc89 +--- /dev/null ++++ b/.buildkite/lm-eval-harness/run-lm-eval-gsm-hf-baseline.sh +@@ -0,0 +1,46 @@ ++#!/bin/bash ++# We can use this script to compute baseline accuracy on GSM for transformers. ++# ++# Make sure you have lm-eval-harness installed: ++# pip install lm-eval==0.4.4 ++ ++usage() { ++ echo`` ++ echo "Runs lm eval harness on GSM8k using huggingface transformers." ++ echo "This pathway is intended to be used to create baselines for " ++ echo "our automated nm-test-accuracy workflow" ++ echo ++ echo "usage: ${0} " ++ echo ++ echo " -m - huggingface stub or local directory of the model" ++ echo " -b - batch size to run the evaluation at" ++ echo " -l - limit number of samples to run" ++ echo " -f - number of fewshot samples to use" ++ echo ++} ++ ++while getopts "m:b:l:f:" OPT; do ++ case ${OPT} in ++ m ) ++ MODEL="$OPTARG" ++ ;; ++ b ) ++ BATCH_SIZE="$OPTARG" ++ ;; ++ l ) ++ LIMIT="$OPTARG" ++ ;; ++ f ) ++ FEWSHOT="$OPTARG" ++ ;; ++ \? ) ++ usage ++ exit 1 ++ ;; ++ esac ++done ++ ++lm_eval --model hf \ ++ --model_args "pretrained=$MODEL,parallelize=True" \ ++ --tasks gsm8k --num_fewshot "$FEWSHOT" --limit "$LIMIT" \ ++ --batch_size "$BATCH_SIZE" +diff --git a/.buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh b/.buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh +new file mode 100644 +index 0000000..65be3c5 +--- /dev/null ++++ b/.buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh +@@ -0,0 +1,51 @@ ++#!/bin/bash ++# We can use this script to compute baseline accuracy on GSM for vllm. ++# We use this for fp8, which HF does not support. ++# ++# Make sure you have lm-eval-harness installed: ++# pip install lm-eval==0.4.4 ++ ++usage() { ++ echo`` ++ echo "Runs lm eval harness on GSM8k using huggingface transformers." ++ echo "This pathway is intended to be used to create baselines for " ++ echo "our automated nm-test-accuracy workflow" ++ echo ++ echo "usage: ${0} " ++ echo ++ echo " -m - huggingface stub or local directory of the model" ++ echo " -b - batch size to run the evaluation at" ++ echo " -l - limit number of samples to run" ++ echo " -f - number of fewshot samples to use" ++ echo " -t - tensor parallel size to run at" ++ echo ++} ++ ++while getopts "m:b:l:f:t:" OPT; do ++ case ${OPT} in ++ m ) ++ MODEL="$OPTARG" ++ ;; ++ b ) ++ BATCH_SIZE="$OPTARG" ++ ;; ++ l ) ++ LIMIT="$OPTARG" ++ ;; ++ f ) ++ FEWSHOT="$OPTARG" ++ ;; ++ t ) ++ TP_SIZE="$OPTARG" ++ ;; ++ \? ) ++ usage ++ exit 1 ++ ;; ++ esac ++done ++ ++lm_eval --model vllm \ ++ --model_args "pretrained=$MODEL,tensor_parallel_size=$TP_SIZE,distributed_executor_backend=ray,trust_remote_code=true,max_model_len=4096" \ ++ --tasks gsm8k --num_fewshot "$FEWSHOT" --limit "$LIMIT" \ ++ --batch_size "$BATCH_SIZE" +diff --git a/.buildkite/lm-eval-harness/run-tests.sh b/.buildkite/lm-eval-harness/run-tests.sh +new file mode 100644 +index 0000000..26f33b7 +--- /dev/null ++++ b/.buildkite/lm-eval-harness/run-tests.sh +@@ -0,0 +1,59 @@ ++#!/bin/bash ++ ++usage() { ++ echo`` ++ echo "Runs lm eval harness on GSM8k using vllm and compares to " ++ echo "precomputed baseline (measured by HF transformers.)" ++ echo ++ echo "usage: ${0} " ++ echo ++ echo " -c - path to the test data config (e.g. configs/small-models.txt)" ++ echo " -t - tensor parallel size" ++ echo ++} ++ ++SUCCESS=0 ++ ++while getopts "c:t:" OPT; do ++ case ${OPT} in ++ c ) ++ CONFIG="$OPTARG" ++ ;; ++ t ) ++ TP_SIZE="$OPTARG" ++ ;; ++ \? ) ++ usage ++ exit 1 ++ ;; ++ esac ++done ++ ++# Parse list of configs. ++IFS=$'\n' read -d '' -r -a MODEL_CONFIGS < "$CONFIG" ++ ++for MODEL_CONFIG in "${MODEL_CONFIGS[@]}" ++do ++ LOCAL_SUCCESS=0 ++ ++ echo "=== RUNNING MODEL: $MODEL_CONFIG WITH TP SIZE: $TP_SIZE===" ++ ++ export LM_EVAL_TEST_DATA_FILE=$PWD/configs/${MODEL_CONFIG} ++ export LM_EVAL_TP_SIZE=$TP_SIZE ++ pytest -s test_lm_eval_correctness.py || LOCAL_SUCCESS=$? ++ ++ if [[ $LOCAL_SUCCESS == 0 ]]; then ++ echo "=== PASSED MODEL: ${MODEL_CONFIG} ===" ++ else ++ echo "=== FAILED MODEL: ${MODEL_CONFIG} ===" ++ fi ++ ++ SUCCESS=$((SUCCESS + LOCAL_SUCCESS)) ++ ++done ++ ++if [ "${SUCCESS}" -eq "0" ]; then ++ exit 0 ++else ++ exit 1 ++fi +diff --git a/.buildkite/lm-eval-harness/test_lm_eval_correctness.py b/.buildkite/lm-eval-harness/test_lm_eval_correctness.py +new file mode 100644 +index 0000000..afc935c +--- /dev/null ++++ b/.buildkite/lm-eval-harness/test_lm_eval_correctness.py +@@ -0,0 +1,63 @@ ++""" ++LM eval harness on model to compare vs HF baseline computed offline. ++Configs are found in configs/$MODEL.yaml ++ ++* export LM_EVAL_TEST_DATA_FILE=configs/Meta-Llama-3-70B-Instruct.yaml ++* export LM_EVAL_TP_SIZE=4 ++* pytest -s test_lm_eval_correctness.py ++""" ++ ++import os ++from pathlib import Path ++ ++import lm_eval ++import numpy ++import yaml ++ ++RTOL = 0.05 ++TEST_DATA_FILE = os.environ.get( ++ "LM_EVAL_TEST_DATA_FILE", ++ ".buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct.yaml") ++ ++TP_SIZE = os.environ.get("LM_EVAL_TP_SIZE", 1) ++ ++ ++def launch_lm_eval(eval_config): ++ trust_remote_code = eval_config.get('trust_remote_code', False) ++ ++ model_args = f"pretrained={eval_config['model_name']}," \ ++ f"tensor_parallel_size={TP_SIZE}," \ ++ f"add_bos_token=true," \ ++ f"trust_remote_code={trust_remote_code}" ++ ++ results = lm_eval.simple_evaluate( ++ model="vllm", ++ model_args=model_args, ++ tasks=[task["name"] for task in eval_config["tasks"]], ++ num_fewshot=eval_config["num_fewshot"], ++ limit=eval_config["limit"], ++ batch_size="auto") ++ ++ return results ++ ++ ++def test_lm_eval_correctness(): ++ eval_config = yaml.safe_load( ++ Path(TEST_DATA_FILE).read_text(encoding="utf-8")) ++ ++ # Launch eval requests. ++ results = launch_lm_eval(eval_config) ++ ++ # Confirm scores match ground truth. ++ success = True ++ for task in eval_config["tasks"]: ++ for metric in task["metrics"]: ++ ground_truth = metric["value"] ++ measured_value = results["results"][task["name"]][metric["name"]] ++ print(f'{task["name"]} | {metric["name"]}: ' ++ f'ground_truth={ground_truth} | measured={measured_value}') ++ success = success and numpy.isclose( ++ ground_truth, measured_value, rtol=RTOL) ++ ++ # Assert at the end, print all scores even on failure for debugging. ++ assert success +diff --git a/.buildkite/nightly-benchmarks/README.md b/.buildkite/nightly-benchmarks/README.md +new file mode 100644 +index 0000000..fbf41eb +--- /dev/null ++++ b/.buildkite/nightly-benchmarks/README.md +@@ -0,0 +1,153 @@ ++# vLLM benchmark suite ++ ++ ++## Introduction ++ ++This directory contains two sets of benchmark for vllm. ++- Performance benchmark: benchmark vllm's performance under various workload, for **developers** to gain clarity on whether their PR improves/degrades vllm's performance ++- Nightly benchmark: compare vllm's performance against alternatives (tgi, trt-llm and lmdeploy), for **the public** to know when to choose vllm. ++ ++ ++See [vLLM performance dashboard](https://perf.vllm.ai) for the latest performance benchmark results and [vLLM GitHub README](https://github.com/vllm-project/vllm/blob/main/README.md) for latest nightly benchmark results. ++ ++ ++## Performance benchmark quick overview ++ ++**Benchmarking Coverage**: latency, throughput and fix-qps serving on A100 (the support for FP8 benchmark on H100 is coming!), with different models. ++ ++**Benchmarking Duration**: about 1hr. ++ ++**For benchmarking developers**: please try your best to constraint the duration of benchmarking to about 1 hr so that it won't take forever to run. ++ ++ ++## Nightly benchmark quick overview ++ ++**Benchmarking Coverage**: Fix-qps serving on A100 (the support for FP8 benchmark on H100 is coming!) on Llama-3 8B, 70B and Mixtral 8x7B. ++ ++**Benchmarking engines**: vllm, TGI, trt-llm and lmdeploy. ++ ++**Benchmarking Duration**: about 3.5hrs. ++ ++ ++ ++## Trigger the benchmark ++ ++Performance benchmark will be triggered when: ++- A PR being merged into vllm. ++- Every commit for those PRs with `perf-benchmarks` label AND `ready` label. ++ ++Nightly benchmark will be triggered when: ++- Every commit for those PRs with `perf-benchmarks` label and `nightly-benchmarks` label. ++ ++ ++ ++ ++## Performance benchmark details ++ ++ ++See [performance-benchmarks-descriptions.md](performance-benchmarks-descriptions.md) for detailed descriptions, and use `tests/latency-tests.json`, `tests/throughput-tests.json`, `tests/serving-tests.json` to configure the test cases. ++ ++ ++#### Latency test ++ ++Here is an example of one test inside `latency-tests.json`: ++ ++```json ++[ ++ { ++ "test_name": "latency_llama8B_tp1", ++ "parameters": { ++ "model": "meta-llama/Meta-Llama-3-8B", ++ "tensor_parallel_size": 1, ++ "load_format": "dummy", ++ "num_iters_warmup": 5, ++ "num_iters": 15 ++ } ++ }, ++] ++``` ++ ++In this example: ++- The `test_name` attributes is a unique identifier for the test. In `latency-tests.json`, it must start with `latency_`. ++- The `parameters` attribute control the command line arguments to be used for `benchmark_latency.py`. Note that please use underline `_` instead of the dash `-` when specifying the command line arguments, and `run-performance-benchmarks.sh` will convert the underline to dash when feeding the arguments to `benchmark_latency.py`. For example, the corresponding command line arguments for `benchmark_latency.py` will be `--model meta-llama/Meta-Llama-3-8B --tensor-parallel-size 1 --load-format dummy --num-iters-warmup 5 --num-iters 15` ++ ++Note that the performance numbers are highly sensitive to the value of the parameters. Please make sure the parameters are set correctly. ++ ++WARNING: The benchmarking script will save json results by itself, so please do not configure `--output-json` parameter in the json file. ++ ++ ++#### Throughput test ++The tests are specified in `throughput-tests.json`. The syntax is similar to `latency-tests.json`, except for that the parameters will be fed forward to `benchmark_throughput.py`. ++ ++The number of this test is also stable -- a slight change on the value of this number might vary the performance numbers by a lot. ++ ++#### Serving test ++We test the throughput by using `benchmark_serving.py` with request rate = inf to cover the online serving overhead. The corresponding parameters are in `serving-tests.json`, and here is an example: ++ ++``` ++[ ++ { ++ "test_name": "serving_llama8B_tp1_sharegpt", ++ "qps_list": [1, 4, 16, "inf"], ++ "server_parameters": { ++ "model": "meta-llama/Meta-Llama-3-8B", ++ "tensor_parallel_size": 1, ++ "swap_space": 16, ++ "disable_log_stats": "", ++ "disable_log_requests": "", ++ "load_format": "dummy" ++ }, ++ "client_parameters": { ++ "model": "meta-llama/Meta-Llama-3-8B", ++ "backend": "vllm", ++ "dataset_name": "sharegpt", ++ "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", ++ "num_prompts": 200 ++ } ++ }, ++] ++``` ++ ++Inside this example: ++- The `test_name` attribute is also a unique identifier for the test. It must start with `serving_`. ++- The `server-parameters` includes the command line arguments for vLLM server. ++- The `client-parameters` includes the command line arguments for `benchmark_serving.py`. ++- The `qps_list` controls the list of qps for test. It will be used to configure the `--request-rate` parameter in `benchmark_serving.py` ++ ++The number of this test is less stable compared to the delay and latency benchmarks (due to randomized sharegpt dataset sampling inside `benchmark_serving.py`), but a large change on this number (e.g. 5% change) still vary the output greatly. ++ ++WARNING: The benchmarking script will save json results by itself, so please do not configure `--save-results` or other results-saving-related parameters in `serving-tests.json`. ++ ++#### Visualizing the results ++The `convert-results-json-to-markdown.py` helps you put the benchmarking results inside a markdown table, by formatting [descriptions.md](tests/descriptions.md) with real benchmarking results. ++You can find the result presented as a table inside the `buildkite/performance-benchmark` job page. ++If you do not see the table, please wait till the benchmark finish running. ++The json version of the table (together with the json version of the benchmark) will be also attached to the markdown file. ++The raw benchmarking results (in the format of json files) are in the `Artifacts` tab of the benchmarking. ++ ++ ++ ++## Nightly test details ++ ++See [nightly-descriptions.md](nightly-descriptions.md) for the detailed description on test workload, models and docker containers of benchmarking other llm engines. ++ ++ ++#### Workflow ++ ++- The [nightly-pipeline.yaml](nightly-pipeline.yaml) specifies the docker containers for different LLM serving engines. ++- Inside each container, we run [run-nightly-suite.sh](run-nightly-suite.sh), which will probe the serving engine of the current container. ++- The `run-nightly-suite.sh` will redirect the request to `tests/run-[llm serving engine name]-nightly.sh`, which parses the workload described in [nightly-tests.json](tests/nightly-tests.json) and performs the benchmark. ++- At last, we run [scripts/plot-nightly-results.py](scripts/plot-nightly-results.py) to collect and plot the final benchmarking results, and update the results to buildkite. ++ ++#### Nightly tests ++ ++In [nightly-tests.json](tests/nightly-tests.json), we include the command line arguments for benchmarking commands, together with the benchmarking test cases. The format is highly similar to performance benchmark. ++ ++#### Docker containers ++ ++The docker containers for benchmarking are specified in `nightly-pipeline.yaml`. ++ ++WARNING: the docker versions are HARD-CODED and SHOULD BE ALIGNED WITH `nightly-descriptions.md`. The docker versions need to be hard-coded as there are several version-specific bug fixes inside `tests/run-[llm serving engine name]-nightly.sh`. ++ ++WARNING: populating `trt-llm` to latest version is not easy, as it requires updating several protobuf files in [tensorrt-demo](https://github.com/neuralmagic/tensorrt-demo.git). ++ +diff --git a/.buildkite/nightly-benchmarks/benchmark-pipeline.yaml b/.buildkite/nightly-benchmarks/benchmark-pipeline.yaml +new file mode 100644 +index 0000000..679abf1 +--- /dev/null ++++ b/.buildkite/nightly-benchmarks/benchmark-pipeline.yaml +@@ -0,0 +1,92 @@ ++steps: ++ - label: "Wait for container to be ready" ++ key: wait-for-container-image ++ agents: ++ queue: A100 ++ plugins: ++ - kubernetes: ++ podSpec: ++ containers: ++ - image: badouralix/curl-jq ++ command: ++ - sh .buildkite/nightly-benchmarks/scripts/wait-for-image.sh ++ ++ - label: "A100" ++ # skip: "use this flag to conditionally skip the benchmark step, useful for PR testing" ++ agents: ++ queue: A100 ++ depends_on: wait-for-container-image ++ plugins: ++ - kubernetes: ++ podSpec: ++ priorityClassName: perf-benchmark ++ containers: ++ - image: public.ecr.aws/q9t5s3a7/vllm-ci-postmerge-repo:$BUILDKITE_COMMIT ++ command: ++ - bash .buildkite/nightly-benchmarks/scripts/run-performance-benchmarks.sh ++ resources: ++ limits: ++ nvidia.com/gpu: 8 ++ volumeMounts: ++ - name: devshm ++ mountPath: /dev/shm ++ env: ++ - name: VLLM_USAGE_SOURCE ++ value: ci-test ++ - name: HF_TOKEN ++ valueFrom: ++ secretKeyRef: ++ name: hf-token-secret ++ key: token ++ nodeSelector: ++ nvidia.com/gpu.product: NVIDIA-A100-SXM4-80GB ++ volumes: ++ - name: devshm ++ emptyDir: ++ medium: Memory ++ ++ - label: "H200" ++ # skip: "use this flag to conditionally skip the benchmark step, useful for PR testing" ++ agents: ++ queue: H200 ++ depends_on: wait-for-container-image ++ plugins: ++ - docker#v5.12.0: ++ image: public.ecr.aws/q9t5s3a7/vllm-ci-postmerge-repo:$BUILDKITE_COMMIT ++ command: ++ - bash ++ - .buildkite/nightly-benchmarks/scripts/run-performance-benchmarks.sh ++ mount-buildkite-agent: true ++ propagate-environment: true ++ ipc: host ++ gpus: 4,5,6,7 ++ volumes: ++ - /data/benchmark-hf-cache:/root/.cache/huggingface ++ environment: ++ - VLLM_USAGE_SOURCE ++ - HF_TOKEN ++ ++ #- block: "Run H100 Benchmark" ++ #key: block-h100 ++ #depends_on: ~ ++ ++ - label: "H100" ++ # skip: "use this flag to conditionally skip the benchmark step, useful for PR testing" ++ agents: ++ queue: H100 ++ depends_on: wait-for-container-image ++ plugins: ++ - docker#v5.12.0: ++ image: public.ecr.aws/q9t5s3a7/vllm-ci-postmerge-repo:$BUILDKITE_COMMIT ++ command: ++ - bash ++ - .buildkite/nightly-benchmarks/scripts/run-performance-benchmarks.sh ++ mount-buildkite-agent: true ++ propagate-environment: true ++ ipc: host ++ gpus: all # see CUDA_VISIBLE_DEVICES for actual GPUs used ++ volumes: ++ - /data/benchmark-hf-cache:/root/.cache/huggingface ++ environment: ++ - VLLM_USAGE_SOURCE ++ - HF_TOKEN +diff --git a/.buildkite/nightly-benchmarks/nightly-annotation.md b/.buildkite/nightly-benchmarks/nightly-annotation.md +new file mode 100644 +index 0000000..1e33793 +--- /dev/null ++++ b/.buildkite/nightly-benchmarks/nightly-annotation.md +@@ -0,0 +1,28 @@ ++ ++## Description ++ ++This file contains the downloading link for benchmarking results. ++ ++- [benchmarking pipeline](artifact://nightly-pipeline.yaml) ++- [benchmarking results](artifact://results.zip) ++- [benchmarking code](artifact://nightly-benchmarks.zip) ++ ++Please download the visualization scripts in the post ++ ++ ++## Results reproduction ++ ++- Find the docker we use in `benchmarking pipeline` ++- Deploy the docker, and inside the docker: ++ - Download `nightly-benchmarks.zip`. ++ - In the same folder, run the following code ++``` ++export HF_TOKEN= ++apt update ++apt install -y git ++unzip nightly-benchmarks.zip ++VLLM_SOURCE_CODE_LOC=./ bash .buildkite/nightly-benchmarks/scripts/run-nightly-benchmarks.sh ++``` ++ ++And the results will be inside `./benchmarks/results`. ++ +diff --git a/.buildkite/nightly-benchmarks/nightly-descriptions.md b/.buildkite/nightly-benchmarks/nightly-descriptions.md +new file mode 100644 +index 0000000..7dec7a0 +--- /dev/null ++++ b/.buildkite/nightly-benchmarks/nightly-descriptions.md +@@ -0,0 +1,39 @@ ++ ++# Nightly benchmark ++ ++This benchmark aims to: ++- Provide performance clarity: Provide clarity on which one (vllm, tensorrt-llm, lmdeploy and SGLang) leads in performance in what workload. ++- Be reproducible: one can run the exact same set of benchmarking commands inside the exact same docker by following reproducing instructions. ++ ++Latest results: [results link](https://blog.vllm.ai/2024/09/05/perf-update.html), scroll to the end. ++ ++Latest reproduction guilde: [github issue link](https://github.com/vllm-project/vllm/issues/8176) ++ ++ ++## Setup ++ ++- Docker images: ++ - vLLM: `vllm/vllm-openai:v0.6.2` ++ - SGLang: `lmsysorg/sglang:v0.3.2-cu121` ++ - LMDeploy: `openmmlab/lmdeploy:v0.6.1-cu12` ++ - TensorRT-LLM: `nvcr.io/nvidia/tritonserver:24.07-trtllm-python-py3` ++ - *NOTE: we uses r24.07 as the current implementation only works for this version. We are going to bump this up.* ++ - Check [nightly-pipeline.yaml](nightly-pipeline.yaml) for the concrete docker images, specs and commands we use for the benchmark. ++- Hardware ++ - 8x Nvidia A100 GPUs ++- Workload: ++ - Dataset ++ - ShareGPT dataset ++ - Prefill-heavy dataset (in average 462 input tokens, 16 tokens as output) ++ - Decode-heavy dataset (in average 462 input tokens, 256 output tokens) ++ - Check [nightly-tests.json](tests/nightly-tests.json) for the concrete configuration of datasets we use. ++ - Models: llama-3 8B, llama-3 70B. ++ - We do not use llama 3.1 as it is incompatible with trt-llm r24.07. ([issue](https://github.com/NVIDIA/TensorRT-LLM/issues/2105)). ++ - Average QPS (query per second): 2, 4, 8, 16, 32 and inf. ++ - Queries are randomly sampled, and arrival patterns are determined via Poisson process, but all with fixed random seed. ++ - Evaluation metrics: Throughput (higher the better), TTFT (time to the first token, lower the better), ITL (inter-token latency, lower the better). ++ ++# Known issues ++ ++- TRT-LLM crashes with Llama 3.1 8B [issue](https://github.com/NVIDIA/TensorRT-LLM/issues/2105). ++- TGI does not support `ignore-eos` flag. +\ No newline at end of file +diff --git a/.buildkite/nightly-benchmarks/nightly-pipeline.yaml b/.buildkite/nightly-benchmarks/nightly-pipeline.yaml +new file mode 100644 +index 0000000..199517e +--- /dev/null ++++ b/.buildkite/nightly-benchmarks/nightly-pipeline.yaml +@@ -0,0 +1,196 @@ ++common_pod_spec: &common_pod_spec ++ priorityClassName: perf-benchmark ++ nodeSelector: ++ nvidia.com/gpu.product: NVIDIA-A100-SXM4-80GB ++ volumes: ++ - name: devshm ++ emptyDir: ++ medium: Memory ++ - name: hf-cache ++ hostPath: ++ path: /root/.cache/huggingface ++ type: Directory ++ ++common_container_settings: &common_container_settings ++ command: ++ - bash .buildkite/nightly-benchmarks/scripts/run-nightly-benchmarks.sh ++ resources: ++ limits: ++ nvidia.com/gpu: 8 ++ volumeMounts: ++ - name: devshm ++ mountPath: /dev/shm ++ - name: hf-cache ++ mountPath: /root/.cache/huggingface ++ env: ++ - name: VLLM_USAGE_SOURCE ++ value: ci-test ++ - name: HF_HOME ++ value: /root/.cache/huggingface ++ - name: VLLM_SOURCE_CODE_LOC ++ value: /workspace/build/buildkite/vllm/performance-benchmark ++ - name: HF_TOKEN ++ valueFrom: ++ secretKeyRef: ++ name: hf-token-secret ++ key: token ++ ++steps: ++ - block: ":rocket: Ready for comparing vllm against alternatives? This will take 4 hours." ++ ++ ++ ++ - label: "A100 vllm step 10" ++ priority: 100 ++ agents: ++ queue: A100 ++ plugins: ++ - kubernetes: ++ podSpec: ++ <<: *common_pod_spec ++ containers: ++ - image: vllm/vllm-openai:v0.6.2 ++ <<: *common_container_settings ++ ++ ++ ++ - label: "A100 sglang benchmark" ++ priority: 100 ++ agents: ++ queue: A100 ++ plugins: ++ - kubernetes: ++ podSpec: ++ <<: *common_pod_spec ++ containers: ++ - image: lmsysorg/sglang:v0.3.2-cu121 ++ <<: *common_container_settings ++ ++ - label: "A100 lmdeploy benchmark" ++ priority: 100 ++ agents: ++ queue: A100 ++ plugins: ++ - kubernetes: ++ podSpec: ++ <<: *common_pod_spec ++ containers: ++ - image: openmmlab/lmdeploy:v0.6.1-cu12 ++ <<: *common_container_settings ++ ++ ++ ++ ++ - label: "A100 trt llama-8B" ++ priority: 100 ++ agents: ++ queue: A100 ++ plugins: ++ - kubernetes: ++ podSpec: ++ <<: *common_pod_spec ++ containers: ++ - image: nvcr.io/nvidia/tritonserver:24.07-trtllm-python-py3 ++ <<: *common_container_settings ++ env: ++ - name: VLLM_USAGE_SOURCE ++ value: ci-test ++ - name: HF_HOME ++ value: /root/.cache/huggingface ++ - name: VLLM_SOURCE_CODE_LOC ++ value: /workspace/build/buildkite/vllm/performance-benchmark ++ - name: HF_TOKEN ++ valueFrom: ++ secretKeyRef: ++ name: hf-token-secret ++ key: token ++ - name: TEST_SELECTOR ++ value: "llama8B" ++ ++ ++ - label: "A100 trt llama-70B" ++ priority: 100 ++ agents: ++ queue: A100 ++ plugins: ++ - kubernetes: ++ podSpec: ++ <<: *common_pod_spec ++ containers: ++ - image: nvcr.io/nvidia/tritonserver:24.07-trtllm-python-py3 ++ <<: *common_container_settings ++ env: ++ - name: VLLM_USAGE_SOURCE ++ value: ci-test ++ - name: HF_HOME ++ value: /root/.cache/huggingface ++ - name: VLLM_SOURCE_CODE_LOC ++ value: /workspace/build/buildkite/vllm/performance-benchmark ++ - name: HF_TOKEN ++ valueFrom: ++ secretKeyRef: ++ name: hf-token-secret ++ key: token ++ - name: TEST_SELECTOR ++ value: "llama70B" ++ ++ ++ # FIXME(Kuntai): uncomment this after NVIDIA gives us their test docker image ++ # - label: "A100 trt benchmark" ++ # priority: 100 ++ # agents: ++ # queue: A100 ++ # plugins: ++ # - kubernetes: ++ # podSpec: ++ # <<: *common_pod_spec ++ # containers: ++ # - image: nvcr.io/nvidia/tritonserver:24.07-trtllm-python-py3 ++ # <<: *common_container_settings ++ ++ ++ # FIXME(Kuntai): uncomment this after TGI supports `--ignore-eos`. ++ # - label: "A100 tgi benchmark" ++ # priority: 100 ++ # agents: ++ # queue: A100 ++ # plugins: ++ # - kubernetes: ++ # podSpec: ++ # <<: *common_pod_spec ++ # containers: ++ # - image: ghcr.io/huggingface/text-generation-inference:2.2.0 ++ # <<: *common_container_settings ++ ++ - wait ++ ++ - label: "Collect the results" ++ priority: 100 ++ agents: ++ queue: A100 ++ plugins: ++ - kubernetes: ++ podSpec: ++ <<: *common_pod_spec ++ containers: ++ - image: vllm/vllm-openai:v0.5.0.post1 ++ command: ++ - bash .buildkite/nightly-benchmarks/scripts/nightly-annotate.sh ++ resources: ++ limits: ++ nvidia.com/gpu: 8 ++ volumeMounts: ++ - name: devshm ++ mountPath: /dev/shm ++ env: ++ - name: VLLM_USAGE_SOURCE ++ value: ci-test ++ - name: VLLM_SOURCE_CODE_LOC ++ value: /workspace/build/buildkite/vllm/performance-benchmark ++ - name: HF_TOKEN ++ valueFrom: ++ secretKeyRef: ++ name: hf-token-secret ++ key: token ++ ++ - block: ":rocket: check the results!" +\ No newline at end of file +diff --git a/.buildkite/nightly-benchmarks/performance-benchmarks-descriptions.md b/.buildkite/nightly-benchmarks/performance-benchmarks-descriptions.md +new file mode 100644 +index 0000000..da32d1f +--- /dev/null ++++ b/.buildkite/nightly-benchmarks/performance-benchmarks-descriptions.md +@@ -0,0 +1,62 @@ ++ ++## Latency tests ++ ++- Input length: 32 tokens. ++- Output length: 128 tokens. ++- Batch size: fixed (8). ++- Models: llama-3.1 8B, llama-3 70B, mixtral 8x7B. ++- Evaluation metrics: end-to-end latency (mean, median, p99). ++ ++ ++{latency_tests_markdown_table} ++ ++ ++## Throughput tests ++ ++- Input length: randomly sample 200 prompts from ShareGPT dataset (with fixed random seed). ++- Output length: the corresponding output length of these 200 prompts. ++- Batch size: dynamically determined by vllm to achieve maximum throughput. ++- Models: llama-3.1 8B, llama-3 70B, mixtral 8x7B. ++- Evaluation metrics: throughput. ++ ++ ++{throughput_tests_markdown_table} ++ ++ ++## Serving tests ++ ++- Input length: randomly sample 200 prompts from ShareGPT dataset (with fixed random seed). ++- Output length: the corresponding output length of these 200 prompts. ++- Batch size: dynamically determined by vllm and the arrival pattern of the requests. ++- **Average QPS (query per second)**: 1, 4, 16 and inf. QPS = inf means all requests come at once. For other QPS values, the arrival time of each query is determined using a random Poisson process (with fixed random seed). ++- Models: llama-3.1 8B, llama-3 70B, mixtral 8x7B. ++- We also added a speculative decoding test for llama-3 70B, under QPS 2 ++- Evaluation metrics: throughput, TTFT (time to the first token, with mean, median and p99), ITL (inter-token latency, with mean, median and p99). ++ ++ ++{serving_tests_markdown_table} ++ ++ ++## json version of the benchmarking tables ++ ++This section contains the data of the markdown tables above in JSON format. ++You can load the benchmarking tables into pandas dataframes as follows: ++ ++```python ++import json ++import pandas as pd ++ ++benchmarking_results_json = """The json string""" ++benchmarking_results = json.loads(benchmarking_results_json) ++latency_results = pd.DataFrame.from_dict(benchmarking_results["latency"]) ++throughput_results = pd.DataFrame.from_dict(benchmarking_results["throughput"]) ++serving_results = pd.DataFrame.from_dict(benchmarking_results["serving"]) ++``` ++ ++The json string for all benchmarking tables: ++```json ++{benchmarking_results_in_json_string} ++``` ++ ++You can also check the raw experiment data in the Artifact tab of the Buildkite page. ++ +diff --git a/.buildkite/nightly-benchmarks/scripts/convert-results-json-to-markdown.py b/.buildkite/nightly-benchmarks/scripts/convert-results-json-to-markdown.py +new file mode 100644 +index 0000000..9d3646e +--- /dev/null ++++ b/.buildkite/nightly-benchmarks/scripts/convert-results-json-to-markdown.py +@@ -0,0 +1,204 @@ ++import json ++import os ++from pathlib import Path ++ ++import pandas as pd ++from tabulate import tabulate ++ ++results_folder = Path("results/") ++ ++# latency results and the keys that will be printed into markdown ++latency_results = [] ++latency_column_mapping = { ++ "test_name": "Test name", ++ "gpu_type": "GPU", ++ "avg_latency": "Mean latency (ms)", ++ # "P10": "P10 (s)", ++ # "P25": "P25 (s)", ++ "P50": "Median latency (ms)", ++ # "P75": "P75 (s)", ++ # "P90": "P90 (s)", ++ "P99": "P99 latency (ms)", ++} ++ ++# throughput tests and the keys that will be printed into markdown ++throughput_results = [] ++throughput_results_column_mapping = { ++ "test_name": "Test name", ++ "gpu_type": "GPU", ++ # "num_requests": "# of req.", ++ # "total_num_tokens": "Total # of tokens", ++ # "elapsed_time": "Elapsed time (s)", ++ "requests_per_second": "Tput (req/s)", ++ # "tokens_per_second": "Tput (tok/s)", ++} ++ ++# serving results and the keys that will be printed into markdown ++serving_results = [] ++serving_column_mapping = { ++ "test_name": "Test name", ++ "gpu_type": "GPU", ++ # "completed": "# of req.", ++ "request_throughput": "Tput (req/s)", ++ # "input_throughput": "Input Tput (tok/s)", ++ # "output_throughput": "Output Tput (tok/s)", ++ "mean_ttft_ms": "Mean TTFT (ms)", ++ "median_ttft_ms": "Median TTFT (ms)", ++ "p99_ttft_ms": "P99 TTFT (ms)", ++ # "mean_tpot_ms": "Mean TPOT (ms)", ++ # "median_tpot_ms": "Median", ++ # "p99_tpot_ms": "P99", ++ "mean_itl_ms": "Mean ITL (ms)", ++ "median_itl_ms": "Median ITL (ms)", ++ "p99_itl_ms": "P99 ITL (ms)", ++} ++ ++ ++def read_markdown(file): ++ if os.path.exists(file): ++ with open(file) as f: ++ return f.read() + "\n" ++ else: ++ return f"{file} not found.\n" ++ ++ ++def results_to_json(latency, throughput, serving): ++ return json.dumps({ ++ 'latency': latency.to_dict(), ++ 'throughput': throughput.to_dict(), ++ 'serving': serving.to_dict() ++ }) ++ ++ ++if __name__ == "__main__": ++ ++ # collect results ++ for test_file in results_folder.glob("*.json"): ++ ++ with open(test_file) as f: ++ raw_result = json.loads(f.read()) ++ ++ if "serving" in str(test_file): ++ # this result is generated via `benchmark_serving.py` ++ ++ # attach the benchmarking command to raw_result ++ with open(test_file.with_suffix(".commands")) as f: ++ command = json.loads(f.read()) ++ raw_result.update(command) ++ ++ # update the test name of this result ++ raw_result.update({"test_name": test_file.stem}) ++ ++ # add the result to raw_result ++ serving_results.append(raw_result) ++ continue ++ ++ elif "latency" in f.name: ++ # this result is generated via `benchmark_latency.py` ++ ++ # attach the benchmarking command to raw_result ++ with open(test_file.with_suffix(".commands")) as f: ++ command = json.loads(f.read()) ++ raw_result.update(command) ++ ++ # update the test name of this result ++ raw_result.update({"test_name": test_file.stem}) ++ ++ # get different percentiles ++ for perc in [10, 25, 50, 75, 90, 99]: ++ # Multiply 1000 to convert the time unit from s to ms ++ raw_result.update( ++ {f"P{perc}": 1000 * raw_result["percentiles"][str(perc)]}) ++ raw_result["avg_latency"] = raw_result["avg_latency"] * 1000 ++ ++ # add the result to raw_result ++ latency_results.append(raw_result) ++ continue ++ ++ elif "throughput" in f.name: ++ # this result is generated via `benchmark_throughput.py` ++ ++ # attach the benchmarking command to raw_result ++ with open(test_file.with_suffix(".commands")) as f: ++ command = json.loads(f.read()) ++ raw_result.update(command) ++ ++ # update the test name of this result ++ raw_result.update({"test_name": test_file.stem}) ++ ++ # add the result to raw_result ++ throughput_results.append(raw_result) ++ continue ++ ++ print(f"Skipping {test_file}") ++ ++ latency_results = pd.DataFrame.from_dict(latency_results) ++ serving_results = pd.DataFrame.from_dict(serving_results) ++ throughput_results = pd.DataFrame.from_dict(throughput_results) ++ ++ raw_results_json = results_to_json(latency_results, throughput_results, ++ serving_results) ++ ++ # remapping the key, for visualization purpose ++ if not latency_results.empty: ++ latency_results = latency_results[list( ++ latency_column_mapping.keys())].rename( ++ columns=latency_column_mapping) ++ if not serving_results.empty: ++ serving_results = serving_results[list( ++ serving_column_mapping.keys())].rename( ++ columns=serving_column_mapping) ++ if not throughput_results.empty: ++ throughput_results = throughput_results[list( ++ throughput_results_column_mapping.keys())].rename( ++ columns=throughput_results_column_mapping) ++ ++ processed_results_json = results_to_json(latency_results, ++ throughput_results, ++ serving_results) ++ ++ for df in [latency_results, serving_results, throughput_results]: ++ if df.empty: ++ continue ++ ++ # Sort all dataframes by their respective "Test name" columns ++ df.sort_values(by="Test name", inplace=True) ++ ++ # The GPUs sometimes come in format of "GPUTYPE\nGPUTYPE\n...", ++ # we want to turn it into "8xGPUTYPE" ++ df["GPU"] = df["GPU"].apply( ++ lambda x: f"{len(x.split('\n'))}x{x.split('\n')[0]}") ++ ++ # get markdown tables ++ latency_md_table = tabulate(latency_results, ++ headers='keys', ++ tablefmt='pipe', ++ showindex=False) ++ serving_md_table = tabulate(serving_results, ++ headers='keys', ++ tablefmt='pipe', ++ showindex=False) ++ throughput_md_table = tabulate(throughput_results, ++ headers='keys', ++ tablefmt='pipe', ++ showindex=False) ++ ++ # document the result ++ with open(results_folder / "benchmark_results.md", "w") as f: ++ ++ results = read_markdown("../.buildkite/nightly-benchmarks/" + ++ "performance-benchmarks-descriptions.md") ++ results = results.format( ++ latency_tests_markdown_table=latency_md_table, ++ throughput_tests_markdown_table=throughput_md_table, ++ serving_tests_markdown_table=serving_md_table, ++ benchmarking_results_in_json_string=processed_results_json) ++ f.write(results) ++ ++ # document benchmarking results in json ++ with open(results_folder / "benchmark_results.json", "w") as f: ++ ++ results = latency_results.to_dict( ++ orient='records') + throughput_results.to_dict( ++ orient='records') + serving_results.to_dict(orient='records') ++ f.write(json.dumps(results)) +diff --git a/.buildkite/nightly-benchmarks/scripts/download-tokenizer.py b/.buildkite/nightly-benchmarks/scripts/download-tokenizer.py +new file mode 100644 +index 0000000..68ac590 +--- /dev/null ++++ b/.buildkite/nightly-benchmarks/scripts/download-tokenizer.py +@@ -0,0 +1,26 @@ ++import argparse ++ ++from transformers import AutoTokenizer ++ ++ ++def main(model, cachedir): ++ # Load the tokenizer and save it to the specified directory ++ tokenizer = AutoTokenizer.from_pretrained(model) ++ tokenizer.save_pretrained(cachedir) ++ print(f"Tokenizer saved to {cachedir}") ++ ++ ++if __name__ == "__main__": ++ parser = argparse.ArgumentParser( ++ description="Download and save Hugging Face tokenizer") ++ parser.add_argument("--model", ++ type=str, ++ required=True, ++ help="Name of the model") ++ parser.add_argument("--cachedir", ++ type=str, ++ required=True, ++ help="Directory to save the tokenizer") ++ ++ args = parser.parse_args() ++ main(args.model, args.cachedir) +diff --git a/.buildkite/nightly-benchmarks/scripts/generate-nightly-markdown.py b/.buildkite/nightly-benchmarks/scripts/generate-nightly-markdown.py +new file mode 100644 +index 0000000..052060c +--- /dev/null ++++ b/.buildkite/nightly-benchmarks/scripts/generate-nightly-markdown.py +@@ -0,0 +1,95 @@ ++import argparse ++import json ++from pathlib import Path ++ ++import numpy as np ++import pandas as pd ++from tabulate import tabulate ++ ++ ++def parse_arguments(): ++ parser = argparse.ArgumentParser( ++ description= ++ 'Parse command line arguments for summary-nightly-results script.') ++ parser.add_argument('--results-folder', ++ type=str, ++ required=True, ++ help='The folder where the results are stored.') ++ parser.add_argument('--description', ++ type=str, ++ required=True, ++ help='Description of the results.') ++ ++ args = parser.parse_args() ++ return args ++ ++ ++def get_perf(df, method, model, metric): ++ ++ means = [] ++ ++ for qps in [2, 4, 8, 16, "inf"]: ++ target = df['Test name'].str.contains(model) ++ target = target & df['Engine'].str.contains(method) ++ target = target & df['Test name'].str.contains("qps_" + str(qps)) ++ filtered_df = df[target] ++ ++ if filtered_df.empty: ++ means.append(0.) ++ else: ++ means.append(filtered_df[metric].values[0]) ++ ++ return np.array(means) ++ ++ ++def get_perf_w_std(df, method, model, metric): ++ ++ if metric in ["TTFT", "ITL"]: ++ mean = get_perf(df, method, model, "Mean " + metric + " (ms)") ++ mean = mean.tolist() ++ std = get_perf(df, method, model, "Std " + metric + " (ms)") ++ if std.mean() == 0: ++ std = None ++ success = get_perf(df, method, model, "Successful req.") ++ if std is not None: ++ std = std / np.sqrt(success) ++ std = std.tolist() ++ ++ else: ++ assert metric == "Tput" ++ mean = get_perf(df, method, model, "Input Tput (tok/s)") + get_perf( ++ df, method, model, "Output Tput (tok/s)") ++ mean = mean.tolist() ++ std = None ++ ++ return mean, std ++ ++ ++def main(args): ++ results_folder = Path(args.results_folder) ++ ++ results = [] ++ ++ # collect results ++ for test_file in results_folder.glob("*_nightly_results.json"): ++ with open(test_file) as f: ++ results = results + json.loads(f.read()) ++ ++ # generate markdown table ++ df = pd.DataFrame.from_dict(results) ++ ++ md_table = tabulate(df, headers='keys', tablefmt='pipe', showindex=False) ++ ++ with open(args.description) as f: ++ description = f.read() ++ ++ description = description.format( ++ nightly_results_benchmarking_table=md_table) ++ ++ with open("nightly_results.md", "w") as f: ++ f.write(description) ++ ++ ++if __name__ == '__main__': ++ args = parse_arguments() ++ main(args) +diff --git a/.buildkite/nightly-benchmarks/scripts/get-lmdeploy-modelname.py b/.buildkite/nightly-benchmarks/scripts/get-lmdeploy-modelname.py +new file mode 100644 +index 0000000..18bcc3a +--- /dev/null ++++ b/.buildkite/nightly-benchmarks/scripts/get-lmdeploy-modelname.py +@@ -0,0 +1,6 @@ ++from lmdeploy.serve.openai.api_client import APIClient ++ ++api_client = APIClient("http://localhost:8000") ++model_name = api_client.available_models[0] ++ ++print(model_name) +diff --git a/.buildkite/nightly-benchmarks/scripts/launch-server.sh b/.buildkite/nightly-benchmarks/scripts/launch-server.sh +new file mode 100644 +index 0000000..fb5063d +--- /dev/null ++++ b/.buildkite/nightly-benchmarks/scripts/launch-server.sh +@@ -0,0 +1,228 @@ ++#!/bin/bash ++ ++# Currently FP8 benchmark is NOT enabled. ++ ++set -x ++server_params=$1 ++common_params=$2 ++ ++json2args() { ++ # transforms the JSON string to command line args, and '_' is replaced to '-' ++ # example: ++ # input: { "model": "meta-llama/Llama-2-7b-chat-hf", "tensor_parallel_size": 1 } ++ # output: --model meta-llama/Llama-2-7b-chat-hf --tensor-parallel-size 1 ++ local json_string=$1 ++ local args=$( ++ echo "$json_string" | jq -r ' ++ to_entries | ++ map("--" + (.key | gsub("_"; "-")) + " " + (.value | tostring)) | ++ join(" ") ++ ' ++ ) ++ echo "$args" ++} ++ ++launch_trt_server() { ++ ++ model_path=$(echo "$common_params" | jq -r '.model') ++ model_name="${model_path#*/}" ++ model_type=$(echo "$server_params" | jq -r '.model_type') ++ model_dtype=$(echo "$server_params" | jq -r '.model_dtype') ++ model_tp_size=$(echo "$common_params" | jq -r '.tp') ++ max_batch_size=$(echo "$server_params" | jq -r '.max_batch_size') ++ max_input_len=$(echo "$server_params" | jq -r '.max_input_len') ++ max_seq_len=$(echo "$server_params" | jq -r '.max_seq_len') ++ max_num_tokens=$(echo "$server_params" | jq -r '.max_num_tokens') ++ trt_llm_version=$(echo "$server_params" | jq -r '.trt_llm_version') ++ ++ # create model caching directory ++ cd ~ ++ rm -rf models ++ mkdir -p models ++ cd models ++ models_dir=$(pwd) ++ trt_model_path=${models_dir}/${model_name}-trt-ckpt ++ trt_engine_path=${models_dir}/${model_name}-trt-engine ++ ++ # clone tensorrt backend ++ cd / ++ rm -rf tensorrtllm_backend ++ git clone https://github.com/triton-inference-server/tensorrtllm_backend.git ++ git lfs install ++ cd tensorrtllm_backend ++ git checkout "$trt_llm_version" ++ git submodule update --init --recursive ++ ++ # build trtllm engine ++ cd /tensorrtllm_backend ++ cd "./tensorrt_llm/examples/${model_type}" ++ python3 convert_checkpoint.py \ ++ --model_dir "${model_path}" \ ++ --dtype "${model_dtype}" \ ++ --tp_size "${model_tp_size}" \ ++ --output_dir "${trt_model_path}" ++ trtllm-build \ ++ --checkpoint_dir "${trt_model_path}" \ ++ --use_fused_mlp \ ++ --reduce_fusion disable \ ++ --workers 8 \ ++ --gpt_attention_plugin "${model_dtype}" \ ++ --gemm_plugin "${model_dtype}" \ ++ --tp_size "${model_tp_size}" \ ++ --max_batch_size "${max_batch_size}" \ ++ --max_input_len "${max_input_len}" \ ++ --max_seq_len "${max_seq_len}" \ ++ --max_num_tokens "${max_num_tokens}" \ ++ --output_dir "${trt_engine_path}" ++ ++ # handle triton protobuf files and launch triton server ++ cd /tensorrtllm_backend ++ mkdir triton_model_repo ++ cp -r all_models/inflight_batcher_llm/* triton_model_repo/ ++ cd triton_model_repo ++ rm -rf ./tensorrt_llm/1/* ++ cp -r "${trt_engine_path}"/* ./tensorrt_llm/1 ++ python3 ../tools/fill_template.py -i tensorrt_llm/config.pbtxt triton_backend:tensorrtllm,engine_dir:/tensorrtllm_backend/triton_model_repo/tensorrt_llm/1,decoupled_mode:true,batching_strategy:inflight_fused_batching,batch_scheduler_policy:guaranteed_no_evict,exclude_input_in_output:true,triton_max_batch_size:2048,max_queue_delay_microseconds:0,max_beam_width:1,max_queue_size:2048,enable_kv_cache_reuse:false ++ python3 ../tools/fill_template.py -i preprocessing/config.pbtxt "triton_max_batch_size:2048,tokenizer_dir:$model_path,preprocessing_instance_count:5" ++ python3 ../tools/fill_template.py -i postprocessing/config.pbtxt "triton_max_batch_size:2048,tokenizer_dir:$model_path,postprocessing_instance_count:5,skip_special_tokens:false" ++ python3 ../tools/fill_template.py -i ensemble/config.pbtxt triton_max_batch_size:"$max_batch_size" ++ python3 ../tools/fill_template.py -i tensorrt_llm_bls/config.pbtxt "triton_max_batch_size:$max_batch_size,decoupled_mode:true,accumulate_tokens:False,bls_instance_count:1" ++ cd /tensorrtllm_backend ++ python3 scripts/launch_triton_server.py \ ++ --world_size="${model_tp_size}" \ ++ --model_repo=/tensorrtllm_backend/triton_model_repo & ++ ++} ++ ++launch_tgi_server() { ++ model=$(echo "$common_params" | jq -r '.model') ++ tp=$(echo "$common_params" | jq -r '.tp') ++ port=$(echo "$common_params" | jq -r '.port') ++ server_args=$(json2args "$server_params") ++ ++ if echo "$common_params" | jq -e 'has("fp8")' >/dev/null; then ++ echo "Key 'fp8' exists in common params." ++ server_command="/tgi-entrypoint.sh \ ++ --model-id $model \ ++ --num-shard $tp \ ++ --port $port \ ++ --quantize fp8 \ ++ $server_args" ++ else ++ echo "Key 'fp8' does not exist in common params." ++ server_command="/tgi-entrypoint.sh \ ++ --model-id $model \ ++ --num-shard $tp \ ++ --port $port \ ++ $server_args" ++ fi ++ ++ echo "Server command: $server_command" ++ eval "$server_command" & ++ ++} ++ ++launch_lmdeploy_server() { ++ model=$(echo "$common_params" | jq -r '.model') ++ tp=$(echo "$common_params" | jq -r '.tp') ++ port=$(echo "$common_params" | jq -r '.port') ++ server_args=$(json2args "$server_params") ++ ++ server_command="lmdeploy serve api_server $model \ ++ --tp $tp \ ++ --server-port $port \ ++ $server_args" ++ ++ # run the server ++ echo "Server command: $server_command" ++ bash -c "$server_command" & ++} ++ ++launch_sglang_server() { ++ ++ model=$(echo "$common_params" | jq -r '.model') ++ tp=$(echo "$common_params" | jq -r '.tp') ++ port=$(echo "$common_params" | jq -r '.port') ++ server_args=$(json2args "$server_params") ++ ++ if echo "$common_params" | jq -e 'has("fp8")' >/dev/null; then ++ echo "Key 'fp8' exists in common params. Use neuralmagic fp8 model for convenience." ++ model=$(echo "$common_params" | jq -r '.neuralmagic_quantized_model') ++ server_command="python3 \ ++ -m sglang.launch_server \ ++ --tp $tp \ ++ --model-path $model \ ++ --port $port \ ++ $server_args" ++ else ++ echo "Key 'fp8' does not exist in common params." ++ server_command="python3 \ ++ -m sglang.launch_server \ ++ --tp $tp \ ++ --model-path $model \ ++ --port $port \ ++ $server_args" ++ fi ++ ++ # run the server ++ echo "Server command: $server_command" ++ eval "$server_command" & ++} ++ ++launch_vllm_server() { ++ ++ export VLLM_HOST_IP=$(hostname -I | awk '{print $1}') ++ ++ model=$(echo "$common_params" | jq -r '.model') ++ tp=$(echo "$common_params" | jq -r '.tp') ++ port=$(echo "$common_params" | jq -r '.port') ++ server_args=$(json2args "$server_params") ++ ++ if echo "$common_params" | jq -e 'has("fp8")' >/dev/null; then ++ echo "Key 'fp8' exists in common params. Use neuralmagic fp8 model for convenience." ++ model=$(echo "$common_params" | jq -r '.neuralmagic_quantized_model') ++ server_command="python3 \ ++ -m vllm.entrypoints.openai.api_server \ ++ -tp $tp \ ++ --model $model \ ++ --port $port \ ++ $server_args" ++ else ++ echo "Key 'fp8' does not exist in common params." ++ server_command="python3 \ ++ -m vllm.entrypoints.openai.api_server \ ++ -tp $tp \ ++ --model $model \ ++ --port $port \ ++ $server_args" ++ fi ++ ++ # run the server ++ echo "Server command: $server_command" ++ eval "$server_command" & ++} ++ ++main() { ++ ++ if [[ "$CURRENT_LLM_SERVING_ENGINE" == "trt" ]]; then ++ launch_trt_server ++ fi ++ ++ if [[ "$CURRENT_LLM_SERVING_ENGINE" == "tgi" ]]; then ++ launch_tgi_server ++ fi ++ ++ if [[ "$CURRENT_LLM_SERVING_ENGINE" == "lmdeploy" ]]; then ++ launch_lmdeploy_server ++ fi ++ ++ if [[ "$CURRENT_LLM_SERVING_ENGINE" == "sglang" ]]; then ++ launch_sglang_server ++ fi ++ ++ if [[ "$CURRENT_LLM_SERVING_ENGINE" == *"vllm"* ]]; then ++ launch_vllm_server ++ fi ++} ++ ++main +diff --git a/.buildkite/nightly-benchmarks/scripts/nightly-annotate.sh b/.buildkite/nightly-benchmarks/scripts/nightly-annotate.sh +new file mode 100644 +index 0000000..686f70d +--- /dev/null ++++ b/.buildkite/nightly-benchmarks/scripts/nightly-annotate.sh +@@ -0,0 +1,78 @@ ++#!/bin/bash ++ ++set -ex ++set -o pipefail ++ ++ ++main() { ++ ++ (which wget && which curl) || (apt-get update && apt-get install -y wget curl) ++ (which jq) || (apt-get update && apt-get -y install jq) ++ (which zip) || (apt-get install -y zip) ++ ++ if [ ! -f /workspace/buildkite-agent ]; then ++ echo "buildkite-agent binary not found. Skip plotting the results." ++ exit 0 ++ fi ++ ++ # initial annotation ++ #description="$VLLM_SOURCE_CODE_LOC/.buildkite/nightly-benchmarks/nightly-descriptions.md" ++ ++ # download results ++ cd "$VLLM_SOURCE_CODE_LOC/benchmarks" ++ mkdir -p results/ ++ /workspace/buildkite-agent artifact download 'results/*nightly_results.json' results/ ++ ls ++ ls results/ ++ ++ # upload benchmark results ++ zip -r results.zip results/ ++ /workspace/buildkite-agent artifact upload "results.zip" ++ ++ # upload benchmarking scripts ++ cd "$VLLM_SOURCE_CODE_LOC/" ++ zip -r nightly-benchmarks.zip .buildkite/ benchmarks/ ++ /workspace/buildkite-agent artifact upload "nightly-benchmarks.zip" ++ ++ cd "$VLLM_SOURCE_CODE_LOC/.buildkite/nightly-benchmarks/" ++ # upload benchmarking pipeline ++ /workspace/buildkite-agent artifact upload "nightly-pipeline.yaml" ++ ++ cd "$VLLM_SOURCE_CODE_LOC/.buildkite/nightly-benchmarks/" ++ /workspace/buildkite-agent annotate --style "success" --context "nightly-benchmarks-results" --append < nightly-annotation.md ++ ++ ++ ++ # The figures should be genereated by a separate process outside the CI/CD pipeline ++ ++ # # generate figures ++ # python3 -m pip install tabulate pandas matplotlib ++ ++ # python3 $VLLM_SOURCE_CODE_LOC/.buildkite/nightly-benchmarks/scripts/generate-nightly-markdown.py \ ++ # --description $description \ ++ # --results-folder results/ ++ ++ ++ # python3 $VLLM_SOURCE_CODE_LOC/.buildkite/nightly-benchmarks/scripts/plot-nightly-results.py \ ++ # --description $description \ ++ # --results-folder results/ \ ++ # --dataset sharegpt ++ ++ # python3 $VLLM_SOURCE_CODE_LOC/.buildkite/nightly-benchmarks/scripts/plot-nightly-results.py \ ++ # --description $description \ ++ # --results-folder results/ \ ++ # --dataset sonnet_2048_128 ++ ++ # python3 $VLLM_SOURCE_CODE_LOC/.buildkite/nightly-benchmarks/scripts/plot-nightly-results.py \ ++ # --description $description \ ++ # --results-folder results/ \ ++ # --dataset sonnet_128_2048 ++ ++ # # upload results and figures ++ # /workspace/buildkite-agent artifact upload "nightly_results*.png" ++ # /workspace/buildkite-agent artifact upload $VLLM_SOURCE_CODE_LOC/.buildkite/nightly-benchmarks/nightly-pipeline.yaml ++ # /workspace/buildkite-agent artifact upload $VLLM_SOURCE_CODE_LOC/.buildkite/nightly-benchmarks/tests/nightly-tests.json ++ # /workspace/buildkite-agent annotate --style "success" --context "nightly-benchmarks-results" --append < nightly_results.md ++} ++ ++main "$@" +diff --git a/.buildkite/nightly-benchmarks/scripts/run-nightly-benchmarks.sh b/.buildkite/nightly-benchmarks/scripts/run-nightly-benchmarks.sh +new file mode 100644 +index 0000000..3f38cf5 +--- /dev/null ++++ b/.buildkite/nightly-benchmarks/scripts/run-nightly-benchmarks.sh +@@ -0,0 +1,355 @@ ++#!/bin/bash ++ ++set -o pipefail ++set -x ++ ++check_gpus() { ++ # check the number of GPUs and GPU type. ++ declare -g gpu_count=$(nvidia-smi --list-gpus | wc -l) ++ if [[ $gpu_count -gt 0 ]]; then ++ echo "GPU found." ++ else ++ echo "Need at least 1 GPU to run benchmarking." ++ exit 1 ++ fi ++ declare -g gpu_type="$(nvidia-smi --query-gpu=name --format=csv,noheader | awk '{print $2}')" ++ echo "GPU type is $gpu_type" ++} ++ ++check_hf_token() { ++ # check if HF_TOKEN is available and valid ++ if [[ -z "$HF_TOKEN" ]]; then ++ echo "Error: HF_TOKEN is not set." ++ exit 1 ++ elif [[ ! "$HF_TOKEN" =~ ^hf_ ]]; then ++ echo "Error: HF_TOKEN does not start with 'hf_'." ++ exit 1 ++ else ++ echo "HF_TOKEN is set and valid." ++ fi ++} ++ ++ ++upload_to_buildkite() { ++ # upload the benchmarking results to buildkite ++ ++ # if the agent binary is not found, skip uploading the results, exit 0 ++ if [ ! -f /workspace/buildkite-agent ]; then ++ echo "buildkite-agent binary not found. Skip uploading the results." ++ return 0 ++ fi ++ # /workspace/buildkite-agent annotate --style "success" --context "benchmark-results" --append < $RESULTS_FOLDER/${CURRENT_LLM_SERVING_ENGINE}_nightly_results.md ++ /workspace/buildkite-agent artifact upload "$RESULTS_FOLDER/*" ++} ++ ++ ++get_current_llm_serving_engine() { ++ ++ if which lmdeploy >/dev/null; then ++ echo "Container: lmdeploy" ++ export CURRENT_LLM_SERVING_ENGINE=lmdeploy ++ return ++ fi ++ ++ if [ -e /tgi-entrypoint.sh ]; then ++ echo "Container: tgi" ++ export CURRENT_LLM_SERVING_ENGINE=tgi ++ return ++ fi ++ ++ if which trtllm-build >/dev/null; then ++ echo "Container: tensorrt-llm" ++ export CURRENT_LLM_SERVING_ENGINE=trt ++ return ++ fi ++ ++ if [ -e /sgl-workspace ]; then ++ echo "Container: sglang" ++ export CURRENT_LLM_SERVING_ENGINE=sglang ++ return ++ fi ++ ++ if [ -e /vllm-workspace ]; then ++ echo "Container: vllm" ++ # move to a completely irrelevant directory, to avoid import vllm from current folder ++ export CURRENT_LLM_SERVING_ENGINE=vllm ++ ++ return ++ fi ++} ++ ++json2args() { ++ # transforms the JSON string to command line args, and '_' is replaced to '-' ++ # example: ++ # input: { "model": "meta-llama/Llama-2-7b-chat-hf", "tensor_parallel_size": 1 } ++ # output: --model meta-llama/Llama-2-7b-chat-hf --tensor-parallel-size 1 ++ local json_string=$1 ++ local args=$( ++ echo "$json_string" | jq -r ' ++ to_entries | ++ map("--" + (.key | gsub("_"; "-")) + " " + (.value | tostring)) | ++ join(" ") ++ ' ++ ) ++ echo "$args" ++} ++ ++kill_gpu_processes() { ++ pkill -f python ++ pkill -f python3 ++ pkill -f tritonserver ++ pkill -f pt_main_thread ++ pkill -f text-generation ++ pkill -f lmdeploy ++ ++ while [ "$(nvidia-smi --query-gpu=memory.used --format=csv,noheader,nounits | head -n 1)" -ge 1000 ]; do ++ sleep 1 ++ done ++} ++ ++wait_for_server() { ++ # wait for vllm server to start ++ # return 1 if vllm server crashes ++ timeout 1200 bash -c ' ++ until curl -s localhost:8000/v1/completions > /dev/null; do ++ sleep 1 ++ done' && return 0 || return 1 ++} ++ ++ensure_installed() { ++ # Ensure that the given command is installed by apt-get ++ local cmd=$1 ++ if ! which "$cmd" >/dev/null; then ++ apt-get update && apt-get install -y "$cmd" ++ fi ++} ++ ++run_serving_tests() { ++ # run serving tests using `benchmark_serving.py` ++ # $1: a json file specifying serving test cases ++ ++ local serving_test_file ++ serving_test_file=$1 ++ ++ # Iterate over serving tests ++ jq -c '.[]' "$serving_test_file" | while read -r params; do ++ # get the test name, and append the GPU type back to it. ++ test_name=$(echo "$params" | jq -r '.test_name') ++ ++ # if TEST_SELECTOR is set, only run the test cases that match the selector ++ if [[ -n "$TEST_SELECTOR" ]] && [[ ! "$test_name" =~ $TEST_SELECTOR ]]; then ++ echo "Skip test case $test_name." ++ continue ++ fi ++ ++ # prepend the current serving engine to the test name ++ test_name=${CURRENT_LLM_SERVING_ENGINE}_${test_name} ++ ++ # get common parameters ++ common_params=$(echo "$params" | jq -r '.common_parameters') ++ model=$(echo "$common_params" | jq -r '.model') ++ tp=$(echo "$common_params" | jq -r '.tp') ++ dataset_name=$(echo "$common_params" | jq -r '.dataset_name') ++ dataset_path=$(echo "$common_params" | jq -r '.dataset_path') ++ port=$(echo "$common_params" | jq -r '.port') ++ num_prompts=$(echo "$common_params" | jq -r '.num_prompts') ++ reuse_server=$(echo "$common_params" | jq -r '.reuse_server') ++ ++ # get client and server arguments ++ server_params=$(echo "$params" | jq -r ".${CURRENT_LLM_SERVING_ENGINE}_server_parameters") ++ client_params=$(echo "$params" | jq -r ".${CURRENT_LLM_SERVING_ENGINE}_client_parameters") ++ client_args=$(json2args "$client_params") ++ qps_list=$(echo "$params" | jq -r '.qps_list') ++ qps_list=$(echo "$qps_list" | jq -r '.[] | @sh') ++ echo "Running over qps list $qps_list" ++ ++ # check if there is enough GPU to run the test ++ if [[ $gpu_count -lt $tp ]]; then ++ echo "Required num-shard $tp but only $gpu_count GPU found. Skip testcase $test_name." ++ continue ++ fi ++ ++ if [[ $reuse_server == "true" ]]; then ++ echo "Reuse previous server for test case $test_name" ++ else ++ kill_gpu_processes ++ bash "$VLLM_SOURCE_CODE_LOC/.buildkite/nightly-benchmarks/scripts/launch-server.sh" \ ++ "$server_params" "$common_params" ++ fi ++ ++ if wait_for_server; then ++ echo "" ++ echo "$CURRENT_LLM_SERVING_ENGINE server is up and running." ++ else ++ echo "" ++ echo "$CURRENT_LLM_SERVING_ENGINE failed to start within the timeout period." ++ break ++ fi ++ ++ # prepare tokenizer ++ # this is required for lmdeploy. ++ cd "$VLLM_SOURCE_CODE_LOC/benchmarks" ++ rm -rf /tokenizer_cache ++ mkdir /tokenizer_cache ++ python3 ../.buildkite/nightly-benchmarks/scripts/download-tokenizer.py \ ++ --model "$model" \ ++ --cachedir /tokenizer_cache ++ cd "$VLLM_SOURCE_CODE_LOC/benchmarks" ++ ++ ++ # change model name for lmdeploy (it will not follow standard hf name) ++ if [[ "$CURRENT_LLM_SERVING_ENGINE" == "lmdeploy" ]]; then ++ model=$(python ../.buildkite/nightly-benchmarks/scripts/get-lmdeploy-modelname.py) ++ fi ++ ++ # iterate over different QPS ++ for qps in $qps_list; do ++ # remove the surrounding single quote from qps ++ if [[ "$qps" == *"inf"* ]]; then ++ echo "qps was $qps" ++ qps="inf" ++ echo "now qps is $qps" ++ fi ++ ++ new_test_name=$test_name"_qps_"$qps ++ ++ backend=$CURRENT_LLM_SERVING_ENGINE ++ ++ if [[ $backend = "trt" ]]; then ++ backend="tensorrt-llm" ++ fi ++ ++ if [[ "$backend" == *"vllm"* ]]; then ++ backend="vllm" ++ fi ++ ++ if [[ "$dataset_name" = "sharegpt" ]]; then ++ ++ client_command="python3 benchmark_serving.py \ ++ --backend $backend \ ++ --tokenizer /tokenizer_cache \ ++ --model $model \ ++ --dataset-name $dataset_name \ ++ --dataset-path $dataset_path \ ++ --num-prompts $num_prompts \ ++ --port $port \ ++ --save-result \ ++ --result-dir $RESULTS_FOLDER \ ++ --result-filename ${new_test_name}.json \ ++ --request-rate $qps \ ++ --ignore-eos \ ++ $client_args" ++ ++ elif [[ "$dataset_name" = "sonnet" ]]; then ++ ++ sonnet_input_len=$(echo "$common_params" | jq -r '.sonnet_input_len') ++ sonnet_output_len=$(echo "$common_params" | jq -r '.sonnet_output_len') ++ sonnet_prefix_len=$(echo "$common_params" | jq -r '.sonnet_prefix_len') ++ ++ client_command="python3 benchmark_serving.py \ ++ --backend $backend \ ++ --tokenizer /tokenizer_cache \ ++ --model $model \ ++ --dataset-name $dataset_name \ ++ --dataset-path $dataset_path \ ++ --num-prompts $num_prompts \ ++ --sonnet-input-len $sonnet_input_len \ ++ --sonnet-output-len $sonnet_output_len \ ++ --sonnet-prefix-len $sonnet_prefix_len \ ++ --port $port \ ++ --save-result \ ++ --result-dir $RESULTS_FOLDER \ ++ --result-filename ${new_test_name}.json \ ++ --request-rate $qps \ ++ --ignore-eos \ ++ $client_args" ++ ++ else ++ ++ echo "The dataset name must be either 'sharegpt' or 'sonnet'. Got $dataset_name." ++ exit 1 ++ ++ fi ++ ++ ++ ++ echo "Running test case $test_name with qps $qps" ++ echo "Client command: $client_command" ++ ++ eval "$client_command" ++ ++ server_command="None" ++ ++ # record the benchmarking commands ++ jq_output=$(jq -n \ ++ --arg server "$server_command" \ ++ --arg client "$client_command" \ ++ --arg gpu "$gpu_type" \ ++ --arg engine "$CURRENT_LLM_SERVING_ENGINE" \ ++ '{ ++ server_command: $server, ++ client_command: $client, ++ gpu_type: $gpu, ++ engine: $engine ++ }') ++ echo "$jq_output" >"$RESULTS_FOLDER/${new_test_name}.commands" ++ ++ done ++ ++ done ++ ++ kill_gpu_processes ++} ++ ++ ++prepare_dataset() { ++ ++ # download sharegpt dataset ++ cd "$VLLM_SOURCE_CODE_LOC/benchmarks" ++ wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json ++ ++ # duplicate sonnet by 4x, to allow benchmarking with input length 2048 ++ cd "$VLLM_SOURCE_CODE_LOC/benchmarks" ++ echo "" > sonnet_4x.txt ++ for _ in {1..4} ++ do ++ cat sonnet.txt >> sonnet_4x.txt ++ done ++ ++} ++ ++main() { ++ ++ # check if the environment variable is successfully injected from yaml ++ ++ check_gpus ++ check_hf_token ++ get_current_llm_serving_engine ++ ++ pip install -U transformers ++ ++ # check storage ++ df -h ++ ++ ensure_installed wget ++ ensure_installed curl ++ ensure_installed jq ++ ++ prepare_dataset ++ ++ cd "$VLLM_SOURCE_CODE_LOC/benchmarks" ++ declare -g RESULTS_FOLDER=results/ ++ mkdir -p $RESULTS_FOLDER ++ BENCHMARK_ROOT="$VLLM_SOURCE_CODE_LOC/.buildkite/nightly-benchmarks/" ++ ++ # run the test ++ run_serving_tests "$BENCHMARK_ROOT/tests/nightly-tests.json" ++ ++ # upload benchmark results to buildkite ++ python3 -m pip install tabulate pandas ++ python3 "$BENCHMARK_ROOT/scripts/summary-nightly-results.py" ++ upload_to_buildkite ++ ++} ++ ++main "$@" +diff --git a/.buildkite/nightly-benchmarks/scripts/run-performance-benchmarks.sh b/.buildkite/nightly-benchmarks/scripts/run-performance-benchmarks.sh +new file mode 100644 +index 0000000..0d16a83 +--- /dev/null ++++ b/.buildkite/nightly-benchmarks/scripts/run-performance-benchmarks.sh +@@ -0,0 +1,377 @@ ++#!/bin/bash ++ ++# This script should be run inside the CI process ++# This script assumes that we are already inside the vllm/ directory ++# Benchmarking results will be available inside vllm/benchmarks/results/ ++ ++# Do not set -e, as the mixtral 8x22B model tends to crash occasionally ++# and we still want to see other benchmarking results even when mixtral crashes. ++set -x ++set -o pipefail ++ ++check_gpus() { ++ # check the number of GPUs and GPU type. ++ declare -g gpu_count=$(nvidia-smi --list-gpus | wc -l) ++ if [[ $gpu_count -gt 0 ]]; then ++ echo "GPU found." ++ else ++ echo "Need at least 1 GPU to run benchmarking." ++ exit 1 ++ fi ++ declare -g gpu_type=$(nvidia-smi --query-gpu=name --format=csv,noheader | awk '{print $2}') ++ echo "GPU type is $gpu_type" ++} ++ ++check_hf_token() { ++ # check if HF_TOKEN is available and valid ++ if [[ -z "$HF_TOKEN" ]]; then ++ echo "Error: HF_TOKEN is not set." ++ exit 1 ++ elif [[ ! "$HF_TOKEN" =~ ^hf_ ]]; then ++ echo "Error: HF_TOKEN does not start with 'hf_'." ++ exit 1 ++ else ++ echo "HF_TOKEN is set and valid." ++ fi ++} ++ ++ensure_sharegpt_downloaded() { ++ local FILE=ShareGPT_V3_unfiltered_cleaned_split.json ++ if [ ! -f "$FILE" ]; then ++ wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/$FILE ++ else ++ echo "$FILE already exists." ++ fi ++} ++ ++json2args() { ++ # transforms the JSON string to command line args, and '_' is replaced to '-' ++ # example: ++ # input: { "model": "meta-llama/Llama-2-7b-chat-hf", "tensor_parallel_size": 1 } ++ # output: --model meta-llama/Llama-2-7b-chat-hf --tensor-parallel-size 1 ++ local json_string=$1 ++ local args=$( ++ echo "$json_string" | jq -r ' ++ to_entries | ++ map("--" + (.key | gsub("_"; "-")) + " " + (.value | tostring)) | ++ join(" ") ++ ' ++ ) ++ echo "$args" ++} ++ ++wait_for_server() { ++ # wait for vllm server to start ++ # return 1 if vllm server crashes ++ timeout 1200 bash -c ' ++ until curl -X POST localhost:8000/v1/completions; do ++ sleep 1 ++ done' && return 0 || return 1 ++} ++ ++kill_processes_launched_by_current_bash() { ++ # Kill all python processes launched from current bash script ++ current_shell_pid=$$ ++ processes=$(ps -eo pid,ppid,command | awk -v ppid="$current_shell_pid" -v proc="$1" '$2 == ppid && $3 ~ proc {print $1}') ++ if [ -n "$processes" ]; then ++ echo "Killing the following processes matching '$1':" ++ echo "$processes" ++ echo "$processes" | xargs kill -9 ++ else ++ echo "No processes found matching '$1'." ++ fi ++} ++ ++kill_gpu_processes() { ++ ++ ps -aux ++ lsof -t -i:8000 | xargs -r kill -9 ++ pgrep python3 | xargs -r kill -9 ++ ++ ++ # wait until GPU memory usage smaller than 1GB ++ while [ "$(nvidia-smi --query-gpu=memory.used --format=csv,noheader,nounits | head -n 1)" -ge 1000 ]; do ++ sleep 1 ++ done ++ ++ # remove vllm config file ++ rm -rf ~/.config/vllm ++ ++} ++ ++upload_to_buildkite() { ++ # upload the benchmarking results to buildkite ++ ++ # if the agent binary is not found, skip uploading the results, exit 0 ++ # Check if buildkite-agent is available in the PATH or at /workspace/buildkite-agent ++ if command -v buildkite-agent >/dev/null 2>&1; then ++ BUILDKITE_AGENT_COMMAND="buildkite-agent" ++ elif [ -f /workspace/buildkite-agent ]; then ++ BUILDKITE_AGENT_COMMAND="/workspace/buildkite-agent" ++ else ++ echo "buildkite-agent binary not found. Skip uploading the results." ++ return 0 ++ fi ++ ++ # Use the determined command to annotate and upload artifacts ++ $BUILDKITE_AGENT_COMMAND annotate --style "info" --context "$BUILDKITE_LABEL-benchmark-results" < "$RESULTS_FOLDER/benchmark_results.md" ++ $BUILDKITE_AGENT_COMMAND artifact upload "$RESULTS_FOLDER/*" ++} ++ ++run_latency_tests() { ++ # run latency tests using `benchmark_latency.py` ++ # $1: a json file specifying latency test cases ++ ++ local latency_test_file ++ latency_test_file=$1 ++ ++ # Iterate over latency tests ++ jq -c '.[]' "$latency_test_file" | while read -r params; do ++ # get the test name, and append the GPU type back to it. ++ test_name=$(echo "$params" | jq -r '.test_name') ++ if [[ ! "$test_name" =~ ^latency_ ]]; then ++ echo "In latency-test.json, test_name must start with \"latency_\"." ++ exit 1 ++ fi ++ ++ # if TEST_SELECTOR is set, only run the test cases that match the selector ++ if [[ -n "$TEST_SELECTOR" ]] && [[ ! "$test_name" =~ $TEST_SELECTOR ]]; then ++ echo "Skip test case $test_name." ++ continue ++ fi ++ ++ # get arguments ++ latency_params=$(echo "$params" | jq -r '.parameters') ++ latency_args=$(json2args "$latency_params") ++ ++ # check if there is enough GPU to run the test ++ tp=$(echo "$latency_params" | jq -r '.tensor_parallel_size') ++ if [[ $gpu_count -lt $tp ]]; then ++ echo "Required tensor-parallel-size $tp but only $gpu_count GPU found. Skip testcase $test_name." ++ continue ++ fi ++ ++ latency_command="python3 benchmark_latency.py \ ++ --output-json $RESULTS_FOLDER/${test_name}.json \ ++ $latency_args" ++ ++ echo "Running test case $test_name" ++ echo "Latency command: $latency_command" ++ ++ # recoding benchmarking command ang GPU command ++ jq_output=$(jq -n \ ++ --arg latency "$latency_command" \ ++ --arg gpu "$gpu_type" \ ++ '{ ++ latency_command: $latency, ++ gpu_type: $gpu ++ }') ++ echo "$jq_output" >"$RESULTS_FOLDER/$test_name.commands" ++ ++ # run the benchmark ++ eval "$latency_command" ++ ++ kill_gpu_processes ++ ++ done ++} ++ ++run_throughput_tests() { ++ # run throughput tests using `benchmark_throughput.py` ++ # $1: a json file specifying throughput test cases ++ ++ local throughput_test_file ++ throughput_test_file=$1 ++ ++ # Iterate over throughput tests ++ jq -c '.[]' "$throughput_test_file" | while read -r params; do ++ # get the test name, and append the GPU type back to it. ++ test_name=$(echo "$params" | jq -r '.test_name') ++ if [[ ! "$test_name" =~ ^throughput_ ]]; then ++ echo "In throughput-test.json, test_name must start with \"throughput_\"." ++ exit 1 ++ fi ++ ++ # if TEST_SELECTOR is set, only run the test cases that match the selector ++ if [[ -n "$TEST_SELECTOR" ]] && [[ ! "$test_name" =~ $TEST_SELECTOR ]]; then ++ echo "Skip test case $test_name." ++ continue ++ fi ++ ++ # get arguments ++ throughput_params=$(echo "$params" | jq -r '.parameters') ++ throughput_args=$(json2args "$throughput_params") ++ ++ # check if there is enough GPU to run the test ++ tp=$(echo "$throughput_params" | jq -r '.tensor_parallel_size') ++ if [[ $gpu_count -lt $tp ]]; then ++ echo "Required tensor-parallel-size $tp but only $gpu_count GPU found. Skip testcase $test_name." ++ continue ++ fi ++ ++ throughput_command="python3 benchmark_throughput.py \ ++ --output-json $RESULTS_FOLDER/${test_name}.json \ ++ $throughput_args" ++ ++ echo "Running test case $test_name" ++ echo "Throughput command: $throughput_command" ++ # recoding benchmarking command ang GPU command ++ jq_output=$(jq -n \ ++ --arg command "$throughput_command" \ ++ --arg gpu "$gpu_type" \ ++ '{ ++ throughput_command: $command, ++ gpu_type: $gpu ++ }') ++ echo "$jq_output" >"$RESULTS_FOLDER/$test_name.commands" ++ ++ # run the benchmark ++ eval "$throughput_command" ++ ++ kill_gpu_processes ++ ++ done ++} ++ ++run_serving_tests() { ++ # run serving tests using `benchmark_serving.py` ++ # $1: a json file specifying serving test cases ++ ++ local serving_test_file ++ serving_test_file=$1 ++ ++ # Iterate over serving tests ++ jq -c '.[]' "$serving_test_file" | while read -r params; do ++ # get the test name, and append the GPU type back to it. ++ test_name=$(echo "$params" | jq -r '.test_name') ++ if [[ ! "$test_name" =~ ^serving_ ]]; then ++ echo "In serving-test.json, test_name must start with \"serving_\"." ++ exit 1 ++ fi ++ ++ # if TEST_SELECTOR is set, only run the test cases that match the selector ++ if [[ -n "$TEST_SELECTOR" ]] && [[ ! "$test_name" =~ $TEST_SELECTOR ]]; then ++ echo "Skip test case $test_name." ++ continue ++ fi ++ ++ # get client and server arguments ++ server_params=$(echo "$params" | jq -r '.server_parameters') ++ client_params=$(echo "$params" | jq -r '.client_parameters') ++ server_args=$(json2args "$server_params") ++ client_args=$(json2args "$client_params") ++ qps_list=$(echo "$params" | jq -r '.qps_list') ++ qps_list=$(echo "$qps_list" | jq -r '.[] | @sh') ++ echo "Running over qps list $qps_list" ++ ++ # check if there is enough GPU to run the test ++ tp=$(echo "$server_params" | jq -r '.tensor_parallel_size') ++ if [[ $gpu_count -lt $tp ]]; then ++ echo "Required tensor-parallel-size $tp but only $gpu_count GPU found. Skip testcase $test_name." ++ continue ++ fi ++ ++ # check if server model and client model is aligned ++ server_model=$(echo "$server_params" | jq -r '.model') ++ client_model=$(echo "$client_params" | jq -r '.model') ++ if [[ $server_model != "$client_model" ]]; then ++ echo "Server model and client model must be the same. Skip testcase $test_name." ++ continue ++ fi ++ ++ server_command="python3 \ ++ -m vllm.entrypoints.openai.api_server \ ++ $server_args" ++ ++ # run the server ++ echo "Running test case $test_name" ++ echo "Server command: $server_command" ++ bash -c "$server_command" & ++ server_pid=$! ++ ++ # wait until the server is alive ++ if wait_for_server; then ++ echo "" ++ echo "vllm server is up and running." ++ else ++ echo "" ++ echo "vllm failed to start within the timeout period." ++ fi ++ ++ # iterate over different QPS ++ for qps in $qps_list; do ++ # remove the surrounding single quote from qps ++ if [[ "$qps" == *"inf"* ]]; then ++ echo "qps was $qps" ++ qps="inf" ++ echo "now qps is $qps" ++ fi ++ ++ new_test_name=$test_name"_qps_"$qps ++ ++ client_command="python3 benchmark_serving.py \ ++ --save-result \ ++ --result-dir $RESULTS_FOLDER \ ++ --result-filename ${new_test_name}.json \ ++ --request-rate $qps \ ++ $client_args" ++ ++ echo "Running test case $test_name with qps $qps" ++ echo "Client command: $client_command" ++ ++ bash -c "$client_command" ++ ++ # record the benchmarking commands ++ jq_output=$(jq -n \ ++ --arg server "$server_command" \ ++ --arg client "$client_command" \ ++ --arg gpu "$gpu_type" \ ++ '{ ++ server_command: $server, ++ client_command: $client, ++ gpu_type: $gpu ++ }') ++ echo "$jq_output" >"$RESULTS_FOLDER/${new_test_name}.commands" ++ ++ done ++ ++ # clean up ++ kill -9 $server_pid ++ kill_gpu_processes ++ done ++} ++ ++main() { ++ check_gpus ++ check_hf_token ++ ++ # dependencies ++ (which wget && which curl) || (apt-get update && apt-get install -y wget curl) ++ (which jq) || (apt-get update && apt-get -y install jq) ++ (which lsof) || (apt-get update && apt-get install -y lsof) ++ ++ # get the current IP address, required by benchmark_serving.py ++ export VLLM_HOST_IP=$(hostname -I | awk '{print $1}') ++ # turn of the reporting of the status of each request, to clean up the terminal output ++ export VLLM_LOG_LEVEL="WARNING" ++ ++ # prepare for benchmarking ++ cd benchmarks || exit 1 ++ ensure_sharegpt_downloaded ++ declare -g RESULTS_FOLDER=results/ ++ mkdir -p $RESULTS_FOLDER ++ QUICK_BENCHMARK_ROOT=../.buildkite/nightly-benchmarks/ ++ ++ # benchmarking ++ run_serving_tests $QUICK_BENCHMARK_ROOT/tests/serving-tests.json ++ run_latency_tests $QUICK_BENCHMARK_ROOT/tests/latency-tests.json ++ run_throughput_tests $QUICK_BENCHMARK_ROOT/tests/throughput-tests.json ++ ++ # postprocess benchmarking results ++ pip install tabulate pandas ++ python3 $QUICK_BENCHMARK_ROOT/scripts/convert-results-json-to-markdown.py ++ ++ upload_to_buildkite ++} ++ ++main "$@" +diff --git a/.buildkite/nightly-benchmarks/scripts/summary-nightly-results.py b/.buildkite/nightly-benchmarks/scripts/summary-nightly-results.py +new file mode 100644 +index 0000000..92d6fad +--- /dev/null ++++ b/.buildkite/nightly-benchmarks/scripts/summary-nightly-results.py +@@ -0,0 +1,83 @@ ++import datetime ++import json ++import os ++from pathlib import Path ++ ++import pandas as pd ++from tabulate import tabulate ++ ++results_folder = Path("results/") ++ ++# serving results and the keys that will be printed into markdown ++serving_results = [] ++serving_column_mapping = { ++ "test_name": "Test name", ++ "gpu_type": "GPU", ++ "completed": "Successful req.", ++ "request_throughput": "Tput (req/s)", ++ "mean_ttft_ms": "Mean TTFT (ms)", ++ "std_ttft_ms": "Std TTFT (ms)", ++ "median_ttft_ms": "Median TTFT (ms)", ++ "mean_itl_ms": "Mean ITL (ms)", ++ "std_itl_ms": "Std ITL (ms)", ++ "median_itl_ms": "Median ITL (ms)", ++ "mean_tpot_ms": "Mean TPOT (ms)", ++ "std_tpot_ms": "Std TPOT (ms)", ++ "median_tpot_ms": "Median TPOT (ms)", ++ "total_token_throughput": "Total Token Tput (tok/s)", ++ "output_throughput": "Output Tput (tok/s)", ++ "total_input_tokens": "Total input tokens", ++ "total_output_tokens": "Total output tokens", ++ "engine": "Engine", ++} ++ ++if __name__ == "__main__": ++ ++ # collect results ++ for test_file in results_folder.glob("*.json"): ++ ++ with open(test_file) as f: ++ raw_result = json.loads(f.read()) ++ ++ # attach the benchmarking command to raw_result ++ with open(test_file.with_suffix(".commands")) as f: ++ command = json.loads(f.read()) ++ raw_result.update(command) ++ ++ # update the test name of this result ++ raw_result.update({"test_name": test_file.stem}) ++ ++ # add the result to raw_result ++ serving_results.append(raw_result) ++ continue ++ ++ serving_results = pd.DataFrame.from_dict(serving_results) ++ ++ if not serving_results.empty: ++ serving_results = serving_results[list( ++ serving_column_mapping.keys())].rename( ++ columns=serving_column_mapping) ++ ++ serving_md_table_with_headers = tabulate(serving_results, ++ headers='keys', ++ tablefmt='pipe', ++ showindex=False) ++ # remove the first line of header ++ serving_md_table_lines = serving_md_table_with_headers.split('\n') ++ serving_md_table_without_header = '\n'.join(serving_md_table_lines[2:]) ++ ++ prefix = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S") ++ prefix = prefix + "_" + os.environ.get("CURRENT_LLM_SERVING_ENGINE") ++ ++ # document benchmarking results in markdown ++ with open(results_folder / f"{prefix}_nightly_results.md", "w") as f: ++ # document results with header. ++ # for those who wants to reproduce our benchmark. ++ f.write(serving_md_table_with_headers) ++ f.write('\n') ++ ++ # document benchmarking results in json ++ with open(results_folder / f"{prefix}_nightly_results.json", "w") as f: ++ ++ results = serving_results.to_dict(orient='records') ++ f.write(json.dumps(results)) +diff --git a/.buildkite/nightly-benchmarks/scripts/wait-for-image.sh b/.buildkite/nightly-benchmarks/scripts/wait-for-image.sh +new file mode 100644 +index 0000000..aa0f7ad +--- /dev/null ++++ b/.buildkite/nightly-benchmarks/scripts/wait-for-image.sh +@@ -0,0 +1,19 @@ ++#!/bin/sh ++TOKEN=$(curl -s -L "https://public.ecr.aws/token?service=public.ecr.aws&scope=repository:q9t5s3a7/vllm-ci-postmerge-repo:pull" | jq -r .token) ++URL="https://public.ecr.aws/v2/q9t5s3a7/vllm-ci-postmerge-repo/manifests/$BUILDKITE_COMMIT" ++ ++TIMEOUT_SECONDS=10 ++ ++retries=0 ++while [ $retries -lt 1000 ]; do ++ if [ "$(curl -s --max-time "$TIMEOUT_SECONDS" -L -H "Authorization: Bearer $TOKEN" -o /dev/null -w "%{http_code}" "$URL")" -eq 200 ]; then ++ exit 0 ++ fi ++ ++ echo "Waiting for image to be available..." ++ ++ retries=$((retries + 1)) ++ sleep 5 ++done ++ ++exit 1 +diff --git a/.buildkite/nightly-benchmarks/tests/latency-tests.json b/.buildkite/nightly-benchmarks/tests/latency-tests.json +new file mode 100644 +index 0000000..1841186 +--- /dev/null ++++ b/.buildkite/nightly-benchmarks/tests/latency-tests.json +@@ -0,0 +1,32 @@ ++[ ++ { ++ "test_name": "latency_llama8B_tp1", ++ "parameters": { ++ "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", ++ "tensor_parallel_size": 1, ++ "load_format": "dummy", ++ "num_iters_warmup": 5, ++ "num_iters": 15 ++ } ++ }, ++ { ++ "test_name": "latency_llama70B_tp4", ++ "parameters": { ++ "model": "meta-llama/Meta-Llama-3.1-70B-Instruct", ++ "tensor_parallel_size": 4, ++ "load_format": "dummy", ++ "num-iters-warmup": 5, ++ "num-iters": 15 ++ } ++ }, ++ { ++ "test_name": "latency_mixtral8x7B_tp2", ++ "parameters": { ++ "model": "mistralai/Mixtral-8x7B-Instruct-v0.1", ++ "tensor_parallel_size": 2, ++ "load_format": "dummy", ++ "num-iters-warmup": 5, ++ "num-iters": 15 ++ } ++ } ++] +\ No newline at end of file +diff --git a/.buildkite/nightly-benchmarks/tests/nightly-tests.json b/.buildkite/nightly-benchmarks/tests/nightly-tests.json +new file mode 100644 +index 0000000..fda1a7a +--- /dev/null ++++ b/.buildkite/nightly-benchmarks/tests/nightly-tests.json +@@ -0,0 +1,323 @@ ++[ ++ { ++ "test_name": "llama8B_tp1_sharegpt", ++ "qps_list": [4,8,16,32,"inf"], ++ "common_parameters": { ++ "model": "meta-llama/Meta-Llama-3-8B-Instruct", ++ "tp": 1, ++ "dataset_name": "sharegpt", ++ "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", ++ "num_prompts": 500, ++ "port": 8000, ++ "reuse_server": false ++ }, ++ "lmdeploy_server_parameters": { ++ "dtype": "bfloat16" ++ }, ++ "lmdeploy_client_parameters": { ++ }, ++ "tgi_server_parameters": { ++ }, ++ "tgi_client_parameters": { ++ "endpoint": "/generate_stream" ++ }, ++ "trt_server_parameters": { ++ "model_type": "llama", ++ "model_dtype": "bfloat16", ++ "max_batch_size": 2048, ++ "max_input_len": 4096, ++ "max_seq_len": 6144, ++ "max_num_tokens": 16384, ++ "trt_llm_version": "v0.11.0" ++ }, ++ "trt_client_parameters": { ++ "endpoint": "/v2/models/ensemble/generate_stream" ++ }, ++ "vllm_server_parameters": { ++ "disable_log_stats": "", ++ "disable_log_requests": "", ++ "gpu_memory_utilization": 0.9, ++ "num_scheduler_steps": 10, ++ "max_num_seqs": 512, ++ "dtype": "bfloat16" ++ }, ++ "vllm_client_parameters": { ++ }, ++ "sglang_server_parameters": { ++ "disable_radix_cache": "", ++ "enable_torch_compile": "", ++ "dtype": "bfloat16" ++ }, ++ "sglang_client_parameters": { ++ } ++ }, ++ { ++ "test_name": "llama8B_tp1_sonnet_512_16", ++ "qps_list": [4,8,16,32,"inf"], ++ "common_parameters": { ++ "model": "meta-llama/Meta-Llama-3-8B-Instruct", ++ "tp": 1, ++ "dataset_name": "sonnet", ++ "dataset_path": "./sonnet_4x.txt", ++ "num_prompts": 500, ++ "port": 8000, ++ "sonnet_input_len": 512, ++ "sonnet_output_len": 16, ++ "sonnet_prefix_len": 50, ++ "reuse_server": true ++ }, ++ "lmdeploy_server_parameters": { ++ "dtype": "bfloat16" ++ }, ++ "lmdeploy_client_parameters": { ++ }, ++ "tgi_server_parameters": { ++ }, ++ "tgi_client_parameters": { ++ "endpoint": "/generate_stream" ++ }, ++ "trt_server_parameters": { ++ "model_type": "llama", ++ "model_dtype": "bfloat16", ++ "max_batch_size": 2048, ++ "max_input_len": 4096, ++ "max_seq_len": 6144, ++ "max_num_tokens": 16384, ++ "trt_llm_version": "v0.11.0" ++ }, ++ "trt_client_parameters": { ++ "endpoint": "/v2/models/ensemble/generate_stream" ++ }, ++ "vllm_server_parameters": { ++ "disable_log_stats": "", ++ "disable_log_requests": "", ++ "gpu_memory_utilization": 0.9, ++ "num_scheduler_steps": 10, ++ "max_num_seqs": 512, ++ "dtype": "bfloat16" ++ }, ++ "vllm_client_parameters": { ++ }, ++ "sglang_server_parameters": { ++ "disable_radix_cache": "", ++ "enable_torch_compile": "", ++ "dtype": "bfloat16" ++ }, ++ "sglang_client_parameters": { ++ } ++ }, ++ { ++ "test_name": "llama8B_tp1_sonnet_512_256", ++ "qps_list": [4,8,16,32,"inf"], ++ "common_parameters": { ++ "model": "meta-llama/Meta-Llama-3-8B-Instruct", ++ "tp": 1, ++ "dataset_name": "sonnet", ++ "dataset_path": "./sonnet_4x.txt", ++ "num_prompts": 500, ++ "port": 8000, ++ "sonnet_input_len": 512, ++ "sonnet_output_len": 256, ++ "sonnet_prefix_len": 50, ++ "reuse_server": true ++ }, ++ "lmdeploy_server_parameters": { ++ "dtype": "bfloat16" ++ }, ++ "lmdeploy_client_parameters": { ++ }, ++ "tgi_server_parameters": { ++ }, ++ "tgi_client_parameters": { ++ "endpoint": "/generate_stream" ++ }, ++ "trt_server_parameters": { ++ "model_type": "llama", ++ "model_dtype": "bfloat16", ++ "max_batch_size": 2048, ++ "max_input_len": 4096, ++ "max_seq_len": 6144, ++ "max_num_tokens": 16384, ++ "trt_llm_version": "v0.11.0" ++ }, ++ "trt_client_parameters": { ++ "endpoint": "/v2/models/ensemble/generate_stream" ++ }, ++ "vllm_server_parameters": { ++ "disable_log_stats": "", ++ "disable_log_requests": "", ++ "gpu_memory_utilization": 0.9, ++ "num_scheduler_steps": 10, ++ "max_num_seqs": 512, ++ "dtype": "bfloat16" ++ }, ++ "vllm_client_parameters": { ++ }, ++ "sglang_server_parameters": { ++ "disable_radix_cache": "", ++ "enable_torch_compile": "", ++ "dtype": "bfloat16" ++ }, ++ "sglang_client_parameters": { ++ } ++ }, ++ { ++ "test_name": "llama70B_tp4_sharegpt", ++ "qps_list": [4,8,16,32,"inf"], ++ "common_parameters": { ++ "model": "meta-llama/Meta-Llama-3-70B-Instruct", ++ "tp": 4, ++ "dataset_name": "sharegpt", ++ "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", ++ "num_prompts": 500, ++ "port": 8000, ++ "reuse_server": false ++ }, ++ "lmdeploy_server_parameters": { ++ "dtype": "bfloat16" ++ }, ++ "lmdeploy_client_parameters": { ++ }, ++ "tgi_server_parameters": { ++ }, ++ "tgi_client_parameters": { ++ "endpoint": "/generate_stream" ++ }, ++ "trt_server_parameters": { ++ "model_type": "llama", ++ "model_dtype": "bfloat16", ++ "max_batch_size": 2048, ++ "max_input_len": 4096, ++ "max_seq_len": 6144, ++ "max_num_tokens": 16384, ++ "trt_llm_version": "v0.11.0" ++ }, ++ "trt_client_parameters": { ++ "endpoint": "/v2/models/ensemble/generate_stream" ++ }, ++ "vllm_server_parameters": { ++ "disable_log_stats": "", ++ "disable_log_requests": "", ++ "gpu_memory_utilization": 0.9, ++ "num_scheduler_steps": 10, ++ "max_num_seqs": 512, ++ "dtype": "bfloat16" ++ }, ++ "vllm_client_parameters": { ++ }, ++ "sglang_server_parameters": { ++ "disable_radix_cache": "", ++ "dtype": "bfloat16" ++ }, ++ "sglang_client_parameters": { ++ } ++ }, ++ { ++ "test_name": "llama70B_tp4_sonnet_512_16", ++ "qps_list": [4,8,16,32,"inf"], ++ "common_parameters": { ++ "model": "meta-llama/Meta-Llama-3-70B-Instruct", ++ "tp": 4, ++ "dataset_name": "sonnet", ++ "dataset_path": "./sonnet_4x.txt", ++ "num_prompts": 500, ++ "port": 8000, ++ "sonnet_input_len": 512, ++ "sonnet_output_len": 16, ++ "sonnet_prefix_len": 50, ++ "reuse_server": true ++ }, ++ "lmdeploy_server_parameters": { ++ "dtype": "bfloat16" ++ }, ++ "lmdeploy_client_parameters": { ++ }, ++ "tgi_server_parameters": { ++ }, ++ "tgi_client_parameters": { ++ "endpoint": "/generate_stream" ++ }, ++ "trt_server_parameters": { ++ "model_type": "llama", ++ "model_dtype": "bfloat16", ++ "max_batch_size": 2048, ++ "max_input_len": 4096, ++ "max_seq_len": 6144, ++ "max_num_tokens": 16384, ++ "trt_llm_version": "v0.11.0" ++ }, ++ "trt_client_parameters": { ++ "endpoint": "/v2/models/ensemble/generate_stream" ++ }, ++ "vllm_server_parameters": { ++ "disable_log_stats": "", ++ "disable_log_requests": "", ++ "gpu_memory_utilization": 0.9, ++ "num_scheduler_steps": 10, ++ "max_num_seqs": 512, ++ "dtype": "bfloat16" ++ }, ++ "vllm_client_parameters": { ++ }, ++ "sglang_server_parameters": { ++ "disable_radix_cache": "", ++ "dtype": "bfloat16" ++ }, ++ "sglang_client_parameters": { ++ } ++ }, ++ { ++ "test_name": "llama70B_tp4_sonnet_512_256", ++ "qps_list": [4,8,16,32,"inf"], ++ "common_parameters": { ++ "model": "meta-llama/Meta-Llama-3-70B-Instruct", ++ "tp": 4, ++ "dataset_name": "sonnet", ++ "dataset_path": "./sonnet_4x.txt", ++ "num_prompts": 500, ++ "port": 8000, ++ "sonnet_input_len": 512, ++ "sonnet_output_len": 256, ++ "sonnet_prefix_len": 50, ++ "reuse_server": true ++ }, ++ "lmdeploy_server_parameters": { ++ "dtype": "bfloat16" ++ }, ++ "lmdeploy_client_parameters": { ++ }, ++ "tgi_server_parameters": { ++ }, ++ "tgi_client_parameters": { ++ "endpoint": "/generate_stream" ++ }, ++ "trt_server_parameters": { ++ "model_type": "llama", ++ "model_dtype": "bfloat16", ++ "max_batch_size": 2048, ++ "max_input_len": 4096, ++ "max_seq_len": 6144, ++ "max_num_tokens": 16384, ++ "trt_llm_version": "v0.11.0" ++ }, ++ "trt_client_parameters": { ++ "endpoint": "/v2/models/ensemble/generate_stream" ++ }, ++ "vllm_server_parameters": { ++ "disable_log_stats": "", ++ "disable_log_requests": "", ++ "gpu_memory_utilization": 0.9, ++ "num_scheduler_steps": 10, ++ "max_num_seqs": 512, ++ "dtype": "bfloat16" ++ }, ++ "vllm_client_parameters": { ++ }, ++ "sglang_server_parameters": { ++ "disable_radix_cache": "", ++ "dtype": "bfloat16" ++ }, ++ "sglang_client_parameters": { ++ } ++ } ++] +\ No newline at end of file +diff --git a/.buildkite/nightly-benchmarks/tests/serving-tests.json b/.buildkite/nightly-benchmarks/tests/serving-tests.json +new file mode 100644 +index 0000000..facb0ea +--- /dev/null ++++ b/.buildkite/nightly-benchmarks/tests/serving-tests.json +@@ -0,0 +1,80 @@ ++[ ++ { ++ "test_name": "serving_llama8B_tp1_sharegpt", ++ "qps_list": [1, 4, 16, "inf"], ++ "server_parameters": { ++ "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", ++ "tensor_parallel_size": 1, ++ "swap_space": 16, ++ "disable_log_stats": "", ++ "disable_log_requests": "", ++ "load_format": "dummy" ++ }, ++ "client_parameters": { ++ "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", ++ "backend": "vllm", ++ "dataset_name": "sharegpt", ++ "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", ++ "num_prompts": 200 ++ } ++ }, ++ { ++ "test_name": "serving_llama70B_tp4_sharegpt", ++ "qps_list": [1, 4, 16, "inf"], ++ "server_parameters": { ++ "model": "meta-llama/Meta-Llama-3.1-70B-Instruct", ++ "tensor_parallel_size": 4, ++ "swap_space": 16, ++ "disable_log_stats": "", ++ "disable_log_requests": "", ++ "load_format": "dummy" ++ }, ++ "client_parameters": { ++ "model": "meta-llama/Meta-Llama-3.1-70B-Instruct", ++ "backend": "vllm", ++ "dataset_name": "sharegpt", ++ "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", ++ "num_prompts": 200 ++ } ++ }, ++ { ++ "test_name": "serving_mixtral8x7B_tp2_sharegpt", ++ "qps_list": [1, 4, 16, "inf"], ++ "server_parameters": { ++ "model": "mistralai/Mixtral-8x7B-Instruct-v0.1", ++ "tensor_parallel_size": 2, ++ "swap_space": 16, ++ "disable_log_stats": "", ++ "disable_log_requests": "", ++ "load_format": "dummy" ++ }, ++ "client_parameters": { ++ "model": "mistralai/Mixtral-8x7B-Instruct-v0.1", ++ "backend": "vllm", ++ "dataset_name": "sharegpt", ++ "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", ++ "num_prompts": 200 ++ } ++ }, ++ { ++ "test_name": "serving_llama70B_tp4_sharegpt_specdecode", ++ "qps_list": [2], ++ "server_parameters": { ++ "model": "meta-llama/Meta-Llama-3.1-70B-Instruct", ++ "disable_log_requests": "", ++ "tensor_parallel_size": 4, ++ "swap_space": 16, ++ "speculative_model": "turboderp/Qwama-0.5B-Instruct", ++ "num_speculative_tokens": 4, ++ "speculative_draft_tensor_parallel_size": 1, ++ "use_v2_block_manager": "" ++ }, ++ "client_parameters": { ++ "model": "meta-llama/Meta-Llama-3.1-70B-Instruct", ++ "backend": "vllm", ++ "dataset_name": "sharegpt", ++ "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", ++ "num_prompts": 200 ++ } ++ } ++] +diff --git a/.buildkite/nightly-benchmarks/tests/throughput-tests.json b/.buildkite/nightly-benchmarks/tests/throughput-tests.json +new file mode 100644 +index 0000000..91ef6d1 +--- /dev/null ++++ b/.buildkite/nightly-benchmarks/tests/throughput-tests.json +@@ -0,0 +1,35 @@ ++[ ++ { ++ "test_name": "throughput_llama8B_tp1", ++ "parameters": { ++ "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", ++ "tensor_parallel_size": 1, ++ "load_format": "dummy", ++ "dataset": "./ShareGPT_V3_unfiltered_cleaned_split.json", ++ "num_prompts": 200, ++ "backend": "vllm" ++ } ++ }, ++ { ++ "test_name": "throughput_llama70B_tp4", ++ "parameters": { ++ "model": "meta-llama/Meta-Llama-3.1-70B-Instruct", ++ "tensor_parallel_size": 4, ++ "load_format": "dummy", ++ "dataset": "./ShareGPT_V3_unfiltered_cleaned_split.json", ++ "num_prompts": 200, ++ "backend": "vllm" ++ } ++ }, ++ { ++ "test_name": "throughput_mixtral8x7B_tp2", ++ "parameters": { ++ "model": "mistralai/Mixtral-8x7B-Instruct-v0.1", ++ "tensor_parallel_size": 2, ++ "load_format": "dummy", ++ "dataset": "./ShareGPT_V3_unfiltered_cleaned_split.json", ++ "num_prompts": 200, ++ "backend": "vllm" ++ } ++ } ++] +\ No newline at end of file +diff --git a/.buildkite/release-pipeline.yaml b/.buildkite/release-pipeline.yaml +new file mode 100644 +index 0000000..51618a2 +--- /dev/null ++++ b/.buildkite/release-pipeline.yaml +@@ -0,0 +1,72 @@ ++steps: ++ - label: "Build wheel - CUDA 12.1" ++ agents: ++ queue: cpu_queue_postmerge ++ commands: ++ - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.1.0 --tag vllm-ci:build-image --target build --progress plain ." ++ - "mkdir artifacts" ++ - "docker run --rm -v $(pwd)/artifacts:/artifacts_host vllm-ci:build-image bash -c 'cp -r dist /artifacts_host && chmod -R a+rw /artifacts_host'" ++ - "bash .buildkite/upload-wheels.sh" ++ env: ++ DOCKER_BUILDKIT: "1" ++ ++ # Note(simon): We can always build CUDA 11.8 wheel to ensure the build is working. ++ # However, this block can be uncommented to save some compute hours. ++ # - block: "Build CUDA 11.8 wheel" ++ # key: block-build-cu118-wheel ++ ++ - label: "Build wheel - CUDA 11.8" ++ # depends_on: block-build-cu118-wheel ++ agents: ++ queue: cpu_queue_postmerge ++ commands: ++ - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=11.8.0 --tag vllm-ci:build-image --target build --progress plain ." ++ - "mkdir artifacts" ++ - "docker run --rm -v $(pwd)/artifacts:/artifacts_host vllm-ci:build-image bash -c 'cp -r dist /artifacts_host && chmod -R a+rw /artifacts_host'" ++ - "bash .buildkite/upload-wheels.sh" ++ env: ++ DOCKER_BUILDKIT: "1" ++ ++ - block: "Build release image" ++ depends_on: ~ ++ key: block-release-image-build ++ ++ - label: "Build release image" ++ depends_on: block-release-image-build ++ agents: ++ queue: cpu_queue_postmerge ++ commands: ++ - "aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws/q9t5s3a7" ++ - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.1.0 --tag public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT --target vllm-openai --progress plain ." ++ - "docker push public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT" ++ ++ - label: "Build and publish TPU release image" ++ depends_on: ~ ++ if: build.env("NIGHTLY") == "1" ++ agents: ++ queue: tpu_queue_postmerge ++ commands: ++ - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --tag vllm/vllm-tpu:nightly --tag vllm/vllm-tpu:$BUILDKITE_COMMIT --progress plain -f Dockerfile.tpu ." ++ - "docker push vllm/vllm-tpu:nightly" ++ - "docker push vllm/vllm-tpu:$BUILDKITE_COMMIT" ++ plugins: ++ - docker-login#v3.0.0: ++ username: vllm ++ password-env: DOCKERHUB_TOKEN ++ env: ++ DOCKER_BUILDKIT: "1" ++ ++ - block: "Build CPU release image" ++ key: block-cpu-release-image-build ++ depends_on: ~ ++ ++ - label: "Build and publish CPU release image" ++ depends_on: block-cpu-release-image-build ++ agents: ++ queue: cpu_queue_postmerge ++ commands: ++ - "aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws/q9t5s3a7" ++ - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg GIT_REPO_CHECK=1 --tag public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:$RELEASE_VERSION --progress plain -f Dockerfile.cpu ." ++ - "docker push public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:$RELEASE_VERSION" ++ env: ++ DOCKER_BUILDKIT: "1" +diff --git a/.buildkite/run-amd-test.sh b/.buildkite/run-amd-test.sh +index c04e05a..3515ccd 100644 +--- a/.buildkite/run-amd-test.sh ++++ b/.buildkite/run-amd-test.sh +@@ -1,10 +1,49 @@ +-# This script build the ROCm docker image and runs test inside it. +-set -ex ++#!/bin/bash ++ ++# This script runs test inside the corresponding ROCm docker container. ++set -o pipefail + + # Print ROCm version ++echo "--- Confirming Clean Initial State" ++while true; do ++ sleep 3 ++ if grep -q clean /opt/amdgpu/etc/gpu_state; then ++ echo "GPUs state is \"clean\"" ++ break ++ fi ++done ++ + echo "--- ROCm info" + rocminfo + ++# cleanup older docker images ++cleanup_docker() { ++ # Get Docker's root directory ++ docker_root=$(docker info -f '{{.DockerRootDir}}') ++ if [ -z "$docker_root" ]; then ++ echo "Failed to determine Docker root directory." ++ exit 1 ++ fi ++ echo "Docker root directory: $docker_root" ++ # Check disk usage of the filesystem where Docker's root directory is located ++ disk_usage=$(df "$docker_root" | tail -1 | awk '{print $5}' | sed 's/%//') ++ # Define the threshold ++ threshold=70 ++ if [ "$disk_usage" -gt "$threshold" ]; then ++ echo "Disk usage is above $threshold%. Cleaning up Docker images and volumes..." ++ # Remove dangling images (those that are not tagged and not used by any container) ++ docker image prune -f ++ # Remove unused volumes / force the system prune for old images as well. ++ docker volume prune -f && docker system prune --force --filter "until=72h" --all ++ echo "Docker images and volumes cleanup completed." ++ else ++ echo "Disk usage is below $threshold%. No cleanup needed." ++ fi ++} ++ ++# Call the cleanup docker function ++cleanup_docker ++ + echo "--- Resetting GPUs" + + echo "reset" > /opt/amdgpu/etc/gpu_state +@@ -17,28 +56,101 @@ while true; do + fi + done + +-echo "--- Building container" +-sha=$(git rev-parse --short HEAD) +-container_name=rocm_${sha} +-docker build \ +- -t ${container_name} \ +- -f Dockerfile.rocm \ +- --progress plain \ +- . ++echo "--- Pulling container" ++image_name="rocm/vllm-ci:${BUILDKITE_COMMIT}" ++container_name="rocm_${BUILDKITE_COMMIT}_$(tr -dc A-Za-z0-9 < /dev/urandom | head -c 10; echo)" ++docker pull "${image_name}" + + remove_docker_container() { +- docker rm -f ${container_name} || docker image rm -f ${container_name} || true ++ docker rm -f "${container_name}" || docker image rm -f "${image_name}" || true + } + trap remove_docker_container EXIT + + echo "--- Running container" + +-docker run \ ++HF_CACHE="$(realpath ~)/huggingface" ++mkdir -p "${HF_CACHE}" ++HF_MOUNT="/root/.cache/huggingface" ++ ++commands=$@ ++echo "Commands:$commands" ++#ignore certain kernels tests ++if [[ $commands == *" kernels "* ]]; then ++ commands="${commands} \ ++ --ignore=kernels/test_attention.py \ ++ --ignore=kernels/test_attention_selector.py \ ++ --ignore=kernels/test_blocksparse_attention.py \ ++ --ignore=kernels/test_causal_conv1d.py \ ++ --ignore=kernels/test_cutlass.py \ ++ --ignore=kernels/test_encoder_decoder_attn.py \ ++ --ignore=kernels/test_flash_attn.py \ ++ --ignore=kernels/test_flashinfer.py \ ++ --ignore=kernels/test_int8_quant.py \ ++ --ignore=kernels/test_machete_gemm.py \ ++ --ignore=kernels/test_mamba_ssm.py \ ++ --ignore=kernels/test_marlin_gemm.py \ ++ --ignore=kernels/test_moe.py \ ++ --ignore=kernels/test_prefix_prefill.py \ ++ --ignore=kernels/test_rand.py \ ++ --ignore=kernels/test_sampler.py" ++fi ++ ++#ignore certain Entrypoints tests ++if [[ $commands == *" entrypoints/openai "* ]]; then ++ commands=${commands//" entrypoints/openai "/" entrypoints/openai \ ++ --ignore=entrypoints/openai/test_accuracy.py \ ++ --ignore=entrypoints/openai/test_audio.py \ ++ --ignore=entrypoints/openai/test_encoder_decoder.py \ ++ --ignore=entrypoints/openai/test_embedding.py \ ++ --ignore=entrypoints/openai/test_oot_registration.py "} ++fi ++ ++PARALLEL_JOB_COUNT=8 ++# check if the command contains shard flag, we will run all shards in parallel because the host have 8 GPUs. ++if [[ $commands == *"--shard-id="* ]]; then ++ # assign job count as the number of shards used ++ commands=${commands//"--num-shards= "/"--num-shards=${PARALLEL_JOB_COUNT} "} ++ for GPU in $(seq 0 $(($PARALLEL_JOB_COUNT-1))); do ++ # assign shard-id for each shard ++ commands_gpu=${commands//"--shard-id= "/"--shard-id=${GPU} "} ++ echo "Shard ${GPU} commands:$commands_gpu" ++ docker run \ + --device /dev/kfd --device /dev/dri \ + --network host \ ++ --shm-size=16gb \ + --rm \ ++ -e HIP_VISIBLE_DEVICES="${GPU}" \ + -e HF_TOKEN \ +- --name ${container_name} \ +- ${container_name} \ +- /bin/bash -c $(echo $1 | sed "s/^'//" | sed "s/'$//") +- ++ -v "${HF_CACHE}:${HF_MOUNT}" \ ++ -e "HF_HOME=${HF_MOUNT}" \ ++ --name "${container_name}_${GPU}" \ ++ "${image_name}" \ ++ /bin/bash -c "${commands_gpu}" \ ++ |& while read -r line; do echo ">>Shard $GPU: $line"; done & ++ PIDS+=($!) ++ done ++ #wait for all processes to finish and collect exit codes ++ for pid in "${PIDS[@]}"; do ++ wait "${pid}" ++ STATUS+=($?) ++ done ++ for st in "${STATUS[@]}"; do ++ if [[ ${st} -ne 0 ]]; then ++ echo "One of the processes failed with $st" ++ exit "${st}" ++ fi ++ done ++else ++ docker run \ ++ --device /dev/kfd --device /dev/dri \ ++ --network host \ ++ --shm-size=16gb \ ++ --rm \ ++ -e HIP_VISIBLE_DEVICES=0 \ ++ -e HF_TOKEN \ ++ -v "${HF_CACHE}:${HF_MOUNT}" \ ++ -e "HF_HOME=${HF_MOUNT}" \ ++ --name "${container_name}" \ ++ "${image_name}" \ ++ /bin/bash -c "${commands}" ++fi +diff --git a/.buildkite/run-benchmarks.sh b/.buildkite/run-benchmarks.sh +index 7fbad1c..1641c1f 100644 +--- a/.buildkite/run-benchmarks.sh ++++ b/.buildkite/run-benchmarks.sh +@@ -1,3 +1,5 @@ ++#!/bin/bash ++ + # This script is run by buildkite to run the benchmarks and upload the results to buildkite + + set -ex +@@ -9,10 +11,10 @@ cd "$(dirname "${BASH_SOURCE[0]}")/.." + (which wget && which curl) || (apt-get update && apt-get install -y wget curl) + + # run python-based benchmarks and upload the result to buildkite +-python3 benchmarks/benchmark_latency.py 2>&1 | tee benchmark_latency.txt ++python3 benchmarks/benchmark_latency.py --output-json latency_results.json 2>&1 | tee benchmark_latency.txt + bench_latency_exit_code=$? + +-python3 benchmarks/benchmark_throughput.py --input-len 256 --output-len 256 2>&1 | tee benchmark_throughput.txt ++python3 benchmarks/benchmark_throughput.py --input-len 256 --output-len 256 --output-json throughput_results.json 2>&1 | tee benchmark_throughput.txt + bench_throughput_exit_code=$? + + # run server-based benchmarks and upload the result to buildkite +@@ -50,16 +52,16 @@ echo "### Serving Benchmarks" >> benchmark_results.md + sed -n '1p' benchmark_serving.txt >> benchmark_results.md # first line + echo "" >> benchmark_results.md + echo '```' >> benchmark_results.md +-tail -n 20 benchmark_serving.txt >> benchmark_results.md # last 20 lines ++tail -n 24 benchmark_serving.txt >> benchmark_results.md # last 24 lines + echo '```' >> benchmark_results.md + + # if the agent binary is not found, skip uploading the results, exit 0 +-if [ ! -f /workspace/buildkite-agent ]; then ++if [ ! -f /usr/bin/buildkite-agent ]; then + exit 0 + fi + + # upload the results to buildkite +-/workspace/buildkite-agent annotate --style "info" --context "benchmark-results" < benchmark_results.md ++buildkite-agent annotate --style "info" --context "benchmark-results" < benchmark_results.md + + # exit with the exit code of the benchmarks + if [ $bench_latency_exit_code -ne 0 ]; then +@@ -74,4 +76,5 @@ if [ $bench_serving_exit_code -ne 0 ]; then + exit $bench_serving_exit_code + fi + +-/workspace/buildkite-agent artifact upload openai-*.json ++rm ShareGPT_V3_unfiltered_cleaned_split.json ++buildkite-agent artifact upload "*.json" +diff --git a/.buildkite/run-cpu-test-ppc64le.sh b/.buildkite/run-cpu-test-ppc64le.sh +new file mode 100644 +index 0000000..bc06838 +--- /dev/null ++++ b/.buildkite/run-cpu-test-ppc64le.sh +@@ -0,0 +1,14 @@ ++#!/bin/bash ++ ++# This script build the CPU docker image and run the offline inference inside the container. ++# It serves a sanity check for compilation and basic model usage. ++set -ex ++ ++# Setup cleanup ++remove_docker_container() { docker rm -f cpu-test || true; docker system prune -f; } ++trap remove_docker_container EXIT ++remove_docker_container ++ ++# Try building the docker image ++docker build -t cpu-test -f Dockerfile.ppc64le . ++ +diff --git a/.buildkite/run-cpu-test.sh b/.buildkite/run-cpu-test.sh +index f187d1f..9925db7 100644 +--- a/.buildkite/run-cpu-test.sh ++++ b/.buildkite/run-cpu-test.sh +@@ -1,14 +1,88 @@ ++#!/bin/bash ++ + # This script build the CPU docker image and run the offline inference inside the container. + # It serves a sanity check for compilation and basic model usage. + set -ex + ++# allow to bind to different cores ++CORE_RANGE=${CORE_RANGE:-48-95} ++NUMA_NODE=${NUMA_NODE:-1} ++ + # Try building the docker image +-docker build -t cpu-test -f Dockerfile.cpu . ++numactl -C "$CORE_RANGE" -N "$NUMA_NODE" docker build -t cpu-test-"$BUILDKITE_BUILD_NUMBER" -f Dockerfile.cpu . ++numactl -C "$CORE_RANGE" -N "$NUMA_NODE" docker build --build-arg VLLM_CPU_DISABLE_AVX512="true" -t cpu-test-"$BUILDKITE_BUILD_NUMBER"-avx2 -f Dockerfile.cpu . + + # Setup cleanup +-remove_docker_container() { docker rm -f cpu-test || true; } ++remove_docker_container() { set -e; docker rm -f cpu-test-"$BUILDKITE_BUILD_NUMBER"-"$NUMA_NODE" cpu-test-"$BUILDKITE_BUILD_NUMBER"-avx2-"$NUMA_NODE" || true; } + trap remove_docker_container EXIT + remove_docker_container + +-# Run the image and launch offline inference +-docker run --network host --env VLLM_CPU_KVCACHE_SPACE=1 --name cpu-test cpu-test python3 examples/offline_inference.py ++# Run the image, setting --shm-size=4g for tensor parallel. ++docker run -itd --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --cpuset-cpus="$CORE_RANGE" \ ++ --cpuset-mems="$NUMA_NODE" --privileged=true --network host -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --shm-size=4g --name cpu-test-"$BUILDKITE_BUILD_NUMBER"-"$NUMA_NODE" cpu-test-"$BUILDKITE_BUILD_NUMBER" ++docker run -itd --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --cpuset-cpus="$CORE_RANGE" \ ++ --cpuset-mems="$NUMA_NODE" --privileged=true --network host -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --shm-size=4g --name cpu-test-"$BUILDKITE_BUILD_NUMBER"-avx2-"$NUMA_NODE" cpu-test-"$BUILDKITE_BUILD_NUMBER"-avx2 ++ ++function cpu_tests() { ++ set -e ++ export NUMA_NODE=$2 ++ ++ # offline inference ++ docker exec cpu-test-"$BUILDKITE_BUILD_NUMBER"-avx2-"$NUMA_NODE" bash -c " ++ set -e ++ python3 examples/offline_inference/basic.py" ++ ++ # Run basic model test ++ docker exec cpu-test-"$BUILDKITE_BUILD_NUMBER"-"$NUMA_NODE" bash -c " ++ set -e ++ pip install -r vllm/requirements-test.txt ++ pytest -v -s tests/models/decoder_only/language -m cpu_model ++ pytest -v -s tests/models/embedding/language -m cpu_model ++ pytest -v -s tests/models/encoder_decoder/language -m cpu_model ++ pytest -v -s tests/models/decoder_only/audio_language -m cpu_model ++ pytest -v -s tests/models/decoder_only/vision_language -m cpu_model" ++ ++ # Run compressed-tensor test ++ docker exec cpu-test-"$BUILDKITE_BUILD_NUMBER"-"$NUMA_NODE" bash -c " ++ set -e ++ pytest -s -v \ ++ tests/quantization/test_compressed_tensors.py::test_compressed_tensors_w8a8_static_setup \ ++ tests/quantization/test_compressed_tensors.py::test_compressed_tensors_w8a8_dynamic_per_token" ++ ++ # Run AWQ test ++ docker exec cpu-test-"$BUILDKITE_BUILD_NUMBER"-"$NUMA_NODE" bash -c " ++ set -e ++ pytest -s -v \ ++ tests/quantization/test_ipex_quant.py" ++ ++ # Run chunked-prefill and prefix-cache test ++ docker exec cpu-test-"$BUILDKITE_BUILD_NUMBER"-"$NUMA_NODE" bash -c " ++ set -e ++ pytest -s -v -k cpu_model \ ++ tests/basic_correctness/test_chunked_prefill.py" ++ ++ # online serving ++ docker exec cpu-test-"$BUILDKITE_BUILD_NUMBER"-"$NUMA_NODE" bash -c " ++ set -e ++ export VLLM_CPU_KVCACHE_SPACE=10 ++ export VLLM_CPU_OMP_THREADS_BIND=$1 ++ python3 -m vllm.entrypoints.openai.api_server --model facebook/opt-125m --dtype half & ++ timeout 600 bash -c 'until curl localhost:8000/v1/models; do sleep 1; done' || exit 1 ++ python3 benchmarks/benchmark_serving.py \ ++ --backend vllm \ ++ --dataset-name random \ ++ --model facebook/opt-125m \ ++ --num-prompts 20 \ ++ --endpoint /v1/completions \ ++ --tokenizer facebook/opt-125m" ++ ++ # Run multi-lora tests ++ docker exec cpu-test-"$BUILDKITE_BUILD_NUMBER"-"$NUMA_NODE" bash -c " ++ set -e ++ pytest -s -v \ ++ tests/lora/test_qwen2vl.py" ++} ++ ++# All of CPU tests are expected to be finished less than 25 mins. ++export -f cpu_tests ++timeout 30m bash -c "cpu_tests $CORE_RANGE $NUMA_NODE" +diff --git a/.buildkite/run-gh200-test.sh b/.buildkite/run-gh200-test.sh +new file mode 100644 +index 0000000..3e4e409 +--- /dev/null ++++ b/.buildkite/run-gh200-test.sh +@@ -0,0 +1,28 @@ ++#!/bin/bash ++ ++# This script build the GH200 docker image and run the offline inference inside the container. ++# It serves a sanity check for compilation and basic model usage. ++set -ex ++ ++# Skip the new torch installation during build since we are using the specified version for arm64 in the Dockerfile ++python3 use_existing_torch.py ++ ++# Try building the docker image ++DOCKER_BUILDKIT=1 docker build . \ ++ --target vllm-openai \ ++ --platform "linux/arm64" \ ++ -t gh200-test \ ++ --build-arg max_jobs=66 \ ++ --build-arg nvcc_threads=2 \ ++ --build-arg torch_cuda_arch_list="9.0+PTX" \ ++ --build-arg vllm_fa_cmake_gpu_arches="90-real" ++ ++# Setup cleanup ++remove_docker_container() { docker rm -f gh200-test || true; } ++trap remove_docker_container EXIT ++remove_docker_container ++ ++# Run the image and test offline inference ++docker run --name gh200-test --gpus=all --entrypoint="" gh200-test bash -c ' ++ python3 examples/offline_inference/basic.py ++' +diff --git a/.buildkite/run-hpu-test.sh b/.buildkite/run-hpu-test.sh +new file mode 100644 +index 0000000..8f3b082 +--- /dev/null ++++ b/.buildkite/run-hpu-test.sh +@@ -0,0 +1,16 @@ ++#!/bin/bash ++ ++# This script build the CPU docker image and run the offline inference inside the container. ++# It serves a sanity check for compilation and basic model usage. ++set -ex ++ ++# Try building the docker image ++docker build -t hpu-test-env -f Dockerfile.hpu . ++ ++# Setup cleanup ++remove_docker_container() { docker rm -f hpu-test || true; } ++trap remove_docker_container EXIT ++remove_docker_container ++ ++# Run the image and launch offline inference ++docker run --runtime=habana --name=hpu-test --network=host -e HABANA_VISIBLE_DEVICES=all -e VLLM_SKIP_WARMUP=true --entrypoint="" hpu-test-env python3 examples/offline_inference/basic.py +\ No newline at end of file +diff --git a/.buildkite/run-multi-node-test.sh b/.buildkite/run-multi-node-test.sh +new file mode 100644 +index 0000000..530bf90 +--- /dev/null ++++ b/.buildkite/run-multi-node-test.sh +@@ -0,0 +1,108 @@ ++#!/bin/bash ++ ++set -euox pipefail ++ ++if [[ $# -lt 4 ]]; then ++ echo "Usage: .buildkite/run-multi-node-test.sh WORKING_DIR NUM_NODES NUM_GPUS DOCKER_IMAGE COMMAND1 COMMAND2 ... COMMANDN" ++ exit 1 ++fi ++ ++WORKING_DIR=$1 ++NUM_NODES=$2 ++NUM_GPUS=$3 ++DOCKER_IMAGE=$4 ++ ++shift 4 ++COMMANDS=("$@") ++if [ ${#COMMANDS[@]} -ne "$NUM_NODES" ]; then ++ echo "The number of commands must be equal to the number of nodes." ++ echo "Number of nodes: $NUM_NODES" ++ echo "Number of commands: ${#COMMANDS[@]}" ++ exit 1 ++fi ++ ++echo "List of commands" ++for command in "${COMMANDS[@]}"; do ++ echo "$command" ++done ++ ++start_network() { ++ docker network create --subnet=192.168.10.0/24 docker-net ++} ++ ++start_nodes() { ++ for node in $(seq 0 $(($NUM_NODES-1))); do ++ GPU_DEVICES='"device=' ++ for node_gpu in $(seq 0 $(($NUM_GPUS - 1))); do ++ DEVICE_NUM=$(($node * $NUM_GPUS + $node_gpu)) ++ GPU_DEVICES+=$(($DEVICE_NUM)) ++ if [ "$node_gpu" -lt $(($NUM_GPUS - 1)) ]; then ++ GPU_DEVICES+=',' ++ fi ++ done ++ GPU_DEVICES+='"' ++ ++ # start the container in detached mode ++ # things to note: ++ # 1. --shm-size=10.24gb is required. don't use --ipc=host ++ # 2. pass HF_TOKEN to the container ++ # 3. map the huggingface cache directory to the container ++ # 3. assign ip addresses to the containers (head node: 192.168.10.10, worker nodes: ++ # starting from 192.168.10.11) ++ docker run -d --gpus "$GPU_DEVICES" --shm-size=10.24gb -e HF_TOKEN \ ++ -v ~/.cache/huggingface:/root/.cache/huggingface --name "node$node" \ ++ --network docker-net --ip 192.168.10.$((10 + $node)) --rm "$DOCKER_IMAGE" \ ++ /bin/bash -c "tail -f /dev/null" ++ ++ # organize containers into a ray cluster ++ if [ "$node" -eq 0 ]; then ++ # start the ray head node ++ docker exec -d "node$node" /bin/bash -c "ray start --head --port=6379 --block" ++ # wait for the head node to be ready ++ sleep 10 ++ else ++ # start the ray worker nodes, and connect them to the head node ++ docker exec -d "node$node" /bin/bash -c "ray start --address=192.168.10.10:6379 --block" ++ fi ++ done ++ ++ # wait for the cluster to be ready ++ sleep 10 ++ ++ # print the cluster status ++ docker exec node0 /bin/bash -c "ray status" ++} ++ ++run_nodes() { ++ # important: iterate in reverse order to start the head node last ++ # we start the worker nodes first, in detached mode, and then start the head node ++ # in the foreground, so that the output of the head node is visible in the buildkite logs ++ for node in $(seq $(($NUM_NODES - 1)) -1 0); do ++ GPU_DEVICES='"device=' ++ for node_gpu in $(seq 0 $(($NUM_GPUS - 1))); do ++ DEVICE_NUM=$(($node * $NUM_GPUS + $node_gpu)) ++ GPU_DEVICES+=$(($DEVICE_NUM)) ++ if [ "$node_gpu" -lt $(($NUM_GPUS - 1)) ]; then ++ GPU_DEVICES+=',' ++ fi ++ done ++ GPU_DEVICES+='"' ++ echo "Running node$node with GPU devices: $GPU_DEVICES" ++ if [ "$node" -ne 0 ]; then ++ docker exec -d "node$node" /bin/bash -c "cd $WORKING_DIR ; ${COMMANDS[$node]}" ++ else ++ docker exec "node$node" /bin/bash -c "cd $WORKING_DIR ; ${COMMANDS[$node]}" ++ fi ++ done ++} ++cleanup() { ++ for node in $(seq 0 $(($NUM_NODES-1))); do ++ docker stop "node$node" ++ done ++ docker network rm docker-net ++} ++trap cleanup EXIT ++start_network ++start_nodes ++run_nodes ++ +diff --git a/.buildkite/run-neuron-test.sh b/.buildkite/run-neuron-test.sh +index 252c0f7..189714e 100644 +--- a/.buildkite/run-neuron-test.sh ++++ b/.buildkite/run-neuron-test.sh +@@ -1,6 +1,20 @@ ++#!/bin/bash ++ + # This script build the Neuron docker image and run the API server inside the container. + # It serves a sanity check for compilation and basic model usage. + set -e ++set -v ++ ++image_name="neuron/vllm-ci" ++container_name="neuron_$(tr -dc A-Za-z0-9 < /dev/urandom | head -c 10; echo)" ++ ++HF_CACHE="$(realpath ~)/huggingface" ++mkdir -p "${HF_CACHE}" ++HF_MOUNT="/root/.cache/huggingface" ++ ++NEURON_COMPILE_CACHE_URL="$(realpath ~)/neuron_compile_cache" ++mkdir -p "${NEURON_COMPILE_CACHE_URL}" ++NEURON_COMPILE_CACHE_MOUNT="/root/.cache/neuron_compile_cache" + + # Try building the docker image + aws ecr get-login-password --region us-west-2 | docker login --username AWS --password-stdin 763104351884.dkr.ecr.us-west-2.amazonaws.com +@@ -11,41 +25,30 @@ if [ -f /tmp/neuron-docker-build-timestamp ]; then + last_build=$(cat /tmp/neuron-docker-build-timestamp) + current_time=$(date +%s) + if [ $((current_time - last_build)) -gt 86400 ]; then ++ docker image prune -f + docker system prune -f +- echo $current_time > /tmp/neuron-docker-build-timestamp ++ rm -rf "${HF_MOUNT:?}/*" ++ rm -rf "${NEURON_COMPILE_CACHE_MOUNT:?}/*" ++ echo "$current_time" > /tmp/neuron-docker-build-timestamp + fi + else +- echo $(date +%s) > /tmp/neuron-docker-build-timestamp ++ date "+%s" > /tmp/neuron-docker-build-timestamp + fi + +-docker build -t neuron -f Dockerfile.neuron . ++docker build -t "${image_name}" -f Dockerfile.neuron . + + # Setup cleanup +-remove_docker_container() { docker rm -f neuron || true; } ++remove_docker_container() { ++ docker image rm -f "${image_name}" || true; ++} + trap remove_docker_container EXIT +-remove_docker_container + + # Run the image +-docker run --device=/dev/neuron0 --device=/dev/neuron1 --network host --name neuron neuron python3 -m vllm.entrypoints.api_server \ +- --model TinyLlama/TinyLlama-1.1B-Chat-v1.0 --max-num-seqs 8 --max-model-len 128 --block-size 128 --device neuron --tensor-parallel-size 2 & +- +-# Wait for the server to start +-wait_for_server_to_start() { +- timeout=300 +- counter=0 +- +- while [ "$(curl -s -o /dev/null -w ''%{http_code}'' localhost:8000/health)" != "200" ]; do +- sleep 1 +- counter=$((counter + 1)) +- if [ $counter -ge $timeout ]; then +- echo "Timeout after $timeout seconds" +- break +- fi +- done +-} +-wait_for_server_to_start +- +-# Test a simple prompt +-curl -X POST -H "Content-Type: application/json" \ +- localhost:8000/generate \ +- -d '{"prompt": "San Francisco is a"}' ++docker run --rm -it --device=/dev/neuron0 --device=/dev/neuron1 --network host \ ++ -v "${HF_CACHE}:${HF_MOUNT}" \ ++ -e "HF_HOME=${HF_MOUNT}" \ ++ -v "${NEURON_COMPILE_CACHE_URL}:${NEURON_COMPILE_CACHE_MOUNT}" \ ++ -e "NEURON_COMPILE_CACHE_URL=${NEURON_COMPILE_CACHE_MOUNT}" \ ++ --name "${container_name}" \ ++ ${image_name} \ ++ /bin/bash -c "python3 /workspace/vllm/examples/offline_inference/neuron.py" +diff --git a/.buildkite/run-openvino-test.sh b/.buildkite/run-openvino-test.sh +new file mode 100644 +index 0000000..6159b21 +--- /dev/null ++++ b/.buildkite/run-openvino-test.sh +@@ -0,0 +1,16 @@ ++#!/bin/bash ++ ++# This script build the OpenVINO docker image and run the offline inference inside the container. ++# It serves a sanity check for compilation and basic model usage. ++set -ex ++ ++# Try building the docker image ++docker build -t openvino-test -f Dockerfile.openvino . ++ ++# Setup cleanup ++remove_docker_container() { docker rm -f openvino-test || true; } ++trap remove_docker_container EXIT ++remove_docker_container ++ ++# Run the image and launch offline inference ++docker run --network host --env VLLM_OPENVINO_KVCACHE_SPACE=1 --name openvino-test openvino-test python3 /workspace/examples/offline_inference/basic.py +diff --git a/.buildkite/run-tpu-test.sh b/.buildkite/run-tpu-test.sh +new file mode 100644 +index 0000000..650af0f +--- /dev/null ++++ b/.buildkite/run-tpu-test.sh +@@ -0,0 +1,26 @@ ++#!/bin/bash ++ ++set -e ++ ++# Build the docker image. ++docker build -f Dockerfile.tpu -t vllm-tpu . ++ ++# Set up cleanup. ++remove_docker_container() { docker rm -f tpu-test || true; } ++trap remove_docker_container EXIT ++# Remove the container that might not be cleaned up in the previous run. ++remove_docker_container ++ ++# For HF_TOKEN. ++source /etc/environment ++# Run a simple end-to-end example. ++docker run --privileged --net host --shm-size=16G -it \ ++ -e "HF_TOKEN=$HF_TOKEN" --name tpu-test \ ++ vllm-tpu /bin/bash -c "python3 -m pip install git+https://github.com/thuml/depyf.git \ ++ && python3 -m pip install pytest \ ++ && python3 -m pip install lm_eval[api]==0.4.4 \ ++ && pytest -v -s /workspace/vllm/tests/entrypoints/openai/test_accuracy.py \ ++ && pytest -v -s /workspace/vllm/tests/tpu/test_custom_dispatcher.py \ ++ && python3 /workspace/vllm/tests/tpu/test_compilation.py \ ++ && python3 /workspace/vllm/tests/tpu/test_quantization_accuracy.py \ ++ && python3 /workspace/vllm/examples/offline_inference/tpu.py" +diff --git a/.buildkite/run-xpu-test.sh b/.buildkite/run-xpu-test.sh +new file mode 100644 +index 0000000..4d344e5 +--- /dev/null ++++ b/.buildkite/run-xpu-test.sh +@@ -0,0 +1,19 @@ ++#!/bin/bash ++ ++# This script build the CPU docker image and run the offline inference inside the container. ++# It serves a sanity check for compilation and basic model usage. ++set -ex ++ ++# Try building the docker image ++docker build -t xpu-test -f Dockerfile.xpu . ++ ++# Setup cleanup ++remove_docker_container() { docker rm -f xpu-test || true; } ++trap remove_docker_container EXIT ++remove_docker_container ++ ++# Run the image and test offline inference/tensor parallel ++docker run --name xpu-test --device /dev/dri -v /dev/dri/by-path:/dev/dri/by-path --entrypoint="" xpu-test sh -c ' ++ python3 examples/offline_inference/basic.py ++ python3 examples/offline_inference/cli.py -tp 2 ++' +diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml +index e49a565..74b287c 100644 +--- a/.buildkite/test-pipeline.yaml ++++ b/.buildkite/test-pipeline.yaml +@@ -1,132 +1,594 @@ + # In this file, you can add more tests to run either by adding a new step or + # adding a new command to an existing step. See different options here for examples. +-# This script will be feed into Jinja template in `test-template.j2` to generate +-# the final pipeline yaml file. ++ ++# This script will be feed into Jinja template in `test-template-aws.j2` at ++# https://github.com/vllm-project/buildkite-ci/blob/main/scripts/test-template-aws.j2 ++# to generate the final pipeline yaml file. ++ ++# Documentation ++# label(str): the name of the test. emoji allowed. ++# fast_check(bool): whether to run this on each commit on fastcheck pipeline. ++# fast_check_only(bool): run this test on fastcheck pipeline only ++# optional(bool): never run this test by default (i.e. need to unblock manually) unless it's scheduled nightly run. ++# command(str): the single command to run for tests. incompatible with commands. ++# commands(list): the list of commands to run for test. incompatbile with command. ++# mirror_hardwares(list): the list of hardwares to run the test on as well. currently only supports [amd] ++# gpu(str): override the GPU selection for the test. default is on L4 GPUs. currently only supports a100 ++# num_gpus(int): override the number of GPUs for the test. default to 1 GPU. currently support 2,4. ++# num_nodes(int): whether to simulate multi-node setup by launch multiple containers on one host, ++# in this case, commands must be specified. the first command runs on first host, the second ++# command runs on the second host. ++# working_dir(str): specify the place where command should execute, default to /vllm-workspace/tests ++# source_file_dependencies(list): the list of prefix to opt-in the test for, if empty, the test will always run. ++ ++# When adding a test ++# - If the test belong to an existing group, add it there ++# - If the test is short, add to any existing step ++# - If the test takes more than 10min, then it is okay to create a new step. ++# Note that all steps execute in parallel. + + steps: +-- label: Regression Test +- command: pytest -v -s test_regression.py +- working_dir: "/vllm-workspace/tests" # optional ++##### fast check tests ##### ++ ++- label: Documentation Build # 2min ++ working_dir: "/vllm-workspace/test_docs/docs" ++ fast_check: true ++ no_gpu: True ++ commands: ++ - pip install -r requirements-docs.txt ++ - SPHINXOPTS=\"-W\" make html ++ # Check API reference (if it fails, you may have missing mock imports) ++ - grep \"sig sig-object py\" build/html/api/inference_params.html ++ ++- label: Async Engine, Inputs, Utils, Worker Test # 24min ++ fast_check: true ++ source_file_dependencies: ++ - vllm/ ++ - tests/mq_llm_engine ++ - tests/async_engine ++ - tests/test_inputs ++ - tests/multimodal ++ - tests/test_utils ++ - tests/worker ++ - tests/standalone_tests/lazy_torch_compile.py ++ commands: ++ - pip install git+https://github.com/Isotr0py/DeepSeek-VL2.git # Used by multimoda processing test ++ - python3 standalone_tests/lazy_torch_compile.py ++ - pytest -v -s mq_llm_engine # MQLLMEngine ++ - pytest -v -s async_engine # AsyncLLMEngine ++ - NUM_SCHEDULER_STEPS=4 pytest -v -s async_engine/test_async_llm_engine.py ++ - pytest -v -s test_inputs.py ++ - pytest -v -s multimodal ++ - pytest -v -s test_utils.py # Utils ++ - pytest -v -s worker # Worker ++ ++- label: Python-only Installation Test ++ source_file_dependencies: ++ - tests/standalone_tests/python_only_compile.sh ++ - setup.py ++ commands: ++ - bash standalone_tests/python_only_compile.sh + +-- label: AsyncEngine Test +- command: pytest -v -s async_engine ++- label: Basic Correctness Test # 30min ++ #mirror_hardwares: [amd] ++ fast_check: true ++ source_file_dependencies: ++ - vllm/ ++ - tests/basic_correctness/test_basic_correctness ++ - tests/basic_correctness/test_cpu_offload ++ - tests/basic_correctness/test_preemption ++ commands: ++ - pytest -v -s basic_correctness/test_basic_correctness.py ++ - pytest -v -s basic_correctness/test_cpu_offload.py ++ - VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 pytest -v -s basic_correctness/test_preemption.py + +-- label: Basic Correctness Test ++- label: Chunked Prefill Test ++ source_file_dependencies: ++ - vllm/ ++ - tests/basic_correctness/test_chunked_prefill + commands: +- - VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_basic_correctness.py +- - VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_basic_correctness.py + - VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_chunked_prefill.py + - VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_chunked_prefill.py +- - VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 pytest -v -s basic_correctness/test_preemption.py + +-- label: Core Test ++- label: Core Test # 10min + mirror_hardwares: [amd] +- command: pytest -v -s core ++ fast_check: true ++ source_file_dependencies: ++ - vllm/core ++ - vllm/distributed ++ - tests/core ++ commands: ++ - pytest -v -s core + +-- label: Distributed Comm Ops Test +- command: pytest -v -s test_comm_ops.py +- working_dir: "/vllm-workspace/tests/distributed" +- num_gpus: 2 ++- label: Entrypoints Test # 40min ++ working_dir: "/vllm-workspace/tests" ++ fast_check: true ++ mirror_hardwares: [amd] ++ source_file_dependencies: ++ - vllm/ ++ commands: ++ - pytest -v -s entrypoints/llm --ignore=entrypoints/llm/test_lazy_outlines.py --ignore=entrypoints/llm/test_generate.py --ignore=entrypoints/llm/test_generate_multiple_loras.py --ignore=entrypoints/llm/test_guided_generate.py ++ - pytest -v -s entrypoints/llm/test_lazy_outlines.py # it needs a clean process ++ - pytest -v -s entrypoints/llm/test_generate.py # it needs a clean process ++ - pytest -v -s entrypoints/llm/test_generate_multiple_loras.py # it needs a clean process ++ - pytest -v -s entrypoints/llm/test_guided_generate.py # it needs a clean process ++ - pytest -v -s entrypoints/openai --ignore=entrypoints/openai/test_oot_registration.py ++ - pytest -v -s entrypoints/test_chat_utils.py ++ - pytest -v -s entrypoints/offline_mode # Needs to avoid interference with other tests ++ ++- label: Distributed Tests (4 GPUs) # 10min ++ working_dir: "/vllm-workspace/tests" ++ num_gpus: 4 ++ fast_check: true ++ source_file_dependencies: ++ - vllm/distributed/ ++ - vllm/core/ ++ - tests/distributed ++ - tests/spec_decode/e2e/test_integration_dist_tp4 ++ - tests/compile ++ commands: ++ - pytest -v -s distributed/test_utils.py ++ - pytest -v -s compile/test_basic_correctness.py ++ - pytest -v -s distributed/test_pynccl.py ++ - pytest -v -s spec_decode/e2e/test_integration_dist_tp4.py ++ ++- label: Metrics, Tracing Test # 10min ++ num_gpus: 2 ++ fast_check: true ++ source_file_dependencies: ++ - vllm/ ++ - tests/metrics ++ - tests/tracing ++ commands: ++ - pytest -v -s metrics ++ - "pip install \ ++ 'opentelemetry-sdk>=1.26.0,<1.27.0' \ ++ 'opentelemetry-api>=1.26.0,<1.27.0' \ ++ 'opentelemetry-exporter-otlp>=1.26.0,<1.27.0' \ ++ 'opentelemetry-semantic-conventions-ai>=0.4.1,<0.5.0'" ++ - pytest -v -s tracing + +-- label: Distributed Tests +- working_dir: "/vllm-workspace/tests/distributed" ++##### fast check tests ##### ++##### 1 GPU test ##### + +- num_gpus: 2 # only support 1 or 2 for now. ++- label: Regression Test # 5min + mirror_hardwares: [amd] ++ source_file_dependencies: ++ - vllm/ ++ - tests/test_regression ++ commands: ++ - pip install modelscope ++ - pytest -v -s test_regression.py ++ working_dir: "/vllm-workspace/tests" # optional + ++- label: Engine Test # 10min ++ mirror_hardwares: [amd] ++ source_file_dependencies: ++ - vllm/ ++ - tests/engine ++ - tests/tokenization + commands: +- - pytest -v -s test_pynccl_library.py +- - TEST_DIST_MODEL=facebook/opt-125m pytest -v -s test_basic_distributed_correctness.py +- - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest -v -s test_basic_distributed_correctness.py +- - TEST_DIST_MODEL=facebook/opt-125m pytest -v -s test_chunked_prefill_distributed.py +- - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest -v -s test_chunked_prefill_distributed.py ++ - pytest -v -s engine test_sequence.py test_config.py test_logger.py ++ # OOM in the CI unless we run this separately ++ - pytest -v -s tokenization + +-- label: Distributed Tests (Multiple Groups) +- working_dir: "/vllm-workspace/tests/distributed" +- num_gpus: 4 ++- label: V1 Test ++ #mirror_hardwares: [amd] ++ source_file_dependencies: ++ - vllm/ ++ - tests/v1 ++ commands: ++ - VLLM_USE_V1=1 pytest -v -s v1 ++ ++- label: Examples Test # 25min ++ working_dir: "/vllm-workspace/examples" ++ #mirror_hardwares: [amd] ++ source_file_dependencies: ++ - vllm/entrypoints ++ - examples/ + commands: +- - pytest -v -s test_pynccl.py ++ - pip install tensorizer # for tensorizer test ++ - python3 offline_inference/basic.py ++ - python3 offline_inference/cpu_offload.py ++ - python3 offline_inference/chat.py ++ - python3 offline_inference/prefix_caching.py ++ - python3 offline_inference/llm_engine_example.py ++ - python3 offline_inference/vision_language.py ++ - python3 offline_inference/vision_language_multi_image.py ++ - python3 other/tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 other/tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors ++ - python3 offline_inference/encoder_decoder.py ++ - python3 offline_inference/classification.py ++ - python3 offline_inference/embedding.py ++ - python3 offline_inference/scoring.py ++ - python3 offline_inference/profiling.py --model facebook/opt-125m run_num_steps --num-steps 2 + +-- label: Engine Test ++- label: Prefix Caching Test # 9min + mirror_hardwares: [amd] +- command: pytest -v -s engine tokenization test_sequence.py test_config.py test_logger.py ++ source_file_dependencies: ++ - vllm/ ++ - tests/prefix_caching ++ commands: ++ - pytest -v -s prefix_caching + +-- label: Entrypoints Test ++- label: Samplers Test # 36min ++ source_file_dependencies: ++ - vllm/model_executor/layers ++ - vllm/sampling_metadata.py ++ - tests/samplers ++ - tests/conftest.py + commands: +- # these tests have to be separated, because each one will allocate all posible GPU memory +- - pytest -v -s entrypoints --ignore=entrypoints/test_server_oot_registration.py +- - pytest -v -s entrypoints/test_server_oot_registration.py ++ - pytest -v -s samplers ++ - VLLM_USE_FLASHINFER_SAMPLER=1 pytest -v -s samplers + +-- label: Examples Test +- working_dir: "/vllm-workspace/examples" ++- label: LogitsProcessor Test # 5min + mirror_hardwares: [amd] ++ source_file_dependencies: ++ - vllm/model_executor/layers ++ - vllm/model_executor/guided_decoding ++ - tests/test_logits_processor ++ - tests/model_executor/test_guided_processors ++ commands: ++ - pytest -v -s test_logits_processor.py ++ - pytest -v -s model_executor/test_guided_processors.py ++ ++- label: Speculative decoding tests # 40min ++ source_file_dependencies: ++ - vllm/spec_decode ++ - tests/spec_decode ++ - vllm/model_executor/models/eagle.py + commands: +- # install aws cli for llava_example.py +- - pip install awscli +- - python3 offline_inference.py +- - python3 offline_inference_with_prefix.py +- - python3 llm_engine_example.py +- - python3 llava_example.py ++ - pytest -v -s spec_decode/e2e/test_multistep_correctness.py ++ - VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s spec_decode --ignore=spec_decode/e2e/test_multistep_correctness.py ++ - pytest -v -s spec_decode/e2e/test_eagle_correctness.py + +-- label: Kernels Test %N +- command: pytest -v -s kernels --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT ++- label: LoRA Test %N # 15min each ++ mirror_hardwares: [amd] ++ source_file_dependencies: ++ - vllm/lora ++ - tests/lora ++ command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --ignore=lora/test_long_context.py --ignore=lora/test_chatglm3_tp.py --ignore=lora/test_llama_tp.py --ignore=lora/test_minicpmv_tp.py + parallelism: 4 + +-- label: Models Test ++- label: "PyTorch Fullgraph Smoke Test" # 9min ++ fast_check: true ++ source_file_dependencies: ++ - vllm/ ++ - tests/compile ++ commands: ++ - pytest -v -s compile/test_basic_correctness.py ++ # these tests need to be separated, cannot combine ++ - pytest -v -s compile/piecewise/test_simple.py ++ - pytest -v -s compile/piecewise/test_toy_llama.py ++ ++- label: "PyTorch Fullgraph Test" # 18min ++ source_file_dependencies: ++ - vllm/ ++ - tests/compile ++ commands: ++ - pytest -v -s compile/test_full_graph.py ++ ++- label: Kernels Test %N # 1h each + mirror_hardwares: [amd] ++ source_file_dependencies: ++ - csrc/ ++ - vllm/attention ++ - tests/kernels + commands: +- - bash ../.buildkite/download-images.sh +- - pytest -v -s models --ignore=models/test_llava.py --ignore=models/test_mistral.py ++ - pytest -v -s kernels --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT ++ parallelism: 4 + +-- label: Llava Test ++- label: Tensorizer Test # 11min + mirror_hardwares: [amd] ++ soft_fail: true ++ source_file_dependencies: ++ - vllm/model_executor/model_loader ++ - tests/tensorizer_loader + commands: +- - bash ../.buildkite/download-images.sh +- - pytest -v -s models/test_llava.py ++ - apt-get update && apt-get install -y curl libsodium23 ++ - export VLLM_WORKER_MULTIPROC_METHOD=spawn ++ - pytest -v -s tensorizer_loader + +-- label: Prefix Caching Test ++- label: Benchmarks # 9min ++ working_dir: "/vllm-workspace/.buildkite" + mirror_hardwares: [amd] ++ source_file_dependencies: ++ - benchmarks/ + commands: +- - pytest -v -s prefix_caching ++ - bash run-benchmarks.sh + +-- label: Samplers Test +- command: pytest -v -s samplers ++- label: Quantization Test # 33min ++ source_file_dependencies: ++ - csrc/ ++ - vllm/model_executor/layers/quantization ++ - tests/quantization ++ command: VLLM_TEST_FORCE_LOAD_FORMAT=auto pytest -v -s quantization + +-- label: LogitsProcessor Test +- mirror_hardwares: [amd] +- command: pytest -v -s test_logits_processor.py ++- label: LM Eval Small Models # 53min ++ working_dir: "/vllm-workspace/.buildkite/lm-eval-harness" ++ source_file_dependencies: ++ - csrc/ ++ - vllm/model_executor/layers/quantization ++ commands: ++ - export VLLM_WORKER_MULTIPROC_METHOD=spawn ++ - bash ./run-tests.sh -c configs/models-small.txt -t 1 + +-- label: Worker Test +- mirror_hardwares: [amd] +- command: pytest -v -s worker ++- label: Encoder Decoder tests # 5min ++ source_file_dependencies: ++ - vllm/ ++ - tests/encoder_decoder ++ commands: ++ - pytest -v -s encoder_decoder + +-- label: Speculative decoding tests +- mirror_hardwares: [amd] +- command: pytest -v -s spec_decode ++- label: OpenAI-Compatible Tool Use # 20 min ++ fast_check: false ++ mirror_hardwares: [ amd ] ++ source_file_dependencies: ++ - vllm/ ++ - tests/tool_use ++ commands: ++ - pytest -v -s tool_use + +-- label: LoRA Test %N +- command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT +- parallelism: 4 ++##### models test ##### + +-- label: Tensorizer Test +- command: apt-get install curl libsodium23 && pytest -v -s tensorizer_loader ++- label: Basic Models Test # 24min ++ source_file_dependencies: ++ - vllm/ ++ - tests/models ++ commands: ++ - pytest -v -s models/test_registry.py ++ - pytest -v -s models/test_initialization.py + +-- label: Metrics Test +- command: pytest -v -s metrics ++- label: Language Models Test (Standard) # 32min ++ #mirror_hardwares: [amd] ++ source_file_dependencies: ++ - vllm/ ++ - tests/models/decoder_only/language ++ - tests/models/embedding/language ++ - tests/models/encoder_decoder/language ++ commands: ++ - pytest -v -s models/decoder_only/language -m 'core_model or quant_model' ++ - pytest -v -s models/embedding/language -m core_model + +-- label: Quantization Test +- command: pytest -v -s quantization ++- label: Language Models Test (Extended) # 1h10min ++ optional: true ++ source_file_dependencies: ++ - vllm/ ++ - tests/models/decoder_only/language ++ - tests/models/embedding/language ++ - tests/models/encoder_decoder/language ++ commands: ++ - pytest -v -s models/decoder_only/language -m 'not core_model and not quant_model' ++ - pytest -v -s models/embedding/language -m 'not core_model' + +-- label: Benchmarks +- working_dir: "/vllm-workspace/.buildkite" +- mirror_hardwares: [amd] ++- label: Multi-Modal Models Test (Standard) # 40min ++ #mirror_hardwares: [amd] ++ source_file_dependencies: ++ - vllm/ ++ - tests/models/decoder_only/audio_language ++ - tests/models/decoder_only/vision_language ++ - tests/models/embedding/vision_language ++ - tests/models/encoder_decoder/audio_language ++ - tests/models/encoder_decoder/vision_language + commands: +- - pip install aiohttp +- - bash run-benchmarks.sh ++ - pip install git+https://github.com/TIGER-AI-Lab/Mantis.git ++ - pytest -v -s models/multimodal ++ - pytest -v -s models/decoder_only/audio_language -m 'core_model or quant_model' ++ - pytest -v -s --ignore models/decoder_only/vision_language/test_phi3v.py models/decoder_only/vision_language -m 'core_model or quant_model' ++ - pytest -v -s models/embedding/vision_language -m core_model ++ - pytest -v -s models/encoder_decoder/audio_language -m core_model ++ - pytest -v -s models/encoder_decoder/language -m core_model ++ - pytest -v -s models/encoder_decoder/vision_language -m core_model + +-- label: Documentation Build +- working_dir: "/vllm-workspace/test_docs/docs" +- no_gpu: True ++- label: Multi-Modal Models Test (Extended) 1 # 48m ++ optional: true ++ source_file_dependencies: ++ - vllm/ ++ - tests/models/decoder_only/audio_language ++ - tests/models/decoder_only/vision_language ++ - tests/models/embedding/vision_language ++ - tests/models/encoder_decoder/vision_language + commands: +- - pip install -r requirements-docs.txt +- - SPHINXOPTS=\"-W\" make html ++ - pip install git+https://github.com/TIGER-AI-Lab/Mantis.git ++ - pytest -v -s models/decoder_only/audio_language -m 'not core_model and not quant_model' ++ - pytest -v -s models/decoder_only/vision_language/test_models.py -m 'split(group=0) and not core_model and not quant_model' ++ # HACK - run phi3v tests separately to sidestep this transformers bug ++ # https://github.com/huggingface/transformers/issues/34307 ++ - pytest -v -s models/decoder_only/vision_language/test_phi3v.py ++ - pytest -v -s --ignore models/decoder_only/vision_language/test_models.py --ignore models/decoder_only/vision_language/test_phi3v.py models/decoder_only/vision_language -m 'not core_model and not quant_model' ++ - pytest -v -s models/embedding/vision_language -m 'not core_model' ++ - pytest -v -s models/encoder_decoder/language -m 'not core_model' ++ - pytest -v -s models/encoder_decoder/vision_language -m 'not core_model' ++ ++- label: Multi-Modal Models Test (Extended) 2 # 38m ++ optional: true ++ source_file_dependencies: ++ - vllm/ ++ - tests/models/decoder_only/vision_language ++ commands: ++ - pip install git+https://github.com/TIGER-AI-Lab/Mantis.git ++ - pytest -v -s models/decoder_only/vision_language/test_models.py -m 'split(group=1) and not core_model and not quant_model' ++ ++# This test is used only in PR development phase to test individual models and should never run on main ++- label: Custom Models Test ++ optional: true ++ commands: ++ - echo 'Testing custom models...' ++ # PR authors can temporarily add commands below to test individual models ++ # e.g. pytest -v -s models/encoder_decoder/vision_language/test_mllama.py ++ # *To avoid merge conflicts, remember to REMOVE (not just comment out) them before merging the PR* ++ ++##### 1 GPU test ##### ++##### multi gpus test ##### ++ ++- label: Distributed Comm Ops Test # 7min ++ working_dir: "/vllm-workspace/tests" ++ num_gpus: 2 ++ source_file_dependencies: ++ - vllm/distributed ++ - tests/distributed ++ commands: ++ - pytest -v -s distributed/test_comm_ops.py ++ - pytest -v -s distributed/test_shm_broadcast.py ++ ++- label: 2 Node Tests (4 GPUs in total) # 16min ++ working_dir: "/vllm-workspace/tests" ++ num_gpus: 2 ++ num_nodes: 2 ++ source_file_dependencies: ++ - vllm/distributed/ ++ - vllm/engine/ ++ - vllm/executor/ ++ - vllm/model_executor/models/ ++ - tests/distributed/ ++ commands: ++ - # the following commands are for the first node, with ip 192.168.10.10 (ray environment already set up) ++ - VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py | grep 'Same node test passed' ++ - VLLM_MULTI_NODE=1 pytest -v -s distributed/test_multi_node_assignment.py ++ - VLLM_MULTI_NODE=1 pytest -v -s distributed/test_pipeline_parallel.py ++ - # the following commands are for the second node, with ip 192.168.10.11 (ray environment already set up) ++ - VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py | grep 'Same node test passed' ++ ++- label: Distributed Tests (2 GPUs) # 40min ++ #mirror_hardwares: [amd] ++ working_dir: "/vllm-workspace/tests" ++ num_gpus: 2 ++ source_file_dependencies: ++ - vllm/distributed/ ++ - vllm/engine/ ++ - vllm/executor/ ++ - vllm/model_executor/models/ ++ - tests/distributed/ ++ - vllm/compilation ++ - vllm/worker/worker_base.py ++ - vllm/worker/worker.py ++ - vllm/worker/model_runner.py ++ commands: ++ - pytest -v -s ./compile/test_basic_correctness.py ++ - pytest -v -s ./compile/test_wrapper.py ++ - VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep 'Same node test passed' ++ - TARGET_TEST_SUITE=L4 pytest basic_correctness/ -v -s -m 'distributed(num_gpus=2)' ++ # Avoid importing model tests that cause CUDA reinitialization error ++ - pytest models/encoder_decoder/language/test_bart.py -v -s -m 'distributed(num_gpus=2)' ++ - pytest models/encoder_decoder/vision_language/test_broadcast.py -v -s -m 'distributed(num_gpus=2)' ++ - pytest models/decoder_only/vision_language/test_models.py -v -s -m 'distributed(num_gpus=2)' ++ - pytest -v -s spec_decode/e2e/test_integration_dist_tp2.py ++ - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py ++ - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s kv_transfer/disagg_test.py ++ ++- label: Plugin Tests (2 GPUs) # 40min ++ working_dir: "/vllm-workspace/tests" ++ num_gpus: 2 ++ fast_check: true ++ source_file_dependencies: ++ - vllm/plugins/ ++ - tests/plugins/ ++ commands: ++ # begin platform plugin tests, all the code in-between runs on dummy platform ++ - pip install -e ./plugins/vllm_add_dummy_platform ++ - pytest -v -s plugins_tests/test_platform_plugins.py ++ - pip uninstall vllm_add_dummy_platform -y ++ # end platform plugin tests ++ # other tests continue here: ++ - pip install -e ./plugins/vllm_add_dummy_model ++ - pytest -v -s distributed/test_distributed_oot.py ++ - pytest -v -s entrypoints/openai/test_oot_registration.py # it needs a clean process ++ - pytest -v -s models/test_oot_registration.py # it needs a clean process ++ ++- label: Multi-step Tests (4 GPUs) # 36min ++ working_dir: "/vllm-workspace/tests" ++ num_gpus: 4 ++ source_file_dependencies: ++ - vllm/model_executor/layers/sampler.py ++ - vllm/sequence.py ++ - vllm/worker/worker_base.py ++ - vllm/worker/worker.py ++ - vllm/worker/multi_step_worker.py ++ - vllm/worker/model_runner_base.py ++ - vllm/worker/model_runner.py ++ - vllm/worker/multi_step_model_runner.py ++ - vllm/engine ++ - tests/multi_step ++ commands: ++ - pytest -v -s multi_step/test_correctness_async_llm.py ++ - pytest -v -s multi_step/test_correctness_llm.py ++ ++- label: Pipeline Parallelism Test # 45min ++ working_dir: "/vllm-workspace/tests" ++ num_gpus: 4 ++ source_file_dependencies: ++ - vllm/distributed/ ++ - vllm/engine/ ++ - vllm/executor/ ++ - vllm/model_executor/models/ ++ - tests/distributed/ ++ commands: ++ - pytest -v -s distributed/test_pp_cudagraph.py ++ - pytest -v -s distributed/test_pipeline_parallel.py ++ ++- label: LoRA TP Test (Distributed) ++ num_gpus: 4 ++ source_file_dependencies: ++ - vllm/lora ++ - tests/lora ++ commands: ++ # FIXIT: find out which code initialize cuda before running the test ++ # before the fix, we need to use spawn to test it ++ - export VLLM_WORKER_MULTIPROC_METHOD=spawn ++ # This test runs llama 13B, so it is required to run on 4 GPUs. ++ - pytest -v -s -x lora/test_long_context.py ++ # There is some Tensor Parallelism related processing logic in LoRA that ++ # requires multi-GPU testing for validation. ++ - pytest -v -s -x lora/test_chatglm3_tp.py ++ - pytest -v -s -x lora/test_llama_tp.py ++ - pytest -v -s -x lora/test_minicpmv_tp.py ++ ++ ++- label: Weight Loading Multiple GPU Test # 33min ++ working_dir: "/vllm-workspace/tests" ++ num_gpus: 2 ++ source_file_dependencies: ++ - vllm/ ++ - tests/weight_loading ++ commands: ++ - bash weight_loading/run_model_weight_loading_test.sh -c weight_loading/models.txt ++ ++- label: Weight Loading Multiple GPU Test - Large Models # optional ++ working_dir: "/vllm-workspace/tests" ++ num_gpus: 2 ++ gpu: a100 ++ optional: true ++ source_file_dependencies: ++ - vllm/ ++ - tests/weight_loading ++ commands: ++ - bash weight_loading/run_model_weight_loading_test.sh -c weight_loading/models-large.txt ++ ++ ++##### multi gpus test ##### ++##### A100 test ##### ++ ++- label: Distributed Tests (A100) # optional ++ gpu: a100 ++ optional: true ++ num_gpus: 4 ++ source_file_dependencies: ++ - vllm/ ++ commands: ++ # NOTE: don't test llama model here, it seems hf implementation is buggy ++ # see https://github.com/vllm-project/vllm/pull/5689 for details ++ - pytest -v -s distributed/test_custom_all_reduce.py ++ - torchrun --nproc_per_node=2 distributed/test_ca_buffer_sharing.py ++ - TARGET_TEST_SUITE=A100 pytest basic_correctness/ -v -s -m 'distributed(num_gpus=2)' ++ - pytest -v -s -x lora/test_mixtral.py ++ ++- label: LM Eval Large Models # optional ++ gpu: a100 ++ optional: true ++ num_gpus: 4 ++ working_dir: "/vllm-workspace/.buildkite/lm-eval-harness" ++ source_file_dependencies: ++ - csrc/ ++ - vllm/model_executor/layers/quantization ++ commands: ++ - export VLLM_WORKER_MULTIPROC_METHOD=spawn ++ - bash ./run-tests.sh -c configs/models-large.txt -t 4 +diff --git a/.buildkite/upload-wheels.sh b/.buildkite/upload-wheels.sh +new file mode 100644 +index 0000000..3c75665 +--- /dev/null ++++ b/.buildkite/upload-wheels.sh +@@ -0,0 +1,71 @@ ++#!/usr/bin/env bash ++ ++set -ex ++ ++# Assume wheels are in artifacts/dist/*.whl ++wheel_files=(artifacts/dist/*.whl) ++ ++# Check that exactly one wheel is found ++if [[ ${#wheel_files[@]} -ne 1 ]]; then ++ echo "Error: Expected exactly one wheel file in artifacts/dist/, but found ${#wheel_files[@]}" ++ exit 1 ++fi ++ ++# Get the single wheel file ++wheel="${wheel_files[0]}" ++ ++# Rename 'linux' to 'manylinux1' in the wheel filename ++new_wheel="${wheel/linux/manylinux1}" ++mv -- "$wheel" "$new_wheel" ++wheel="$new_wheel" ++ ++# Extract the version from the wheel ++version=$(unzip -p "$wheel" '**/METADATA' | grep '^Version: ' | cut -d' ' -f2) ++echo "Version: $version" ++ ++normal_wheel="$wheel" # Save the original wheel filename ++ ++# If the version contains "dev", rename it to v1.0.0.dev for consistency ++if [[ $version == *dev* ]]; then ++ suffix="${version##*.}" ++ if [[ $suffix == cu* ]]; then ++ new_version="1.0.0.dev+${suffix}" ++ else ++ new_version="1.0.0.dev" ++ fi ++ new_wheel="${wheel/$version/$new_version}" ++ # use cp to keep both files in the artifacts directory ++ cp -- "$wheel" "$new_wheel" ++ wheel="$new_wheel" ++ version="$new_version" ++fi ++ ++# Upload the wheel to S3 ++python3 .buildkite/generate_index.py --wheel "$normal_wheel" ++ ++# generate index for this commit ++aws s3 cp "$wheel" "s3://vllm-wheels/$BUILDKITE_COMMIT/" ++aws s3 cp "$normal_wheel" "s3://vllm-wheels/$BUILDKITE_COMMIT/" ++ ++if [[ $normal_wheel == *"cu118"* ]]; then ++ # if $normal_wheel matches cu118, do not upload the index.html ++ echo "Skipping index files for cu118 wheels" ++else ++ # only upload index.html for cu12 wheels (default wheels) ++ aws s3 cp index.html "s3://vllm-wheels/$BUILDKITE_COMMIT/vllm/index.html" ++ aws s3 cp "s3://vllm-wheels/nightly/index.html" "s3://vllm-wheels/$BUILDKITE_COMMIT/index.html" ++fi ++ ++# generate index for nightly ++aws s3 cp "$wheel" "s3://vllm-wheels/nightly/" ++aws s3 cp "$normal_wheel" "s3://vllm-wheels/nightly/" ++ ++if [[ $normal_wheel == *"cu118"* ]]; then ++ # if $normal_wheel matches cu118, do not upload the index.html ++ echo "Skipping index files for cu118 wheels" ++else ++ # only upload index.html for cu12 wheels (default wheels) ++ aws s3 cp index.html "s3://vllm-wheels/nightly/vllm/index.html" ++fi ++ ++aws s3 cp "$wheel" "s3://vllm-wheels/$version/" +\ No newline at end of file +diff --git a/.clang-format b/.clang-format +new file mode 100644 +index 0000000..7f9e6d7 +--- /dev/null ++++ b/.clang-format +@@ -0,0 +1,26 @@ ++BasedOnStyle: Google ++UseTab: Never ++IndentWidth: 2 ++ColumnLimit: 80 ++ ++# Force pointers to the type for C++. ++DerivePointerAlignment: false ++PointerAlignment: Left ++ ++# Reordering #include statements can (and currently will) introduce errors ++SortIncludes: false ++ ++# Style choices ++AlignConsecutiveAssignments: false ++AlignConsecutiveDeclarations: false ++IndentPPDirectives: BeforeHash ++ ++IncludeCategories: ++ - Regex: '^<' ++ Priority: 4 ++ - Regex: '^"(llvm|llvm-c|clang|clang-c|mlir|mlir-c)/' ++ Priority: 3 ++ - Regex: '^"(qoda|\.\.)/' ++ Priority: 2 ++ - Regex: '.*' ++ Priority: 1 +diff --git a/.dockerignore b/.dockerignore +index 5cfe0dc..3863656 100644 +--- a/.dockerignore ++++ b/.dockerignore +@@ -1 +1,33 @@ ++/.venv ++/build ++dist + vllm/*.so ++ ++# Byte-compiled / optimized / DLL files ++__pycache__/ ++*.py[cod] ++*$py.class ++ ++.mypy_cache ++ ++# Distribution / packaging ++.Python ++/build/ ++cmake-build-*/ ++CMakeUserPresets.json ++develop-eggs/ ++/dist/ ++downloads/ ++eggs/ ++.eggs/ ++lib/ ++lib64/ ++parts/ ++sdist/ ++var/ ++wheels/ ++share/python-wheels/ ++*.egg-info/ ++.installed.cfg ++*.egg ++MANIFEST +diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS +new file mode 100644 +index 0000000..3cb91fc +--- /dev/null ++++ b/.github/CODEOWNERS +@@ -0,0 +1,33 @@ ++# See https://help.github.com/articles/about-codeowners/ ++# for more info about CODEOWNERS file ++ ++# This lists cover the "core" components of vLLM that require careful review ++/vllm/attention/backends/abstract.py @WoosukKwon @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill ++/vllm/core @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill ++/vllm/engine/llm_engine.py @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill ++/vllm/executor/executor_base.py @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill ++/vllm/worker/worker_base.py @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill ++/vllm/worker/worker.py @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill ++/vllm/model_executor/layers/sampler.py @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill ++CMakeLists.txt @tlrmchlsmth ++ ++# vLLM V1 ++/vllm/v1 @WoosukKwon @robertgshaw2-neuralmagic @njhill @ywang96 @comaniac @alexm-neuralmagic ++ ++# Test ownership ++/tests/async_engine @njhill @robertgshaw2-neuralmagic @simon-mo ++/tests/test_inputs.py @DarkLight1337 @ywang96 ++/tests/entrypoints @DarkLight1337 @robertgshaw2-neuralmagic @simon-mo ++/tests/models @DarkLight1337 @ywang96 ++/tests/multimodal @DarkLight1337 @ywang96 ++/tests/prefix_caching @comaniac @KuntaiDu ++/tests/spec_decode @njhill @LiuXiaoxuanPKU ++/tests/kernels @tlrmchlsmth @WoosukKwon ++/tests/quantization @mgoin @robertgshaw2-neuralmagic ++/.buildkite/lm-eval-harness @mgoin @simon-mo ++/tests/distributed/test_multi_node_assignment.py @youkaichao ++/tests/distributed/test_pipeline_parallel.py @youkaichao ++/tests/distributed/test_same_node.py @youkaichao ++/tests/multi_step @alexm-neuralmagic @comaniac ++/tests/weight_loading @mgoin @youkaichao ++/tests/basic_correctness/test_chunked_prefill @rkooo567 @comaniac +diff --git a/.github/FUNDING.yml b/.github/FUNDING.yml +new file mode 100644 +index 0000000..d1f6105 +--- /dev/null ++++ b/.github/FUNDING.yml +@@ -0,0 +1,2 @@ ++github: [vllm-project] ++open_collective: vllm +diff --git a/.github/ISSUE_TEMPLATE/100-documentation.yml b/.github/ISSUE_TEMPLATE/100-documentation.yml +index 501c0aa..74d397b 100644 +--- a/.github/ISSUE_TEMPLATE/100-documentation.yml ++++ b/.github/ISSUE_TEMPLATE/100-documentation.yml +@@ -20,3 +20,10 @@ body: + attributes: + value: > + Thanks for contributing 🎉! ++- type: checkboxes ++ id: askllm ++ attributes: ++ label: Before submitting a new issue... ++ options: ++ - label: Make sure you already searched for relevant issues, and asked the chatbot living at the bottom right corner of the [documentation page](https://docs.vllm.ai/en/latest/), which can answer lots of frequently asked questions. ++ required: true +diff --git a/.github/ISSUE_TEMPLATE/200-installation.yml b/.github/ISSUE_TEMPLATE/200-installation.yml +index df41ade..590e56c 100644 +--- a/.github/ISSUE_TEMPLATE/200-installation.yml ++++ b/.github/ISSUE_TEMPLATE/200-installation.yml +@@ -38,3 +38,10 @@ body: + attributes: + value: > + Thanks for contributing 🎉! ++- type: checkboxes ++ id: askllm ++ attributes: ++ label: Before submitting a new issue... ++ options: ++ - label: Make sure you already searched for relevant issues, and asked the chatbot living at the bottom right corner of the [documentation page](https://docs.vllm.ai/en/latest/), which can answer lots of frequently asked questions. ++ required: true +diff --git a/.github/ISSUE_TEMPLATE/300-usage.yml b/.github/ISSUE_TEMPLATE/300-usage.yml +index 54763af..004798a 100644 +--- a/.github/ISSUE_TEMPLATE/300-usage.yml ++++ b/.github/ISSUE_TEMPLATE/300-usage.yml +@@ -36,3 +36,10 @@ body: + attributes: + value: > + Thanks for contributing 🎉! ++- type: checkboxes ++ id: askllm ++ attributes: ++ label: Before submitting a new issue... ++ options: ++ - label: Make sure you already searched for relevant issues, and asked the chatbot living at the bottom right corner of the [documentation page](https://docs.vllm.ai/en/latest/), which can answer lots of frequently asked questions. ++ required: true +diff --git a/.github/ISSUE_TEMPLATE/400-bug-report.yml b/.github/ISSUE_TEMPLATE/400-bug-report.yml +new file mode 100644 +index 0000000..30db172 +--- /dev/null ++++ b/.github/ISSUE_TEMPLATE/400-bug-report.yml +@@ -0,0 +1,107 @@ ++name: 🐛 Bug report ++description: Raise an issue here if you find a bug. ++title: "[Bug]: " ++labels: ["bug"] ++ ++body: ++- type: markdown ++ attributes: ++ value: > ++ #### Before submitting an issue, please make sure the issue hasn't been already addressed by searching through [the existing and past issues](https://github.com/vllm-project/vllm/issues?q=is%3Aissue+sort%3Acreated-desc+). ++- type: textarea ++ attributes: ++ label: Your current environment ++ description: | ++ Please run the following and paste the output below. ++ ```sh ++ wget https://raw.githubusercontent.com/vllm-project/vllm/main/collect_env.py ++ # For security purposes, please feel free to check the contents of collect_env.py before running it. ++ python collect_env.py ++ ``` ++ It is suggested to download and execute the latest script, as vllm might frequently update the diagnosis information needed for accurately and quickly responding to issues. ++ value: | ++
++ The output of `python collect_env.py` ++ ++ ```text ++ Your output of `python collect_env.py` here ++ ``` ++ ++
++ validations: ++ required: true ++- type: textarea ++ attributes: ++ label: Model Input Dumps ++ description: | ++ If you are facing crashing due to illegal memory access or other issues with model execution, vLLM may dump the problematic input of the model. In this case, you will see the message `Error in model execution (input dumped to /tmp/err_xxx.pkl)`. If you see this message, please zip the file (because GitHub doesn't support .pkl file format) and upload it here. This will help us to reproduce the issue and facilitate the debugging process. ++ placeholder: | ++ Upload the dumped input file. ++ validations: ++ required: false ++- type: textarea ++ attributes: ++ label: 🐛 Describe the bug ++ description: | ++ Please provide a clear and concise description of what the bug is. ++ ++ If relevant, add a minimal example so that we can reproduce the error by running the code. It is very important for the snippet to be as succinct (minimal) as possible, so please take time to trim down any irrelevant code to help us debug efficiently. We are going to copy-paste your code and we expect to get the same result as you did: avoid any external data, and include the relevant imports, etc. For example: ++ ++ ```python ++ from vllm import LLM, SamplingParams ++ ++ prompts = [ ++ "Hello, my name is", ++ "The president of the United States is", ++ "The capital of France is", ++ "The future of AI is", ++ ] ++ sampling_params = SamplingParams(temperature=0.8, top_p=0.95) ++ ++ llm = LLM(model="facebook/opt-125m") ++ ++ outputs = llm.generate(prompts, sampling_params) ++ ++ # Print the outputs. ++ for output in outputs: ++ prompt = output.prompt ++ generated_text = output.outputs[0].text ++ print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") ++ ``` ++ ++ If the code is too long (hopefully, it isn't), feel free to put it in a public gist and link it in the issue: https://gist.github.com. ++ ++ Please also paste or describe the results you observe instead of the expected results. If you observe an error, please paste the error message including the **full** traceback of the exception. It may be relevant to wrap error messages in ```` ```triple quotes blocks``` ````. ++ ++ Please set the environment variable `export VLLM_LOGGING_LEVEL=DEBUG` to turn on more logging to help debugging potential issues. ++ ++ If you experienced crashes or hangs, it would be helpful to run vllm with `export VLLM_TRACE_FUNCTION=1` . All the function calls in vllm will be recorded. Inspect these log files, and tell which function crashes or hangs. ++ placeholder: | ++ A clear and concise description of what the bug is. ++ ++ ```python ++ # Sample code to reproduce the problem ++ ``` ++ ++ ``` ++ The error message you got, with the full traceback. ++ ``` ++ validations: ++ required: true ++- type: markdown ++ attributes: ++ value: > ++ ⚠️ Please separate bugs of `transformers` implementation or usage from bugs of `vllm`. If you think anything is wrong with the models' output: ++ ++ - Try the counterpart of `transformers` first. If the error appears, please go to [their issues](https://github.com/huggingface/transformers/issues?q=is%3Aissue+is%3Aopen+sort%3Aupdated-desc). ++ ++ - If the error only appears in vllm, please provide the detailed script of how you run `transformers` and `vllm`, also highlight the difference and what you expect. ++ ++ Thanks for contributing 🎉! ++- type: checkboxes ++ id: askllm ++ attributes: ++ label: Before submitting a new issue... ++ options: ++ - label: Make sure you already searched for relevant issues, and asked the chatbot living at the bottom right corner of the [documentation page](https://docs.vllm.ai/en/latest/), which can answer lots of frequently asked questions. ++ required: true +diff --git a/.github/ISSUE_TEMPLATE/500-feature-request.yml b/.github/ISSUE_TEMPLATE/500-feature-request.yml +new file mode 100644 +index 0000000..097d88f +--- /dev/null ++++ b/.github/ISSUE_TEMPLATE/500-feature-request.yml +@@ -0,0 +1,38 @@ ++name: 🚀 Feature request ++description: Submit a proposal/request for a new vllm feature ++title: "[Feature]: " ++labels: ["feature request"] ++ ++body: ++- type: markdown ++ attributes: ++ value: > ++ #### Before submitting an issue, please make sure the issue hasn't been already addressed by searching through [the existing and past issues](https://github.com/vllm-project/vllm/issues?q=is%3Aissue+sort%3Acreated-desc+). ++- type: textarea ++ attributes: ++ label: 🚀 The feature, motivation and pitch ++ description: > ++ A clear and concise description of the feature proposal. Please outline the motivation for the proposal. Is your feature request related to a specific problem? e.g., *"I'm working on X and would like Y to be possible"*. If this is related to another GitHub issue, please link here too. ++ validations: ++ required: true ++- type: textarea ++ attributes: ++ label: Alternatives ++ description: > ++ A description of any alternative solutions or features you've considered, if any. ++- type: textarea ++ attributes: ++ label: Additional context ++ description: > ++ Add any other context or screenshots about the feature request. ++- type: markdown ++ attributes: ++ value: > ++ Thanks for contributing 🎉! ++- type: checkboxes ++ id: askllm ++ attributes: ++ label: Before submitting a new issue... ++ options: ++ - label: Make sure you already searched for relevant issues, and asked the chatbot living at the bottom right corner of the [documentation page](https://docs.vllm.ai/en/latest/), which can answer lots of frequently asked questions. ++ required: true +diff --git a/.github/ISSUE_TEMPLATE/600-new-model.yml b/.github/ISSUE_TEMPLATE/600-new-model.yml +new file mode 100644 +index 0000000..713e76c +--- /dev/null ++++ b/.github/ISSUE_TEMPLATE/600-new-model.yml +@@ -0,0 +1,40 @@ ++name: 🤗 Support request for a new model from huggingface ++description: Submit a proposal/request for a new model from huggingface ++title: "[New Model]: " ++labels: ["new model"] ++ ++body: ++- type: markdown ++ attributes: ++ value: > ++ #### Before submitting an issue, please make sure the issue hasn't been already addressed by searching through [the existing and past issues](https://github.com/vllm-project/vllm/issues?q=is%3Aissue+sort%3Acreated-desc+). ++ ++ #### We also highly recommend you read https://docs.vllm.ai/en/latest/contributing/model/adding_model.html first to understand how to add a new model. ++- type: textarea ++ attributes: ++ label: The model to consider. ++ description: > ++ A huggingface url, pointing to the model, e.g. https://huggingface.co/openai-community/gpt2 . ++ validations: ++ required: true ++- type: textarea ++ attributes: ++ label: The closest model vllm already supports. ++ description: > ++ Here is the list of models already supported by vllm: https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/models . Which model is the most similar to the model you want to add support for? ++- type: textarea ++ attributes: ++ label: What's your difficulty of supporting the model you want? ++ description: > ++ For example, any new operators or new architecture? ++- type: markdown ++ attributes: ++ value: > ++ Thanks for contributing 🎉! ++- type: checkboxes ++ id: askllm ++ attributes: ++ label: Before submitting a new issue... ++ options: ++ - label: Make sure you already searched for relevant issues, and asked the chatbot living at the bottom right corner of the [documentation page](https://docs.vllm.ai/en/latest/), which can answer lots of frequently asked questions. ++ required: true +diff --git a/.github/ISSUE_TEMPLATE/700-performance-discussion.yml b/.github/ISSUE_TEMPLATE/700-performance-discussion.yml +new file mode 100644 +index 0000000..273f50d +--- /dev/null ++++ b/.github/ISSUE_TEMPLATE/700-performance-discussion.yml +@@ -0,0 +1,59 @@ ++name: ⚡ Discussion on the performance of vllm ++description: Submit a proposal/discussion about the performance of vllm ++title: "[Performance]: " ++labels: ["performance"] ++ ++body: ++- type: markdown ++ attributes: ++ value: > ++ #### Before submitting an issue, please make sure the issue hasn't been already addressed by searching through [the existing and past issues](https://github.com/vllm-project/vllm/issues?q=is%3Aissue+sort%3Acreated-desc+). ++- type: textarea ++ attributes: ++ label: Proposal to improve performance ++ description: > ++ How do you plan to improve vllm's performance? ++ validations: ++ required: false ++- type: textarea ++ attributes: ++ label: Report of performance regression ++ description: > ++ Please provide detailed description of performance comparison to confirm the regression. You may want to run the benchmark script at https://github.com/vllm-project/vllm/tree/main/benchmarks . ++ validations: ++ required: false ++- type: textarea ++ attributes: ++ label: Misc discussion on performance ++ description: > ++ Anything about the performance. ++ validations: ++ required: false ++- type: textarea ++ attributes: ++ label: Your current environment (if you think it is necessary) ++ description: | ++ Please run the following and paste the output below. ++ ```sh ++ wget https://raw.githubusercontent.com/vllm-project/vllm/main/collect_env.py ++ # For security purposes, please feel free to check the contents of collect_env.py before running it. ++ python collect_env.py ++ ``` ++ It is suggested to download and execute the latest script, as vllm might frequently update the diagnosis information needed for accurately and quickly responding to issues. ++ value: | ++ ```text ++ The output of `python collect_env.py` ++ ``` ++ validations: ++ required: false ++- type: markdown ++ attributes: ++ value: > ++ Thanks for contributing 🎉! ++- type: checkboxes ++ id: askllm ++ attributes: ++ label: Before submitting a new issue... ++ options: ++ - label: Make sure you already searched for relevant issues, and asked the chatbot living at the bottom right corner of the [documentation page](https://docs.vllm.ai/en/latest/), which can answer lots of frequently asked questions. ++ required: true +diff --git a/.github/ISSUE_TEMPLATE/750-RFC.yml b/.github/ISSUE_TEMPLATE/750-RFC.yml +index 5382b12..e447c07 100644 +--- a/.github/ISSUE_TEMPLATE/750-RFC.yml ++++ b/.github/ISSUE_TEMPLATE/750-RFC.yml +@@ -47,3 +47,10 @@ body: + attributes: + value: > + Thanks for contributing 🎉! ++- type: checkboxes ++ id: askllm ++ attributes: ++ label: Before submitting a new issue... ++ options: ++ - label: Make sure you already searched for relevant issues, and asked the chatbot living at the bottom right corner of the [documentation page](https://docs.vllm.ai/en/latest/), which can answer lots of frequently asked questions. ++ required: true +diff --git a/.github/ISSUE_TEMPLATE/800-misc-discussion.yml b/.github/ISSUE_TEMPLATE/800-misc-discussion.yml +new file mode 100644 +index 0000000..79e6e90 +--- /dev/null ++++ b/.github/ISSUE_TEMPLATE/800-misc-discussion.yml +@@ -0,0 +1,28 @@ ++name: 🎲 Misc/random discussions that do not fit into the above categories. ++description: Submit a discussion as you like. Note that developers are heavily overloaded and we mainly rely on community users to answer these issues. ++title: "[Misc]: " ++labels: ["misc"] ++ ++body: ++- type: markdown ++ attributes: ++ value: > ++ #### Before submitting an issue, please make sure the issue hasn't been already addressed by searching through [the existing and past issues](https://github.com/vllm-project/vllm/issues?q=is%3Aissue+sort%3Acreated-desc+). ++- type: textarea ++ attributes: ++ label: Anything you want to discuss about vllm. ++ description: > ++ Anything you want to discuss about vllm. ++ validations: ++ required: true ++- type: markdown ++ attributes: ++ value: > ++ Thanks for contributing 🎉! ++- type: checkboxes ++ id: askllm ++ attributes: ++ label: Before submitting a new issue... ++ options: ++ - label: Make sure you already searched for relevant issues, and asked the chatbot living at the bottom right corner of the [documentation page](https://docs.vllm.ai/en/latest/), which can answer lots of frequently asked questions. ++ required: true +diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md +index 262ce8e..51a73c8 100644 +--- a/.github/PULL_REQUEST_TEMPLATE.md ++++ b/.github/PULL_REQUEST_TEMPLATE.md +@@ -2,63 +2,4 @@ FILL IN THE PR DESCRIPTION HERE + + FIX #xxxx (*link existing issues this PR will resolve*) + +-**BEFORE SUBMITTING, PLEASE READ THE CHECKLIST BELOW AND FILL IN THE DESCRIPTION ABOVE** +- +---- +- +-
+- +- PR Checklist (Click to Expand) +- +-

Thank you for your contribution to vLLM! Before submitting the pull request, please ensure the PR meets the following criteria. This helps vLLM maintain the code quality and improve the efficiency of the review process.

+- +-

PR Title and Classification

+-

Only specific types of PRs will be reviewed. The PR title is prefixed appropriately to indicate the type of change. Please use one of the following:

+-
    +-
  • [Bugfix] for bug fixes.
  • +-
  • [CI/Build] for build or continuous integration improvements.
  • +-
  • [Doc] for documentation fixes and improvements.
  • +-
  • [Model] for adding a new model or improving an existing model. Model name should appear in the title.
  • +-
  • [Frontend] For changes on the vLLM frontend (e.g., OpenAI API server, LLM class, etc.)
  • +-
  • [Kernel] for changes affecting CUDA kernels or other compute kernels.
  • +-
  • [Core] for changes in the core vLLM logic (e.g., LLMEngine, AsyncLLMEngine, Scheduler, etc.)
  • +-
  • [Hardware][Vendor] for hardware-specific changes. Vendor name should appear in the prefix (e.g., [Hardware][AMD]).
  • +-
  • [Misc] for PRs that do not fit the above categories. Please use this sparingly.
  • +-
+-

Note: If the PR spans more than one category, please include all relevant prefixes.

+- +-

Code Quality

+- +-

The PR need to meet the following code quality standards:

+- +-
    +-
  • We adhere to Google Python style guide and Google C++ style guide.
  • +-
  • Pass all linter checks. Please use format.sh to format your code.
  • +-
  • The code need to be well-documented to ensure future contributors can easily understand the code.
  • +-
  • Include sufficient tests to ensure the project to stay correct and robust. This includes both unit tests and integration tests.
  • +-
  • Please add documentation to docs/source/ if the PR modifies the user-facing behaviors of vLLM. It helps vLLM user understand and utilize the new features or changes.
  • +-
+- +-

Notes for Large Changes

+-

Please keep the changes as concise as possible. For major architectural changes (>500 LOC excluding kernel/data/config/test), we would expect a GitHub issue (RFC) discussing the technical design and justification. Otherwise, we will tag it with rfc-required and might not go through the PR.

+- +-

What to Expect for the Reviews

+- +-

The goal of the vLLM team is to be a transparent reviewing machine. We would like to make the review process transparent and efficient and make sure no contributor feel confused or frustrated. However, the vLLM team is small, so we need to prioritize some PRs over others. Here is what you can expect from the review process:

+- +-
    +-
  • After the PR is submitted, the PR will be assigned to a reviewer. Every reviewer will pick up the PRs based on their expertise and availability.
  • +-
  • After the PR is assigned, the reviewer will provide status update every 2-3 days. If the PR is not reviewed within 7 days, please feel free to ping the reviewer or the vLLM team.
  • +-
  • After the review, the reviewer will put an action-required label on the PR if there are changes required. The contributor should address the comments and ping the reviewer to re-review the PR.
  • +-
  • Please respond to all comments within a reasonable time frame. If a comment isn't clear or you disagree with a suggestion, feel free to ask for clarification or discuss the suggestion. +-
  • +-
+- +-

Thank You

+- +-

Finally, thank you for taking the time to read these guidelines and for your interest in contributing to vLLM. Your contributions make vLLM a great tool for everyone!

+- +- +-
+- +- ++**BEFORE SUBMITTING, PLEASE READ https://docs.vllm.ai/en/latest/contributing/overview.html ** +diff --git a/.github/dependabot.yml b/.github/dependabot.yml +new file mode 100644 +index 0000000..683b70c +--- /dev/null ++++ b/.github/dependabot.yml +@@ -0,0 +1,31 @@ ++version: 2 ++updates: ++ # Maintain dependencies for GitHub Actions ++ - package-ecosystem: "github-actions" ++ directory: "/" ++ schedule: ++ interval: "weekly" ++ - package-ecosystem: "pip" ++ directory: "/" ++ schedule: ++ interval: "weekly" ++ labels: ["dependencies"] ++ open-pull-requests-limit: 5 ++ reviewers: ["khluu", "simon-mo"] ++ allow: ++ - dependency-type: "all" ++ ignore: ++ - dependency-name: "*" ++ update-types: ["version-update:semver-patch"] ++ - dependency-name: "torch" ++ - dependency-name: "torchvision" ++ - dependency-name: "xformers" ++ - dependency-name: "lm-format-enforcer" ++ - dependency-name: "gguf" ++ - dependency-name: "compressed-tensors" ++ - dependency-name: "ray[adag]" ++ - dependency-name: "lm-eval" ++ groups: ++ minor-update: ++ applies-to: version-updates ++ update-types: ["minor"] +diff --git a/.github/mergify.yml b/.github/mergify.yml +new file mode 100644 +index 0000000..ca4bd7e +--- /dev/null ++++ b/.github/mergify.yml +@@ -0,0 +1,60 @@ ++pull_request_rules: ++- name: label-documentation ++ description: Automatically apply documentation label ++ conditions: ++ - or: ++ - files~=^[^/]+\.md$ ++ - files~=^docs/ ++ actions: ++ label: ++ add: ++ - documentation ++ ++- name: label-ci-build ++ description: Automatically apply ci/build label ++ conditions: ++ - or: ++ - files~=^\.github/ ++ - files~=\.buildkite/ ++ - files~=^cmake/ ++ - files=CMakeLists.txt ++ - files~=^Dockerfile ++ - files~=^requirements.*\.txt ++ - files=setup.py ++ actions: ++ label: ++ add: ++ - ci/build ++ ++- name: label-frontend ++ description: Automatically apply frontend label ++ conditions: ++ - files~=^vllm/entrypoints/ ++ actions: ++ label: ++ add: ++ - frontend ++ ++- name: ping author on conflicts and add 'needs-rebase' label ++ conditions: ++ - conflict ++ - -closed ++ actions: ++ label: ++ add: ++ - needs-rebase ++ comment: ++ message: | ++ This pull request has merge conflicts that must be resolved before it can be ++ merged. Please rebase the PR, @{{author}}. ++ ++ https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork ++ ++- name: remove 'needs-rebase' label when conflict is resolved ++ conditions: ++ - -conflict ++ - -closed ++ actions: ++ label: ++ remove: ++ - needs-rebase +diff --git a/.github/scripts/cleanup_pr_body.sh b/.github/scripts/cleanup_pr_body.sh +new file mode 100644 +index 0000000..3246c6f +--- /dev/null ++++ b/.github/scripts/cleanup_pr_body.sh +@@ -0,0 +1,50 @@ ++#!/bin/bash ++ ++set -eu ++ ++# ensure 1 argument is passed ++if [ "$#" -ne 1 ]; then ++ echo "Usage: $0 " ++ exit 1 ++fi ++ ++PR_NUMBER=$1 ++OLD=/tmp/orig_pr_body.txt ++NEW=/tmp/new_pr_body.txt ++ ++gh pr view --json body --template "{{.body}}" "${PR_NUMBER}" > "${OLD}" ++cp "${OLD}" "${NEW}" ++ ++# Remove "FIX #xxxx (*link existing issues this PR will resolve*)" ++sed -i '/FIX #xxxx.*$/d' "${NEW}" ++ ++# Remove "FILL IN THE PR DESCRIPTION HERE" ++sed -i '/FILL IN THE PR DESCRIPTION HERE/d' "${NEW}" ++ ++# Remove all lines after and including "**BEFORE SUBMITTING, PLEASE READ THE CHECKLIST BELOW AND FILL IN THE DESCRIPTION ABOVE**" ++sed -i '/\*\*BEFORE SUBMITTING, PLEASE READ.*\*\*/,$d' "${NEW}" ++ ++# Remove HTML
section that includes text of "PR Checklist (Click to Expand)" ++python3 - <.*?.*?PR Checklist \(Click to Expand\).*?.*?
', re.DOTALL) ++content = re.sub(pattern, '', content) ++ ++with open("${NEW}", "w") as file: ++ file.write(content) ++EOF ++ ++# Run this only if ${NEW} is different than ${OLD} ++if ! cmp -s "${OLD}" "${NEW}"; then ++ gh pr edit --body-file "${NEW}" "${PR_NUMBER}" ++ echo ++ echo "Updated PR body:" ++ echo ++ cat "${NEW}" ++else ++ echo "No changes needed" ++fi +diff --git a/.github/workflows/actionlint.yml b/.github/workflows/actionlint.yml +new file mode 100644 +index 0000000..0226cf0 +--- /dev/null ++++ b/.github/workflows/actionlint.yml +@@ -0,0 +1,40 @@ ++name: Lint GitHub Actions workflows ++on: ++ push: ++ branches: ++ - "main" ++ paths: ++ - '.github/workflows/*.ya?ml' ++ - '.github/workflows/actionlint.*' ++ - '.github/workflows/matchers/actionlint.json' ++ pull_request: ++ branches: ++ - "main" ++ paths: ++ - '.github/workflows/*.ya?ml' ++ - '.github/workflows/actionlint.*' ++ - '.github/workflows/matchers/actionlint.json' ++ ++env: ++ LC_ALL: en_US.UTF-8 ++ ++defaults: ++ run: ++ shell: bash ++ ++permissions: ++ contents: read ++ ++jobs: ++ actionlint: ++ runs-on: ubuntu-latest ++ steps: ++ - name: "Checkout" ++ uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 ++ with: ++ fetch-depth: 0 ++ ++ - name: "Run actionlint" ++ run: | ++ echo "::add-matcher::.github/workflows/matchers/actionlint.json" ++ tools/actionlint.sh -color +diff --git a/.github/workflows/add_label_automerge.yml b/.github/workflows/add_label_automerge.yml +new file mode 100644 +index 0000000..c9d6d42 +--- /dev/null ++++ b/.github/workflows/add_label_automerge.yml +@@ -0,0 +1,21 @@ ++name: Add label on auto-merge enabled ++on: ++ pull_request_target: ++ types: ++ - auto_merge_enabled ++jobs: ++ add-label-on-auto-merge: ++ runs-on: ubuntu-latest ++ steps: ++ - name: Add label ++ uses: actions/github-script@60a0d83039c74a4aee543508d2ffcb1c3799cdea # v7.0.1 ++ with: ++ script: | ++ github.rest.issues.addLabels({ ++ owner: context.repo.owner, ++ repo: context.repo.repo, ++ issue_number: context.issue.number, ++ labels: ['ready'] ++ }) ++ env: ++ GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} +diff --git a/.github/workflows/clang-format.yml b/.github/workflows/clang-format.yml +new file mode 100644 +index 0000000..68149d2 +--- /dev/null ++++ b/.github/workflows/clang-format.yml +@@ -0,0 +1,53 @@ ++name: clang-format ++ ++on: ++ # Trigger the workflow on push or pull request, ++ # but only for the main branch ++ push: ++ branches: ++ - main ++ paths: ++ - '**/*.h' ++ - '**/*.cpp' ++ - '**/*.cu' ++ - '**/*.cuh' ++ - '.github/workflows/clang-format.yml' ++ pull_request: ++ branches: ++ - main ++ paths: ++ - '**/*.h' ++ - '**/*.cpp' ++ - '**/*.cu' ++ - '**/*.cuh' ++ - '.github/workflows/clang-format.yml' ++ ++jobs: ++ clang-format: ++ runs-on: ubuntu-latest ++ strategy: ++ matrix: ++ python-version: ["3.11"] ++ steps: ++ - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 ++ - name: Set up Python ${{ matrix.python-version }} ++ uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 ++ with: ++ python-version: ${{ matrix.python-version }} ++ - name: Install dependencies ++ run: | ++ python -m pip install --upgrade pip ++ pip install clang-format==18.1.5 ++ - name: Running clang-format ++ run: | ++ EXCLUDES=( ++ 'csrc/moe/topk_softmax_kernels.cu' ++ 'csrc/quantization/gguf/ggml-common.h' ++ 'csrc/quantization/gguf/dequantize.cuh' ++ 'csrc/quantization/gguf/vecdotq.cuh' ++ 'csrc/quantization/gguf/mmq.cuh' ++ 'csrc/quantization/gguf/mmvq.cuh' ++ ) ++ find csrc/ \( -name '*.h' -o -name '*.cpp' -o -name '*.cu' -o -name '*.cuh' \) -print \ ++ | grep -vFf <(printf "%s\n" "${EXCLUDES[@]}") \ ++ | xargs clang-format --dry-run --Werror +diff --git a/.github/workflows/cleanup_pr_body.yml b/.github/workflows/cleanup_pr_body.yml +new file mode 100644 +index 0000000..0085a1c +--- /dev/null ++++ b/.github/workflows/cleanup_pr_body.yml +@@ -0,0 +1,26 @@ ++name: Cleanup PR Body ++ ++on: ++ pull_request_target: ++ types: [opened, reopened, edited] ++ ++permissions: ++ pull-requests: write ++ ++jobs: ++ update-description: ++ runs-on: ubuntu-latest ++ ++ steps: ++ - name: Checkout repository ++ uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 ++ ++ - name: Set up Python ++ uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 ++ with: ++ python-version: '3.12' ++ ++ - name: Update PR description ++ env: ++ GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} ++ run: .github/scripts/cleanup_pr_body.sh "${{ github.event.number }}" +diff --git a/.github/workflows/codespell.yml b/.github/workflows/codespell.yml +new file mode 100644 +index 0000000..68887ad +--- /dev/null ++++ b/.github/workflows/codespell.yml +@@ -0,0 +1,45 @@ ++name: codespell ++ ++on: ++ # Trigger the workflow on push or pull request, ++ # but only for the main branch ++ push: ++ branches: ++ - main ++ paths: ++ - "**/*.py" ++ - "**/*.md" ++ - "**/*.rst" ++ - pyproject.toml ++ - requirements-lint.txt ++ - .github/workflows/codespell.yml ++ pull_request: ++ branches: ++ - main ++ paths: ++ - "**/*.py" ++ - "**/*.md" ++ - "**/*.rst" ++ - pyproject.toml ++ - requirements-lint.txt ++ - .github/workflows/codespell.yml ++ ++jobs: ++ codespell: ++ runs-on: ubuntu-latest ++ strategy: ++ matrix: ++ python-version: ["3.12"] ++ steps: ++ - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 ++ - name: Set up Python ${{ matrix.python-version }} ++ uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 ++ with: ++ python-version: ${{ matrix.python-version }} ++ - name: Install dependencies ++ run: | ++ python -m pip install --upgrade pip ++ pip install -r requirements-lint.txt ++ - name: Spelling check with codespell ++ run: | ++ codespell --toml pyproject.toml +diff --git a/.github/workflows/doc-lint.yml b/.github/workflows/doc-lint.yml +new file mode 100644 +index 0000000..2f5ee8b +--- /dev/null ++++ b/.github/workflows/doc-lint.yml +@@ -0,0 +1,32 @@ ++name: Lint documentation ++ ++on: ++ push: ++ branches: ++ - main ++ paths: ++ - "docs/**" ++ pull_request: ++ branches: ++ - main ++ paths: ++ - "docs/**" ++ ++jobs: ++ doc-lint: ++ runs-on: ubuntu-latest ++ strategy: ++ matrix: ++ python-version: ["3.12"] ++ steps: ++ - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 ++ - name: Set up Python ${{ matrix.python-version }} ++ uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 ++ with: ++ python-version: ${{ matrix.python-version }} ++ - name: Install dependencies ++ run: | ++ python -m pip install --upgrade pip ++ pip install -r requirements-lint.txt ++ - name: Linting docs ++ run: tools/doc-lint.sh +diff --git a/.github/workflows/lint-and-deploy.yaml b/.github/workflows/lint-and-deploy.yaml +new file mode 100644 +index 0000000..556b60d +--- /dev/null ++++ b/.github/workflows/lint-and-deploy.yaml +@@ -0,0 +1,82 @@ ++name: Lint and Deploy Charts ++ ++on: pull_request ++ ++jobs: ++ lint-and-deploy: ++ runs-on: ubuntu-latest ++ steps: ++ - name: Checkout ++ uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 ++ with: ++ fetch-depth: 0 ++ ++ - name: Set up Helm ++ uses: azure/setup-helm@fe7b79cd5ee1e45176fcad797de68ecaf3ca4814 # v4.2.0 ++ with: ++ version: v3.14.4 ++ ++ #Python is required because ct lint runs Yamale and yamllint which require Python. ++ - uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 ++ with: ++ python-version: '3.13' ++ ++ - name: Set up chart-testing ++ uses: helm/chart-testing-action@e6669bcd63d7cb57cb4380c33043eebe5d111992 # v2.6.1 ++ with: ++ version: v3.10.1 ++ ++ - name: Run chart-testing (lint) ++ run: ct lint --target-branch ${{ github.event.repository.default_branch }} --chart-dirs examples/online_serving/chart-helm --charts examples/online_serving/chart-helm ++ ++ - name: Setup minio ++ run: | ++ docker network create vllm-net ++ docker run -d -p 9000:9000 --name minio --net vllm-net \ ++ -e "MINIO_ACCESS_KEY=minioadmin" \ ++ -e "MINIO_SECRET_KEY=minioadmin" \ ++ -v /tmp/data:/data \ ++ -v /tmp/config:/root/.minio \ ++ minio/minio server /data ++ export AWS_ACCESS_KEY_ID=minioadmin ++ export AWS_SECRET_ACCESS_KEY=minioadmin ++ export AWS_EC2_METADATA_DISABLED=true ++ mkdir opt-125m ++ cd opt-125m && curl -O -Ls "https://huggingface.co/facebook/opt-125m/resolve/main/{pytorch_model.bin,config.json,generation_config.json,merges.txt,special_tokens_map.json,tokenizer_config.json,vocab.json}" && cd .. ++ aws --endpoint-url http://127.0.0.1:9000/ s3 mb s3://testbucket ++ aws --endpoint-url http://127.0.0.1:9000/ s3 cp opt-125m/ s3://testbucket/opt-125m --recursive ++ ++ - name: Create kind cluster ++ uses: helm/kind-action@0025e74a8c7512023d06dc019c617aa3cf561fde # v1.10.0 ++ ++ - name: Build the Docker image vllm cpu ++ run: docker buildx build -f Dockerfile.cpu -t vllm-cpu-env . ++ ++ - name: Configuration of docker images, network and namespace for the kind cluster ++ run: | ++ docker pull amazon/aws-cli:2.6.4 ++ kind load docker-image amazon/aws-cli:2.6.4 --name chart-testing ++ kind load docker-image vllm-cpu-env:latest --name chart-testing ++ docker network connect vllm-net "$(docker ps -aqf "name=chart-testing-control-plane")" ++ kubectl create ns ns-vllm ++ ++ - name: Run chart-testing (install) ++ run: | ++ export AWS_ACCESS_KEY_ID=minioadmin ++ export AWS_SECRET_ACCESS_KEY=minioadmin ++ sleep 30 && kubectl -n ns-vllm logs -f "$(kubectl -n ns-vllm get pods | awk '/deployment/ {print $1;exit}')" & ++ helm install --wait --wait-for-jobs --timeout 5m0s --debug --create-namespace --namespace=ns-vllm test-vllm examples/online_serving/chart-helm -f examples/online_serving/chart-helm/values.yaml --set secrets.s3endpoint=http://minio:9000 --set secrets.s3bucketname=testbucket --set secrets.s3accesskeyid=$AWS_ACCESS_KEY_ID --set secrets.s3accesskey=$AWS_SECRET_ACCESS_KEY --set resources.requests.cpu=1 --set resources.requests.memory=4Gi --set resources.limits.cpu=2 --set resources.limits.memory=5Gi --set image.env[0].name=VLLM_CPU_KVCACHE_SPACE --set image.env[1].name=VLLM_LOGGING_LEVEL --set-string image.env[0].value="1" --set-string image.env[1].value="DEBUG" --set-string extraInit.s3modelpath="opt-125m/" --set-string 'resources.limits.nvidia\.com/gpu=0' --set-string 'resources.requests.nvidia\.com/gpu=0' --set-string image.repository="vllm-cpu-env" ++ ++ - name: curl test ++ run: | ++ kubectl -n ns-vllm port-forward service/test-vllm-service 8001:80 & ++ sleep 10 ++ CODE="$(curl -v -f --location http://localhost:8001/v1/completions \ ++ --header "Content-Type: application/json" \ ++ --data '{ ++ "model": "opt-125m", ++ "prompt": "San Francisco is a", ++ "max_tokens": 7, ++ "temperature": 0 ++ }'):$CODE" ++ echo "$CODE" +\ No newline at end of file +diff --git a/.github/workflows/matchers/actionlint.json b/.github/workflows/matchers/actionlint.json +new file mode 100644 +index 0000000..4613e16 +--- /dev/null ++++ b/.github/workflows/matchers/actionlint.json +@@ -0,0 +1,17 @@ ++{ ++ "problemMatcher": [ ++ { ++ "owner": "actionlint", ++ "pattern": [ ++ { ++ "regexp": "^(?:\\x1b\\[\\d+m)?(.+?)(?:\\x1b\\[\\d+m)*:(?:\\x1b\\[\\d+m)*(\\d+)(?:\\x1b\\[\\d+m)*:(?:\\x1b\\[\\d+m)*(\\d+)(?:\\x1b\\[\\d+m)*: (?:\\x1b\\[\\d+m)*(.+?)(?:\\x1b\\[\\d+m)* \\[(.+?)\\]$", ++ "file": 1, ++ "line": 2, ++ "column": 3, ++ "message": 4, ++ "code": 5 ++ } ++ ] ++ } ++ ] ++} +diff --git a/.github/workflows/matchers/mypy.json b/.github/workflows/matchers/mypy.json +new file mode 100644 +index 0000000..f048fce +--- /dev/null ++++ b/.github/workflows/matchers/mypy.json +@@ -0,0 +1,16 @@ ++{ ++ "problemMatcher": [ ++ { ++ "owner": "mypy", ++ "pattern": [ ++ { ++ "regexp": "^(.+):(\\d+):\\s(error|warning):\\s(.+)$", ++ "file": 1, ++ "line": 2, ++ "severity": 3, ++ "message": 4 ++ } ++ ] ++ } ++ ] ++} +diff --git a/.github/workflows/matchers/ruff.json b/.github/workflows/matchers/ruff.json +new file mode 100644 +index 0000000..f6d4479 +--- /dev/null ++++ b/.github/workflows/matchers/ruff.json +@@ -0,0 +1,17 @@ ++{ ++ "problemMatcher": [ ++ { ++ "owner": "ruff", ++ "pattern": [ ++ { ++ "regexp": "^(.+?):(\\d+):(\\d+): (\\w+): (.+)$", ++ "file": 1, ++ "line": 2, ++ "column": 3, ++ "code": 4, ++ "message": 5 ++ } ++ ] ++ } ++ ] ++ } +diff --git a/.github/workflows/mypy.yaml b/.github/workflows/mypy.yaml +index a20753d..73eeacf 100644 +--- a/.github/workflows/mypy.yaml ++++ b/.github/workflows/mypy.yaml +@@ -6,45 +6,46 @@ on: + push: + branches: + - main ++ paths: ++ - '**/*.py' ++ - '.github/workflows/mypy.yaml' ++ - 'tools/mypy.sh' ++ - 'pyproject.toml' + pull_request: + branches: + - main ++ # This workflow is only relevant when one of the following files changes. ++ # However, we have github configured to expect and require this workflow ++ # to run and pass before github with auto-merge a pull request. Until github ++ # allows more flexible auto-merge policy, we can just run this on every PR. ++ # It doesn't take that long to run, anyway. ++ #paths: ++ # - '**/*.py' ++ # - '.github/workflows/mypy.yaml' ++ # - 'tools/mypy.sh' ++ # - 'pyproject.toml' + + jobs: +- ruff: ++ mypy: + runs-on: ubuntu-latest + strategy: + matrix: +- python-version: ["3.8", "3.9", "3.10", "3.11"] ++ python-version: ["3.9", "3.10", "3.11", "3.12"] + steps: +- - uses: actions/checkout@v2 ++ - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + - name: Set up Python ${{ matrix.python-version }} +- uses: actions/setup-python@v2 ++ uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip +- pip install mypy==1.9.0 ++ pip install mypy==1.11.1 + pip install types-setuptools + pip install types-PyYAML + pip install types-requests + pip install types-setuptools + - name: Mypy + run: | +- mypy vllm/attention --config-file pyproject.toml +- mypy vllm/core --config-file pyproject.toml +- mypy vllm/distributed --config-file pyproject.toml +- mypy vllm/entrypoints --config-file pyproject.toml +- mypy vllm/executor --config-file pyproject.toml +- mypy vllm/usage --config-file pyproject.toml +- mypy vllm/*.py --config-file pyproject.toml +- mypy vllm/transformers_utils --config-file pyproject.toml +- mypy vllm/engine --config-file pyproject.toml +- mypy vllm/worker --config-file pyproject.toml +- mypy vllm/spec_decode --config-file pyproject.toml +- mypy vllm/model_executor --config-file pyproject.toml +- mypy vllm/lora --config-file pyproject.toml +- mypy vllm/logging --config-file pyproject.toml +- mypy vllm/model_executor --config-file pyproject.toml +- ++ echo "::add-matcher::.github/workflows/matchers/mypy.json" ++ tools/mypy.sh 1 ${{ matrix.python-version }} +diff --git a/.github/workflows/png-lint.yml b/.github/workflows/png-lint.yml +new file mode 100644 +index 0000000..4932af9 +--- /dev/null ++++ b/.github/workflows/png-lint.yml +@@ -0,0 +1,37 @@ ++name: Lint PNG exports from excalidraw ++on: ++ push: ++ branches: ++ - "main" ++ paths: ++ - '*.excalidraw.png' ++ - '.github/workflows/png-lint.yml' ++ pull_request: ++ branches: ++ - "main" ++ paths: ++ - '*.excalidraw.png' ++ - '.github/workflows/png-lint.yml' ++ ++env: ++ LC_ALL: en_US.UTF-8 ++ ++defaults: ++ run: ++ shell: bash ++ ++permissions: ++ contents: read ++ ++jobs: ++ actionlint: ++ runs-on: ubuntu-latest ++ steps: ++ - name: "Checkout" ++ uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 ++ with: ++ fetch-depth: 0 ++ ++ - name: "Run png-lint.sh to check excalidraw exported images" ++ run: | ++ tools/png-lint.sh +diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml +index ac60ce0..e40ceaa 100644 +--- a/.github/workflows/publish.yml ++++ b/.github/workflows/publish.yml +@@ -21,16 +21,16 @@ jobs: + upload_url: ${{ steps.create_release.outputs.upload_url }} + steps: + - name: Checkout +- uses: actions/checkout@v3 ++ uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + + - name: Extract branch info + shell: bash + run: | +- echo "release_tag=${GITHUB_REF#refs/*/}" >> $GITHUB_ENV ++ echo "release_tag=${GITHUB_REF#refs/*/}" >> "$GITHUB_ENV" + + - name: Create Release + id: create_release +- uses: "actions/github-script@v6" ++ uses: actions/github-script@60a0d83039c74a4aee543508d2ffcb1c3799cdea # v7.0.1 + env: + RELEASE_TAG: ${{ env.release_tag }} + with: +@@ -39,64 +39,68 @@ jobs: + const script = require('.github/workflows/scripts/create_release.js') + await script(github, context, core) + +- wheel: +- name: Build Wheel +- runs-on: ${{ matrix.os }} +- needs: release +- +- strategy: +- fail-fast: false +- matrix: +- os: ['ubuntu-20.04'] +- python-version: ['3.8', '3.9', '3.10', '3.11'] +- pytorch-version: ['2.3.0'] # Must be the most recent version that meets requirements-cuda.txt. +- cuda-version: ['11.8', '12.1'] +- +- steps: +- - name: Checkout +- uses: actions/checkout@v3 +- +- - name: Setup ccache +- uses: hendrikmuhs/ccache-action@v1.2 +- +- - name: Set up Linux Env +- if: ${{ runner.os == 'Linux' }} +- run: | +- bash -x .github/workflows/scripts/env.sh +- +- - name: Set up Python +- uses: actions/setup-python@v4 +- with: +- python-version: ${{ matrix.python-version }} +- +- - name: Install CUDA ${{ matrix.cuda-version }} +- run: | +- bash -x .github/workflows/scripts/cuda-install.sh ${{ matrix.cuda-version }} ${{ matrix.os }} +- +- - name: Install PyTorch ${{ matrix.pytorch-version }} with CUDA ${{ matrix.cuda-version }} +- run: | +- bash -x .github/workflows/scripts/pytorch-install.sh ${{ matrix.python-version }} ${{ matrix.pytorch-version }} ${{ matrix.cuda-version }} +- +- - name: Build wheel +- shell: bash +- env: +- CMAKE_BUILD_TYPE: Release # do not compile with debug symbol to reduce wheel size +- run: | +- bash -x .github/workflows/scripts/build.sh ${{ matrix.python-version }} ${{ matrix.cuda-version }} +- wheel_name=$(ls dist/*whl | xargs -n 1 basename) +- asset_name=${wheel_name//"linux"/"manylinux1"} +- echo "wheel_name=${wheel_name}" >> $GITHUB_ENV +- echo "asset_name=${asset_name}" >> $GITHUB_ENV +- +- - name: Upload Release Asset +- uses: actions/upload-release-asset@v1 +- env: +- GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} +- with: +- upload_url: ${{ needs.release.outputs.upload_url }} +- asset_path: ./dist/${{ env.wheel_name }} +- asset_name: ${{ env.asset_name }} +- asset_content_type: application/* ++ # NOTE(simon): No longer build wheel using Github Actions. See buildkite's release workflow. ++ # wheel: ++ # name: Build Wheel ++ # runs-on: ${{ matrix.os }} ++ # needs: release ++ ++ # strategy: ++ # fail-fast: false ++ # matrix: ++ # os: ['ubuntu-20.04'] ++ # python-version: ['3.9', '3.10', '3.11', '3.12'] ++ # pytorch-version: ['2.4.0'] # Must be the most recent version that meets requirements-cuda.txt. ++ # cuda-version: ['11.8', '12.1'] ++ ++ # steps: ++ # - name: Checkout ++ # uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 ++ ++ # - name: Setup ccache ++ # uses: hendrikmuhs/ccache-action@ed74d11c0b343532753ecead8a951bb09bb34bc9 # v1.2.14 ++ # with: ++ # create-symlink: true ++ # key: ${{ github.job }}-${{ matrix.python-version }}-${{ matrix.cuda-version }} ++ ++ # - name: Set up Linux Env ++ # if: ${{ runner.os == 'Linux' }} ++ # run: | ++ # bash -x .github/workflows/scripts/env.sh ++ ++ # - name: Set up Python ++ # uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 ++ # with: ++ # python-version: ${{ matrix.python-version }} ++ ++ # - name: Install CUDA ${{ matrix.cuda-version }} ++ # run: | ++ # bash -x .github/workflows/scripts/cuda-install.sh ${{ matrix.cuda-version }} ${{ matrix.os }} ++ ++ # - name: Install PyTorch ${{ matrix.pytorch-version }} with CUDA ${{ matrix.cuda-version }} ++ # run: | ++ # bash -x .github/workflows/scripts/pytorch-install.sh ${{ matrix.python-version }} ${{ matrix.pytorch-version }} ${{ matrix.cuda-version }} ++ ++ # - name: Build wheel ++ # shell: bash ++ # env: ++ # CMAKE_BUILD_TYPE: Release # do not compile with debug symbol to reduce wheel size ++ # run: | ++ # bash -x .github/workflows/scripts/build.sh ${{ matrix.python-version }} ${{ matrix.cuda-version }} ++ # wheel_name=$(find dist -name "*whl" -print0 | xargs -0 -n 1 basename) ++ # asset_name=${wheel_name//"linux"/"manylinux1"} ++ # echo "wheel_name=${wheel_name}" >> "$GITHUB_ENV" ++ # echo "asset_name=${asset_name}" >> "$GITHUB_ENV" ++ ++ # - name: Upload Release Asset ++ # uses: actions/upload-release-asset@e8f9f06c4b078e705bd2ea027f0926603fc9b4d5 # v1.0.2 ++ # env: ++ # GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} ++ # with: ++ # upload_url: ${{ needs.release.outputs.upload_url }} ++ # asset_path: ./dist/${{ env.wheel_name }} ++ # asset_name: ${{ env.asset_name }} ++ # asset_content_type: application/* + + # (Danielkinz): This last step will publish the .whl to pypi. Warning: untested + # - name: Publish package +diff --git a/.github/workflows/reminder_comment.yml b/.github/workflows/reminder_comment.yml +new file mode 100644 +index 0000000..df62539 +--- /dev/null ++++ b/.github/workflows/reminder_comment.yml +@@ -0,0 +1,21 @@ ++name: PR Reminder Comment Bot ++on: ++ pull_request_target: ++ types: [opened] ++ ++jobs: ++ pr_reminder: ++ runs-on: ubuntu-latest ++ steps: ++ - name: Remind to run full CI on PR ++ uses: actions/github-script@60a0d83039c74a4aee543508d2ffcb1c3799cdea # v7.0.1 ++ with: ++ script: | ++ github.rest.issues.createComment({ ++ owner: context.repo.owner, ++ repo: context.repo.repo, ++ issue_number: context.issue.number, ++ body: '👋 Hi! Thank you for contributing to the vLLM project.\n Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run `fastcheck` CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your `fastcheck` build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping `simon-mo` or `khluu` to add you in our Buildkite org. \n\nOnce the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.\n\n To run CI, PR reviewers can do one of these:\n- Add `ready` label to the PR\n- Enable auto-merge.\n\n🚀' ++ }) ++ env: ++ GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} +diff --git a/.github/workflows/ruff.yml b/.github/workflows/ruff.yml +index e71033f..7266cc3 100644 +--- a/.github/workflows/ruff.yml ++++ b/.github/workflows/ruff.yml +@@ -6,32 +6,47 @@ on: + push: + branches: + - main ++ paths: ++ - "**/*.py" ++ - pyproject.toml ++ - requirements-lint.txt ++ - .github/workflows/matchers/ruff.json ++ - .github/workflows/ruff.yml + pull_request: + branches: + - main ++ # This workflow is only relevant when one of the following files changes. ++ # However, we have github configured to expect and require this workflow ++ # to run and pass before github with auto-merge a pull request. Until github ++ # allows more flexible auto-merge policy, we can just run this on every PR. ++ # It doesn't take that long to run, anyway. ++ #paths: ++ # - "**/*.py" ++ # - pyproject.toml ++ # - requirements-lint.txt ++ # - .github/workflows/matchers/ruff.json ++ # - .github/workflows/ruff.yml + + jobs: + ruff: + runs-on: ubuntu-latest + strategy: + matrix: +- python-version: ["3.8", "3.9", "3.10", "3.11"] ++ python-version: ["3.12"] + steps: +- - uses: actions/checkout@v2 +- - name: Set up Python ${{ matrix.python-version }} +- uses: actions/setup-python@v2 +- with: +- python-version: ${{ matrix.python-version }} +- - name: Install dependencies +- run: | +- python -m pip install --upgrade pip +- pip install ruff==0.1.5 codespell==2.2.6 tomli==2.0.1 isort==5.13.2 +- - name: Analysing the code with ruff +- run: | +- ruff . +- - name: Spelling check with codespell +- run: | +- codespell --toml pyproject.toml +- - name: Run isort +- run: | +- isort . --check-only ++ - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 ++ - name: Set up Python ${{ matrix.python-version }} ++ uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 ++ with: ++ python-version: ${{ matrix.python-version }} ++ - name: Install dependencies ++ run: | ++ python -m pip install --upgrade pip ++ pip install -r requirements-lint.txt ++ - name: Analysing the code with ruff ++ run: | ++ echo "::add-matcher::.github/workflows/matchers/ruff.json" ++ ruff check --output-format github . ++ - name: Run isort ++ run: | ++ isort . --check-only +diff --git a/.github/workflows/scripts/build.sh b/.github/workflows/scripts/build.sh +index 60a3978..122e4e1 100644 +--- a/.github/workflows/scripts/build.sh ++++ b/.github/workflows/scripts/build.sh +@@ -1,4 +1,5 @@ + #!/bin/bash ++set -eux + + python_executable=python$1 + cuda_home=/usr/local/cuda-$2 +@@ -8,14 +9,15 @@ PATH=${cuda_home}/bin:$PATH + LD_LIBRARY_PATH=${cuda_home}/lib64:$LD_LIBRARY_PATH + + # Install requirements +-$python_executable -m pip install wheel packaging +-$python_executable -m pip install -r requirements-cuda.txt ++$python_executable -m pip install -r requirements-build.txt -r requirements-cuda.txt + + # Limit the number of parallel jobs to avoid OOM + export MAX_JOBS=1 +-# Make sure punica is built for the release (for LoRA) +-export VLLM_INSTALL_PUNICA_KERNELS=1 + # Make sure release wheels are built for the following architectures + export TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 8.9 9.0+PTX" ++export VLLM_FA_CMAKE_GPU_ARCHES="80-real;90-real" ++ ++bash tools/check_repo.sh ++ + # Build + $python_executable setup.py bdist_wheel --dist-dir=dist +diff --git a/.github/workflows/scripts/cuda-install.sh b/.github/workflows/scripts/cuda-install.sh +index 312c6e8..3d0b7a1 100644 +--- a/.github/workflows/scripts/cuda-install.sh ++++ b/.github/workflows/scripts/cuda-install.sh +@@ -1,16 +1,16 @@ + #!/bin/bash + + # Replace '.' with '-' ex: 11.8 -> 11-8 +-cuda_version=$(echo $1 | tr "." "-") ++cuda_version=$(echo "$1" | tr "." "-") + # Removes '-' and '.' ex: ubuntu-20.04 -> ubuntu2004 +-OS=$(echo $2 | tr -d ".\-") ++OS=$(echo "$2" | tr -d ".\-") + + # Installs CUDA +-wget -nv https://developer.download.nvidia.com/compute/cuda/repos/${OS}/x86_64/cuda-keyring_1.1-1_all.deb ++wget -nv "https://developer.download.nvidia.com/compute/cuda/repos/${OS}/x86_64/cuda-keyring_1.1-1_all.deb" + sudo dpkg -i cuda-keyring_1.1-1_all.deb + rm cuda-keyring_1.1-1_all.deb + sudo apt -qq update +-sudo apt -y install cuda-${cuda_version} cuda-nvcc-${cuda_version} cuda-libraries-dev-${cuda_version} ++sudo apt -y install "cuda-${cuda_version}" "cuda-nvcc-${cuda_version}" "cuda-libraries-dev-${cuda_version}" + sudo apt clean + + # Test nvcc +diff --git a/.github/workflows/scripts/pytorch-install.sh b/.github/workflows/scripts/pytorch-install.sh +index dfc1851..e3cda7d 100644 +--- a/.github/workflows/scripts/pytorch-install.sh ++++ b/.github/workflows/scripts/pytorch-install.sh +@@ -6,7 +6,7 @@ cuda_version=$3 + + # Install torch + $python_executable -m pip install numpy pyyaml scipy ipython mkl mkl-include ninja cython typing pandas typing-extensions dataclasses setuptools && conda clean -ya +-$python_executable -m pip install torch==${pytorch_version}+cu${cuda_version//./} --extra-index-url https://download.pytorch.org/whl/cu${cuda_version//./} ++$python_executable -m pip install torch=="${pytorch_version}+cu${cuda_version//./}" --extra-index-url "https://download.pytorch.org/whl/cu${cuda_version//./}" + + # Print version information + $python_executable --version +diff --git a/.github/workflows/shellcheck.yml b/.github/workflows/shellcheck.yml +new file mode 100644 +index 0000000..4b1587e +--- /dev/null ++++ b/.github/workflows/shellcheck.yml +@@ -0,0 +1,37 @@ ++name: Lint shell scripts ++on: ++ push: ++ branches: ++ - "main" ++ paths: ++ - '**/*.sh' ++ - '.github/workflows/shellcheck.yml' ++ pull_request: ++ branches: ++ - "main" ++ paths: ++ - '**/*.sh' ++ - '.github/workflows/shellcheck.yml' ++ ++env: ++ LC_ALL: en_US.UTF-8 ++ ++defaults: ++ run: ++ shell: bash ++ ++permissions: ++ contents: read ++ ++jobs: ++ shellcheck: ++ runs-on: ubuntu-latest ++ steps: ++ - name: "Checkout" ++ uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 ++ with: ++ fetch-depth: 0 ++ ++ - name: "Check shell scripts" ++ run: | ++ tools/shellcheck.sh +diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml +new file mode 100644 +index 0000000..81e7c9b +--- /dev/null ++++ b/.github/workflows/stale.yml +@@ -0,0 +1,52 @@ ++name: 'Close inactive issues and PRs' ++ ++on: ++ schedule: ++ # Daily at 1:30 AM UTC ++ - cron: '30 1 * * *' ++ ++jobs: ++ close-issues-and-pull-requests: ++ permissions: ++ issues: write ++ pull-requests: write ++ actions: write ++ runs-on: ubuntu-latest ++ steps: ++ - uses: actions/stale@28ca1036281a5e5922ead5184a1bbf96e5fc984e # v9.0.0 ++ with: ++ # Increasing this value ensures that changes to this workflow ++ # propagate to all issues and PRs in days rather than months ++ operations-per-run: 1000 ++ ++ exempt-draft-pr: true ++ exempt-issue-labels: 'keep-open' ++ exempt-pr-labels: 'keep-open' ++ ++ labels-to-add-when-unstale: 'unstale' ++ labels-to-remove-when-stale: 'unstale' ++ ++ days-before-issue-stale: 90 ++ days-before-issue-close: 30 ++ stale-issue-label: 'stale' ++ stale-issue-message: > ++ This issue has been automatically marked as stale because it has not ++ had any activity within 90 days. It will be automatically closed if no ++ further activity occurs within 30 days. Leave a comment if ++ you feel this issue should remain open. Thank you! ++ close-issue-message: > ++ This issue has been automatically closed due to inactivity. Please ++ feel free to reopen if you feel it is still relevant. Thank you! ++ ++ days-before-pr-stale: 90 ++ days-before-pr-close: 30 ++ stale-pr-label: 'stale' ++ stale-pr-message: > ++ This pull request has been automatically marked as stale because it ++ has not had any activity within 90 days. It will be automatically ++ closed if no further activity occurs within 30 days. Leave a comment ++ if you feel this pull request should remain open. Thank you! ++ close-pr-message: > ++ This pull request has been automatically closed due to inactivity. ++ Please feel free to reopen if you intend to continue working on it. ++ Thank you! +diff --git a/.github/workflows/yapf.yml b/.github/workflows/yapf.yml +index 04f307b..ff441f9 100644 +--- a/.github/workflows/yapf.yml ++++ b/.github/workflows/yapf.yml +@@ -6,26 +6,33 @@ on: + push: + branches: + - main ++ paths: ++ - "**/*.py" ++ - .github/workflows/yapf.yml + pull_request: + branches: + - main ++ paths: ++ - "**/*.py" ++ - .github/workflows/yapf.yml ++ + jobs: + yapf: + runs-on: ubuntu-latest + strategy: + matrix: +- python-version: ["3.8", "3.9", "3.10", "3.11"] ++ python-version: ["3.12"] + steps: +- - uses: actions/checkout@v2 +- - name: Set up Python ${{ matrix.python-version }} +- uses: actions/setup-python@v2 +- with: +- python-version: ${{ matrix.python-version }} +- - name: Install dependencies +- run: | +- python -m pip install --upgrade pip +- pip install yapf==0.32.0 +- pip install toml==0.10.2 +- - name: Running yapf +- run: | +- yapf --diff --recursive . ++ - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 ++ - name: Set up Python ${{ matrix.python-version }} ++ uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 ++ with: ++ python-version: ${{ matrix.python-version }} ++ - name: Install dependencies ++ run: | ++ python -m pip install --upgrade pip ++ pip install yapf==0.32.0 ++ pip install toml==0.10.2 ++ - name: Running yapf ++ run: | ++ yapf --diff --recursive . +diff --git a/.gitignore b/.gitignore +index e077366..89dab8f 100644 +--- a/.gitignore ++++ b/.gitignore +@@ -1,3 +1,9 @@ ++# version file generated by setuptools-scm ++/vllm/_version.py ++ ++# vllm-flash-attn built from source ++vllm/vllm_flash_attn/ ++ + # Byte-compiled / optimized / DLL files + __pycache__/ + *.py[cod] +@@ -9,6 +15,8 @@ __pycache__/ + # Distribution / packaging + .Python + build/ ++cmake-build-*/ ++CMakeUserPresets.json + develop-eggs/ + dist/ + downloads/ +@@ -25,6 +33,7 @@ share/python-wheels/ + .installed.cfg + *.egg + MANIFEST ++/.deps/ + + # PyInstaller + # Usually these files are written by a python script from a template +@@ -70,8 +79,7 @@ instance/ + + # Sphinx documentation + docs/_build/ +-docs/source/getting_started/examples/*.rst +-!**/*.template.rst ++docs/source/getting_started/examples/ + + # PyBuilder + .pybuilder/ +@@ -84,6 +92,9 @@ target/ + profile_default/ + ipython_config.py + ++# generated files ++**/generated/** ++ + # pyenv + # For a library or package, you might want to ignore these files since the code is + # intended to run in multiple environments; otherwise, check them in: +@@ -186,4 +197,8 @@ _build/ + hip_compat.h + + # Benchmark dataset +-*.json ++benchmarks/*.json ++ ++# Linting ++actionlint ++shellcheck*/ +diff --git a/.readthedocs.yaml b/.readthedocs.yaml +index 428e199..284196b 100644 +--- a/.readthedocs.yaml ++++ b/.readthedocs.yaml +@@ -6,16 +6,16 @@ version: 2 + build: + os: ubuntu-22.04 + tools: +- python: "3.8" ++ python: "3.12" + + sphinx: +- configuration: docs/source/conf.py ++ configuration: docs/source/conf.py ++ fail_on_warning: true + + # If using Sphinx, optionally build your docs in additional formats such as PDF +-formats: +- - pdf ++formats: [] + + # Optionally declare the Python requirements required to build your docs + python: +- install: +- - requirements: docs/requirements-docs.txt ++ install: ++ - requirements: docs/requirements-docs.txt +diff --git a/.shellcheckrc b/.shellcheckrc +new file mode 100644 +index 0000000..f3b6eed +--- /dev/null ++++ b/.shellcheckrc +@@ -0,0 +1,9 @@ ++# rules currently disabled: ++# ++# SC1091 (info): Not following: was not specified as input (see shellcheck -x) ++# SC2004 (style): $/${} is unnecessary on arithmetic variables. ++# SC2129 (style): Consider using { cmd1; cmd2; } >> file instead of individual redirects. ++# SC2155 (warning): Declare and assign separately to avoid masking return values. ++# SC2164 (warning): Use 'cd ... || exit' or 'cd ... || return' in case cd fails. ++# ++disable=SC1091,SC2004,SC2129,SC2155,SC2164 +diff --git a/CMakeLists.txt b/CMakeLists.txt +index f817f33..f4b9c3e 100644 +--- a/CMakeLists.txt ++++ b/CMakeLists.txt +@@ -1,25 +1,43 @@ +-cmake_minimum_required(VERSION 3.21) ++cmake_minimum_required(VERSION 3.26) + ++# When building directly using CMake, make sure you run the install step ++# (it places the .so files in the correct location). ++# ++# Example: ++# mkdir build && cd build ++# cmake -G Ninja -DVLLM_PYTHON_EXECUTABLE=`which python3` -DCMAKE_INSTALL_PREFIX=.. .. ++# cmake --build . --target install ++# ++# If you want to only build one target, make sure to install it manually: ++# cmake --build . --target _C ++# cmake --install . --component _C + project(vllm_extensions LANGUAGES CXX) + +-option(VLLM_TARGET_DEVICE "Target device backend for vLLM" "cuda") ++# CUDA by default, can be overridden by using -DVLLM_TARGET_DEVICE=... (used by setup.py) ++set(VLLM_TARGET_DEVICE "cuda" CACHE STRING "Target device backend for vLLM") + + message(STATUS "Build type: ${CMAKE_BUILD_TYPE}") + message(STATUS "Target device: ${VLLM_TARGET_DEVICE}") + + include(${CMAKE_CURRENT_LIST_DIR}/cmake/utils.cmake) + ++# Suppress potential warnings about unused manually-specified variables ++set(ignoreMe "${VLLM_PYTHON_PATH}") ++ ++# Prevent installation of dependencies (cutlass) by default. ++install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY TRUE)" ALL_COMPONENTS) ++ + # + # Supported python versions. These versions will be searched in order, the + # first match will be selected. These should be kept in sync with setup.py. + # +-set(PYTHON_SUPPORTED_VERSIONS "3.8" "3.9" "3.10" "3.11") ++set(PYTHON_SUPPORTED_VERSIONS "3.9" "3.10" "3.11" "3.12") + + # Supported NVIDIA architectures. +-set(CUDA_SUPPORTED_ARCHS "7.0;7.5;8.0;8.6;8.9;9.0") ++set(CUDA_SUPPORTED_ARCHS "7.0;7.2;7.5;8.0;8.6;8.7;8.9;9.0") + + # Supported AMD GPU architectures. +-set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100") ++set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100;gfx1101") + + # + # Supported/expected torch versions for CUDA/ROCm. +@@ -31,9 +49,8 @@ set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx11 + # requirements.txt files and should be kept consistent. The ROCm torch + # versions are derived from Dockerfile.rocm + # +-set(TORCH_SUPPORTED_VERSION_CUDA "2.3.0") +-set(TORCH_SUPPORTED_VERSION_ROCM_5X "2.0.1") +-set(TORCH_SUPPORTED_VERSION_ROCM_6X "2.1.1") ++set(TORCH_SUPPORTED_VERSION_CUDA "2.5.1") ++set(TORCH_SUPPORTED_VERSION_ROCM "2.5.1") + + # + # Try to find python package with an executable that exactly matches +@@ -66,19 +83,6 @@ endif() + # + find_package(Torch REQUIRED) + +-# +-# Normally `torch.utils.cpp_extension.CUDAExtension` would add +-# `libtorch_python.so` for linking against an extension. Torch's cmake +-# configuration does not include this library (presumably since the cmake +-# config is used for standalone C++ binaries that link against torch). +-# The `libtorch_python.so` library defines some of the glue code between +-# torch/python via pybind and is required by VLLM extensions for this +-# reason. So, add it by manually with `find_library` using torch's +-# installed library path. +-# +-find_library(torch_python_LIBRARY torch_python PATHS +- "${TORCH_INSTALL_PREFIX}/lib") +- + # + # Forward the non-CUDA device extensions to external CMake scripts. + # +@@ -87,7 +91,7 @@ if (NOT VLLM_TARGET_DEVICE STREQUAL "cuda" AND + if (VLLM_TARGET_DEVICE STREQUAL "cpu") + include(${CMAKE_CURRENT_LIST_DIR}/cmake/cpu_extension.cmake) + else() +- message(FATAL_ERROR "Unsupported vLLM target device: ${VLLM_TARGET_DEVICE}") ++ return() + endif() + return() + endif() +@@ -111,31 +115,42 @@ elseif(HIP_FOUND) + # .hip extension automatically, HIP must be enabled explicitly. + enable_language(HIP) + +- # ROCm 5.x +- if (ROCM_VERSION_DEV_MAJOR EQUAL 5 AND +- NOT Torch_VERSION VERSION_EQUAL ${TORCH_SUPPORTED_VERSION_ROCM_5X}) +- message(WARNING "Pytorch version ${TORCH_SUPPORTED_VERSION_ROCM_5X} " +- "expected for ROCMm 5.x build, saw ${Torch_VERSION} instead.") +- endif() +- +- # ROCm 6.x +- if (ROCM_VERSION_DEV_MAJOR EQUAL 6 AND +- NOT Torch_VERSION VERSION_EQUAL ${TORCH_SUPPORTED_VERSION_ROCM_6X}) +- message(WARNING "Pytorch version ${TORCH_SUPPORTED_VERSION_ROCM_6X} " +- "expected for ROCMm 6.x build, saw ${Torch_VERSION} instead.") ++ # ROCm 5.X and 6.X ++ if (ROCM_VERSION_DEV_MAJOR GREATER_EQUAL 5 AND ++ NOT Torch_VERSION VERSION_EQUAL ${TORCH_SUPPORTED_VERSION_ROCM}) ++ message(WARNING "Pytorch version >= ${TORCH_SUPPORTED_VERSION_ROCM} " ++ "expected for ROCm build, saw ${Torch_VERSION} instead.") + endif() + else() + message(FATAL_ERROR "Can't find CUDA or HIP installation.") + endif() + +-# +-# Override the GPU architectures detected by cmake/torch and filter them by +-# the supported versions for the current language. +-# The final set of arches is stored in `VLLM_GPU_ARCHES`. +-# +-override_gpu_arches(VLLM_GPU_ARCHES +- ${VLLM_GPU_LANG} +- "${${VLLM_GPU_LANG}_SUPPORTED_ARCHS}") ++ ++if(VLLM_GPU_LANG STREQUAL "CUDA") ++ # ++ # For cuda we want to be able to control which architectures we compile for on ++ # a per-file basis in order to cut down on compile time. So here we extract ++ # the set of architectures we want to compile for and remove the from the ++ # CMAKE_CUDA_FLAGS so that they are not applied globally. ++ # ++ clear_cuda_arches(CUDA_ARCH_FLAGS) ++ extract_unique_cuda_archs_ascending(CUDA_ARCHS "${CUDA_ARCH_FLAGS}") ++ message(STATUS "CUDA target architectures: ${CUDA_ARCHS}") ++ # Filter the target architectures by the supported supported archs ++ # since for some files we will build for all CUDA_ARCHS. ++ cuda_archs_loose_intersection(CUDA_ARCHS ++ "${CUDA_SUPPORTED_ARCHS}" "${CUDA_ARCHS}") ++ message(STATUS "CUDA supported target architectures: ${CUDA_ARCHS}") ++else() ++ # ++ # For other GPU targets override the GPU architectures detected by cmake/torch ++ # and filter them by the supported versions for the current language. ++ # The final set of arches is stored in `VLLM_GPU_ARCHES`. ++ # ++ override_gpu_arches(VLLM_GPU_ARCHES ++ ${VLLM_GPU_LANG} ++ "${${VLLM_GPU_LANG}_SUPPORTED_ARCHS}") ++endif() + + # + # Query torch for additional GPU compilation flags for the given +@@ -151,8 +166,19 @@ if(NVCC_THREADS AND VLLM_GPU_LANG STREQUAL "CUDA") + list(APPEND VLLM_GPU_FLAGS "--threads=${NVCC_THREADS}") + endif() + ++ + # +-# Define extension targets ++# Use FetchContent for C++ dependencies that are compiled as part of vLLM's build process. ++# setup.py will override FETCHCONTENT_BASE_DIR to play nicely with sccache. ++# Each dependency that produces build artifacts should override its BINARY_DIR to avoid ++# conflicts between build types. It should instead be set to ${CMAKE_BINARY_DIR}/. ++# ++include(FetchContent) ++file(MAKE_DIRECTORY ${FETCHCONTENT_BASE_DIR}) # Ensure the directory exists ++message(STATUS "FetchContent base directory: ${FETCHCONTENT_BASE_DIR}") ++ ++# ++# Define other extension targets + # + + # +@@ -161,27 +187,243 @@ endif() + + set(VLLM_EXT_SRC + "csrc/cache_kernels.cu" +- "csrc/attention/attention_kernels.cu" ++ "csrc/attention/paged_attention_v1.cu" ++ "csrc/attention/paged_attention_v2.cu" + "csrc/pos_encoding_kernels.cu" + "csrc/activation_kernels.cu" + "csrc/layernorm_kernels.cu" +- "csrc/quantization/squeezellm/quant_cuda_kernel.cu" ++ "csrc/layernorm_quant_kernels.cu" + "csrc/quantization/gptq/q_gemm.cu" +- "csrc/quantization/fp8/fp8_cuda_kernels.cu" ++ "csrc/quantization/compressed_tensors/int8_quant_kernels.cu" ++ "csrc/quantization/fp8/common.cu" ++ "csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu" ++ "csrc/quantization/gguf/gguf_kernel.cu" + "csrc/cuda_utils_kernels.cu" +- "csrc/moe_align_block_size_kernels.cu" +- "csrc/pybind.cpp") ++ "csrc/prepare_inputs/advance_step.cu" ++ "csrc/torch_bindings.cpp") + + if(VLLM_GPU_LANG STREQUAL "CUDA") ++ SET(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable only the header library") ++ ++ # Set CUTLASS_REVISION manually -- its revision detection doesn't work in this case. ++ set(CUTLASS_REVISION "v3.6.0" CACHE STRING "CUTLASS revision to use") ++ ++ # Use the specified CUTLASS source directory for compilation if VLLM_CUTLASS_SRC_DIR is provided ++ if (DEFINED ENV{VLLM_CUTLASS_SRC_DIR}) ++ set(VLLM_CUTLASS_SRC_DIR $ENV{VLLM_CUTLASS_SRC_DIR}) ++ endif() ++ ++ if(VLLM_CUTLASS_SRC_DIR) ++ if(NOT IS_ABSOLUTE VLLM_CUTLASS_SRC_DIR) ++ get_filename_component(VLLM_CUTLASS_SRC_DIR "${VLLM_CUTLASS_SRC_DIR}" ABSOLUTE) ++ endif() ++ message(STATUS "The VLLM_CUTLASS_SRC_DIR is set, using ${VLLM_CUTLASS_SRC_DIR} for compilation") ++ FetchContent_Declare(cutlass SOURCE_DIR ${VLLM_CUTLASS_SRC_DIR}) ++ else() ++ FetchContent_Declare( ++ cutlass ++ GIT_REPOSITORY https://github.com/nvidia/cutlass.git ++ GIT_TAG v3.6.0 ++ GIT_PROGRESS TRUE ++ ++ # Speed up CUTLASS download by retrieving only the specified GIT_TAG instead of the history. ++ # Important: If GIT_SHALLOW is enabled then GIT_TAG works only with branch names and tags. ++ # So if the GIT_TAG above is updated to a commit hash, GIT_SHALLOW must be set to FALSE ++ GIT_SHALLOW TRUE ++ ) ++ endif() ++ FetchContent_MakeAvailable(cutlass) ++ + list(APPEND VLLM_EXT_SRC ++ "csrc/mamba/mamba_ssm/selective_scan_fwd.cu" ++ "csrc/mamba/causal_conv1d/causal_conv1d.cu" + "csrc/quantization/aqlm/gemm_kernels.cu" + "csrc/quantization/awq/gemm_kernels.cu" +- "csrc/quantization/marlin/marlin_cuda_kernel.cu" +- "csrc/quantization/gptq_marlin/gptq_marlin.cu" +- "csrc/quantization/gptq_marlin/gptq_marlin_repack.cu" +- "csrc/custom_all_reduce.cu") ++ "csrc/custom_all_reduce.cu" ++ "csrc/permute_cols.cu" ++ "csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu" ++ "csrc/sparse/cutlass/sparse_scaled_mm_entry.cu" ++ "csrc/sparse/cutlass/sparse_compressor_entry.cu" ++ "csrc/cutlass_extensions/common.cpp") ++ ++ set_gencode_flags_for_srcs( ++ SRCS "${VLLM_EXT_SRC}" ++ CUDA_ARCHS "${CUDA_ARCHS}") ++ ++ # Only build Marlin kernels if we are building for at least some compatible archs. ++ # Keep building Marlin for 9.0 as there are some group sizes and shapes that ++ # are not supported by Machete yet. ++ cuda_archs_loose_intersection(MARLIN_ARCHS "8.0;8.6;8.7;8.9;9.0" ${CUDA_ARCHS}) ++ if (MARLIN_ARCHS) ++ set(MARLIN_SRCS ++ "csrc/quantization/fp8/fp8_marlin.cu" ++ "csrc/quantization/marlin/dense/marlin_cuda_kernel.cu" ++ "csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu" ++ "csrc/quantization/marlin/qqq/marlin_qqq_gemm_kernel.cu" ++ "csrc/quantization/gptq_marlin/gptq_marlin.cu" ++ "csrc/quantization/gptq_marlin/gptq_marlin_repack.cu" ++ "csrc/quantization/gptq_marlin/awq_marlin_repack.cu") ++ set_gencode_flags_for_srcs( ++ SRCS "${MARLIN_SRCS}" ++ CUDA_ARCHS "${MARLIN_ARCHS}") ++ list(APPEND VLLM_EXT_SRC "${MARLIN_SRCS}") ++ message(STATUS "Building Marlin kernels for archs: ${MARLIN_ARCHS}") ++ else() ++ message(STATUS "Not building Marlin kernels as no compatible archs found" ++ " in CUDA target architectures") ++ endif() ++ ++ # The cutlass_scaled_mm kernels for Hopper (c3x, i.e. CUTLASS 3.x) require ++ # CUDA 12.0 or later (and only work on Hopper, 9.0/9.0a for now). ++ cuda_archs_loose_intersection(SCALED_MM_3X_ARCHS "9.0;9.0a" "${CUDA_ARCHS}") ++ if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND SCALED_MM_3X_ARCHS) ++ set(SRCS "csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu") ++ set_gencode_flags_for_srcs( ++ SRCS "${SRCS}" ++ CUDA_ARCHS "${SCALED_MM_3X_ARCHS}") ++ list(APPEND VLLM_EXT_SRC "${SRCS}") ++ list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_C3X=1") ++ message(STATUS "Building scaled_mm_c3x for archs: ${SCALED_MM_3X_ARCHS}") ++ else() ++ if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND SCALED_MM_3X_ARCHS) ++ message(STATUS "Not building scaled_mm_c3x as CUDA Compiler version is " ++ "not >= 12.0, we recommend upgrading to CUDA 12.0 or " ++ "later if you intend on running FP8 quantized models on " ++ "Hopper.") ++ else() ++ message(STATUS "Not building scaled_mm_c3x as no compatible archs found " ++ "in CUDA target architectures") ++ endif() ++ ++ # clear SCALED_MM_3X_ARCHS so the scaled_mm_c2x kernels know we didn't ++ # build any 3x kernels ++ set(SCALED_MM_3X_ARCHS) ++ endif() ++ ++ # ++ # For the cutlass_scaled_mm kernels we want to build the c2x (CUTLASS 2.x) ++ # kernels for the remaining archs that are not already built for 3x. ++ cuda_archs_loose_intersection(SCALED_MM_2X_ARCHS ++ "7.5;8.0;8.6;8.7;8.9;9.0" "${CUDA_ARCHS}") ++ # subtract out the archs that are already built for 3x ++ list(REMOVE_ITEM SCALED_MM_2X_ARCHS ${SCALED_MM_3X_ARCHS}) ++ if (SCALED_MM_2X_ARCHS) ++ set(SRCS "csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu") ++ set_gencode_flags_for_srcs( ++ SRCS "${SRCS}" ++ CUDA_ARCHS "${SCALED_MM_2X_ARCHS}") ++ list(APPEND VLLM_EXT_SRC "${SRCS}") ++ list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_C2X=1") ++ message(STATUS "Building scaled_mm_c2x for archs: ${SCALED_MM_2X_ARCHS}") ++ else() ++ if (SCALED_MM_3X_ARCHS) ++ message(STATUS "Not building scaled_mm_c2x as all archs are already built" ++ " for and covered by scaled_mm_c3x") ++ else() ++ message(STATUS "Not building scaled_mm_c2x as no compatible archs found " ++ "in CUDA target architectures") ++ endif() ++ endif() ++ ++ # ++ # 2:4 Sparse Kernels ++ ++ # The 2:4 sparse kernels cutlass_scaled_sparse_mm and cutlass_compressor ++ # require CUDA 12.2 or later (and only work on Hopper, 9.0/9.0a for now). ++ if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.2 AND SCALED_MM_3X_ARCHS) ++ set(SRCS "csrc/sparse/cutlass/sparse_compressor_c3x.cu" ++ "csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu") ++ set_gencode_flags_for_srcs( ++ SRCS "${SRCS}" ++ CUDA_ARCHS "${SCALED_MM_3X_ARCHS}") ++ list(APPEND VLLM_EXT_SRC "${SRCS}") ++ list(APPEND VLLM_GPU_FLAGS "-DENABLE_SPARSE_SCALED_MM_C3X=1") ++ message(STATUS "Building sparse_scaled_mm_c3x for archs: ${SCALED_MM_3X_ARCHS}") ++ else() ++ if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.2 AND SCALED_MM_3X_ARCHS) ++ message(STATUS "Not building sparse_scaled_mm_c3x kernels as CUDA Compiler version is " ++ "not >= 12.2, we recommend upgrading to CUDA 12.2 or later " ++ "if you intend on running FP8 sparse quantized models on Hopper.") ++ else() ++ message(STATUS "Not building sparse_scaled_mm_c3x as no compatible archs found " ++ "in CUDA target architectures") ++ endif() ++ endif() ++ ++ ++ # ++ # Machete kernels ++ ++ # The machete kernels only work on hopper and require CUDA 12.0 or later. ++ # Only build Machete kernels if we are building for something compatible with sm90a ++ cuda_archs_loose_intersection(MACHETE_ARCHS "9.0a" "${CUDA_ARCHS}") ++ if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND MACHETE_ARCHS) ++ # ++ # For the Machete kernels we automatically generate sources for various ++ # preselected input type pairs and schedules. ++ # Generate sources: ++ set(MACHETE_GEN_SCRIPT ++ ${CMAKE_CURRENT_SOURCE_DIR}/csrc/quantization/machete/generate.py) ++ file(MD5 ${MACHETE_GEN_SCRIPT} MACHETE_GEN_SCRIPT_HASH) ++ ++ message(STATUS "Machete generation script hash: ${MACHETE_GEN_SCRIPT_HASH}") ++ message(STATUS "Last run machete generate script hash: $CACHE{MACHETE_GEN_SCRIPT_HASH}") ++ ++ if (NOT DEFINED CACHE{MACHETE_GEN_SCRIPT_HASH} ++ OR NOT $CACHE{MACHETE_GEN_SCRIPT_HASH} STREQUAL ${MACHETE_GEN_SCRIPT_HASH}) ++ execute_process( ++ COMMAND ${CMAKE_COMMAND} -E env ++ PYTHONPATH=${CMAKE_CURRENT_SOURCE_DIR}/csrc/cutlass_extensions/:${CUTLASS_DIR}/python/:${VLLM_PYTHON_PATH}:$PYTHONPATH ++ ${Python_EXECUTABLE} ${MACHETE_GEN_SCRIPT} ++ RESULT_VARIABLE machete_generation_result ++ OUTPUT_VARIABLE machete_generation_output ++ OUTPUT_FILE ${CMAKE_CURRENT_BINARY_DIR}/machete_generation.log ++ ERROR_FILE ${CMAKE_CURRENT_BINARY_DIR}/machete_generation.log ++ ) ++ ++ if (NOT machete_generation_result EQUAL 0) ++ message(FATAL_ERROR "Machete generation failed." ++ " Result: \"${machete_generation_result}\"" ++ "\nCheck the log for details: " ++ "${CMAKE_CURRENT_BINARY_DIR}/machete_generation.log") ++ else() ++ set(MACHETE_GEN_SCRIPT_HASH ${MACHETE_GEN_SCRIPT_HASH} ++ CACHE STRING "Last run machete generate script hash" FORCE) ++ message(STATUS "Machete generation completed successfully.") ++ endif() ++ else() ++ message(STATUS "Machete generation script has not changed, skipping generation.") ++ endif() ++ ++ # Add machete generated sources ++ file(GLOB MACHETE_GEN_SOURCES "csrc/quantization/machete/generated/*.cu") ++ list(APPEND VLLM_EXT_SRC ${MACHETE_GEN_SOURCES}) ++ ++ # forward compatible ++ set_gencode_flags_for_srcs( ++ SRCS "${MACHETE_GEN_SOURCES}" ++ CUDA_ARCHS "${MACHETE_ARCHS}") ++ ++ list(APPEND VLLM_EXT_SRC ++ csrc/quantization/machete/machete_pytorch.cu) ++ ++ message(STATUS "Building Machete kernels for archs: ${MACHETE_ARCHS}") ++ else() ++ if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 ++ AND MACHETE_ARCHS) ++ message(STATUS "Not building Machete kernels as CUDA Compiler version is " ++ "not >= 12.0, we recommend upgrading to CUDA 12.0 or " ++ "later if you intend on running w4a16 quantized models on " ++ "Hopper.") ++ else() ++ message(STATUS "Not building Machete kernels as no compatible archs " ++ "found in CUDA target architectures") ++ endif() ++ endif() ++# if CUDA endif + endif() + ++message(STATUS "Enabling C extension.") + define_gpu_extension_target( + _C + DESTINATION vllm +@@ -189,16 +431,55 @@ define_gpu_extension_target( + SOURCES ${VLLM_EXT_SRC} + COMPILE_FLAGS ${VLLM_GPU_FLAGS} + ARCHITECTURES ${VLLM_GPU_ARCHES} ++ INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR};${CUTLASS_TOOLS_UTIL_INCLUDE_DIR} ++ USE_SABI 3 + WITH_SOABI) + ++# If CUTLASS is compiled on NVCC >= 12.5, it by default uses ++# cudaGetDriverEntryPointByVersion as a wrapper to avoid directly calling the ++# driver API. This causes problems when linking with earlier versions of CUDA. ++# Setting this variable sidesteps the issue by calling the driver directly. ++target_compile_definitions(_C PRIVATE CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1) ++ + # + # _moe_C extension + # + + set(VLLM_MOE_EXT_SRC +- "csrc/moe/moe_ops.cpp" ++ "csrc/moe/torch_bindings.cpp" ++ "csrc/moe/moe_align_sum_kernels.cu" + "csrc/moe/topk_softmax_kernels.cu") + ++set_gencode_flags_for_srcs( ++ SRCS "${VLLM_MOE_EXT_SRC}" ++ CUDA_ARCHS "${CUDA_ARCHS}") ++ ++if(VLLM_GPU_LANG STREQUAL "CUDA") ++ cuda_archs_loose_intersection(MARLIN_MOE_ARCHS "8.0;8.6;8.7;8.9;9.0" "${CUDA_ARCHS}") ++ if (MARLIN_MOE_ARCHS) ++ set(MARLIN_MOE_SRC ++ "csrc/moe/marlin_kernels/marlin_moe_kernel.h" ++ "csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.h" ++ "csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.cu" ++ "csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.h" ++ "csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.cu" ++ "csrc/moe/marlin_kernels/marlin_moe_kernel_ku4.h" ++ "csrc/moe/marlin_kernels/marlin_moe_kernel_ku4.cu" ++ "csrc/moe/marlin_moe_ops.cu") ++ ++ set_gencode_flags_for_srcs( ++ SRCS "${MARLIN_MOE_SRC}" ++ CUDA_ARCHS "${MARLIN_MOE_ARCHS}") ++ ++ list(APPEND VLLM_MOE_EXT_SRC "${MARLIN_MOE_SRC}") ++ message(STATUS "Building Marlin MOE kernels for archs: ${MARLIN_MOE_ARCHS}") ++ else() ++ message(STATUS "Not building Marlin MOE kernels as no compatible archs found" ++ " in CUDA target architectures") ++ endif() ++endif() ++ ++message(STATUS "Enabling moe extension.") + define_gpu_extension_target( + _moe_C + DESTINATION vllm +@@ -206,89 +487,101 @@ define_gpu_extension_target( + SOURCES ${VLLM_MOE_EXT_SRC} + COMPILE_FLAGS ${VLLM_GPU_FLAGS} + ARCHITECTURES ${VLLM_GPU_ARCHES} ++ USE_SABI 3 + WITH_SOABI) + +-# +-# _punica_C extension +-# +- +-set(VLLM_PUNICA_EXT_SRC +- "csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu" +- "csrc/punica/bgmv/bgmv_bf16_fp32_bf16.cu" +- "csrc/punica/bgmv/bgmv_fp16_fp16_fp16.cu" +- "csrc/punica/bgmv/bgmv_fp16_fp32_fp16.cu" +- "csrc/punica/bgmv/bgmv_fp32_bf16_bf16.cu" +- "csrc/punica/bgmv/bgmv_fp32_fp16_fp16.cu" +- "csrc/punica/punica_ops.cc") +- +-# +-# Copy GPU compilation flags+update for punica +-# +-set(VLLM_PUNICA_GPU_FLAGS ${VLLM_GPU_FLAGS}) +-list(REMOVE_ITEM VLLM_PUNICA_GPU_FLAGS +- "-D__CUDA_NO_HALF_OPERATORS__" +- "-D__CUDA_NO_HALF_CONVERSIONS__" +- "-D__CUDA_NO_BFLOAT16_CONVERSIONS__" +- "-D__CUDA_NO_HALF2_OPERATORS__") +- +-# +-# Filter out CUDA architectures < 8.0 for punica. +-# +-if (${VLLM_GPU_LANG} STREQUAL "CUDA") +- set(VLLM_PUNICA_GPU_ARCHES) +- foreach(ARCH ${VLLM_GPU_ARCHES}) +- string_to_ver(CODE_VER ${ARCH}) +- if (CODE_VER GREATER_EQUAL 8.0) +- list(APPEND VLLM_PUNICA_GPU_ARCHES ${ARCH}) +- endif() +- endforeach() +- message(STATUS "Punica target arches: ${VLLM_PUNICA_GPU_ARCHES}") +-endif() ++if(VLLM_GPU_LANG STREQUAL "HIP") ++ # ++ # _rocm_C extension ++ # ++ set(VLLM_ROCM_EXT_SRC ++ "csrc/rocm/torch_bindings.cpp" ++ "csrc/rocm/attention.cu") + +-if (VLLM_PUNICA_GPU_ARCHES) + define_gpu_extension_target( +- _punica_C ++ _rocm_C + DESTINATION vllm + LANGUAGE ${VLLM_GPU_LANG} +- SOURCES ${VLLM_PUNICA_EXT_SRC} +- COMPILE_FLAGS ${VLLM_PUNICA_GPU_FLAGS} +- ARCHITECTURES ${VLLM_PUNICA_GPU_ARCHES} ++ SOURCES ${VLLM_ROCM_EXT_SRC} ++ COMPILE_FLAGS ${VLLM_GPU_FLAGS} ++ ARCHITECTURES ${VLLM_GPU_ARCHES} ++ USE_SABI 3 + WITH_SOABI) +-else() +- message(WARNING "Unable to create _punica_C target because none of the " +- "requested architectures (${VLLM_GPU_ARCHES}) are supported, i.e. >= 8.0") ++endif() ++ ++# vllm-flash-attn currently only supported on CUDA ++if (NOT VLLM_TARGET_DEVICE STREQUAL "cuda") ++ return() ++endif () ++ ++# vLLM flash attention requires VLLM_GPU_ARCHES to contain the set of target ++# arches in the CMake syntax (75-real, 89-virtual, etc), since we clear the ++# arches in the CUDA case (and instead set the gencodes on a per file basis) ++# we need to manually set VLLM_GPU_ARCHES here. ++if(VLLM_GPU_LANG STREQUAL "CUDA") ++ foreach(_ARCH ${CUDA_ARCHS}) ++ string(REPLACE "." "" _ARCH "${_ARCH}") ++ list(APPEND VLLM_GPU_ARCHES "${_ARCH}-real") ++ endforeach() + endif() + + # +-# Add the `default` target which detects which extensions should be +-# built based on platform/architecture. This is the same logic that +-# setup.py uses to select which extensions should be built and should +-# be kept in sync. +-# +-# The `default` target makes direct use of cmake easier since knowledge +-# of which extensions are supported has been factored in, e.g. ++# Build vLLM flash attention from source + # +-# mkdir build && cd build +-# cmake -G Ninja -DVLLM_PYTHON_EXECUTABLE=`which python3` -DCMAKE_LIBRARY_OUTPUT_DIRECTORY=../vllm .. +-# cmake --build . --target default ++# IMPORTANT: This has to be the last thing we do, because vllm-flash-attn uses the same macros/functions as vLLM. ++# Because functions all belong to the global scope, vllm-flash-attn's functions overwrite vLLMs. ++# They should be identical but if they aren't, this is a massive footgun. + # +-add_custom_target(default) ++# The vllm-flash-attn install rules are nested under vllm to make sure the library gets installed in the correct place. ++# To only install vllm-flash-attn, use --component vllm_flash_attn_c. ++# If no component is specified, vllm-flash-attn is still installed. + +-if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP") +- message(STATUS "Enabling C extension.") +- add_dependencies(default _C) ++# If VLLM_FLASH_ATTN_SRC_DIR is set, vllm-flash-attn is installed from that directory instead of downloading. ++# This is to enable local development of vllm-flash-attn within vLLM. ++# It can be set as an environment variable or passed as a cmake argument. ++# The environment variable takes precedence. ++if (DEFINED ENV{VLLM_FLASH_ATTN_SRC_DIR}) ++ set(VLLM_FLASH_ATTN_SRC_DIR $ENV{VLLM_FLASH_ATTN_SRC_DIR}) + endif() + +-if(VLLM_GPU_LANG STREQUAL "CUDA") +- message(STATUS "Enabling moe extension.") +- add_dependencies(default _moe_C) +- +- # Enable punica if -DVLLM_INSTALL_PUNICA_KERNELS=ON or +- # VLLM_INSTALL_PUNICA_KERNELS is set in the environment and +- # there are supported target arches. +- if (VLLM_PUNICA_GPU_ARCHES AND +- (ENV{VLLM_INSTALL_PUNICA_KERNELS} OR VLLM_INSTALL_PUNICA_KERNELS)) +- message(STATUS "Enabling punica extension.") +- add_dependencies(default _punica_C) +- endif() ++if(VLLM_FLASH_ATTN_SRC_DIR) ++ FetchContent_Declare(vllm-flash-attn SOURCE_DIR ${VLLM_FLASH_ATTN_SRC_DIR}) ++else() ++ FetchContent_Declare( ++ vllm-flash-attn ++ GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git ++ GIT_TAG 96266b1111111f3d11aabefaf3bacbab6a89d03c ++ GIT_PROGRESS TRUE ++ # Don't share the vllm-flash-attn build between build types ++ BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn ++ ) + endif() ++ ++# Set the parent build flag so that the vllm-flash-attn library does not redo compile flag and arch initialization. ++set(VLLM_PARENT_BUILD ON) ++ ++# Ensure the vllm/vllm_flash_attn directory exists before installation ++install(CODE "file(MAKE_DIRECTORY \"\${CMAKE_INSTALL_PREFIX}/vllm/vllm_flash_attn\")" COMPONENT vllm_flash_attn_c) ++ ++# Make sure vllm-flash-attn install rules are nested under vllm/ ++install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY FALSE)" COMPONENT vllm_flash_attn_c) ++install(CODE "set(OLD_CMAKE_INSTALL_PREFIX \"\${CMAKE_INSTALL_PREFIX}\")" COMPONENT vllm_flash_attn_c) ++install(CODE "set(CMAKE_INSTALL_PREFIX \"\${CMAKE_INSTALL_PREFIX}/vllm/\")" COMPONENT vllm_flash_attn_c) ++ ++# Fetch the vllm-flash-attn library ++FetchContent_MakeAvailable(vllm-flash-attn) ++message(STATUS "vllm-flash-attn is available at ${vllm-flash-attn_SOURCE_DIR}") ++ ++# Restore the install prefix ++install(CODE "set(CMAKE_INSTALL_PREFIX \"\${OLD_CMAKE_INSTALL_PREFIX}\")" COMPONENT vllm_flash_attn_c) ++install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY TRUE)" COMPONENT vllm_flash_attn_c) ++ ++# Copy over the vllm-flash-attn python files ++install( ++ DIRECTORY ${vllm-flash-attn_SOURCE_DIR}/vllm_flash_attn/ ++ DESTINATION vllm/vllm_flash_attn ++ COMPONENT vllm_flash_attn_c ++ FILES_MATCHING PATTERN "*.py" ++) ++ ++# Nothing after vllm-flash-attn, see comment about macros above +diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md +new file mode 100644 +index 0000000..f801b5f +--- /dev/null ++++ b/CODE_OF_CONDUCT.md +@@ -0,0 +1,128 @@ ++ ++# vLLM Code of Conduct ++ ++## Our Pledge ++ ++We as members, contributors, and leaders pledge to make participation in our ++community a harassment-free experience for everyone, regardless of age, body ++size, visible or invisible disability, ethnicity, sex characteristics, gender ++identity and expression, level of experience, education, socioeconomic status, ++nationality, personal appearance, race, caste, color, religion, or sexual ++identity and orientation. ++ ++We pledge to act and interact in ways that contribute to an open, welcoming, ++diverse, inclusive, and healthy community. ++ ++## Our Standards ++ ++Examples of behavior that contributes to a positive environment for our ++community include: ++ ++* Demonstrating empathy and kindness toward other people ++* Being respectful of differing opinions, viewpoints, and experiences ++* Giving and gracefully accepting constructive feedback ++* Accepting responsibility and apologizing to those affected by our mistakes, ++ and learning from the experience ++* Focusing on what is best not just for us as individuals, but for the overall ++ community ++ ++Examples of unacceptable behavior include: ++ ++* The use of sexualized language or imagery, and sexual attention or advances of ++ any kind ++* Trolling, insulting or derogatory comments, and personal or political attacks ++* Public or private harassment ++* Publishing others' private information, such as a physical or email address, ++ without their explicit permission ++* Other conduct which could reasonably be considered inappropriate in a ++ professional setting ++ ++## Enforcement Responsibilities ++ ++Community leaders are responsible for clarifying and enforcing our standards of ++acceptable behavior and will take appropriate and fair corrective action in ++response to any behavior that they deem inappropriate, threatening, offensive, ++or harmful. ++ ++Community leaders have the right and responsibility to remove, edit, or reject ++comments, commits, code, wiki edits, issues, and other contributions that are ++not aligned to this Code of Conduct, and will communicate reasons for moderation ++decisions when appropriate. ++ ++## Scope ++ ++This Code of Conduct applies within all community spaces, and also applies when ++an individual is officially representing the community in public spaces. ++Examples of representing our community include using an official email address, ++posting via an official social media account, or acting as an appointed ++representative at an online or offline/IRL event. ++ ++## Enforcement ++ ++Instances of abusive, harassing, or otherwise unacceptable behavior may be ++reported to the community leaders responsible for enforcement in the #code-of-conduct ++channel in the [vLLM Discord](https://discord.com/invite/jz7wjKhh6g). ++All complaints will be reviewed and investigated promptly and fairly. ++ ++All community leaders are obligated to respect the privacy and security of the ++reporter of any incident. ++ ++## Enforcement Guidelines ++ ++Community leaders will follow these Community Impact Guidelines in determining ++the consequences for any action they deem in violation of this Code of Conduct: ++ ++### 1. Correction ++ ++**Community Impact**: Use of inappropriate language or other behavior deemed ++unprofessional or unwelcome in the community. ++ ++**Consequence**: A private, written warning from community leaders, providing ++clarity around the nature of the violation and an explanation of why the ++behavior was inappropriate. A public apology may be requested. ++ ++### 2. Warning ++ ++**Community Impact**: A violation through a single incident or series of ++actions. ++ ++**Consequence**: A warning with consequences for continued behavior. No ++interaction with the people involved, including unsolicited interaction with ++those enforcing the Code of Conduct, for a specified period of time. This ++includes avoiding interactions in community spaces as well as external channels ++like social media. Violating these terms may lead to a temporary or permanent ++ban. ++ ++### 3. Temporary Ban ++ ++**Community Impact**: A serious violation of community standards, including ++sustained inappropriate behavior. ++ ++**Consequence**: A temporary ban from any sort of interaction or public ++communication with the community for a specified period of time. No public or ++private interaction with the people involved, including unsolicited interaction ++with those enforcing the Code of Conduct, is allowed during this period. ++Violating these terms may lead to a permanent ban. ++ ++### 4. Permanent Ban ++ ++**Community Impact**: Demonstrating a pattern of violation of community ++standards, including sustained inappropriate behavior, harassment of an ++individual, or aggression toward or disparagement of classes of individuals. ++ ++**Consequence**: A permanent ban from any sort of public interaction within the ++community. ++ ++## Attribution ++ ++This Code of Conduct is adapted from the [Contributor Covenant](https://www.contributor-covenant.org/), ++version 2.1, available at ++[v2.1](https://www.contributor-covenant.org/version/2/1/code_of_conduct.html). ++ ++Community Impact Guidelines were inspired by ++[Mozilla's code of conduct enforcement ladder](https://github.com/mozilla/inclusion). ++ ++For answers to common questions about this code of conduct, see the ++[Contributor Covenant FAQ](https://www.contributor-covenant.org/faq). Translations are available at ++[Contributor Covenant translations](https://www.contributor-covenant.org/translations). ++ +diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md +index 81a8db2..6d46a6d 100644 +--- a/CONTRIBUTING.md ++++ b/CONTRIBUTING.md +@@ -1,56 +1,3 @@ + # Contributing to vLLM + +-Thank you for your interest in contributing to vLLM! +-Our community is open to everyone and welcomes all kinds of contributions, no matter how small or large. +-There are several ways you can contribute to the project: +- +-- Identify and report any issues or bugs. +-- Request or add a new model. +-- Suggest or implement new features. +- +-However, remember that contributions aren't just about code. +-We believe in the power of community support; thus, answering queries, assisting others, and enhancing the documentation are highly regarded and beneficial contributions. +- +-Finally, one of the most impactful ways to support us is by raising awareness about vLLM. +-Talk about it in your blog posts, highlighting how it's driving your incredible projects. +-Express your support on Twitter if vLLM aids you, or simply offer your appreciation by starring our repository. +- +- +-## Setup for development +- +-### Build from source +- +-```bash +-pip install -e . # This may take several minutes. +-``` +- +-### Testing +- +-```bash +-pip install -r requirements-dev.txt +- +-# linting and formatting +-bash format.sh +-# Static type checking +-mypy +-# Unit tests +-pytest tests/ +-``` +-**Note:** Currently, the repository does not pass the mypy tests. +- +- +-## Contributing Guidelines +- +-### Issue Reporting +- +-If you encounter a bug or have a feature request, please check our issues page first to see if someone else has already reported it. +-If not, please file a new issue, providing as much relevant information as possible. +- +-### Pull Requests & Code Reviews +- +-Please check the PR checklist in the [PR template](.github/PULL_REQUEST_TEMPLATE.md) for detailed guide for contribution. +- +-### Thank You +- +-Finally, thank you for taking the time to read these guidelines and for your interest in contributing to vLLM. +-Your contributions make vLLM a great tool for everyone! ++You may find information about contributing to vLLM on [docs.vllm.ai](https://docs.vllm.ai/en/latest/contributing/overview.html). +diff --git a/DCO b/DCO +new file mode 100644 +index 0000000..49b8cb0 +--- /dev/null ++++ b/DCO +@@ -0,0 +1,34 @@ ++Developer Certificate of Origin ++Version 1.1 ++ ++Copyright (C) 2004, 2006 The Linux Foundation and its contributors. ++ ++Everyone is permitted to copy and distribute verbatim copies of this ++license document, but changing it is not allowed. ++ ++ ++Developer's Certificate of Origin 1.1 ++ ++By making a contribution to this project, I certify that: ++ ++(a) The contribution was created in whole or in part by me and I ++ have the right to submit it under the open source license ++ indicated in the file; or ++ ++(b) The contribution is based upon previous work that, to the best ++ of my knowledge, is covered under an appropriate open source ++ license and I have the right under that license to submit that ++ work with modifications, whether created in whole or in part ++ by me, under the same open source license (unless I am ++ permitted to submit under a different license), as indicated ++ in the file; or ++ ++(c) The contribution was provided directly to me by some other ++ person who certified (a), (b) or (c) and I have not modified ++ it. ++ ++(d) I understand and agree that this project and the contribution ++ are public and that a record of the contribution (including all ++ personal information I submit with it, including my sign-off) is ++ maintained indefinitely and may be redistributed consistent with ++ this project or the open source license(s) involved. +diff --git a/Dockerfile b/Dockerfile +index 90be3a3..4542bc9 100644 +--- a/Dockerfile ++++ b/Dockerfile +@@ -2,34 +2,63 @@ + # to run the OpenAI compatible server. + + # Please update any changes made here to +-# docs/source/dev/dockerfile/dockerfile.rst and +-# docs/source/assets/dev/dockerfile-stages-dependency.png ++# docs/source/contributing/dockerfile/dockerfile.md and ++# docs/source/assets/contributing/dockerfile-stages-dependency.png + ++ARG CUDA_VERSION=12.4.1 + #################### BASE BUILD IMAGE #################### + # prepare basic build environment +-FROM nvidia/cuda:12.4.1-devel-ubuntu22.04 AS dev +- +-RUN apt-get update -y \ +- && apt-get install -y python3-pip git ++FROM nvidia/cuda:${CUDA_VERSION}-devel-ubuntu20.04 AS base ++ARG CUDA_VERSION=12.4.1 ++ARG PYTHON_VERSION=3.12 ++ARG TARGETPLATFORM ++ENV DEBIAN_FRONTEND=noninteractive ++ ++# Install Python and other dependencies ++RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \ ++ && echo 'tzdata tzdata/Zones/America select Los_Angeles' | debconf-set-selections \ ++ && apt-get update -y \ ++ && apt-get install -y ccache software-properties-common git curl sudo \ ++ && add-apt-repository ppa:deadsnakes/ppa \ ++ && apt-get update -y \ ++ && apt-get install -y python${PYTHON_VERSION} python${PYTHON_VERSION}-dev python${PYTHON_VERSION}-venv \ ++ && update-alternatives --install /usr/bin/python3 python3 /usr/bin/python${PYTHON_VERSION} 1 \ ++ && update-alternatives --set python3 /usr/bin/python${PYTHON_VERSION} \ ++ && ln -sf /usr/bin/python${PYTHON_VERSION}-config /usr/bin/python3-config \ ++ && curl -sS https://bootstrap.pypa.io/get-pip.py | python${PYTHON_VERSION} \ ++ && python3 --version && python3 -m pip --version ++ ++# Upgrade to GCC 10 to avoid https://gcc.gnu.org/bugzilla/show_bug.cgi?id=92519 ++# as it was causing spam when compiling the CUTLASS kernels ++RUN apt-get install -y gcc-10 g++-10 ++RUN update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-10 110 --slave /usr/bin/g++ g++ /usr/bin/g++-10 ++RUN <> /etc/environment ++ ++# Install Python and other dependencies ++RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \ ++ && echo 'tzdata tzdata/Zones/America select Los_Angeles' | debconf-set-selections \ ++ && apt-get update -y \ ++ && apt-get install -y ccache software-properties-common git curl wget sudo vim python3-pip \ ++ && apt-get install -y ffmpeg libsm6 libxext6 libgl1 \ ++ && add-apt-repository ppa:deadsnakes/ppa \ ++ && apt-get update -y \ ++ && apt-get install -y python${PYTHON_VERSION} python${PYTHON_VERSION}-dev python${PYTHON_VERSION}-venv libibverbs-dev \ ++ && update-alternatives --install /usr/bin/python3 python3 /usr/bin/python${PYTHON_VERSION} 1 \ ++ && update-alternatives --set python3 /usr/bin/python${PYTHON_VERSION} \ ++ && ln -sf /usr/bin/python${PYTHON_VERSION}-config /usr/bin/python3-config \ ++ && curl -sS https://bootstrap.pypa.io/get-pip.py | python${PYTHON_VERSION} \ ++ && python3 --version && python3 -m pip --version + + # Workaround for https://github.com/openai/triton/issues/2507 and + # https://github.com/pytorch/pytorch/issues/107960 -- hopefully + # this won't be needed for future versions of this docker image + # or future versions of triton. +-RUN ldconfig /usr/local/cuda-12.4/compat/ ++RUN ldconfig /usr/local/cuda-$(echo $CUDA_VERSION | cut -d. -f1,2)/compat/ ++ ++# arm64 (GH200) build follows the practice of "use existing pytorch" build, ++# we need to install torch and torchvision from the nightly builds first, ++# pytorch will not appear as a vLLM dependency in all of the following steps ++# after this step ++RUN --mount=type=cache,target=/root/.cache/pip \ ++ if [ "$TARGETPLATFORM" = "linux/arm64" ]; then \ ++ python3 -m pip install --index-url https://download.pytorch.org/whl/nightly/cu124 "torch==2.6.0.dev20241210+cu124" "torchvision==0.22.0.dev20241215"; \ ++ fi + +-# install vllm wheel first, so that torch etc will be installed ++# Install vllm wheel first, so that torch etc will be installed. + RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist \ + --mount=type=cache,target=/root/.cache/pip \ +- pip install dist/*.whl --verbose ++ python3 -m pip install dist/*.whl --verbose + +-RUN --mount=type=bind,from=flash-attn-builder,src=/usr/src/flash-attention-v2,target=/usr/src/flash-attention-v2 \ +- --mount=type=cache,target=/root/.cache/pip \ +- pip install /usr/src/flash-attention-v2/*.whl --no-cache-dir ++RUN --mount=type=cache,target=/root/.cache/pip \ ++. /etc/environment && \ ++if [ "$TARGETPLATFORM" != "linux/arm64" ]; then \ ++ python3 -m pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.6/flashinfer-0.1.6+cu121torch2.4-cp${PYTHON_VERSION_STR}-cp${PYTHON_VERSION_STR}-linux_x86_64.whl; \ ++fi ++COPY examples examples + #################### vLLM installation IMAGE #################### + +- + #################### TEST IMAGE #################### + # image to run unit testing suite + # note that this uses vllm installed by `pip` +@@ -138,7 +211,19 @@ ADD . /vllm-workspace/ + + # install development dependencies (for testing) + RUN --mount=type=cache,target=/root/.cache/pip \ +- pip install -r requirements-dev.txt ++ python3 -m pip install -r requirements-dev.txt ++ ++# install development dependencies (for testing) ++RUN --mount=type=cache,target=/root/.cache/pip \ ++ python3 -m pip install -e tests/vllm_test_utils ++ ++# enable fast downloads from hf (for testing) ++RUN --mount=type=cache,target=/root/.cache/pip \ ++ python3 -m pip install hf_transfer ++ENV HF_HUB_ENABLE_HF_TRANSFER 1 ++ ++# Copy in the v1 package for testing (it isn't distributed yet) ++COPY vllm/v1 /usr/local/lib/python3.12/dist-packages/vllm/v1 + + # doc requires source code + # we hide them inside `test_docs/` , so that this source code +@@ -146,18 +231,30 @@ RUN --mount=type=cache,target=/root/.cache/pip \ + RUN mkdir test_docs + RUN mv docs test_docs/ + RUN mv vllm test_docs/ +- + #################### TEST IMAGE #################### + + #################### OPENAI API SERVER #################### +-# openai api server alternative +-FROM vllm-base AS vllm-openai ++# base openai image with additional requirements, for any subsequent openai-style images ++FROM vllm-base AS vllm-openai-base + + # install additional dependencies for openai api server + RUN --mount=type=cache,target=/root/.cache/pip \ +- pip install accelerate hf_transfer modelscope ++ if [ "$TARGETPLATFORM" = "linux/arm64" ]; then \ ++ pip install accelerate hf_transfer 'modelscope!=1.15.0' 'bitsandbytes>=0.42.0' 'timm==0.9.10' boto3 runai-model-streamer runai-model-streamer[s3]; \ ++ else \ ++ pip install accelerate hf_transfer 'modelscope!=1.15.0' 'bitsandbytes>=0.45.0' 'timm==0.9.10' boto3 runai-model-streamer runai-model-streamer[s3]; \ ++ fi + + ENV VLLM_USAGE_SOURCE production-docker-image + ++# define sagemaker first, so it is not default from `docker build` ++FROM vllm-openai-base AS vllm-sagemaker ++ ++COPY examples/online_serving/sagemaker-entrypoint.sh . ++RUN chmod +x sagemaker-entrypoint.sh ++ENTRYPOINT ["./sagemaker-entrypoint.sh"] ++ ++FROM vllm-openai-base AS vllm-openai ++ + ENTRYPOINT ["python3", "-m", "vllm.entrypoints.openai.api_server"] + #################### OPENAI API SERVER #################### +diff --git a/Dockerfile.arm b/Dockerfile.arm +new file mode 100644 +index 0000000..093ee22 +--- /dev/null ++++ b/Dockerfile.arm +@@ -0,0 +1,62 @@ ++# This vLLM Dockerfile is used to construct an image that can build and run vLLM on ARM CPU platform. ++ ++FROM ubuntu:22.04 AS cpu-test-arm ++ ++ENV CCACHE_DIR=/root/.cache/ccache ++ ++ENV CMAKE_CXX_COMPILER_LAUNCHER=ccache ++ ++RUN --mount=type=cache,target=/var/cache/apt \ ++ apt-get update -y \ ++ && apt-get install -y curl ccache git wget vim numactl gcc-12 g++-12 python3 python3-pip libtcmalloc-minimal4 libnuma-dev \ ++ && apt-get install -y ffmpeg libsm6 libxext6 libgl1 \ ++ && update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-12 10 --slave /usr/bin/g++ g++ /usr/bin/g++-12 ++ ++# tcmalloc provides better memory allocation efficiency, e.g., holding memory in caches to speed up access of commonly-used objects. ++RUN --mount=type=cache,target=/root/.cache/pip \ ++ pip install py-cpuinfo # Use this to gather CPU info and optimize based on ARM Neoverse cores ++ ++# Set LD_PRELOAD for tcmalloc on ARM ++ENV LD_PRELOAD="/usr/lib/aarch64-linux-gnu/libtcmalloc_minimal.so.4" ++ ++RUN echo 'ulimit -c 0' >> ~/.bashrc ++ ++WORKDIR /workspace ++ ++ARG PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu" ++ENV PIP_EXTRA_INDEX_URL=${PIP_EXTRA_INDEX_URL} ++RUN --mount=type=cache,target=/root/.cache/pip \ ++ --mount=type=bind,src=requirements-build.txt,target=requirements-build.txt \ ++ pip install --upgrade pip && \ ++ pip install -r requirements-build.txt ++ ++FROM cpu-test-arm AS build ++ ++WORKDIR /workspace/vllm ++ ++RUN --mount=type=cache,target=/root/.cache/pip \ ++ --mount=type=bind,src=requirements-common.txt,target=requirements-common.txt \ ++ --mount=type=bind,src=requirements-cpu.txt,target=requirements-cpu.txt \ ++ pip install -v -r requirements-cpu.txt ++ ++COPY . . ++ARG GIT_REPO_CHECK=0 ++RUN --mount=type=bind,source=.git,target=.git \ ++ if [ "$GIT_REPO_CHECK" != 0 ]; then bash tools/check_repo.sh ; fi ++ ++# Disabling AVX512 specific optimizations for ARM ++ARG VLLM_CPU_DISABLE_AVX512="true" ++ENV VLLM_CPU_DISABLE_AVX512=${VLLM_CPU_DISABLE_AVX512} ++ ++RUN --mount=type=cache,target=/root/.cache/pip \ ++ --mount=type=cache,target=/root/.cache/ccache \ ++ --mount=type=bind,source=.git,target=.git \ ++ VLLM_TARGET_DEVICE=cpu python3 setup.py bdist_wheel && \ ++ pip install dist/*.whl && \ ++ rm -rf dist ++ ++WORKDIR /workspace/ ++ ++RUN ln -s /workspace/vllm/tests && ln -s /workspace/vllm/examples && ln -s /workspace/vllm/benchmarks ++ ++ENTRYPOINT ["python3", "-m", "vllm.entrypoints.openai.api_server"] +\ No newline at end of file +diff --git a/Dockerfile.cpu b/Dockerfile.cpu +index 4251fdd..f163edc 100644 +--- a/Dockerfile.cpu ++++ b/Dockerfile.cpu +@@ -1,20 +1,69 @@ + # This vLLM Dockerfile is used to construct image that can build and run vLLM on x86 CPU platform. + +-FROM ubuntu:22.04 ++FROM ubuntu:22.04 AS cpu-test-1 + +-RUN apt-get update -y \ +- && apt-get install -y git wget vim numactl gcc-12 g++-12 python3 python3-pip \ ++ENV CCACHE_DIR=/root/.cache/ccache ++ ++ENV CMAKE_CXX_COMPILER_LAUNCHER=ccache ++ ++RUN --mount=type=cache,target=/var/cache/apt \ ++ apt-get update -y \ ++ && apt-get install -y curl ccache git wget vim numactl gcc-12 g++-12 python3 python3-pip libtcmalloc-minimal4 libnuma-dev \ ++ && apt-get install -y ffmpeg libsm6 libxext6 libgl1 \ + && update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-12 10 --slave /usr/bin/g++ g++ /usr/bin/g++-12 + +-RUN pip install --upgrade pip \ +- && pip install wheel packaging ninja setuptools>=49.4.0 numpy ++# https://intel.github.io/intel-extension-for-pytorch/cpu/latest/tutorials/performance_tuning/tuning_guide.html ++# intel-openmp provides additional performance improvement vs. openmp ++# tcmalloc provides better memory allocation efficiency, e.g, holding memory in caches to speed up access of commonly-used objects. ++RUN --mount=type=cache,target=/root/.cache/pip \ ++ pip install intel-openmp==2025.0.1 ++ ++ENV LD_PRELOAD="/usr/lib/x86_64-linux-gnu/libtcmalloc_minimal.so.4:/usr/local/lib/libiomp5.so" ++ ++RUN echo 'ulimit -c 0' >> ~/.bashrc ++ ++RUN pip install intel_extension_for_pytorch==2.5.0 ++ ++WORKDIR /workspace + +-COPY ./ /workspace/vllm ++COPY requirements-build.txt requirements-build.txt ++ARG PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu" ++ENV PIP_EXTRA_INDEX_URL=${PIP_EXTRA_INDEX_URL} ++RUN --mount=type=cache,target=/root/.cache/pip \ ++ pip install --upgrade pip && \ ++ pip install -r requirements-build.txt ++ ++FROM cpu-test-1 AS build + + WORKDIR /workspace/vllm + +-RUN pip install -v -r requirements-cpu.txt --extra-index-url https://download.pytorch.org/whl/cpu ++COPY requirements-common.txt requirements-common.txt ++COPY requirements-cpu.txt requirements-cpu.txt ++RUN --mount=type=cache,target=/root/.cache/pip \ ++ pip install -v -r requirements-cpu.txt ++ ++COPY . . ++ARG GIT_REPO_CHECK=0 ++RUN --mount=type=bind,source=.git,target=.git \ ++ if [ "$GIT_REPO_CHECK" != 0 ]; then bash tools/check_repo.sh ; fi ++ ++# Support for building with non-AVX512 vLLM: docker build --build-arg VLLM_CPU_DISABLE_AVX512="true" ... ++ARG VLLM_CPU_DISABLE_AVX512 ++ENV VLLM_CPU_DISABLE_AVX512=${VLLM_CPU_DISABLE_AVX512} ++ ++RUN --mount=type=cache,target=/root/.cache/pip \ ++ --mount=type=cache,target=/root/.cache/ccache \ ++ --mount=type=bind,source=.git,target=.git \ ++ VLLM_TARGET_DEVICE=cpu python3 setup.py bdist_wheel && \ ++ pip install dist/*.whl && \ ++ rm -rf dist ++ ++WORKDIR /workspace/ ++ ++RUN ln -s /workspace/vllm/tests && ln -s /workspace/vllm/examples && ln -s /workspace/vllm/benchmarks + +-RUN VLLM_TARGET_DEVICE=cpu python3 setup.py install ++# install development dependencies (for testing) ++RUN --mount=type=cache,target=/root/.cache/pip \ ++ pip install -e tests/vllm_test_utils + +-CMD ["/bin/bash"] ++ENTRYPOINT ["python3", "-m", "vllm.entrypoints.openai.api_server"] +diff --git a/Dockerfile.hpu b/Dockerfile.hpu +new file mode 100644 +index 0000000..87e0c1a +--- /dev/null ++++ b/Dockerfile.hpu +@@ -0,0 +1,21 @@ ++FROM vault.habana.ai/gaudi-docker/1.18.0/ubuntu22.04/habanalabs/pytorch-installer-2.4.0:latest ++ ++COPY ./ /workspace/vllm ++ ++WORKDIR /workspace/vllm ++ ++RUN pip install -v -r requirements-hpu.txt ++ ++ENV no_proxy=localhost,127.0.0.1 ++ENV PT_HPU_ENABLE_LAZY_COLLECTIVES=true ++ ++RUN VLLM_TARGET_DEVICE=hpu python3 setup.py install ++ ++# install development dependencies (for testing) ++RUN python3 -m pip install -e tests/vllm_test_utils ++ ++WORKDIR /workspace/ ++ ++RUN ln -s /workspace/vllm/tests && ln -s /workspace/vllm/examples && ln -s /workspace/vllm/benchmarks ++ ++ENTRYPOINT ["python3", "-m", "vllm.entrypoints.openai.api_server"] +diff --git a/Dockerfile.neuron b/Dockerfile.neuron +index fe42b4e..e9cb828 100644 +--- a/Dockerfile.neuron ++++ b/Dockerfile.neuron +@@ -1,36 +1,49 @@ + # default base image +-ARG BASE_IMAGE="763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-inference-neuronx:2.1.1-neuronx-py310-sdk2.17.0-ubuntu20.04" ++# https://gallery.ecr.aws/neuron/pytorch-inference-neuronx ++ARG BASE_IMAGE="public.ecr.aws/neuron/pytorch-inference-neuronx:2.5.1-neuronx-py310-sdk2.21.0-ubuntu22.04" + + FROM $BASE_IMAGE + + RUN echo "Base image is $BASE_IMAGE" + + # Install some basic utilities +-RUN apt-get update && apt-get install python3 python3-pip -y ++RUN apt-get update && \ ++ apt-get install -y \ ++ git \ ++ python3 \ ++ python3-pip \ ++ ffmpeg libsm6 libxext6 libgl1 + + ### Mount Point ### +-# When launching the container, mount the code directory to /app +-ARG APP_MOUNT=/app ++# When launching the container, mount the code directory to /workspace ++ARG APP_MOUNT=/workspace + VOLUME [ ${APP_MOUNT} ] +-WORKDIR ${APP_MOUNT} ++WORKDIR ${APP_MOUNT}/vllm + + RUN python3 -m pip install --upgrade pip + RUN python3 -m pip install --no-cache-dir fastapi ninja tokenizers pandas +-RUN python3 -m pip install sentencepiece transformers==4.36.2 -U ++RUN python3 -m pip install sentencepiece transformers==4.45.2 -U + RUN python3 -m pip install transformers-neuronx --extra-index-url=https://pip.repos.neuron.amazonaws.com -U +-RUN python3 -m pip install --pre neuronx-cc==2.12.* --extra-index-url=https://pip.repos.neuron.amazonaws.com -U ++RUN python3 -m pip install neuronx-cc==2.16.345.0 --extra-index-url=https://pip.repos.neuron.amazonaws.com -U ++RUN python3 -m pip install pytest + +-COPY ./vllm /app/vllm/vllm +-COPY ./setup.py /app/vllm/setup.py +-COPY ./requirements-common.txt /app/vllm/requirements-common.txt +-COPY ./requirements-neuron.txt /app/vllm/requirements-neuron.txt ++COPY . . ++ARG GIT_REPO_CHECK=0 ++RUN --mount=type=bind,source=.git,target=.git \ ++ if [ "$GIT_REPO_CHECK" != 0 ]; then bash tools/check_repo.sh ; fi + +-RUN cd /app/vllm \ +- && python3 -m pip install -U -r requirements-neuron.txt ++RUN python3 -m pip install -U \ ++ 'cmake>=3.26' ninja packaging 'setuptools-scm>=8' wheel jinja2 \ ++ -r requirements-neuron.txt + +-ENV VLLM_BUILD_WITH_NEURON 1 +-RUN cd /app/vllm \ +- && pip install -e . \ +- && cd .. ++ENV VLLM_TARGET_DEVICE neuron ++RUN --mount=type=bind,source=.git,target=.git \ ++ pip install --no-build-isolation -v -e . ++ ++# install development dependencies (for testing) ++RUN python3 -m pip install -e tests/vllm_test_utils ++ ++# overwrite entrypoint to run bash script ++RUN echo "import subprocess; import sys; subprocess.check_call(sys.argv[1:])" > /usr/local/bin/dockerd-entrypoint.py + + CMD ["/bin/bash"] +diff --git a/Dockerfile.openvino b/Dockerfile.openvino +new file mode 100644 +index 0000000..32bcbfa +--- /dev/null ++++ b/Dockerfile.openvino +@@ -0,0 +1,29 @@ ++# The vLLM Dockerfile is used to construct vLLM image that can be directly used ++# to run the OpenAI compatible server. ++ ++FROM ubuntu:22.04 AS dev ++ ++RUN apt-get update -y && \ ++ apt-get install -y \ ++ git python3-pip \ ++ ffmpeg libsm6 libxext6 libgl1 ++WORKDIR /workspace ++ ++COPY . . ++ARG GIT_REPO_CHECK=0 ++RUN --mount=type=bind,source=.git,target=.git \ ++ if [ "$GIT_REPO_CHECK" != 0 ]; then bash tools/check_repo.sh ; fi ++ ++RUN python3 -m pip install -U pip ++# install build requirements ++RUN PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu" python3 -m pip install -r /workspace/requirements-build.txt ++# build vLLM with OpenVINO backend ++RUN PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu" VLLM_TARGET_DEVICE="openvino" python3 -m pip install /workspace ++ ++COPY examples/ /workspace/examples ++COPY benchmarks/ /workspace/benchmarks ++ ++# install development dependencies (for testing) ++RUN python3 -m pip install -e tests/vllm_test_utils ++ ++CMD ["/bin/bash"] +diff --git a/Dockerfile.ppc64le b/Dockerfile.ppc64le +new file mode 100644 +index 0000000..d3cd1c7 +--- /dev/null ++++ b/Dockerfile.ppc64le +@@ -0,0 +1,38 @@ ++FROM mambaorg/micromamba ++ARG MAMBA_DOCKERFILE_ACTIVATE=1 ++USER root ++ ++ENV PATH="/usr/local/cargo/bin:$PATH:/opt/conda/bin/" ++ ++RUN apt-get update -y && apt-get install -y git wget curl vim libnuma-dev libsndfile-dev libprotobuf-dev build-essential ffmpeg libsm6 libxext6 libgl1 libssl-dev ++ ++# Some packages in requirements-cpu are installed here ++# IBM provides optimized packages for ppc64le processors in the open-ce project for mamba ++# Currently these may not be available for venv or pip directly ++RUN micromamba install -y -n base -c https://ftp.osuosl.org/pub/open-ce/1.11.0-p10/ -c defaults python=3.10 torchvision-cpu=0.16.2 rust && micromamba clean --all --yes ++ ++COPY ./ /workspace/vllm ++ ++WORKDIR /workspace/vllm ++ARG GIT_REPO_CHECK=0 ++RUN --mount=type=bind,source=.git,target=.git \ ++ if [ "$GIT_REPO_CHECK" != 0 ]; then bash tools/check_repo.sh; fi ++ ++RUN --mount=type=cache,target=/root/.cache/pip \ ++ RUSTFLAGS='-L /opt/conda/lib' pip install -v --prefer-binary --extra-index-url https://repo.fury.io/mgiessing \ ++ 'cmake>=3.26' ninja packaging 'setuptools-scm>=8' wheel jinja2 \ ++ torch==2.3.1 \ ++ -r requirements-cpu.txt \ ++ xformers uvloop==0.20.0 ++ ++RUN --mount=type=bind,source=.git,target=.git \ ++ VLLM_TARGET_DEVICE=cpu python3 setup.py install ++ ++# install development dependencies (for testing) ++RUN python3 -m pip install -e tests/vllm_test_utils ++ ++WORKDIR /workspace/ ++ ++RUN ln -s /workspace/vllm/tests && ln -s /workspace/vllm/examples && ln -s /workspace/vllm/benchmarks ++ ++ENTRYPOINT ["/opt/conda/bin/python3", "-m", "vllm.entrypoints.openai.api_server"] +diff --git a/Dockerfile.rocm b/Dockerfile.rocm +index d04bb99..e733994 100644 +--- a/Dockerfile.rocm ++++ b/Dockerfile.rocm +@@ -1,35 +1,27 @@ +-# default base image +-ARG BASE_IMAGE="rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1" +- +-FROM $BASE_IMAGE +- +-ARG BASE_IMAGE="rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1" +- +-RUN echo "Base image is $BASE_IMAGE" +- +-# BASE_IMAGE for ROCm_5.7: "rocm/pytorch:rocm5.7_ubuntu22.04_py3.10_pytorch_2.0.1" +-# BASE_IMAGE for ROCm_6.0: "rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1" ++# Default ROCm 6.2 base image ++ARG BASE_IMAGE="rocm/pytorch:rocm6.2_ubuntu20.04_py3.9_pytorch_release_2.3.0" + ++# Default ROCm ARCHes to build vLLM for. ++ARG PYTORCH_ROCM_ARCH="gfx908;gfx90a;gfx942;gfx1100" + ++# Whether to install CK-based flash-attention ++# If 0, will not install flash-attention ++ARG BUILD_FA="1" + ARG FA_GFX_ARCHS="gfx90a;gfx942" +-RUN echo "FA_GFX_ARCHS is $FA_GFX_ARCHS" ++ARG FA_BRANCH="3cea2fb" + +-ARG FA_BRANCH="ae7928c" +-RUN echo "FA_BRANCH is $FA_BRANCH" ++# Whether to build triton on rocm ++ARG BUILD_TRITON="1" ++ARG TRITON_BRANCH="e192dba" + +-# whether to build flash-attention +-# if 0, will not build flash attention +-# this is useful for gfx target where flash-attention is not supported +-# In that case, we need to use the python reference attention implementation in vllm +-ARG BUILD_FA="1" ++### Base image build stage ++FROM $BASE_IMAGE AS base + +-# whether to build triton on rocm +-ARG BUILD_TRITON="1" ++# Import arg(s) defined before this build stage ++ARG PYTORCH_ROCM_ARCH + + # Install some basic utilities + RUN apt-get update && apt-get install python3 python3-pip -y +- +-# Install some basic utilities + RUN apt-get update && apt-get install -y \ + curl \ + ca-certificates \ +@@ -40,68 +32,143 @@ RUN apt-get update && apt-get install -y \ + build-essential \ + wget \ + unzip \ +- nvidia-cuda-toolkit \ + tmux \ ++ ccache \ + && rm -rf /var/lib/apt/lists/* + +-### Mount Point ### +-# When launching the container, mount the code directory to /app ++# When launching the container, mount the code directory to /vllm-workspace + ARG APP_MOUNT=/vllm-workspace +-VOLUME [ ${APP_MOUNT} ] + WORKDIR ${APP_MOUNT} + + RUN python3 -m pip install --upgrade pip +-RUN python3 -m pip install --no-cache-dir fastapi ninja tokenizers pandas ++# Remove sccache so it doesn't interfere with ccache ++# TODO: implement sccache support across components ++RUN apt-get purge -y sccache; python3 -m pip uninstall -y sccache; rm -f "$(which sccache)" ++ ++# Install torch == 2.6.0 on ROCm ++RUN --mount=type=cache,target=/root/.cache/pip \ ++ case "$(ls /opt | grep -Po 'rocm-[0-9]\.[0-9]')" in \ ++ *"rocm-6.2"*) \ ++ python3 -m pip uninstall -y torch torchvision \ ++ && python3 -m pip install --pre \ ++ torch==2.6.0.dev20241113+rocm6.2 \ ++ 'setuptools-scm>=8' \ ++ torchvision==0.20.0.dev20241113+rocm6.2 \ ++ --extra-index-url https://download.pytorch.org/whl/nightly/rocm6.2;; \ ++ *) ;; esac + + ENV LLVM_SYMBOLIZER_PATH=/opt/rocm/llvm/bin/llvm-symbolizer + ENV PATH=$PATH:/opt/rocm/bin:/libtorch/bin: + ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/rocm/lib/:/libtorch/lib: + ENV CPLUS_INCLUDE_PATH=$CPLUS_INCLUDE_PATH:/libtorch/include:/libtorch/include/torch/csrc/api/include/:/opt/rocm/include/: + +-# Install ROCm flash-attention +-RUN if [ "$BUILD_FA" = "1" ]; then \ +- mkdir libs \ +- && cd libs \ +- && git clone https://github.com/ROCm/flash-attention.git \ +- && cd flash-attention \ +- && git checkout ${FA_BRANCH} \ +- && git submodule update --init \ +- && export GPU_ARCHS=${FA_GFX_ARCHS} \ +- && if [ "$BASE_IMAGE" = "rocm/pytorch:rocm5.7_ubuntu22.04_py3.10_pytorch_2.0.1" ]; then \ +- patch /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/utils/hipify/hipify_python.py hipify_patch.patch; fi \ +- && python3 setup.py install \ +- && cd ..; \ ++ENV PYTORCH_ROCM_ARCH=${PYTORCH_ROCM_ARCH} ++ENV CCACHE_DIR=/root/.cache/ccache ++ ++ ++### AMD-SMI build stage ++FROM base AS build_amdsmi ++# Build amdsmi wheel always ++RUN cd /opt/rocm/share/amd_smi \ ++ && python3 -m pip wheel . --wheel-dir=/install ++ ++ ++### Flash-Attention wheel build stage ++FROM base AS build_fa ++ARG BUILD_FA ++ARG FA_GFX_ARCHS ++ARG FA_BRANCH ++# Build ROCm flash-attention wheel if `BUILD_FA = 1` ++RUN --mount=type=cache,target=${CCACHE_DIR} \ ++ if [ "$BUILD_FA" = "1" ]; then \ ++ mkdir -p libs \ ++ && cd libs \ ++ && git clone https://github.com/ROCm/flash-attention.git \ ++ && cd flash-attention \ ++ && git checkout "${FA_BRANCH}" \ ++ && git submodule update --init \ ++ && GPU_ARCHS="${FA_GFX_ARCHS}" python3 setup.py bdist_wheel --dist-dir=/install; \ ++ # Create an empty directory otherwise as later build stages expect one ++ else mkdir -p /install; \ + fi + +-# Error related to odd state for numpy 1.20.3 where there is no METADATA etc, but an extra LICENSES_bundled.txt. +-# Manually removed it so that later steps of numpy upgrade can continue +-RUN if [ "$BASE_IMAGE" = "rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1" ]; then \ +- rm -rf /opt/conda/envs/py_3.9/lib/python3.9/site-packages/numpy-1.20.3.dist-info/; fi + +-# build triton +-RUN if [ "$BUILD_TRITON" = "1" ]; then \ ++### Triton wheel build stage ++FROM base AS build_triton ++ARG BUILD_TRITON ++ARG TRITON_BRANCH ++# Build triton wheel if `BUILD_TRITON = 1` ++RUN --mount=type=cache,target=${CCACHE_DIR} \ ++ if [ "$BUILD_TRITON" = "1" ]; then \ + mkdir -p libs \ + && cd libs \ +- && pip uninstall -y triton \ +- && git clone https://github.com/ROCm/triton.git \ +- && cd triton/python \ +- && pip3 install . \ +- && cd ../..; \ ++ && python3 -m pip install ninja cmake wheel pybind11 \ ++ && git clone https://github.com/OpenAI/triton.git \ ++ && cd triton \ ++ && git checkout "${TRITON_BRANCH}" \ ++ && cd python \ ++ && python3 setup.py bdist_wheel --dist-dir=/install; \ ++ # Create an empty directory otherwise as later build stages expect one ++ else mkdir -p /install; \ + fi + +-WORKDIR /vllm-workspace ++ ++### Final vLLM build stage ++FROM base AS final ++# Import the vLLM development directory from the build context + COPY . . ++ARG GIT_REPO_CHECK=0 ++RUN --mount=type=bind,source=.git,target=.git \ ++ if [ "$GIT_REPO_CHECK" != 0 ]; then bash tools/check_repo.sh ; fi + +-RUN python3 -m pip install --upgrade pip numba ++RUN python3 -m pip install --upgrade pip + ++# Package upgrades for useful functionality or to avoid dependency issues + RUN --mount=type=cache,target=/root/.cache/pip \ +- pip install -U -r requirements-rocm.txt \ +- && patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h ./rocm_patch/rocm_bf16.patch \ +- && python3 setup.py install \ +- && cp build/lib.linux-x86_64-cpython-39/vllm/_C.cpython-39-x86_64-linux-gnu.so vllm/ \ +- && cd .. ++ python3 -m pip install --upgrade numba scipy huggingface-hub[cli] pytest-shard + +-RUN python3 -m pip install --upgrade pip +-RUN python3 -m pip install --no-cache-dir ray[all]==2.9.3 ++ ++# Workaround for ray >= 2.10.0 ++ENV RAY_EXPERIMENTAL_NOSET_ROCR_VISIBLE_DEVICES=1 ++# Silences the HF Tokenizers warning ++ENV TOKENIZERS_PARALLELISM=false ++ ++RUN --mount=type=cache,target=${CCACHE_DIR} \ ++ --mount=type=bind,source=.git,target=.git \ ++ --mount=type=cache,target=/root/.cache/pip \ ++ python3 -m pip install -Ur requirements-rocm.txt \ ++ && python3 setup.py clean --all \ ++ && python3 setup.py develop ++ ++# Copy amdsmi wheel into final image ++RUN --mount=type=bind,from=build_amdsmi,src=/install,target=/install \ ++ mkdir -p libs \ ++ && cp /install/*.whl libs \ ++ # Preemptively uninstall to avoid same-version no-installs ++ && python3 -m pip uninstall -y amdsmi; ++ ++# Copy triton wheel(s) into final image if they were built ++RUN --mount=type=bind,from=build_triton,src=/install,target=/install \ ++ mkdir -p libs \ ++ && if ls /install/*.whl; then \ ++ cp /install/*.whl libs \ ++ # Preemptively uninstall to avoid same-version no-installs ++ && python3 -m pip uninstall -y triton; fi ++ ++# Copy flash-attn wheel(s) into final image if they were built ++RUN --mount=type=bind,from=build_fa,src=/install,target=/install \ ++ mkdir -p libs \ ++ && if ls /install/*.whl; then \ ++ cp /install/*.whl libs \ ++ # Preemptively uninstall to avoid same-version no-installs ++ && python3 -m pip uninstall -y flash-attn; fi ++ ++# Install wheels that were built to the final image ++RUN --mount=type=cache,target=/root/.cache/pip \ ++ if ls libs/*.whl; then \ ++ python3 -m pip install libs/*.whl; fi ++ ++# install development dependencies (for testing) ++RUN python3 -m pip install -e tests/vllm_test_utils + + CMD ["/bin/bash"] +diff --git a/Dockerfile.tpu b/Dockerfile.tpu +new file mode 100644 +index 0000000..b617932 +--- /dev/null ++++ b/Dockerfile.tpu +@@ -0,0 +1,28 @@ ++ARG NIGHTLY_DATE="20241017" ++ARG BASE_IMAGE="us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_$NIGHTLY_DATE" ++ ++FROM $BASE_IMAGE ++WORKDIR /workspace/vllm ++ ++# Install some basic utilities ++RUN apt-get update && apt-get install -y \ ++ git \ ++ ffmpeg libsm6 libxext6 libgl1 ++ ++# Build vLLM. ++COPY . . ++ARG GIT_REPO_CHECK=0 ++RUN --mount=type=bind,source=.git,target=.git \ ++ if [ "$GIT_REPO_CHECK" != 0 ]; then bash tools/check_repo.sh; fi ++ ++ENV VLLM_TARGET_DEVICE="tpu" ++RUN --mount=type=cache,target=/root/.cache/pip \ ++ --mount=type=bind,source=.git,target=.git \ ++ python3 -m pip install \ ++ -r requirements-tpu.txt ++RUN python3 setup.py develop ++ ++# install development dependencies (for testing) ++RUN python3 -m pip install -e tests/vllm_test_utils ++ ++CMD ["/bin/bash"] +diff --git a/Dockerfile.xpu b/Dockerfile.xpu +new file mode 100644 +index 0000000..a374f20 +--- /dev/null ++++ b/Dockerfile.xpu +@@ -0,0 +1,69 @@ ++FROM intel/oneapi-basekit:2024.2.1-0-devel-ubuntu22.04 AS vllm-base ++ ++RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB | gpg --dearmor | tee /usr/share/keyrings/intel-oneapi-archive-keyring.gpg > /dev/null && \ ++ echo "deb [signed-by=/usr/share/keyrings/intel-oneapi-archive-keyring.gpg] https://apt.repos.intel.com/oneapi all main " | tee /etc/apt/sources.list.d/oneAPI.list && \ ++ chmod 644 /usr/share/keyrings/intel-oneapi-archive-keyring.gpg && \ ++ wget -O- https://repositories.intel.com/graphics/intel-graphics.key | gpg --dearmor | tee /usr/share/keyrings/intel-graphics.gpg > /dev/null && \ ++ echo "deb [arch=amd64,i386 signed-by=/usr/share/keyrings/intel-graphics.gpg] https://repositories.intel.com/graphics/ubuntu jammy arc" | tee /etc/apt/sources.list.d/intel.gpu.jammy.list && \ ++ chmod 644 /usr/share/keyrings/intel-graphics.gpg ++ ++RUN apt-get update -y && \ ++ apt-get install -y --no-install-recommends --fix-missing \ ++ curl \ ++ ffmpeg \ ++ git \ ++ libsndfile1 \ ++ libsm6 \ ++ libxext6 \ ++ libgl1 \ ++ lsb-release \ ++ numactl \ ++ python3 \ ++ python3-dev \ ++ python3-pip \ ++ # vim \ ++ wget ++ ++WORKDIR /workspace/vllm ++COPY requirements-xpu.txt /workspace/vllm/requirements-xpu.txt ++COPY requirements-common.txt /workspace/vllm/requirements-common.txt ++ ++RUN --mount=type=cache,target=/root/.cache/pip \ ++ pip install --no-cache-dir \ ++ -r requirements-xpu.txt ++ ++RUN git clone https://github.com/intel/pti-gpu && \ ++ cd pti-gpu/sdk && \ ++ git checkout 6c491f07a777ed872c2654ca9942f1d0dde0a082 && \ ++ mkdir build && \ ++ cd build && \ ++ cmake -DCMAKE_BUILD_TYPE=Release -DCMAKE_TOOLCHAIN_FILE=../cmake/toolchains/icpx_toolchain.cmake -DBUILD_TESTING=OFF .. && \ ++ make -j && \ ++ cmake --install . --config Release --prefix "/usr/local" ++ ++ENV LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/usr/local/lib/" ++ ++COPY . . ++ARG GIT_REPO_CHECK ++RUN --mount=type=bind,source=.git,target=.git \ ++ if [ "$GIT_REPO_CHECK" != 0 ]; then bash tools/check_repo.sh; fi ++ ++ENV VLLM_TARGET_DEVICE=xpu ++ ++RUN --mount=type=cache,target=/root/.cache/pip \ ++ --mount=type=bind,source=.git,target=.git \ ++ python3 setup.py install ++ ++CMD ["/bin/bash"] ++ ++FROM vllm-base AS vllm-openai ++ ++# install additional dependencies for openai api server ++RUN --mount=type=cache,target=/root/.cache/pip \ ++ pip install accelerate hf_transfer 'modelscope!=1.15.0' ++ ++ENV VLLM_USAGE_SOURCE production-docker-image \ ++ TRITON_XPU_PROFILE 1 ++# install development dependencies (for testing) ++RUN python3 -m pip install -e tests/vllm_test_utils ++ENTRYPOINT ["python3", "-m", "vllm.entrypoints.openai.api_server"] +diff --git a/README.md b/README.md +index 524d027..67c557b 100644 +--- a/README.md ++++ b/README.md +@@ -10,21 +10,28 @@ Easy, fast, and cheap LLM serving for everyone + + +

+-| Documentation | Blog | Paper | Discord | +- ++| Documentation | Blog | Paper | Discord | Twitter/X | Developer Slack | +

+ ++--- ++ ++The first vLLM meetup in 2025 is happening on January 22nd, Wednesday, with Google Cloud in San Francisco! We will talk about vLLM's performant V1 architecture, Q1 roadmap, Google Cloud's innovation around vLLM: networking, Cloud Run, Vertex, and TPU! [Register Now](https://lu.ma/zep56hui) ++ ++--- ++ + *Latest News* 🔥 ++- [2024/12] vLLM joins [pytorch ecosystem](https://pytorch.org/blog/vllm-joins-pytorch)! Easy, Fast, and Cheap LLM Serving for Everyone! ++- [2024/11] We hosted [the seventh vLLM meetup](https://lu.ma/h0qvrajz) with Snowflake! Please find the meetup slides from vLLM team [here](https://docs.google.com/presentation/d/1e3CxQBV3JsfGp30SwyvS3eM_tW-ghOhJ9PAJGK6KR54/edit?usp=sharing), and Snowflake team [here](https://docs.google.com/presentation/d/1qF3RkDAbOULwz9WK5TOltt2fE9t6uIc_hVNLFAaQX6A/edit?usp=sharing). ++- [2024/10] We have just created a developer slack ([slack.vllm.ai](https://slack.vllm.ai)) focusing on coordinating contributions and discussing features. Please feel free to join us there! ++- [2024/10] Ray Summit 2024 held a special track for vLLM! Please find the opening talk slides from the vLLM team [here](https://docs.google.com/presentation/d/1B_KQxpHBTRa_mDF-tR6i8rWdOU5QoTZNcEg2MKZxEHM/edit?usp=sharing). Learn more from the [talks](https://www.youtube.com/playlist?list=PLzTswPQNepXl6AQwifuwUImLPFRVpksjR) from other vLLM contributors and users! ++- [2024/09] We hosted [the sixth vLLM meetup](https://lu.ma/87q3nvnh) with NVIDIA! Please find the meetup slides [here](https://docs.google.com/presentation/d/1wrLGwytQfaOTd5wCGSPNhoaW3nq0E-9wqyP7ny93xRs/edit?usp=sharing). ++- [2024/07] We hosted [the fifth vLLM meetup](https://lu.ma/lp0gyjqr) with AWS! Please find the meetup slides [here](https://docs.google.com/presentation/d/1RgUD8aCfcHocghoP3zmXzck9vX3RCI9yfUAB2Bbcl4Y/edit?usp=sharing). ++- [2024/07] In partnership with Meta, vLLM officially supports Llama 3.1 with FP8 quantization and pipeline parallelism! Please check out our blog post [here](https://blog.vllm.ai/2024/07/23/llama31.html). ++- [2024/06] We hosted [the fourth vLLM meetup](https://lu.ma/agivllm) with Cloudflare and BentoML! Please find the meetup slides [here](https://docs.google.com/presentation/d/1iJ8o7V2bQEi0BFEljLTwc5G1S10_Rhv3beed5oB0NJ4/edit?usp=sharing). + - [2024/04] We hosted [the third vLLM meetup](https://robloxandvllmmeetup2024.splashthat.com/) with Roblox! Please find the meetup slides [here](https://docs.google.com/presentation/d/1A--47JAK4BJ39t954HyTkvtfwn0fkqtsL8NGFuslReM/edit?usp=sharing). +-- [2024/01] We hosted [the second vLLM meetup](https://lu.ma/ygxbpzhl) in SF! Please find the meetup slides [here](https://docs.google.com/presentation/d/12mI2sKABnUw5RBWXDYY-HtHth4iMSNcEoQ10jDQbxgA/edit?usp=sharing). +-- [2024/01] Added ROCm 6.0 support to vLLM. +-- [2023/12] Added ROCm 5.7 support to vLLM. +-- [2023/10] We hosted [the first vLLM meetup](https://lu.ma/first-vllm-meetup) in SF! Please find the meetup slides [here](https://docs.google.com/presentation/d/1QL-XPFXiFpDBh86DbEegFXBXFXjix4v032GhShbKf3s/edit?usp=sharing). +-- [2023/09] We created our [Discord server](https://discord.gg/jz7wjKhh6g)! Join us to discuss vLLM and LLM serving! We will also post the latest announcements and updates there. +-- [2023/09] We released our [PagedAttention paper](https://arxiv.org/abs/2309.06180) on arXiv! ++- [2024/01] We hosted [the second vLLM meetup](https://lu.ma/ygxbpzhl) with IBM! Please find the meetup slides [here](https://docs.google.com/presentation/d/12mI2sKABnUw5RBWXDYY-HtHth4iMSNcEoQ10jDQbxgA/edit?usp=sharing). ++- [2023/10] We hosted [the first vLLM meetup](https://lu.ma/first-vllm-meetup) with a16z! Please find the meetup slides [here](https://docs.google.com/presentation/d/1QL-XPFXiFpDBh86DbEegFXBXFXjix4v032GhShbKf3s/edit?usp=sharing). + - [2023/08] We would like to express our sincere gratitude to [Andreessen Horowitz](https://a16z.com/2023/08/30/supporting-the-open-source-ai-community/) (a16z) for providing a generous grant to support the open-source development and research of vLLM. +-- [2023/07] Added support for LLaMA-2! You can run and serve 7B/13B/70B LLaMA-2s on vLLM with a single command! +-- [2023/06] Serving vLLM On any Cloud with SkyPilot. Check out a 1-click [example](https://github.com/skypilot-org/skypilot/blob/master/llm/vllm) to start the vLLM demo, and the [blog post](https://blog.skypilot.co/serving-llm-24x-faster-on-the-cloud-with-vllm-and-skypilot/) for the story behind vLLM development on the clouds. + - [2023/06] We officially released vLLM! FastChat-vLLM integration has powered [LMSYS Vicuna and Chatbot Arena](https://chat.lmsys.org) since mid-April. Check out our [blog post](https://vllm.ai). + + --- +@@ -34,77 +41,89 @@ vLLM is a fast and easy-to-use library for LLM inference and serving. + vLLM is fast with: + + - State-of-the-art serving throughput +-- Efficient management of attention key and value memory with **PagedAttention** ++- Efficient management of attention key and value memory with [**PagedAttention**](https://blog.vllm.ai/2023/06/20/vllm.html) + - Continuous batching of incoming requests + - Fast model execution with CUDA/HIP graph +-- Quantization: [GPTQ](https://arxiv.org/abs/2210.17323), [AWQ](https://arxiv.org/abs/2306.00978), [SqueezeLLM](https://arxiv.org/abs/2306.07629), FP8 KV Cache +-- Optimized CUDA kernels ++- Quantizations: [GPTQ](https://arxiv.org/abs/2210.17323), [AWQ](https://arxiv.org/abs/2306.00978), INT4, INT8, and FP8. ++- Optimized CUDA kernels, including integration with FlashAttention and FlashInfer. ++- Speculative decoding ++- Chunked prefill ++ ++**Performance benchmark**: We include a performance benchmark at the end of [our blog post](https://blog.vllm.ai/2024/09/05/perf-update.html). It compares the performance of vLLM against other LLM serving engines ([TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM), [SGLang](https://github.com/sgl-project/sglang) and [LMDeploy](https://github.com/InternLM/lmdeploy)). The implementation is under [nightly-benchmarks folder](.buildkite/nightly-benchmarks/) and you can [reproduce](https://github.com/vllm-project/vllm/issues/8176) this benchmark using our one-click runnable script. + + vLLM is flexible and easy to use with: + + - Seamless integration with popular Hugging Face models + - High-throughput serving with various decoding algorithms, including *parallel sampling*, *beam search*, and more +-- Tensor parallelism support for distributed inference ++- Tensor parallelism and pipeline parallelism support for distributed inference + - Streaming outputs + - OpenAI-compatible API server +-- Support NVIDIA GPUs and AMD GPUs +-- (Experimental) Prefix caching support +-- (Experimental) Multi-lora support +- +-vLLM seamlessly supports many Hugging Face models, including the following architectures: +- +-- Aquila & Aquila2 (`BAAI/AquilaChat2-7B`, `BAAI/AquilaChat2-34B`, `BAAI/Aquila-7B`, `BAAI/AquilaChat-7B`, etc.) +-- Baichuan & Baichuan2 (`baichuan-inc/Baichuan2-13B-Chat`, `baichuan-inc/Baichuan-7B`, etc.) +-- BLOOM (`bigscience/bloom`, `bigscience/bloomz`, etc.) +-- ChatGLM (`THUDM/chatglm2-6b`, `THUDM/chatglm3-6b`, etc.) +-- Command-R (`CohereForAI/c4ai-command-r-v01`, etc.) +-- DBRX (`databricks/dbrx-base`, `databricks/dbrx-instruct` etc.) +-- DeciLM (`Deci/DeciLM-7B`, `Deci/DeciLM-7B-instruct`, etc.) +-- Falcon (`tiiuae/falcon-7b`, `tiiuae/falcon-40b`, `tiiuae/falcon-rw-7b`, etc.) +-- Gemma (`google/gemma-2b`, `google/gemma-7b`, etc.) +-- GPT-2 (`gpt2`, `gpt2-xl`, etc.) +-- GPT BigCode (`bigcode/starcoder`, `bigcode/gpt_bigcode-santacoder`, etc.) +-- GPT-J (`EleutherAI/gpt-j-6b`, `nomic-ai/gpt4all-j`, etc.) +-- GPT-NeoX (`EleutherAI/gpt-neox-20b`, `databricks/dolly-v2-12b`, `stabilityai/stablelm-tuned-alpha-7b`, etc.) +-- InternLM (`internlm/internlm-7b`, `internlm/internlm-chat-7b`, etc.) +-- InternLM2 (`internlm/internlm2-7b`, `internlm/internlm2-chat-7b`, etc.) +-- Jais (`core42/jais-13b`, `core42/jais-13b-chat`, `core42/jais-30b-v3`, `core42/jais-30b-chat-v3`, etc.) +-- LLaMA, Llama 2, and Meta Llama 3 (`meta-llama/Meta-Llama-3-8B-Instruct`, `meta-llama/Meta-Llama-3-70B-Instruct`, `meta-llama/Llama-2-70b-hf`, `lmsys/vicuna-13b-v1.3`, `young-geng/koala`, `openlm-research/open_llama_13b`, etc.) +-- MiniCPM (`openbmb/MiniCPM-2B-sft-bf16`, `openbmb/MiniCPM-2B-dpo-bf16`, etc.) +-- Mistral (`mistralai/Mistral-7B-v0.1`, `mistralai/Mistral-7B-Instruct-v0.1`, etc.) +-- Mixtral (`mistralai/Mixtral-8x7B-v0.1`, `mistralai/Mixtral-8x7B-Instruct-v0.1`, `mistral-community/Mixtral-8x22B-v0.1`, etc.) +-- MPT (`mosaicml/mpt-7b`, `mosaicml/mpt-30b`, etc.) +-- OLMo (`allenai/OLMo-1B-hf`, `allenai/OLMo-7B-hf`, etc.) +-- OPT (`facebook/opt-66b`, `facebook/opt-iml-max-30b`, etc.) +-- Orion (`OrionStarAI/Orion-14B-Base`, `OrionStarAI/Orion-14B-Chat`, etc.) +-- Phi (`microsoft/phi-1_5`, `microsoft/phi-2`, etc.) +-- Phi-3 (`microsoft/Phi-3-mini-4k-instruct`, `microsoft/Phi-3-mini-128k-instruct`, etc.) +-- Qwen (`Qwen/Qwen-7B`, `Qwen/Qwen-7B-Chat`, etc.) +-- Qwen2 (`Qwen/Qwen1.5-7B`, `Qwen/Qwen1.5-7B-Chat`, etc.) +-- Qwen2MoE (`Qwen/Qwen1.5-MoE-A2.7B`, `Qwen/Qwen1.5-MoE-A2.7B-Chat`, etc.) +-- StableLM(`stabilityai/stablelm-3b-4e1t`, `stabilityai/stablelm-base-alpha-7b-v2`, etc.) +-- Starcoder2(`bigcode/starcoder2-3b`, `bigcode/starcoder2-7b`, `bigcode/starcoder2-15b`, etc.) +-- Xverse (`xverse/XVERSE-7B-Chat`, `xverse/XVERSE-13B-Chat`, `xverse/XVERSE-65B-Chat`, etc.) +-- Yi (`01-ai/Yi-6B`, `01-ai/Yi-34B`, etc.) +- +-Install vLLM with pip or [from source](https://vllm.readthedocs.io/en/latest/getting_started/installation.html#build-from-source): ++- Support NVIDIA GPUs, AMD CPUs and GPUs, Intel CPUs and GPUs, PowerPC CPUs, TPU, and AWS Neuron. ++- Prefix caching support ++- Multi-lora support ++ ++vLLM seamlessly supports most popular open-source models on HuggingFace, including: ++- Transformer-like LLMs (e.g., Llama) ++- Mixture-of-Expert LLMs (e.g., Mixtral, Deepseek-V2 and V3) ++- Embedding Models (e.g. E5-Mistral) ++- Multi-modal LLMs (e.g., LLaVA) ++ ++Find the full list of supported models [here](https://docs.vllm.ai/en/latest/models/supported_models.html). ++ ++## Getting Started ++ ++Install vLLM with `pip` or [from source](https://vllm.readthedocs.io/en/latest/getting_started/installation.html#build-from-source): + + ```bash + pip install vllm + ``` + +-## Getting Started +- +-Visit our [documentation](https://vllm.readthedocs.io/en/latest/) to get started. ++Visit our [documentation](https://vllm.readthedocs.io/en/latest/) to learn more. + - [Installation](https://vllm.readthedocs.io/en/latest/getting_started/installation.html) + - [Quickstart](https://vllm.readthedocs.io/en/latest/getting_started/quickstart.html) +-- [Supported Models](https://vllm.readthedocs.io/en/latest/models/supported_models.html) ++- [List of Supported Models](https://vllm.readthedocs.io/en/latest/models/supported_models.html) + + ## Contributing + + We welcome and value any contributions and collaborations. + Please check out [CONTRIBUTING.md](./CONTRIBUTING.md) for how to get involved. + ++## Sponsors ++ ++vLLM is a community project. Our compute resources for development and testing are supported by the following organizations. Thank you for your support! ++ ++ ++ ++Cash Donations: ++- a16z ++- Dropbox ++- Sequoia Capital ++- Skywork AI ++- ZhenFund ++ ++Compute Resources: ++- AMD ++- Anyscale ++- AWS ++- Crusoe Cloud ++- Databricks ++- DeepInfra ++- Google Cloud ++- Lambda Lab ++- Nebius ++- Novita AI ++- NVIDIA ++- Replicate ++- Roblox ++- RunPod ++- Trainy ++- UC Berkeley ++- UC San Diego ++ ++Slack Sponsor: Anyscale ++ ++We also have an official fundraising venue through [OpenCollective](https://opencollective.com/vllm). We plan to use the fund to support the development, maintenance, and adoption of vLLM. ++ + ## Citation + + If you use vLLM for your research, please cite our [paper](https://arxiv.org/abs/2309.06180): +@@ -116,3 +135,15 @@ If you use vLLM for your research, please cite our [paper](https://arxiv.org/abs + year={2023} + } + ``` ++ ++## Contact Us ++ ++* For technical questions and feature requests, please use Github issues or discussions. ++* For discussing with fellow users, please use Discord. ++* For coordinating contributions and development, please use Slack. ++* For security disclosures, please use Github's security advisory feature. ++* For collaborations and partnerships, please contact us at vllm-questions AT lists.berkeley.edu. ++ ++## Media Kit ++ ++* If you wish to use vLLM's logo, please refer to [our media kit repo](https://github.com/vllm-project/media-kit). +diff --git a/SECURITY.md b/SECURITY.md +new file mode 100644 +index 0000000..de0032d +--- /dev/null ++++ b/SECURITY.md +@@ -0,0 +1,11 @@ ++# Security Policy ++ ++## Reporting a Vulnerability ++ ++If you believe you have found a security vulnerability in vLLM, we encourage you to let us know right away. We will investigate all legitimate reports and do our best to quickly fix the problem. ++ ++Please report security issues privately using [the vulnerability submission form](https://github.com/vllm-project/vllm/security/advisories/new). Reports will then be triaged by the [vulnerability management team](https://docs.vllm.ai/contributing/vulnerability_management/). ++ ++--- ++ ++Please see [PyTorch's Security Policy](https://github.com/pytorch/pytorch/blob/main/SECURITY.md) for more information and recommendations on how to securely interact with models. +diff --git a/benchmarks/README.md b/benchmarks/README.md +index 192d6c4..2aa4a28 100644 +--- a/benchmarks/README.md ++++ b/benchmarks/README.md +@@ -6,3 +6,14 @@ You can download the dataset by running: + ```bash + wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json + ``` ++ ++## Downloading the ShareGPT4V dataset ++ ++The json file refers to several image datasets (coco, llava, etc.). The benchmark scripts ++will ignore a datapoint if the referred image is missing. ++```bash ++wget https://huggingface.co/datasets/Lin-Chen/ShareGPT4V/resolve/main/sharegpt4v_instruct_gpt4-vision_cap100k.json ++mkdir coco -p ++wget http://images.cocodataset.org/zips/train2017.zip -O coco/train2017.zip ++unzip coco/train2017.zip -d coco/ ++``` +diff --git a/benchmarks/backend_request_func.py b/benchmarks/backend_request_func.py +index f9d1675..b678490 100644 +--- a/benchmarks/backend_request_func.py ++++ b/benchmarks/backend_request_func.py +@@ -4,10 +4,13 @@ import sys + import time + import traceback + from dataclasses import dataclass, field +-from typing import List, Optional ++from typing import List, Optional, Union + + import aiohttp ++import huggingface_hub.constants + from tqdm.asyncio import tqdm ++from transformers import (AutoTokenizer, PreTrainedTokenizer, ++ PreTrainedTokenizerFast) + + AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60) + +@@ -20,7 +23,10 @@ class RequestFuncInput: + output_len: int + model: str + best_of: int = 1 +- use_beam_search: bool = False ++ logprobs: Optional[int] = None ++ extra_body: Optional[dict] = None ++ multi_modal_content: Optional[dict] = None ++ ignore_eos: bool = False + + + @dataclass +@@ -31,6 +37,7 @@ class RequestFuncOutput: + ttft: float = 0.0 # Time to first token + itl: List[float] = field( + default_factory=list) # List of inter-token latencies ++ tpot: float = 0.0 # avg next-token latencies + prompt_len: int = 0 + error: str = "" + +@@ -43,13 +50,14 @@ async def async_request_tgi( + assert api_url.endswith("generate_stream") + + async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session: +- assert not request_func_input.use_beam_search + params = { + "best_of": request_func_input.best_of, + "max_new_tokens": request_func_input.output_len, + "do_sample": True, + "temperature": 0.01, # TGI does not accept 0.0 temperature. + "top_p": 0.99, # TGI does not accept 1.0 top_p. ++ "truncate": request_func_input.prompt_len, ++ # TGI does not accept ignore_eos flag. + } + payload = { + "inputs": request_func_input.prompt, +@@ -68,9 +76,13 @@ async def async_request_tgi( + chunk_bytes = chunk_bytes.strip() + if not chunk_bytes: + continue ++ chunk_bytes = chunk_bytes.decode("utf-8") + +- chunk = remove_prefix(chunk_bytes.decode("utf-8"), +- "data:") ++ #NOTE: Sometimes TGI returns a ping response without ++ # any data, we should skip it. ++ if chunk_bytes.startswith(":"): ++ continue ++ chunk = chunk_bytes.removeprefix("data:") + + data = json.loads(chunk) + timestamp = time.perf_counter() +@@ -89,6 +101,9 @@ async def async_request_tgi( + output.latency = most_recent_timestamp - st + output.success = True + output.generated_text = data["generated_text"] ++ else: ++ output.error = response.reason or "" ++ output.success = False + except Exception: + output.success = False + exc_info = sys.exc_info() +@@ -107,7 +122,6 @@ async def async_request_trt_llm( + assert api_url.endswith("generate_stream") + + async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session: +- assert not request_func_input.use_beam_search + assert request_func_input.best_of == 1 + payload = { + "accumulate_tokens": True, +@@ -117,6 +131,8 @@ async def async_request_trt_llm( + "max_tokens": request_func_input.output_len, + "stream": True, + } ++ if request_func_input.ignore_eos: ++ payload["min_length"] = request_func_input.output_len + output = RequestFuncOutput() + output.prompt_len = request_func_input.prompt_len + +@@ -131,8 +147,8 @@ async def async_request_trt_llm( + if not chunk_bytes: + continue + +- chunk = remove_prefix(chunk_bytes.decode("utf-8"), +- "data:") ++ chunk = chunk_bytes.decode("utf-8").removeprefix( ++ "data:") + + data = json.loads(chunk) + output.generated_text += data["text_output"] +@@ -171,7 +187,6 @@ async def async_request_deepspeed_mii( + ) -> RequestFuncOutput: + async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session: + assert request_func_input.best_of == 1 +- assert not request_func_input.use_beam_search + + payload = { + "prompt": request_func_input.prompt, +@@ -215,19 +230,22 @@ async def async_request_openai_completions( + ) -> RequestFuncOutput: + api_url = request_func_input.api_url + assert api_url.endswith( +- "v1/completions" +- ), "OpenAI Completions API URL must end with 'v1/completions'." ++ ("completions", "profile") ++ ), "OpenAI Completions API URL must end with 'completions' or 'profile'." + + async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session: +- assert not request_func_input.use_beam_search + payload = { + "model": request_func_input.model, + "prompt": request_func_input.prompt, + "temperature": 0.0, + "best_of": request_func_input.best_of, + "max_tokens": request_func_input.output_len, ++ "logprobs": request_func_input.logprobs, + "stream": True, ++ "ignore_eos": request_func_input.ignore_eos, + } ++ if request_func_input.extra_body: ++ payload.update(request_func_input.extra_body) + headers = { + "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}" + } +@@ -243,39 +261,49 @@ async def async_request_openai_completions( + async with session.post(url=api_url, json=payload, + headers=headers) as response: + if response.status == 200: ++ first_chunk_received = False + async for chunk_bytes in response.content: + chunk_bytes = chunk_bytes.strip() + if not chunk_bytes: + continue + +- chunk = remove_prefix(chunk_bytes.decode("utf-8"), +- "data: ") ++ chunk = chunk_bytes.decode("utf-8").removeprefix( ++ "data: ") + if chunk == "[DONE]": + latency = time.perf_counter() - st + else: + data = json.loads(chunk) + ++ # NOTE: Some completion API might have a last ++ # usage summary response without a token so we ++ # want to check a token was generated + if data["choices"][0]["text"]: + timestamp = time.perf_counter() + # First token +- if ttft == 0.0: ++ if not first_chunk_received: ++ first_chunk_received = True + ttft = time.perf_counter() - st + output.ttft = ttft + + # Decoding phase +- # NOTE: Some completion API might have a last +- # usage summary response without a token so we +- # do not want to include as inter-token-latency +- elif data.get("usage", None) is None: ++ else: + output.itl.append(timestamp - + most_recent_timestamp) + + most_recent_timestamp = timestamp + generated_text += data["choices"][0]["text"] +- ++ if first_chunk_received: ++ output.success = True ++ else: ++ output.success = False ++ output.error = ( ++ "Never received a valid chunk to calculate TTFT." ++ "This response will be marked as failed!") + output.generated_text = generated_text +- output.success = True + output.latency = latency ++ else: ++ output.error = response.reason or "" ++ output.success = False + except Exception: + output.success = False + exc_info = sys.exc_info() +@@ -292,23 +320,28 @@ async def async_request_openai_chat_completions( + ) -> RequestFuncOutput: + api_url = request_func_input.api_url + assert api_url.endswith( +- "v1/chat/completions" +- ), "OpenAI Chat Completions API URL must end with 'v1/chat/completions'." ++ "chat/completions" ++ ), "OpenAI Chat Completions API URL must end with 'chat/completions'." + + async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session: +- assert not request_func_input.use_beam_search ++ content = [{"type": "text", "text": request_func_input.prompt}] ++ if request_func_input.multi_modal_content: ++ content.append(request_func_input.multi_modal_content) + payload = { + "model": request_func_input.model, + "messages": [ + { + "role": "user", +- "content": request_func_input.prompt, ++ "content": content + }, + ], + "temperature": 0.0, +- "max_tokens": request_func_input.output_len, ++ "max_completion_tokens": request_func_input.output_len, + "stream": True, ++ "ignore_eos": request_func_input.ignore_eos, + } ++ if request_func_input.extra_body: ++ payload.update(request_func_input.extra_body) + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", +@@ -330,8 +363,8 @@ async def async_request_openai_chat_completions( + if not chunk_bytes: + continue + +- chunk = remove_prefix(chunk_bytes.decode("utf-8"), +- "data: ") ++ chunk = chunk_bytes.decode("utf-8").removeprefix( ++ "data: ") + if chunk == "[DONE]": + latency = time.perf_counter() - st + else: +@@ -370,12 +403,28 @@ async def async_request_openai_chat_completions( + return output + + +-# Since vllm must support Python 3.8, we can't use str.removeprefix(prefix) +-# introduced in Python 3.9 +-def remove_prefix(text: str, prefix: str) -> str: +- if text.startswith(prefix): +- return text[len(prefix):] +- return text ++def get_model(pretrained_model_name_or_path: str) -> str: ++ if os.getenv('VLLM_USE_MODELSCOPE', 'False').lower() == 'true': ++ from modelscope import snapshot_download ++ ++ model_path = snapshot_download( ++ model_id=pretrained_model_name_or_path, ++ local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE, ++ ignore_file_pattern=[".*.pt", ".*.safetensors", ".*.bin"]) ++ ++ return model_path ++ return pretrained_model_name_or_path ++ ++ ++def get_tokenizer( ++ pretrained_model_name_or_path: str, trust_remote_code: bool ++) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]: ++ if pretrained_model_name_or_path is not None and not os.path.exists( ++ pretrained_model_name_or_path): ++ pretrained_model_name_or_path = get_model( ++ pretrained_model_name_or_path) ++ return AutoTokenizer.from_pretrained(pretrained_model_name_or_path, ++ trust_remote_code=trust_remote_code) + + + ASYNC_REQUEST_FUNCS = { +@@ -386,4 +435,6 @@ ASYNC_REQUEST_FUNCS = { + "openai": async_request_openai_completions, + "openai-chat": async_request_openai_chat_completions, + "tensorrt-llm": async_request_trt_llm, ++ "scalellm": async_request_openai_completions, ++ "sglang": async_request_openai_completions, + } +diff --git a/benchmarks/benchmark_guided.py b/benchmarks/benchmark_guided.py +new file mode 100644 +index 0000000..1a0e625 +--- /dev/null ++++ b/benchmarks/benchmark_guided.py +@@ -0,0 +1,494 @@ ++"""Benchmark guided decoding throughput.""" ++import argparse ++import dataclasses ++import json ++import os ++import random ++import time ++from typing import List ++ ++import datasets ++import pandas as pd ++import uvloop ++from transformers import AutoTokenizer, PreTrainedTokenizerBase ++ ++from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs ++from vllm.entrypoints.openai.api_server import ( ++ build_async_engine_client_from_engine_args) ++from vllm.sampling_params import GuidedDecodingParams ++from vllm.utils import FlexibleArgumentParser, merge_async_iterators ++ ++ ++@dataclasses.dataclass ++class SampleRequest: ++ """A class representing a single inference request for benchmarking. ++ ++ Attributes: ++ prompt: The input text prompt for the model. ++ multi_modal_data: Optional dictionary containing multi-modal data (e.g. ++ images). ++ prompt_len: The length of the prompt in tokens. ++ expected_output_len: The expected length of the output in tokens. ++ """ ++ prompt: str ++ prompt_len: int ++ expected_output_len: int ++ schema: dict ++ structure_type: str = 'json' ++ completion: str = None ++ ++ ++def run_vllm(requests: List[SampleRequest], ++ engine_args: EngineArgs, ++ n: int, ++ guided_decoding_rate: float = 1.0, ++ warmup: bool = False) -> float: ++ from vllm import LLM, SamplingParams ++ llm = LLM(**vars(engine_args)) ++ ++ # Add the requests to the engine. ++ prompts: List[str] = [] ++ sampling_params: List[SamplingParams] = [] ++ # create a list containing random selected true or false ++ guided_decoding_req_idx = random.sample( ++ range(len(requests)), int(len(requests) * guided_decoding_rate)) ++ ++ if warmup: ++ print(">>>>> Running warmup prompt, for the first 5") ++ # We setup the first 5 requests to warmup FSM ++ # if using xgrammar dataset, we will skip warmup ++ warmup_requests = requests[:5] ++ for i, request in enumerate(warmup_requests): ++ prompts.append(request.prompt) ++ sampling_params.append( ++ SamplingParams( ++ n=n, ++ temperature=1.0, ++ top_p=1.0, ++ ignore_eos=True, ++ max_tokens=request.expected_output_len, ++ guided_decoding=GuidedDecodingParams(json=request.schema) ++ if guided_decoding_rate > 0 else None, ++ )) ++ llm.generate(prompts, sampling_params, use_tqdm=False) ++ ++ print(">>>>> Benchmark started...") ++ prompts = [] ++ sampling_params = [] ++ for i, request in enumerate(requests): ++ prompts.append(request.prompt) ++ sampling_params.append( ++ SamplingParams( ++ n=n, ++ temperature=1.0, ++ top_p=1.0, ++ ignore_eos=True, ++ max_tokens=request.expected_output_len, ++ guided_decoding=GuidedDecodingParams( ++ **{request.structure_type: request.schema}) ++ if i in guided_decoding_req_idx else None, ++ )) ++ ++ start = time.perf_counter() ++ outputs = llm.generate(prompts, sampling_params, use_tqdm=False) ++ ret = [] ++ for output, request in zip(outputs, requests): ++ generated_text = output.outputs[0].text ++ ret.append({ ++ "generated": generated_text, ++ "expected": request.completion ++ }) ++ end = time.perf_counter() ++ return end - start, ret ++ ++ ++async def run_vllm_async( ++ requests: List[SampleRequest], ++ engine_args: AsyncEngineArgs, ++ n: int, ++ guided_decoding_rate: float = 1.0, ++ warmup: bool = False, ++ disable_frontend_multiprocessing: bool = False) -> float: ++ from vllm import SamplingParams ++ ++ async with build_async_engine_client_from_engine_args( ++ engine_args, disable_frontend_multiprocessing) as llm: ++ ++ # Add the requests to the engine. ++ prompts: List[str] = [] ++ sampling_params: List[SamplingParams] = [] ++ guided_decoding_req_idx = random.sample( ++ range(len(requests)), int(len(requests) * guided_decoding_rate)) ++ ++ if warmup: ++ print(">>>>>> Running warmup prompt, for the first 5") ++ # We setup the first 5 requests to warmup FSM ++ # if using xgrammar dataset, we will skip warmup ++ warmup_requests = requests[:5] ++ for i, request in enumerate(warmup_requests): ++ prompts.append(request.prompt) ++ sampling_params.append( ++ SamplingParams( ++ n=n, ++ temperature=1.0, ++ top_p=1.0, ++ ignore_eos=True, ++ max_tokens=request.expected_output_len, ++ guided_decoding=GuidedDecodingParams( ++ json=request.schema) ++ if guided_decoding_rate > 0 else None, ++ )) ++ generators = [] ++ for i, (prompt, sp) in enumerate(zip(prompts, sampling_params)): ++ generator = llm.generate(prompt, sp, request_id=f"test{i}") ++ generators.append(generator) ++ all_gens = merge_async_iterators(*generators) ++ async for i, res in all_gens: ++ pass ++ ++ print(">>>>> Benchmark started...") ++ prompts = [] ++ sampling_params = [] ++ for i, request in enumerate(requests): ++ prompts.append(request.prompt) ++ sampling_params.append( ++ SamplingParams( ++ n=n, ++ temperature=1.0, ++ top_p=1.0, ++ ignore_eos=True, ++ max_tokens=request.expected_output_len, ++ guided_decoding=GuidedDecodingParams(json=request.schema) ++ if i in guided_decoding_req_idx else None, ++ )) ++ ++ generators = [] ++ start_time = [] ++ latencies = [] ++ start = time.perf_counter() ++ for i, (prompt, sp) in enumerate(zip(prompts, sampling_params)): ++ generator = llm.generate(prompt, sp, request_id=f"test{i}") ++ generators.append(generator) ++ start_time.append(time.perf_counter()) ++ latencies.append([]) ++ all_gens = merge_async_iterators(*generators) ++ generated_texts = [''] * len(requests) ++ async for i, res in all_gens: ++ generated_texts[i] = res.outputs[0].text ++ lat = time.perf_counter() - start_time[i] ++ latencies[i].append(lat) ++ ret = [{ ++ 'generated': gt, ++ 'expected': req.completion ++ } for gt, req in zip(generated_texts, requests)] ++ end = time.perf_counter() ++ first_latency = pd.Series([lat[0] * 1000 for lat in latencies]) ++ next_latency = pd.Series([(lat[-1] - lat[0]) / len(lat[1:]) * 1000 ++ for lat in latencies]) ++ return end - start, ret, (first_latency, next_latency) ++ ++ ++def sample_requests(tokenizer: PreTrainedTokenizerBase, ++ args: argparse.Namespace) -> List[SampleRequest]: ++ if args.dataset == 'json': ++ if args.json_schema_path is None: ++ dir_path = os.path.dirname(os.path.realpath(__file__)) ++ args.json_schema_path = os.path.join(dir_path, ++ "structured_schemas", ++ "structured_schema_1.json") ++ with open(args.json_schema_path) as f: ++ schema = json.load(f) ++ prompt = f"Generate an example of a user profile given the following schema: {json.dumps(schema)}" # noqa: E501 ++ input_len = len(tokenizer(prompt).input_ids) ++ print(f"Input length of the prompt: {input_len} tokens") ++ requests = [ ++ SampleRequest(prompt=prompt, ++ prompt_len=input_len, ++ expected_output_len=args.output_len, ++ schema=schema, ++ structure_type=args.structure_type) ++ for _ in range(args.num_prompts) ++ ] ++ ++ elif args.dataset == "grammar": ++ schema = """ ++ ?start: select_statement ++ ++ ?select_statement: "SELECT " column_list " FROM " table_name ++ ++ ?column_list: column_name ("," column_name)* ++ ++ ?table_name: identifier ++ ++ ?column_name: identifier ++ ++ ?identifier: /[a-zA-Z_][a-zA-Z0-9_]*/ ++ """ ++ prompt = "Generate an SQL query to show the 'username' \ ++ and 'email' from the 'users' table." ++ ++ input_len = len(tokenizer(prompt).input_ids) ++ print(f"Input length of the prompt: {input_len} tokens") ++ requests = [ ++ SampleRequest(prompt=prompt, ++ prompt_len=input_len, ++ expected_output_len=args.output_len, ++ schema=schema, ++ structure_type=args.structure_type) ++ for _ in range(args.num_prompts) ++ ] ++ ++ elif args.dataset == "regex": ++ regex = r"\w+@\w+\.com\n" ++ args.regex = regex ++ prompt = "Generate an email address for Alan Turing, \ ++ who works in Enigma. End in .com and new line. \ ++ Example result: alan.turing@enigma.com\n" ++ ++ input_len = len(tokenizer(prompt).input_ids) ++ print(f"Input length of the prompt: {input_len} tokens") ++ requests = [ ++ SampleRequest(prompt=prompt, ++ prompt_len=input_len, ++ expected_output_len=args.output_len, ++ schema=regex, ++ structure_type=args.structure_type) ++ for _ in range(args.num_prompts) ++ ] ++ ++ elif args.dataset == "choice": ++ choice = ["Positive", "Negative"] ++ args.choice = choice ++ prompt = "Classify this sentiment: vLLM is wonderful!" ++ input_len = len(tokenizer(prompt).input_ids) ++ print(f"Input length of the prompt: {input_len} tokens") ++ requests = [ ++ SampleRequest(prompt=prompt, ++ prompt_len=input_len, ++ expected_output_len=args.output_len, ++ schema=choice, ++ structure_type=args.structure_type) ++ for _ in range(args.num_prompts) ++ ] ++ ++ elif args.dataset == "xgrammar_bench": ++ args.warmup = False ++ requests: List[SampleRequest] = [] ++ dataset = datasets.load_dataset("NousResearch/json-mode-eval", ++ split="train") ++ print(f"dataset has {len(dataset)} entries") ++ len_dataset = len(dataset) ++ for data_point_idx in range(args.num_prompts): ++ idx = data_point_idx ++ while idx >= len_dataset: ++ idx -= len_dataset ++ schema = dataset["schema"][idx] ++ prompt = tokenizer.apply_chat_template(dataset["prompt"][idx], ++ tokenize=False) ++ input_len = len(tokenizer(prompt).input_ids) ++ completion = dataset["completion"][idx] ++ ++ requests.append( ++ SampleRequest(prompt=prompt, ++ prompt_len=input_len, ++ expected_output_len=args.output_len, ++ schema=schema, ++ completion=completion)) ++ ++ return requests ++ ++ ++def evaluate(ret, args): ++ ++ def _eval_correctness_json(expected, actual): ++ # extract json string from string using regex ++ import re ++ actual = actual.replace('\n', '').replace(' ', '').strip() ++ try: ++ actual = re.search(r'\{.*\}', actual).group() ++ actual = json.loads(actual) ++ except Exception: ++ return False ++ ++ return True ++ ++ def _eval_correctness_choice(expected, actual): ++ return actual in args.choice ++ ++ def _eval_correctness_regex(expected, actual): ++ import re ++ return re.match(args.regex, actual) is not None ++ ++ def _eval_correctness(expected, actual): ++ if args.structure_type == 'json': ++ return _eval_correctness_json(expected, actual) ++ elif args.structure_type == 'regex': ++ return _eval_correctness_regex(expected, actual) ++ elif args.structure_type == 'choice': ++ return _eval_correctness_choice(expected, actual) ++ else: ++ return None ++ ++ scores = [] ++ for res in ret: ++ score = _eval_correctness(res['expected'], res['generated']) ++ res['correctness'] = score ++ scores.append(score) ++ ++ not_none_scores = [score for score in scores if score is not None] ++ ++ return (sum(not_none_scores) / len(not_none_scores) * ++ 100) if len(not_none_scores) > 0 else None ++ ++ ++def main(args: argparse.Namespace): ++ print(args) ++ random.seed(args.seed) ++ ++ # async engine is working for 'regex', 'choice' and 'grammar' ++ if args.dataset == 'grammar': ++ args.structure_type = 'grammar' ++ args.async_engine = False ++ elif args.dataset == 'regex': ++ args.structure_type = 'regex' ++ args.async_engine = False ++ elif args.dataset == 'choice': ++ args.structure_type = 'choice' ++ args.async_engine = False ++ else: ++ args.structure_type = 'json' ++ ++ if args.no_guided_decoding: ++ args.guided_decoding_ratio = 0 ++ if args.save_results: ++ result_file_name = f'{args.guided_decoding_ratio}guided' ++ result_file_name += f"_{args.model.split('/')[-1]}" ++ result_file_name += f"_{args.dataset}" ++ result_file_name += f"_{args.num_prompts}" ++ result_file_name += f"_out{args.output_len}" ++ result_file_name += f"_async{args.async_engine}" ++ result_file_name += f"_warmup{args.warmup}" ++ result_file_name += f"_chunkedprefill{args.enable_chunked_prefill}" ++ result_file_name += ".txt" ++ else: ++ result_file_name = None ++ ++ # Synthesize a prompt with the given input length. ++ tokenizer = AutoTokenizer.from_pretrained( ++ args.tokenizer, trust_remote_code=args.trust_remote_code) ++ requests = sample_requests(tokenizer, args) ++ ++ if args.async_engine: ++ engine_args = AsyncEngineArgs.from_cli_args(args) ++ elapsed_time, ret, (first_latency, next_latency) = uvloop.run( ++ run_vllm_async(requests, engine_args, args.n, ++ args.guided_decoding_ratio, args.warmup, ++ args.disable_frontend_multiprocessing)) ++ else: ++ engine_args = EngineArgs.from_cli_args(args) ++ elapsed_time, ret = run_vllm(requests, engine_args, args.n, ++ args.guided_decoding_ratio, args.warmup) ++ first_latency, next_latency = None, None ++ ++ score = evaluate(ret, args) ++ total_num_tokens = sum(request.prompt_len + request.expected_output_len ++ for request in requests) ++ total_output_tokens = sum(request.expected_output_len ++ for request in requests) ++ if first_latency is not None: ++ latency_breakdown = "\nFirst token latency(msecs):\n" ++ latency_breakdown += f"{first_latency.describe()}" ++ latency_breakdown += "\nNext token latency(msecs):\n" ++ latency_breakdown += f"{next_latency.describe()}" ++ print( ++ f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, " ++ f"{total_num_tokens / elapsed_time:.2f} total tokens/s, " ++ f"{total_output_tokens / elapsed_time:.2f} output tokens/s", ++ f"Correct rate is {score} %", ++ f"{latency_breakdown if first_latency is not None else ''}") ++ ++ # Output JSON results if specified ++ if args.output_json or result_file_name: ++ results = { ++ "elapsed_time": elapsed_time, ++ "num_requests": len(requests), ++ "total_num_tokens": total_num_tokens, ++ "total_output_tokens": total_output_tokens, ++ "requests_per_second": len(requests) / elapsed_time, ++ "tokens_per_second": f"{total_num_tokens / elapsed_time:.2f}", ++ "output_tokens_per_second": ++ f"{total_output_tokens / elapsed_time:.2f}", ++ "correct_rate(%)": score ++ } ++ results = {"outputs": ret, **results} ++ if first_latency is not None: ++ results["first_token_latency(msecs)"] = first_latency.describe( ++ ).to_dict() ++ results["next_token_latency(msecs)"] = next_latency.describe( ++ ).to_dict() ++ if args.output_json: ++ with open(args.output_json, "w") as f: ++ json.dump(results, f, indent=4) ++ elif result_file_name: ++ with open(result_file_name, "w") as f: ++ json.dump(results, f, indent=4) ++ ++ ++if __name__ == "__main__": ++ parser = FlexibleArgumentParser(description="Benchmark guided decoding.") ++ parser = AsyncEngineArgs.add_cli_args(parser) ++ ++ parser.add_argument("--output-len", ++ type=int, ++ default=512, ++ help="Output length for each request. Overrides the " ++ "output length from the dataset.") ++ parser.add_argument( ++ "--dataset", ++ default='json', ++ choices=['json', 'grammar', 'regex', 'choice', 'xgrammar_bench']) ++ parser.add_argument("--json_schema_path", ++ type=str, ++ default=None, ++ help="Path to json schema.") ++ parser.add_argument("--n", ++ type=int, ++ default=1, ++ help="Number of generated sequences per prompt.") ++ parser.add_argument("--num-prompts", ++ type=int, ++ default=10, ++ help="Number of prompts to process.") ++ parser.add_argument( ++ '--output-json', ++ type=str, ++ default=None, ++ help='Path to save the throughput results in JSON format.') ++ parser.add_argument("--async-engine", ++ action='store_true', ++ default=False, ++ help="Use vLLM async engine rather than LLM class.") ++ parser.add_argument("--no-guided-decoding", ++ action='store_true', ++ default=False, ++ help="Whether to disable JSON decoding or not.") ++ parser.add_argument("--guided-decoding-ratio", ++ type=float, ++ default=1.0, ++ help="Ratio of Guided Decoding requests") ++ parser.add_argument("--disable-frontend-multiprocessing", ++ action='store_true', ++ default=False, ++ help="Disable decoupled async engine frontend.") ++ parser.add_argument("--warmup", ++ action="store_true", ++ default=False, ++ help="Run warmup prompts before benchmark.") ++ parser.add_argument("--save-results", ++ action="store_true", ++ default=False, ++ help="save output results.") ++ args = parser.parse_args() ++ if args.tokenizer is None: ++ args.tokenizer = args.model ++ main(args) +diff --git a/benchmarks/benchmark_latency.py b/benchmarks/benchmark_latency.py +index 44da3ba..77c4f6a 100644 +--- a/benchmarks/benchmark_latency.py ++++ b/benchmarks/benchmark_latency.py +@@ -1,42 +1,35 @@ + """Benchmark the latency of processing a single batch of requests.""" + import argparse ++import dataclasses ++import json + import time + from pathlib import Path +-from typing import Optional ++from typing import List, Optional + + import numpy as np + import torch + from tqdm import tqdm + + from vllm import LLM, SamplingParams +-from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS ++from vllm.engine.arg_utils import EngineArgs ++from vllm.inputs import PromptType ++from vllm.sampling_params import BeamSearchParams ++from vllm.utils import FlexibleArgumentParser + + + def main(args: argparse.Namespace): + print(args) + ++ engine_args = EngineArgs.from_cli_args(args) ++ + # NOTE(woosuk): If the request cannot be processed in a single batch, + # the engine will automatically process the request in multiple batches. +- llm = LLM(model=args.model, +- tokenizer=args.tokenizer, +- quantization=args.quantization, +- tensor_parallel_size=args.tensor_parallel_size, +- trust_remote_code=args.trust_remote_code, +- dtype=args.dtype, +- enforce_eager=args.enforce_eager, +- kv_cache_dtype=args.kv_cache_dtype, +- quantization_param_path=args.quantization_param_path, +- device=args.device, +- ray_workers_use_nsight=args.ray_workers_use_nsight, +- enable_chunked_prefill=args.enable_chunked_prefill, +- download_dir=args.download_dir, +- block_size=args.block_size) ++ llm = LLM(**dataclasses.asdict(engine_args)) + + sampling_params = SamplingParams( + n=args.n, +- temperature=0.0 if args.use_beam_search else 1.0, ++ temperature=1.0, + top_p=1.0, +- use_beam_search=args.use_beam_search, + ignore_eos=True, + max_tokens=args.output_len, + ) +@@ -44,7 +37,23 @@ def main(args: argparse.Namespace): + dummy_prompt_token_ids = np.random.randint(10000, + size=(args.batch_size, + args.input_len)) +- dummy_prompt_token_ids = dummy_prompt_token_ids.tolist() ++ dummy_prompts: List[PromptType] = [{ ++ "prompt_token_ids": batch ++ } for batch in dummy_prompt_token_ids.tolist()] ++ ++ def llm_generate(): ++ if not args.use_beam_search: ++ llm.generate(dummy_prompts, ++ sampling_params=sampling_params, ++ use_tqdm=False) ++ else: ++ llm.beam_search( ++ dummy_prompts, ++ BeamSearchParams( ++ beam_width=args.n, ++ max_tokens=args.output_len, ++ ignore_eos=True, ++ )) + + def run_to_completion(profile_dir: Optional[str] = None): + if profile_dir: +@@ -55,15 +64,11 @@ def main(args: argparse.Namespace): + ], + on_trace_ready=torch.profiler.tensorboard_trace_handler( + str(profile_dir))) as p: +- llm.generate(prompt_token_ids=dummy_prompt_token_ids, +- sampling_params=sampling_params, +- use_tqdm=False) +- print(p.key_averages()) ++ llm_generate() ++ print(p.key_averages().table(sort_by="self_cuda_time_total")) + else: + start_time = time.perf_counter() +- llm.generate(prompt_token_ids=dummy_prompt_token_ids, +- sampling_params=sampling_params, +- use_tqdm=False) ++ llm_generate() + end_time = time.perf_counter() + latency = end_time - start_time + return latency +@@ -87,24 +92,27 @@ def main(args: argparse.Namespace): + for _ in tqdm(range(args.num_iters), desc="Profiling iterations"): + latencies.append(run_to_completion(profile_dir=None)) + latencies = np.array(latencies) +- percentages = [10, 25, 50, 75, 90] ++ percentages = [10, 25, 50, 75, 90, 99] + percentiles = np.percentile(latencies, percentages) + print(f'Avg latency: {np.mean(latencies)} seconds') + for percentage, percentile in zip(percentages, percentiles): + print(f'{percentage}% percentile latency: {percentile} seconds') + ++ # Output JSON results if specified ++ if args.output_json: ++ results = { ++ "avg_latency": np.mean(latencies), ++ "latencies": latencies.tolist(), ++ "percentiles": dict(zip(percentages, percentiles.tolist())), ++ } ++ with open(args.output_json, "w") as f: ++ json.dump(results, f, indent=4) ++ + + if __name__ == '__main__': +- parser = argparse.ArgumentParser( ++ parser = FlexibleArgumentParser( + description='Benchmark the latency of processing a single batch of ' + 'requests till completion.') +- parser.add_argument('--model', type=str, default='facebook/opt-125m') +- parser.add_argument('--tokenizer', type=str, default=None) +- parser.add_argument('--quantization', +- '-q', +- choices=[*QUANTIZATION_METHODS, None], +- default=None) +- parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1) + parser.add_argument('--input-len', type=int, default=32) + parser.add_argument('--output-len', type=int, default=128) + parser.add_argument('--batch-size', type=int, default=8) +@@ -121,41 +129,6 @@ if __name__ == '__main__': + type=int, + default=30, + help='Number of iterations to run.') +- parser.add_argument('--trust-remote-code', +- action='store_true', +- help='trust remote code from huggingface') +- parser.add_argument( +- '--dtype', +- type=str, +- default='auto', +- choices=['auto', 'half', 'float16', 'bfloat16', 'float', 'float32'], +- help='data type for model weights and activations. ' +- 'The "auto" option will use FP16 precision ' +- 'for FP32 and FP16 models, and BF16 precision ' +- 'for BF16 models.') +- parser.add_argument('--enforce-eager', +- action='store_true', +- help='enforce eager mode and disable CUDA graph') +- parser.add_argument( +- "--kv-cache-dtype", +- type=str, +- choices=['auto', 'fp8'], +- default='auto', +- help= +- 'Data type for kv cache storage. If "auto", will use model data type. ' +- 'FP8_E5M2 (without scaling) is only supported on cuda version greater ' +- 'than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead supported for ' +- 'common inference criteria.') +- parser.add_argument( +- '--quantization-param-path', +- type=str, +- default=None, +- help='Path to the JSON file containing the KV cache scaling factors. ' +- 'This should generally be supplied, when KV cache dtype is FP8. ' +- 'Otherwise, KV cache scaling factors default to 1.0, which may cause ' +- 'accuracy issues. FP8_E5M2 (without scaling) is only supported on ' +- 'cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is ' +- 'instead supported for common inference criteria.') + parser.add_argument( + '--profile', + action='store_true', +@@ -167,29 +140,11 @@ if __name__ == '__main__': + help=('path to save the pytorch profiler output. Can be visualized ' + 'with ui.perfetto.dev or Tensorboard.')) + parser.add_argument( +- "--device", ++ '--output-json', + type=str, +- default="cuda", +- choices=["cuda", "cpu"], +- help='device type for vLLM execution, supporting CUDA and CPU.') +- parser.add_argument('--block-size', +- type=int, +- default=16, +- help='block size of key/value cache') +- parser.add_argument( +- '--enable-chunked-prefill', +- action='store_true', +- help='If True, the prefill requests can be chunked based on the ' +- 'max_num_batched_tokens') +- parser.add_argument( +- "--ray-workers-use-nsight", +- action='store_true', +- help="If specified, use nsight to profile ray workers", +- ) +- parser.add_argument('--download-dir', +- type=str, +- default=None, +- help='directory to download and load the weights, ' +- 'default to the default cache dir of huggingface') ++ default=None, ++ help='Path to save the latency results in JSON format.') ++ ++ parser = EngineArgs.add_cli_args(parser) + args = parser.parse_args() + main(args) +diff --git a/benchmarks/benchmark_long_document_qa_throughput.py b/benchmarks/benchmark_long_document_qa_throughput.py +new file mode 100644 +index 0000000..0b8fba3 +--- /dev/null ++++ b/benchmarks/benchmark_long_document_qa_throughput.py +@@ -0,0 +1,183 @@ ++""" ++Offline benchmark to test the long document QA throughput. ++ ++Example usage: ++ # This workload samples 8 different prompts with a default input ++ # length of 20000 tokens, then replicates each prompt 2 times ++ # in random order. ++ python benchmark_long_document_qa_throughput.py \ ++ --model meta-llama/Llama-2-7b-chat-hf \ ++ --enable-prefix-caching \ ++ --num-documents 8 \ ++ --repeat-count 2 ++ ++Commandline arguments: ++ --num-documents: The number of documents to sample prompts from. ++ ++ --document-length: The length of each document in tokens. ++ (Optional, default: 20000) ++ ++ --output-len: The number of tokens to generate for each prompt. ++ (Optional, default: 10) ++ ++ --repeat-count: The number of times to repeat each prompt. ++ (Optional, default: 2) ++ ++ --repeat-mode: The mode to repeat prompts. The supported modes are: ++ - 'random': shuffle the prompts randomly. (Default) ++ - 'tile': the entire prompt list is repeated in sequence. (Potentially ++ lowest cache hit) ++ - 'interleave': each prompt is repeated consecutively before ++ moving to the next element. (Highest cache hit) ++ ++ --shuffle-seed: Random seed when the repeat mode is "random". ++ (Optional, default: 0) ++ ++In the meantime, it also supports all the vLLM engine args to initialize the ++LLM engine. You can refer to the `vllm.engine.arg_utils.EngineArgs` for more ++details. ++""" ++ ++import dataclasses ++import random ++import time ++ ++from vllm import LLM, SamplingParams ++from vllm.engine.arg_utils import EngineArgs ++from vllm.utils import FlexibleArgumentParser ++ ++ ++def test_long_document_qa(llm=None, sampling_params=None, prompts=None): ++ """ ++ Test long document QA with the given prompts and sampling parameters. ++ Print the time spent in processing all the prompts. ++ ++ Args: ++ llm: The language model used for generating responses. ++ sampling_params: Sampling parameter used to generate the response. ++ prompts: A list of prompt strings to be processed by the LLM. ++ """ ++ start_time = time.time() ++ llm.generate(prompts, sampling_params=sampling_params) ++ end_time = time.time() ++ print(f"Time to execute all requests: {end_time - start_time:.4f} secs") ++ ++ ++def repeat_prompts(prompts, repeat_count, mode: str): ++ """ ++ Repeat each prompt in the list for a specified number of times. ++ The order of prompts in the output list depends on the mode. ++ ++ Args: ++ prompts: A list of prompts to be repeated. ++ repeat_count: The number of times each prompt is repeated. ++ mode: The mode of repetition. Supported modes are: ++ - 'random': Shuffle the prompts randomly after repetition. ++ - 'tile': Repeat the entire prompt list in sequence. ++ Example: [1, 2, 3] -> [1, 2, 3, 1, 2, 3]. ++ - 'interleave': Repeat each prompt consecutively before moving to ++ the next. Example: [1, 2, 3] -> [1, 1, 2, 2, 3, 3]. ++ ++ Returns: ++ A list of repeated prompts in the specified order. ++ ++ Raises: ++ ValueError: If an invalid mode is provided. ++ """ ++ print("Repeat mode: ", mode) ++ if mode == 'random': ++ repeated_prompts = prompts * repeat_count ++ random.shuffle(repeated_prompts) ++ return repeated_prompts ++ elif mode == 'tile': ++ return prompts * repeat_count ++ elif mode == 'interleave': ++ repeated_prompts = [] ++ for prompt in prompts: ++ repeated_prompts.extend([prompt] * repeat_count) ++ return repeated_prompts ++ else: ++ raise ValueError(f"Invalid mode: {mode}, only support " ++ "'random', 'tile', 'interleave'") ++ ++ ++def main(args): ++ random.seed(args.shuffle_seed) ++ ++ # Prepare the prompts: ++ # we append the document id at the beginning to avoid any of the document ++ # being the prefix of other documents ++ prompts = [ ++ str(i) + ' '.join(['hi'] * args.document_length) ++ for i in range(args.num_documents) ++ ] ++ ++ prompts = repeat_prompts(prompts, args.repeat_count, mode=args.repeat_mode) ++ ++ warmup_prompts = [ ++ "This is warm up request " + str(i) + \ ++ ' '.join(['hi'] * args.document_length) ++ for i in range(args.num_documents)] ++ ++ # Create the LLM engine ++ engine_args = EngineArgs.from_cli_args(args) ++ llm = LLM(**dataclasses.asdict(engine_args)) ++ sampling_params = SamplingParams(temperature=0, max_tokens=args.output_len) ++ ++ print("------warm up------") ++ test_long_document_qa( ++ llm=llm, ++ prompts=warmup_prompts, ++ sampling_params=sampling_params, ++ ) ++ ++ print("------start generating------") ++ test_long_document_qa( ++ llm=llm, ++ prompts=prompts, ++ sampling_params=sampling_params, ++ ) ++ ++ ++if __name__ == "__main__": ++ parser = FlexibleArgumentParser( ++ description= ++ 'Benchmark the performance with or without automatic prefix caching.') ++ ++ parser.add_argument( ++ '--document-length', ++ type=int, ++ # Roughly the number of tokens for a system paper, ++ # excluding images ++ default=20000, ++ help='Range of input lengths for sampling prompts,' ++ 'specified as "min:max" (e.g., "128:256").') ++ ++ parser.add_argument('--num-documents', ++ type=int, ++ default=8, ++ help='Range of input lengths for sampling prompts,' ++ 'specified as "min:max" (e.g., "128:256").') ++ ++ parser.add_argument('--output-len', type=int, default=10) ++ ++ parser.add_argument('--repeat-count', ++ type=int, ++ default=2, ++ help='Number of times to repeat each prompt') ++ ++ parser.add_argument("--repeat-mode", ++ type=str, ++ default='random', ++ help='The mode to repeat prompts. The supported ' ++ 'modes are "random", "tile", and "interleave". ' ++ 'See repeat_prompts() in the source code for details.') ++ ++ parser.add_argument("--shuffle-seed", ++ type=int, ++ default=0, ++ help='Random seed when the repeat mode is "random"') ++ ++ parser = EngineArgs.add_cli_args(parser) ++ args = parser.parse_args() ++ main(args) +diff --git a/benchmarks/benchmark_prefix_caching.py b/benchmarks/benchmark_prefix_caching.py +index 0899669..3ab421a 100644 +--- a/benchmarks/benchmark_prefix_caching.py ++++ b/benchmarks/benchmark_prefix_caching.py +@@ -1,7 +1,47 @@ +-import argparse ++""" ++Benchmark the efficiency of prefix caching. ++ ++This script allows you to benchmark the performance of ++a model with and without prefix caching using either fixed prompts ++or prompts sampled from the ShareGPT dataset. ++ ++Fixed example usage: ++ python benchmark_prefix_caching.py \ ++ --model meta-llama/Llama-2-7b-chat-hf \ ++ --enable-prefix-caching \ ++ --num-prompts 1 \ ++ --repeat-count 100 \ ++ --input-length-range 128:256 ++ ++ShareGPT example usage: ++ # This command samples 20 prompts with input lengths ++ # between 128 and 256 tokens from the ShareGPT dataset, ++ # then replicates each prompt 5 times. ++ python benchmark_prefix_caching.py \ ++ --model meta-llama/Llama-2-7b-chat-hf \ ++ --dataset-path /path/to/ShareGPT_V3_unfiltered_cleaned_split.json \ ++ --enable-prefix-caching \ ++ --num-prompts 20 \ ++ --repeat-count 5 \ ++ --input-length-range 128:256 ++""" ++ ++import dataclasses ++import json ++import random + import time ++from typing import List, Optional, Tuple ++ ++from transformers import PreTrainedTokenizerBase + + from vllm import LLM, SamplingParams ++from vllm.engine.arg_utils import EngineArgs ++from vllm.utils import FlexibleArgumentParser ++ ++try: ++ from vllm.transformers_utils.tokenizer import get_tokenizer ++except ImportError: ++ from backend_request_func import get_tokenizer + + PROMPT = "You are a helpful assistant in recognizes the content of tables in markdown format. Here is a table as fellows. You need to answer my question about the table.\n# Table\n|Opening|Opening|Sl. No.|Film|Cast|Director|Music Director|Notes|\n|----|----|----|----|----|----|----|----|\n|J A N|9|1|Agni Pushpam|Jayabharathi, Kamalahasan|Jeassy|M. K. Arjunan||\n|J A N|16|2|Priyamvada|Mohan Sharma, Lakshmi, KPAC Lalitha|K. S. Sethumadhavan|V. Dakshinamoorthy||\n|J A N|23|3|Yakshagaanam|Madhu, Sheela|Sheela|M. S. Viswanathan||\n|J A N|30|4|Paalkkadal|Sheela, Sharada|T. K. Prasad|A. T. Ummer||\n|F E B|5|5|Amma|Madhu, Srividya|M. Krishnan Nair|M. K. Arjunan||\n|F E B|13|6|Appooppan|Thikkurissi Sukumaran Nair, Kamal Haasan|P. Bhaskaran|M. S. Baburaj||\n|F E B|20|7|Srishti|Chowalloor Krishnankutty, Ravi Alummoodu|K. T. Muhammad|M. S. Baburaj||\n|F E B|20|8|Vanadevatha|Prem Nazir, Madhubala|Yusufali Kechery|G. Devarajan||\n|F E B|27|9|Samasya|Madhu, Kamalahaasan|K. Thankappan|Shyam||\n|F E B|27|10|Yudhabhoomi|K. P. Ummer, Vidhubala|Crossbelt Mani|R. K. Shekhar||\n|M A R|5|11|Seemantha Puthran|Prem Nazir, Jayabharathi|A. B. Raj|M. K. Arjunan||\n|M A R|12|12|Swapnadanam|Rani Chandra, Dr. Mohandas|K. G. George|Bhaskar Chandavarkar||\n|M A R|19|13|Thulavarsham|Prem Nazir, sreedevi, Sudheer|N. Sankaran Nair|V. Dakshinamoorthy||\n|M A R|20|14|Aruthu|Kaviyoor Ponnamma, Kamalahasan|Ravi|G. Devarajan||\n|M A R|26|15|Swimming Pool|Kamal Haasan, M. G. Soman|J. Sasikumar|M. K. Arjunan||\n\n# Question\nWhat' s the content in the (1,1) cells\n" # noqa: E501 + +@@ -15,25 +55,150 @@ def test_prefix(llm=None, sampling_params=None, prompts=None): + print(f"cost time {end_time - start_time}") + + ++@dataclasses.dataclass ++class Request: ++ prompt: str ++ prompt_len: int ++ output_len: int ++ ++ ++def sample_tokens(tokenizer: PreTrainedTokenizerBase, length: int) -> str: ++ vocab = tokenizer.get_vocab() ++ # Remove the special tokens. ++ vocab = { ++ k: v ++ for k, v in vocab.items() if k not in tokenizer.all_special_ids ++ } ++ return random.choices(list(vocab.values()), k=length) ++ ++ ++def sample_requests_from_dataset( ++ dataset_path: str, ++ num_requests: int, ++ tokenizer: PreTrainedTokenizerBase, ++ input_length_range: Tuple[int, int], ++ fixed_output_len: Optional[int], ++) -> List[Request]: ++ if fixed_output_len is not None and fixed_output_len < 4: ++ raise ValueError("output_len too small") ++ ++ # Load the dataset. ++ with open(dataset_path) as f: ++ dataset = json.load(f) ++ # Filter out the conversations with less than 2 turns. ++ dataset = [data for data in dataset if len(data["conversations"]) >= 2] ++ # Only keep the first two turns of each conversation. ++ dataset = [(data["conversations"][0]["value"], ++ data["conversations"][1]["value"]) for data in dataset] ++ ++ # Shuffle the dataset. ++ random.shuffle(dataset) ++ ++ min_len, max_len = input_length_range ++ assert min_len >= 0 and max_len >= min_len, "input_length_range too small" ++ ++ # Filter out sequences that are too long or too short ++ filtered_requests: List[Request] = [] ++ ++ for i in range(len(dataset)): ++ if len(filtered_requests) == num_requests: ++ break ++ ++ # Tokenize the prompts and completions. ++ prompt_token_ids = tokenizer(dataset[i][0]).input_ids ++ prompt = tokenizer.decode(prompt_token_ids) ++ completion = dataset[i][1] ++ completion_token_ids = tokenizer(completion).input_ids ++ prompt_len = len(prompt_token_ids) ++ output_len = (len(completion_token_ids) ++ if fixed_output_len is None else fixed_output_len) ++ if min_len <= prompt_len <= max_len: ++ filtered_requests.append(Request(prompt, prompt_len, output_len)) ++ ++ return filtered_requests ++ ++ ++def sample_requests_from_random( ++ num_requests: int, ++ tokenizer: PreTrainedTokenizerBase, ++ input_length_range: Tuple[int, int], ++ fixed_output_len: Optional[int], ++ prefix_len: int, ++) -> List[Request]: ++ ++ requests = [] ++ prefix_token_ids = sample_tokens(tokenizer, prefix_len) ++ min_len, max_len = input_length_range ++ ++ for i in range(num_requests): ++ unique_part_token_ids = sample_tokens( ++ tokenizer, ++ random.randint(min_len - prefix_len, max_len - prefix_len)) ++ prompt_token_ids = prefix_token_ids + unique_part_token_ids ++ prompt = tokenizer.decode(prompt_token_ids) ++ prompt_len = len(prompt_token_ids) ++ assert (min_len <= prompt_len <= max_len ++ ), f"prompt_len {prompt_len} out of range {min_len}:{max_len}" ++ requests.append(Request(prompt, prompt_len, fixed_output_len)) ++ return requests ++ ++ ++def repeat_and_sort_requests(requests: List[Request], ++ repeat_count: int, ++ sort: bool = False) -> List[str]: ++ repeated_requests = requests * repeat_count ++ if sort: ++ repeated_requests.sort(key=lambda x: x[1]) ++ else: ++ random.shuffle(repeated_requests) ++ return [req.prompt for req in repeated_requests] ++ ++ + def main(args): +- llm = LLM(model=args.model, +- tokenizer_mode='auto', +- trust_remote_code=True, +- enforce_eager=True, +- use_v2_block_manager=args.use_v2_block_manager, +- tensor_parallel_size=args.tensor_parallel_size, +- enable_prefix_caching=args.enable_prefix_caching) +- +- num_prompts = 100 +- prompts = [PROMPT] * num_prompts ++ tokenizer = get_tokenizer(args.model, trust_remote_code=True) ++ input_length_range = tuple(map(int, args.input_length_range.split(':'))) ++ random.seed(args.seed) ++ if args.dataset_path is not None: ++ if args.prefix_len > 0: ++ raise ValueError("prefix-len is not supported when " ++ "dataset-path is provided.") ++ print(f"Start to sample {args.num_prompts} prompts " ++ f"from {args.dataset_path}") ++ filtered_requests = sample_requests_from_dataset( ++ dataset_path=args.dataset_path, ++ num_requests=args.num_prompts, ++ tokenizer=tokenizer, ++ input_length_range=input_length_range, ++ fixed_output_len=args.output_len, ++ ) ++ else: ++ print(f"Start to sample {args.num_prompts} prompts from random") ++ filtered_requests = sample_requests_from_random( ++ num_requests=args.num_prompts, ++ tokenizer=tokenizer, ++ input_length_range=input_length_range, ++ fixed_output_len=args.output_len, ++ prefix_len=args.prefix_len, ++ ) ++ ++ # Print some helpful stats of the requests. ++ print(f"Sampled {len(filtered_requests)} requests.") ++ prompt_lens = [req.prompt_len for req in filtered_requests] ++ print(f"Average input length: {sum(prompt_lens) / len(prompt_lens)}") ++ print(f"P50 input length: {sorted(prompt_lens)[len(prompt_lens) // 2]}") ++ print(f"Min Prompt Length: {min(prompt_lens)}") ++ print(f"Max Prompt Length: {max(prompt_lens)}") ++ ++ engine_args = EngineArgs.from_cli_args(args) ++ ++ llm = LLM(**dataclasses.asdict(engine_args)) ++ + sampling_params = SamplingParams(temperature=0, max_tokens=args.output_len) + +- print("------warm up------") +- test_prefix( +- llm=llm, +- prompts=prompts, +- sampling_params=sampling_params, +- ) ++ print("Testing filtered requests") ++ prompts = repeat_and_sort_requests(filtered_requests, ++ repeat_count=args.repeat_count, ++ sort=args.sort) + + print("------start generating------") + test_prefix( +@@ -44,19 +209,40 @@ def main(args): + + + if __name__ == "__main__": +- parser = argparse.ArgumentParser( +- description='Benchmark the performance with or without automatic ' +- 'prefix caching.') +- parser.add_argument('--model', ++ parser = FlexibleArgumentParser( ++ description= ++ 'Benchmark the performance with or without automatic prefix caching.') ++ parser.add_argument("--dataset-path", + type=str, +- default='baichuan-inc/Baichuan2-13B-Chat') +- parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1) ++ default=None, ++ help="Path to the dataset.") + parser.add_argument('--output-len', type=int, default=10) +- parser.add_argument('--enable-prefix-caching', ++ parser.add_argument('--num-prompts', ++ type=int, ++ required=True, ++ help="Number of the prompts sampled from dataset") ++ parser.add_argument('--repeat-count', ++ type=int, ++ default=1, ++ help='Number of times to repeat each prompt') ++ parser.add_argument('--sort', + action='store_true', +- help='enable prefix caching') +- parser.add_argument('--use-v2-block-manager', +- action='store_true', +- help='Use BlockSpaceMangerV2') ++ help='Sort prompts by input length') ++ parser.add_argument('--input-length-range', ++ type=str, ++ required=True, ++ help='Range of input lengths for sampling prompts,' ++ 'specified as "min:max" (e.g., "128:256").') ++ parser.add_argument( ++ "--prefix-len", ++ type=int, ++ default=0, ++ help="Specifies the length of a common prefix to be " ++ "added to the input prompt. The input-length-range will " ++ "subtract this length when filtering prompts. Only used " ++ "when dataset-path is not provided.", ++ ) ++ ++ parser = EngineArgs.add_cli_args(parser) + args = parser.parse_args() + main(args) +diff --git a/benchmarks/benchmark_prioritization.py b/benchmarks/benchmark_prioritization.py +new file mode 100644 +index 0000000..e0c9e6a +--- /dev/null ++++ b/benchmarks/benchmark_prioritization.py +@@ -0,0 +1,177 @@ ++"""Benchmark offline prioritization.""" ++import argparse ++import dataclasses ++import json ++import random ++import time ++from typing import List, Optional, Tuple ++ ++from transformers import AutoTokenizer, PreTrainedTokenizerBase ++ ++from vllm.engine.arg_utils import EngineArgs ++from vllm.utils import FlexibleArgumentParser ++ ++ ++def sample_requests( ++ dataset_path: str, ++ num_requests: int, ++ tokenizer: PreTrainedTokenizerBase, ++ fixed_output_len: Optional[int], ++) -> List[Tuple[str, int, int]]: ++ if fixed_output_len is not None and fixed_output_len < 4: ++ raise ValueError("output_len too small") ++ ++ # Load the dataset. ++ with open(dataset_path) as f: ++ dataset = json.load(f) ++ # Filter out the conversations with less than 2 turns. ++ dataset = [data for data in dataset if len(data["conversations"]) >= 2] ++ # Only keep the first two turns of each conversation. ++ dataset = [(data["conversations"][0]["value"], ++ data["conversations"][1]["value"]) for data in dataset] ++ ++ # Shuffle the dataset. ++ random.shuffle(dataset) ++ ++ # Filter out sequences that are too long or too short ++ filtered_dataset: List[Tuple[str, int, int]] = [] ++ for i in range(len(dataset)): ++ if len(filtered_dataset) == num_requests: ++ break ++ ++ # Tokenize the prompts and completions. ++ prompt = dataset[i][0] ++ prompt_token_ids = tokenizer(prompt).input_ids ++ completion = dataset[i][1] ++ completion_token_ids = tokenizer(completion).input_ids ++ prompt_len = len(prompt_token_ids) ++ output_len = len(completion_token_ids ++ ) if fixed_output_len is None else fixed_output_len ++ if prompt_len < 4 or output_len < 4: ++ # Prune too short sequences. ++ continue ++ if prompt_len > 1024 or prompt_len + output_len > 2048: ++ # Prune too long sequences. ++ continue ++ ++ #Select a equi-probable random priority ++ priority = 0 if random.random() < 0.5 else 1 ++ ++ filtered_dataset.append((prompt, prompt_len, output_len, priority)) ++ ++ return filtered_dataset ++ ++ ++def run_vllm( ++ requests: List[Tuple[str, int, int]], ++ n: int, ++ engine_args: EngineArgs, ++) -> float: ++ from vllm import LLM, SamplingParams ++ llm = LLM(**dataclasses.asdict(engine_args)) ++ ++ # Add the requests to the engine. ++ prompts = [] ++ sampling_params = [] ++ priority = [] ++ for prompt, _, output_len, _priority in requests: ++ prompts.append(prompt) ++ priority.append(_priority) ++ sampling_params.append( ++ SamplingParams( ++ n=n, ++ temperature=1.0, ++ top_p=1.0, ++ ignore_eos=True, ++ max_tokens=output_len, ++ )) ++ ++ start = time.perf_counter() ++ llm.generate(prompts, sampling_params, priority=priority, use_tqdm=True) ++ end = time.perf_counter() ++ return end - start ++ ++ ++def main(args: argparse.Namespace): ++ print(args) ++ random.seed(args.seed) ++ ++ # Sample the requests. ++ tokenizer = AutoTokenizer.from_pretrained( ++ args.tokenizer, trust_remote_code=args.trust_remote_code) ++ if args.dataset is None: ++ # Synthesize a prompt with the given input length. ++ prompt = "hi" * (args.input_len - 1) ++ requests = [(prompt, args.input_len, args.output_len) ++ for _ in range(args.num_prompts)] ++ else: ++ requests = sample_requests(args.dataset, args.num_prompts, tokenizer, ++ args.output_len) ++ ++ if args.backend == "vllm": ++ elapsed_time = run_vllm(requests, args.n, ++ EngineArgs.from_cli_args(args)) ++ else: ++ raise ValueError(f"Unknown backend: {args.backend}") ++ total_num_tokens = sum(prompt_len + output_len ++ for _, prompt_len, output_len, priority in requests) ++ print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, " ++ f"{total_num_tokens / elapsed_time:.2f} tokens/s") ++ ++ # Output JSON results if specified ++ if args.output_json: ++ results = { ++ "elapsed_time": elapsed_time, ++ "num_requests": len(requests), ++ "total_num_tokens": total_num_tokens, ++ "requests_per_second": len(requests) / elapsed_time, ++ "tokens_per_second": total_num_tokens / elapsed_time, ++ } ++ with open(args.output_json, "w") as f: ++ json.dump(results, f, indent=4) ++ ++ ++if __name__ == "__main__": ++ parser = FlexibleArgumentParser(description="Benchmark the throughput.") ++ parser.add_argument("--backend", ++ type=str, ++ choices=["vllm", "hf", "mii"], ++ default="vllm") ++ parser.add_argument("--dataset", ++ type=str, ++ default=None, ++ help="Path to the dataset.") ++ parser.add_argument("--input-len", ++ type=int, ++ default=None, ++ help="Input prompt length for each request") ++ parser.add_argument("--output-len", ++ type=int, ++ default=None, ++ help="Output length for each request. Overrides the " ++ "output length from the dataset.") ++ parser.add_argument("--n", ++ type=int, ++ default=1, ++ help="Number of generated sequences per prompt.") ++ parser.add_argument("--num-prompts", ++ type=int, ++ default=200, ++ help="Number of prompts to process.") ++ parser.add_argument( ++ '--output-json', ++ type=str, ++ default=None, ++ help='Path to save the throughput results in JSON format.') ++ ++ parser = EngineArgs.add_cli_args(parser) ++ args = parser.parse_args() ++ if args.tokenizer is None: ++ args.tokenizer = args.model ++ if args.dataset is None: ++ assert args.input_len is not None ++ assert args.output_len is not None ++ else: ++ assert args.input_len is None ++ ++ main(args) +diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py +index 2c2d69d..4eb0e1f 100644 +--- a/benchmarks/benchmark_serving.py ++++ b/benchmarks/benchmark_serving.py +@@ -1,9 +1,9 @@ +-"""Benchmark online serving throughput. ++r"""Benchmark online serving throughput. + + On the server side, run one of the following commands: + vLLM OpenAI API server +- python -m vllm.entrypoints.openai.api_server \ +- --model --swap-space 16 \ ++ vllm serve \ ++ --swap-space 16 \ + --disable-log-requests + + (TGI backend) +@@ -17,9 +17,15 @@ On the client side, run: + --dataset-path \ + --request-rate \ # By default is inf + --num-prompts # By default is 1000 ++ ++ when using tgi backend, add ++ --endpoint /generate_stream ++ to the end of the command above. + """ + import argparse + import asyncio ++import base64 ++import io + import json + import os + import random +@@ -27,15 +33,27 @@ import time + import warnings + from dataclasses import dataclass + from datetime import datetime +-from typing import AsyncGenerator, List, Optional, Tuple ++from typing import Any, AsyncGenerator, Collection, Dict, List, Optional, Tuple + + import numpy as np + from backend_request_func import (ASYNC_REQUEST_FUNCS, RequestFuncInput, + RequestFuncOutput) ++from datasets import load_dataset ++from PIL.Image import Image + from tqdm.asyncio import tqdm + from transformers import PreTrainedTokenizerBase + +-from vllm.transformers_utils.tokenizer import get_tokenizer ++try: ++ from vllm.transformers_utils.tokenizer import get_tokenizer ++except ImportError: ++ from backend_request_func import get_tokenizer ++ ++try: ++ from vllm.utils import FlexibleArgumentParser ++except ImportError: ++ from argparse import ArgumentParser as FlexibleArgumentParser ++ ++MILLISECONDS_TO_SECONDS_CONVERSION = 1000 + + + @dataclass +@@ -44,14 +62,28 @@ class BenchmarkMetrics: + total_input: int + total_output: int + request_throughput: float +- input_throughput: float ++ request_goodput: float + output_throughput: float ++ total_token_throughput: float + mean_ttft_ms: float + median_ttft_ms: float +- p99_ttft_ms: float ++ std_ttft_ms: float ++ percentiles_ttft_ms: List[Tuple[float, float]] + mean_tpot_ms: float + median_tpot_ms: float +- p99_tpot_ms: float ++ std_tpot_ms: float ++ percentiles_tpot_ms: List[Tuple[float, float]] ++ mean_itl_ms: float ++ median_itl_ms: float ++ std_itl_ms: float ++ percentiles_itl_ms: List[Tuple[float, float]] ++ # E2EL stands for end-to-end latency per request. ++ # It is the time taken on the client side from sending ++ # a request to receiving a complete response. ++ mean_e2el_ms: float ++ median_e2el_ms: float ++ std_e2el_ms: float ++ percentiles_e2el_ms: List[Tuple[float, float]] + + + def sample_sharegpt_requests( +@@ -59,12 +91,9 @@ def sample_sharegpt_requests( + num_requests: int, + tokenizer: PreTrainedTokenizerBase, + fixed_output_len: Optional[int] = None, +-) -> List[Tuple[str, int, int]]: +- if fixed_output_len is not None and fixed_output_len < 4: +- raise ValueError("output_len too small") +- ++) -> List[Tuple[str, int, int, None]]: + # Load the dataset. +- with open(dataset_path) as f: ++ with open(dataset_path, encoding='utf-8') as f: + dataset = json.load(f) + # Filter out the conversations with less than 2 turns. + dataset = [data for data in dataset if len(data["conversations"]) >= 2] +@@ -89,13 +118,13 @@ def sample_sharegpt_requests( + prompt_len = len(prompt_token_ids) + output_len = len(completion_token_ids + ) if fixed_output_len is None else fixed_output_len +- if prompt_len < 4 or output_len < 4: ++ if prompt_len < 4 or (fixed_output_len is None and output_len < 4): + # Prune too short sequences. + continue + if prompt_len > 1024 or prompt_len + output_len > 2048: + # Prune too long sequences. + continue +- filtered_dataset.append((prompt, prompt_len, output_len)) ++ filtered_dataset.append((prompt, prompt_len, output_len, None)) + + return filtered_dataset + +@@ -107,13 +136,13 @@ def sample_sonnet_requests( + output_len: int, + prefix_len: int, + tokenizer: PreTrainedTokenizerBase, +-) -> List[Tuple[str, str, int, int]]: ++) -> List[Tuple[str, str, int, int, None]]: + assert ( + input_len > prefix_len + ), "'args.sonnet-input-len' must be greater than 'args.prefix-input-len'." + + # Load the dataset. +- with open(dataset_path) as f: ++ with open(dataset_path, encoding='utf-8') as f: + poem_lines = f.readlines() + + # Tokenize the poem lines. +@@ -150,9 +179,9 @@ def sample_sonnet_requests( + # Sample the rest of lines per request. + sampled_requests: List[Tuple[str, int, int]] = [] + for _ in range(num_requests): +- sampled_lines = "".join( +- prefix_lines + +- random.sample(poem_lines, num_input_lines - num_prefix_lines)) ++ num_lines_needed = num_input_lines - num_prefix_lines ++ sampled_lines = "".join(prefix_lines + ++ random.choices(poem_lines, k=num_lines_needed)) + + prompt = f"{base_prompt}{sampled_lines}" + message = [ +@@ -165,24 +194,224 @@ def sample_sonnet_requests( + message, add_generation_prompt=True, tokenize=False) + prompt_len = len(tokenizer(prompt_formatted).input_ids) + sampled_requests.append( +- (prompt, prompt_formatted, prompt_len, output_len)) ++ (prompt, prompt_formatted, prompt_len, output_len, None)) ++ ++ return sampled_requests ++ ++ ++def sample_mmmu_pro_vision_requests( ++ dataset, ++ num_requests: int, ++ tokenizer: PreTrainedTokenizerBase, ++ fixed_output_len: Optional[int] = None, ++) -> List[Tuple[str, str, int, Optional[Dict[str, Collection[str]]]]]: ++ sampled_requests: List[Tuple[str, int, int, Dict[str, ++ Collection[str]]]] = [] ++ for data in dataset: ++ if len(sampled_requests) == num_requests: ++ break ++ ++ # MMMU-Pro vision direct prompt ++ # Ref: https://github.com/MMMU-Benchmark/MMMU/blob/6ce42f4d8f70c1841c67867152648974415b5cac/mmmu-pro/prompts.yaml#L5 ++ prompt = ( ++ "Answer with the option letter from the given choices directly. " ++ "The last line of your response should be of the following " ++ "format: 'Answer: $LETTER' (without quotes) where LETTER is one of " ++ "options.") ++ ++ prompt_token_ids = tokenizer(prompt).input_ids ++ if fixed_output_len is None: ++ # Default max output len is set to 128 ++ print("--hf-output-len is not provided. Using default value 128.") ++ fixed_output_len = 128 ++ ++ prompt_len = len(prompt_token_ids) ++ output_len = fixed_output_len ++ ++ assert isinstance( ++ data["image"], ++ Image), ("Input image format must be `PIL.Image.Image`, " ++ f"given {type(data['image'])}.") ++ image: Image = data["image"] ++ image = image.convert("RGB") ++ image_data = io.BytesIO() ++ image.save(image_data, format='JPEG') ++ image_base64 = base64.b64encode(image_data.getvalue()).decode("utf-8") ++ mm_content = { ++ "type": "image_url", ++ "image_url": { ++ "url": f"data:image/jpeg;base64,{image_base64}" ++ }, ++ } ++ ++ sampled_requests.append((prompt, prompt_len, output_len, mm_content)) ++ ++ return sampled_requests ++ ++ ++def sample_hf_requests( ++ dataset_path: str, ++ dataset_subset: str, ++ dataset_split: str, ++ num_requests: int, ++ tokenizer: PreTrainedTokenizerBase, ++ random_seed: int, ++ fixed_output_len: Optional[int] = None, ++) -> List[Tuple[str, str, int, Optional[Dict[str, Collection[str]]]]]: ++ ++ # Special case for MMMU-Pro vision dataset ++ if dataset_path == 'MMMU/MMMU_Pro' and dataset_subset == 'vision': ++ assert dataset_split == "test" ++ dataset = load_dataset(dataset_path, ++ name=dataset_subset, ++ split=dataset_split, ++ streaming=True) ++ assert "image" in dataset.features, ( ++ "MMMU/MMMU_Pro vision dataset must have 'image' column.") ++ filter_func = lambda x: isinstance(x["image"], Image) ++ dataset = dataset.shuffle(seed=random_seed).filter(filter_func) ++ return sample_mmmu_pro_vision_requests(dataset, num_requests, ++ tokenizer, fixed_output_len) ++ ++ dataset = load_dataset(dataset_path, ++ name=dataset_subset, ++ split=dataset_split, ++ streaming=True) ++ assert "conversations" in dataset.features, ( ++ "HF Dataset must have 'conversations' column.") ++ filter_func = lambda x: len(x["conversations"]) >= 2 ++ filtered_dataset = dataset.shuffle(seed=random_seed).filter(filter_func) ++ sampled_requests: List[Tuple[str, int, int, Dict[str, ++ Collection[str]]]] = [] ++ for data in filtered_dataset: ++ if len(sampled_requests) == num_requests: ++ break ++ ++ # Tokenize the prompts and completions. ++ prompt = data["conversations"][0]["value"] ++ prompt_token_ids = tokenizer(prompt).input_ids ++ completion = data["conversations"][1]["value"] ++ completion_token_ids = tokenizer(completion).input_ids ++ prompt_len = len(prompt_token_ids) ++ output_len = len(completion_token_ids ++ ) if fixed_output_len is None else fixed_output_len ++ if fixed_output_len is None and (prompt_len < 4 or output_len < 4): ++ # Prune too short sequences. ++ continue ++ if fixed_output_len is None and \ ++ (prompt_len > 1024 or prompt_len + output_len > 2048): ++ # Prune too long sequences. ++ continue ++ ++ if "image" in data and isinstance(data["image"], Image): ++ image: Image = data["image"] ++ image = image.convert("RGB") ++ image_data = io.BytesIO() ++ image.save(image_data, format='JPEG') ++ image_base64 = base64.b64encode( ++ image_data.getvalue()).decode("utf-8") ++ mm_content = { ++ "type": "image_url", ++ "image_url": { ++ "url": f"data:image/jpeg;base64,{image_base64}" ++ }, ++ } ++ elif "image" in data and isinstance(data["image"], str): ++ if (data["image"].startswith("http://") or \ ++ data["image"].startswith("file://")): ++ image_url = data["image"] ++ else: ++ image_url = f"file://{data['image']}" ++ ++ mm_content = { ++ "type": "image_url", ++ "image_url": { ++ "url": image_url ++ }, ++ } ++ else: ++ mm_content = None ++ ++ sampled_requests.append((prompt, prompt_len, output_len, mm_content)) + + return sampled_requests + + ++def sample_random_requests( ++ prefix_len: int, ++ input_len: int, ++ output_len: int, ++ num_prompts: int, ++ range_ratio: float, ++ tokenizer: PreTrainedTokenizerBase, ++) -> List[Tuple[str, int, int]]: ++ prefix_token_ids = np.random.randint(0, ++ tokenizer.vocab_size, ++ size=prefix_len).tolist() ++ ++ input_lens = np.random.randint( ++ int(input_len * range_ratio), ++ input_len + 1, ++ size=num_prompts, ++ ) ++ output_lens = np.random.randint( ++ int(output_len * range_ratio), ++ output_len + 1, ++ size=num_prompts, ++ ) ++ offsets = np.random.randint(0, tokenizer.vocab_size, size=num_prompts) ++ input_requests = [] ++ for i in range(num_prompts): ++ prompt = tokenizer.decode(prefix_token_ids + ++ [(offsets[i] + i + j) % tokenizer.vocab_size ++ for j in range(input_lens[i])]) ++ ++ input_requests.append((prompt, int(prefix_len + input_lens[i]), ++ int(output_lens[i]), None)) ++ ++ return input_requests ++ ++ + async def get_request( + input_requests: List[Tuple[str, int, int]], + request_rate: float, ++ burstiness: float = 1.0, + ) -> AsyncGenerator[Tuple[str, int, int], None]: ++ """ ++ Asynchronously generates requests at a specified rate ++ with OPTIONAL burstiness. ++ ++ Args: ++ input_requests: ++ A list of input requests, each represented as a tuple. ++ request_rate: ++ The rate at which requests are generated (requests/s). ++ burstiness (optional): ++ The burstiness factor of the request generation. ++ Only takes effect when request_rate is not inf. ++ Default value is 1, which follows a Poisson process. ++ Otherwise, the request intervals follow a gamma distribution. ++ A lower burstiness value (0 < burstiness < 1) results ++ in more bursty requests, while a higher burstiness value ++ (burstiness > 1) results in a more uniform arrival of requests. ++ """ + input_requests = iter(input_requests) ++ ++ # Calculate scale parameter theta to maintain the desired request_rate. ++ assert burstiness > 0, ( ++ f"A positive burstiness factor is expected, but given {burstiness}.") ++ theta = 1.0 / (request_rate * burstiness) ++ + for request in input_requests: + yield request + + if request_rate == float("inf"): + # If the request rate is infinity, then we don't need to wait. + continue +- # Sample the request interval from the exponential distribution. +- interval = np.random.exponential(1.0 / request_rate) ++ ++ # Sample the request interval from the gamma distribution. ++ # If burstiness is 1, it follows exponential distribution. ++ interval = np.random.gamma(shape=burstiness, scale=theta) + # The next request will be sent after the interval. + await asyncio.sleep(interval) + +@@ -192,39 +421,100 @@ def calculate_metrics( + outputs: List[RequestFuncOutput], + dur_s: float, + tokenizer: PreTrainedTokenizerBase, ++ selected_percentile_metrics: List[str], ++ selected_percentiles: List[float], ++ gootput_config_dict: Dict[str, float], + ) -> Tuple[BenchmarkMetrics, List[int]]: +- actual_output_lens = [] ++ actual_output_lens: List[int] = [] + total_input = 0 + completed = 0 +- tpots = [] +- ttfts = [] ++ good_completed = 0 ++ itls: List[float] = [] ++ tpots: List[float] = [] ++ all_tpots: List[float] = [] ++ ttfts: List[float] = [] ++ e2els: List[float] = [] + for i in range(len(outputs)): + if outputs[i].success: +- output_len = len(tokenizer(outputs[i].generated_text).input_ids) ++ # We use the tokenizer to count the number of output tokens for all ++ # serving backends instead of looking at len(outputs[i].itl) since ++ # multiple output tokens may be bundled together ++ # Note : this may inflate the output token count slightly ++ output_len = len( ++ tokenizer(outputs[i].generated_text, ++ add_special_tokens=False).input_ids) + actual_output_lens.append(output_len) + total_input += input_requests[i][1] ++ tpot = 0 + if output_len > 1: +- tpots.append( +- (outputs[i].latency - outputs[i].ttft) / (output_len - 1)) ++ tpot = (outputs[i].latency - outputs[i].ttft) / (output_len - ++ 1) ++ tpots.append(tpot) ++ # Note: if output_len <= 1, we regard tpot as 0 for goodput ++ all_tpots.append(tpot) ++ itls += outputs[i].itl + ttfts.append(outputs[i].ttft) ++ e2els.append(outputs[i].latency) + completed += 1 + else: + actual_output_lens.append(0) + ++ if gootput_config_dict: ++ valid_metrics = [] ++ slo_values = [] ++ ++ if "ttft" in gootput_config_dict: ++ valid_metrics.append(ttfts) ++ slo_values.append(gootput_config_dict["ttft"] / ++ MILLISECONDS_TO_SECONDS_CONVERSION) ++ if "tpot" in gootput_config_dict: ++ valid_metrics.append(all_tpots) ++ slo_values.append(gootput_config_dict["tpot"] / ++ MILLISECONDS_TO_SECONDS_CONVERSION) ++ if "e2el" in gootput_config_dict: ++ valid_metrics.append(e2els) ++ slo_values.append(gootput_config_dict["e2el"] / ++ MILLISECONDS_TO_SECONDS_CONVERSION) ++ ++ for req_metric in zip(*valid_metrics): ++ is_good_req = all([s >= r for s, r in zip(slo_values, req_metric)]) ++ if is_good_req: ++ good_completed += 1 ++ ++ if completed == 0: ++ warnings.warn( ++ "All requests failed. This is likely due to a misconfiguration " ++ "on the benchmark arguments.", ++ stacklevel=2) + metrics = BenchmarkMetrics( + completed=completed, + total_input=total_input, + total_output=sum(actual_output_lens), + request_throughput=completed / dur_s, +- input_throughput=total_input / dur_s, ++ request_goodput=good_completed / dur_s, + output_throughput=sum(actual_output_lens) / dur_s, ++ total_token_throughput=(total_input + sum(actual_output_lens)) / dur_s, + mean_ttft_ms=np.mean(ttfts or 0) * + 1000, # ttfts is empty if streaming is not supported by backend ++ std_ttft_ms=np.std(ttfts or 0) * 1000, + median_ttft_ms=np.median(ttfts or 0) * 1000, +- p99_ttft_ms=np.percentile(ttfts or 0, 99) * 1000, +- mean_tpot_ms=np.mean(tpots) * 1000, +- median_tpot_ms=np.median(tpots) * 1000, +- p99_tpot_ms=np.percentile(tpots, 99) * 1000, ++ percentiles_ttft_ms=[(p, np.percentile(ttfts or 0, p) * 1000) ++ for p in selected_percentiles], ++ mean_tpot_ms=np.mean(tpots or 0) * 1000, ++ std_tpot_ms=np.std(tpots or 0) * 1000, ++ median_tpot_ms=np.median(tpots or 0) * 1000, ++ percentiles_tpot_ms=[(p, np.percentile(tpots or 0, p) * 1000) ++ for p in selected_percentiles], ++ mean_itl_ms=np.mean(itls or 0) * 1000, ++ std_itl_ms=np.std(itls or 0) * 1000, ++ median_itl_ms=np.median(itls or 0) * 1000, ++ percentiles_itl_ms=[(p, np.percentile(itls or 0, p) * 1000) ++ for p in selected_percentiles], ++ mean_e2el_ms=np.mean(e2els or 0) * 1000, ++ std_e2el_ms=np.std(e2els or 0) * 1000, ++ median_e2el_ms=np.median(e2els or 0) * 1000, ++ percentiles_e2el_ms=[(p, np.percentile(e2els or 0, p) * 1000) ++ for p in selected_percentiles], + ) + + return metrics, actual_output_lens +@@ -233,43 +523,129 @@ def calculate_metrics( + async def benchmark( + backend: str, + api_url: str, ++ base_url: str, + model_id: str, + tokenizer: PreTrainedTokenizerBase, + input_requests: List[Tuple[str, int, int]], ++ logprobs: Optional[int], + best_of: int, +- use_beam_search: bool, + request_rate: float, ++ burstiness: float, + disable_tqdm: bool, ++ profile: bool, ++ selected_percentile_metrics: List[str], ++ selected_percentiles: List[str], ++ ignore_eos: bool, ++ gootput_config_dict: Dict[str, float], ++ max_concurrency: Optional[int], + ): + if backend in ASYNC_REQUEST_FUNCS: +- request_func = ASYNC_REQUEST_FUNCS.get(backend) ++ request_func = ASYNC_REQUEST_FUNCS[backend] + else: + raise ValueError(f"Unknown backend: {backend}") + ++ print("Starting initial single prompt test run...") ++ test_prompt, test_prompt_len, test_output_len, test_mm_content = ( ++ input_requests[0]) ++ if backend != "openai-chat" and test_mm_content is not None: ++ # multi-modal benchmark is only available on OpenAI Chat backend. ++ raise ValueError( ++ "Multi-modal content is only supported on 'openai-chat' backend.") ++ test_input = RequestFuncInput( ++ model=model_id, ++ prompt=test_prompt, ++ api_url=api_url, ++ prompt_len=test_prompt_len, ++ output_len=test_output_len, ++ logprobs=logprobs, ++ best_of=best_of, ++ multi_modal_content=test_mm_content, ++ ignore_eos=ignore_eos, ++ ) ++ test_output = await request_func(request_func_input=test_input) ++ if not test_output.success: ++ raise ValueError( ++ "Initial test run failed - Please make sure benchmark arguments " ++ f"are correctly specified. Error: {test_output.error}") ++ else: ++ print("Initial test run completed. Starting main benchmark run...") ++ ++ if profile: ++ print("Starting profiler...") ++ profile_input = RequestFuncInput(model=model_id, ++ prompt=test_prompt, ++ api_url=base_url + "/start_profile", ++ prompt_len=test_prompt_len, ++ output_len=test_output_len, ++ logprobs=logprobs, ++ best_of=best_of, ++ multi_modal_content=test_mm_content, ++ ignore_eos=ignore_eos) ++ profile_output = await request_func(request_func_input=profile_input) ++ if profile_output.success: ++ print("Profiler started") ++ ++ if burstiness == 1.0: ++ distribution = "Poisson process" ++ else: ++ distribution = "Gamma distribution" ++ + print(f"Traffic request rate: {request_rate}") ++ print(f"Burstiness factor: {burstiness} ({distribution})") ++ print(f"Maximum request concurrency: {max_concurrency}") + + pbar = None if disable_tqdm else tqdm(total=len(input_requests)) + ++ # This can be used once the minimum Python version is 3.10 or higher, ++ # and it will simplify the code in limited_request_func. ++ # semaphore = (asyncio.Semaphore(max_concurrency) ++ # if max_concurrency else contextlib.nullcontext()) ++ semaphore = (asyncio.Semaphore(max_concurrency) ++ if max_concurrency else None) ++ ++ async def limited_request_func(request_func_input, pbar): ++ if semaphore is None: ++ return await request_func(request_func_input=request_func_input, ++ pbar=pbar) ++ async with semaphore: ++ return await request_func(request_func_input=request_func_input, ++ pbar=pbar) ++ + benchmark_start_time = time.perf_counter() +- tasks = [] +- async for request in get_request(input_requests, request_rate): +- prompt, prompt_len, output_len = request +- request_func_input = RequestFuncInput( +- model=model_id, +- prompt=prompt, +- api_url=api_url, +- prompt_len=prompt_len, +- output_len=output_len, +- best_of=best_of, +- use_beam_search=use_beam_search, +- ) ++ tasks: List[asyncio.Task] = [] ++ async for request in get_request(input_requests, request_rate, burstiness): ++ prompt, prompt_len, output_len, mm_content = request ++ request_func_input = RequestFuncInput(model=model_id, ++ prompt=prompt, ++ api_url=api_url, ++ prompt_len=prompt_len, ++ output_len=output_len, ++ logprobs=logprobs, ++ best_of=best_of, ++ multi_modal_content=mm_content, ++ ignore_eos=ignore_eos) + tasks.append( + asyncio.create_task( +- request_func(request_func_input=request_func_input, +- pbar=pbar))) ++ limited_request_func(request_func_input=request_func_input, ++ pbar=pbar))) + outputs: List[RequestFuncOutput] = await asyncio.gather(*tasks) + +- if not disable_tqdm: ++ if profile: ++ print("Stopping profiler...") ++ profile_input = RequestFuncInput( ++ model=model_id, ++ prompt=test_prompt, ++ api_url=base_url + "/stop_profile", ++ prompt_len=test_prompt_len, ++ output_len=test_output_len, ++ logprobs=logprobs, ++ best_of=best_of, ++ ) ++ profile_output = await request_func(request_func_input=profile_input) ++ if profile_output.success: ++ print("Profiler stopped") ++ ++ if pbar is not None: + pbar.close() + + benchmark_duration = time.perf_counter() - benchmark_start_time +@@ -279,6 +655,9 @@ async def benchmark( + outputs=outputs, + dur_s=benchmark_duration, + tokenizer=tokenizer, ++ selected_percentile_metrics=selected_percentile_metrics, ++ selected_percentiles=selected_percentiles, ++ gootput_config_dict=gootput_config_dict, + ) + + print("{s:{c}^{n}}".format(s=' Serving Benchmark Result ', n=50, c='=')) +@@ -290,23 +669,13 @@ async def benchmark( + metrics.total_output)) + print("{:<40} {:<10.2f}".format("Request throughput (req/s):", + metrics.request_throughput)) +- print("{:<40} {:<10.2f}".format("Input token throughput (tok/s):", +- metrics.input_throughput)) ++ if gootput_config_dict: ++ print("{:<40} {:<10.2f}".format("Request goodput (req/s):", ++ metrics.request_goodput)) + print("{:<40} {:<10.2f}".format("Output token throughput (tok/s):", + metrics.output_throughput)) +- print("{s:{c}^{n}}".format(s='Time to First Token', n=50, c='-')) +- print("{:<40} {:<10.2f}".format("Mean TTFT (ms):", metrics.mean_ttft_ms)) +- print("{:<40} {:<10.2f}".format("Median TTFT (ms):", +- metrics.median_ttft_ms)) +- print("{:<40} {:<10.2f}".format("P99 TTFT (ms):", metrics.p99_ttft_ms)) +- print("{s:{c}^{n}}".format(s='Time per Output Token (excl. 1st token)', +- n=50, +- c='-')) +- print("{:<40} {:<10.2f}".format("Mean TPOT (ms):", metrics.mean_tpot_ms)) +- print("{:<40} {:<10.2f}".format("Median TPOT (ms):", +- metrics.median_tpot_ms)) +- print("{:<40} {:<10.2f}".format("P99 TPOT (ms):", metrics.p99_tpot_ms)) +- print("=" * 50) ++ print("{:<40} {:<10.2f}".format("Total Token throughput (tok/s):", ++ metrics.total_token_throughput)) + + result = { + "duration": benchmark_duration, +@@ -314,14 +683,10 @@ async def benchmark( + "total_input_tokens": metrics.total_input, + "total_output_tokens": metrics.total_output, + "request_throughput": metrics.request_throughput, +- "input_throughput": metrics.input_throughput, ++ "request_goodput:": ++ metrics.request_goodput if gootput_config_dict else None, + "output_throughput": metrics.output_throughput, +- "mean_ttft_ms": metrics.mean_ttft_ms, +- "median_ttft_ms": metrics.median_ttft_ms, +- "p99_ttft_ms": metrics.p99_ttft_ms, +- "mean_tpot_ms": metrics.mean_tpot_ms, +- "median_tpot_ms": metrics.median_tpot_ms, +- "p99_tpot_ms": metrics.p99_tpot_ms, ++ "total_token_throughput": metrics.total_token_throughput, + "input_lens": [output.prompt_len for output in outputs], + "output_lens": actual_output_lens, + "ttfts": [output.ttft for output in outputs], +@@ -329,9 +694,85 @@ async def benchmark( + "generated_texts": [output.generated_text for output in outputs], + "errors": [output.error for output in outputs], + } ++ ++ def process_one_metric( ++ # E.g., "ttft" ++ metric_attribute_name: str, ++ # E.g., "TTFT" ++ metric_name: str, ++ # E.g., "Time to First Token" ++ metric_header: str, ++ ): ++ # This function prints and adds statistics of the specified ++ # metric. ++ if metric_attribute_name not in selected_percentile_metrics: ++ return ++ print("{s:{c}^{n}}".format(s=metric_header, n=50, c='-')) ++ print("{:<40} {:<10.2f}".format( ++ f"Mean {metric_name} (ms):", ++ getattr(metrics, f"mean_{metric_attribute_name}_ms"))) ++ print("{:<40} {:<10.2f}".format( ++ f"Median {metric_name} (ms):", ++ getattr(metrics, f"median_{metric_attribute_name}_ms"))) ++ result[f"mean_{metric_attribute_name}_ms"] = getattr( ++ metrics, f"mean_{metric_attribute_name}_ms") ++ result[f"median_{metric_attribute_name}_ms"] = getattr( ++ metrics, f"median_{metric_attribute_name}_ms") ++ result[f"std_{metric_attribute_name}_ms"] = getattr( ++ metrics, f"std_{metric_attribute_name}_ms") ++ for p, value in getattr(metrics, ++ f"percentiles_{metric_attribute_name}_ms"): ++ p_word = str(int(p)) if int(p) == p else str(p) ++ print("{:<40} {:<10.2f}".format(f"P{p_word} {metric_name} (ms):", ++ value)) ++ result[f"p{p_word}_{metric_attribute_name}_ms"] = value ++ ++ process_one_metric("ttft", "TTFT", "Time to First Token") ++ process_one_metric("tpot", "TPOT", ++ "Time per Output Token (excl. 1st token)") ++ process_one_metric("itl", "ITL", "Inter-token Latency") ++ process_one_metric("e2el", "E2EL", "End-to-end Latency") ++ ++ print("=" * 50) ++ + return result + + ++def check_goodput_args(args): ++ # Check and parse goodput arguments ++ gootput_config_dict = {} ++ VALID_NAMES = ["ttft", "tpot", "e2el"] ++ if args.goodput: ++ gootput_config_dict = parse_goodput(args.goodput) ++ for slo_name, slo_val in gootput_config_dict.items(): ++ if slo_name not in VALID_NAMES: ++ raise ValueError( ++ f"Invalid metric name found, {slo_name}: {slo_val}. " ++ "The service level objective name should be one of " ++ f"{str(VALID_NAMES)}. ") ++ if slo_val < 0: ++ raise ValueError( ++ f"Invalid value found, {slo_name}: {slo_val}. " ++ "The service level objective value should be " ++ "non-negative.") ++ return gootput_config_dict ++ ++ ++def parse_goodput(slo_pairs): ++ gootput_config_dict = {} ++ try: ++ for slo_pair in slo_pairs: ++ slo_name, slo_val = slo_pair.split(":") ++ gootput_config_dict[slo_name] = float(slo_val) ++ except ValueError as err: ++ raise argparse.ArgumentTypeError( ++ "Invalid format found for service level objectives. " ++ "Specify service level objectives for goodput as \"KEY:VALUE\" " ++ "pairs, where the key is a metric name, and the value is a " ++ "number in milliseconds.") from err ++ return gootput_config_dict ++ ++ + def main(args: argparse.Namespace): + print(args) + random.seed(args.seed) +@@ -340,13 +781,17 @@ def main(args: argparse.Namespace): + backend = args.backend + model_id = args.model + tokenizer_id = args.tokenizer if args.tokenizer is not None else args.model ++ tokenizer_mode = args.tokenizer_mode + + if args.base_url is not None: + api_url = f"{args.base_url}{args.endpoint}" ++ base_url = f"{args.base_url}" + else: + api_url = f"http://{args.host}:{args.port}{args.endpoint}" ++ base_url = f"http://{args.host}:{args.port}" + + tokenizer = get_tokenizer(tokenizer_id, ++ tokenizer_mode=tokenizer_mode, + trust_remote_code=args.trust_remote_code) + + if args.dataset is not None: +@@ -381,9 +826,9 @@ def main(args: argparse.Namespace): + prefix_len=args.sonnet_prefix_len, + tokenizer=tokenizer, + ) +- input_requests = [(prompt, prompt_len, output_len) ++ input_requests = [(prompt, prompt_len, output_len, None) + for prompt, prompt_formatted, prompt_len, +- output_len in input_requests] ++ output_len, _ in input_requests] + else: + assert ( + tokenizer.chat_template or tokenizer.default_chat_template +@@ -396,29 +841,62 @@ def main(args: argparse.Namespace): + prefix_len=args.sonnet_prefix_len, + tokenizer=tokenizer, + ) +- input_requests = [(prompt_formatted, prompt_len, output_len) ++ input_requests = [(prompt_formatted, prompt_len, output_len, None) + for prompt, prompt_formatted, prompt_len, +- output_len in input_requests] ++ output_len, _ in input_requests] ++ ++ elif args.dataset_name == "hf": ++ input_requests = sample_hf_requests( ++ dataset_path=args.dataset_path, ++ dataset_subset=args.hf_subset, ++ dataset_split=args.hf_split, ++ num_requests=args.num_prompts, ++ tokenizer=tokenizer, ++ random_seed=args.seed, ++ fixed_output_len=args.hf_output_len, ++ ) ++ ++ elif args.dataset_name == "random": ++ input_requests = sample_random_requests( ++ prefix_len=args.random_prefix_len, ++ input_len=args.random_input_len, ++ output_len=args.random_output_len, ++ num_prompts=args.num_prompts, ++ range_ratio=args.random_range_ratio, ++ tokenizer=tokenizer, ++ ) + + else: + raise ValueError(f"Unknown dataset: {args.dataset_name}") + ++ gootput_config_dict = check_goodput_args(args) ++ + benchmark_result = asyncio.run( + benchmark( + backend=backend, + api_url=api_url, ++ base_url=base_url, + model_id=model_id, + tokenizer=tokenizer, + input_requests=input_requests, ++ logprobs=args.logprobs, + best_of=args.best_of, +- use_beam_search=args.use_beam_search, + request_rate=args.request_rate, ++ burstiness=args.burstiness, + disable_tqdm=args.disable_tqdm, ++ profile=args.profile, ++ selected_percentile_metrics=args.percentile_metrics.split(","), ++ selected_percentiles=[ ++ float(p) for p in args.metric_percentiles.split(",") ++ ], ++ ignore_eos=args.ignore_eos, ++ gootput_config_dict=gootput_config_dict, ++ max_concurrency=args.max_concurrency, + )) + + # Save config and results to json + if args.save_result: +- result_json = {} ++ result_json: Dict[str, Any] = {} + + # Setup + current_dt = datetime.now().strftime("%Y%m%d-%H%M%S") +@@ -427,7 +905,6 @@ def main(args: argparse.Namespace): + result_json["model_id"] = model_id + result_json["tokenizer_id"] = tokenizer_id + result_json["best_of"] = args.best_of +- result_json["use_beam_search"] = args.use_beam_search + result_json["num_prompts"] = args.num_prompts + + # Metadata +@@ -444,21 +921,27 @@ def main(args: argparse.Namespace): + # Traffic + result_json["request_rate"] = ( + args.request_rate if args.request_rate < float("inf") else "inf") ++ result_json["burstiness"] = args.burstiness ++ result_json["max_concurrency"] = args.max_concurrency + + # Merge with benchmark result + result_json = {**result_json, **benchmark_result} + + # Save to file + base_model_id = model_id.split("/")[-1] +- file_name = f"{backend}-{args.request_rate}qps-{base_model_id}-{current_dt}.json" #noqa ++ max_concurrency_str = (f"-concurrency{args.max_concurrency}" ++ if args.max_concurrency is not None else "") ++ file_name = f"{backend}-{args.request_rate}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json" #noqa ++ if args.result_filename: ++ file_name = args.result_filename + if args.result_dir: + file_name = os.path.join(args.result_dir, file_name) +- with open(file_name, "w") as outfile: ++ with open(file_name, "w", encoding='utf-8') as outfile: + json.dump(result_json, outfile) + + + if __name__ == "__main__": +- parser = argparse.ArgumentParser( ++ parser = FlexibleArgumentParser( + description="Benchmark the online serving throughput.") + parser.add_argument( + "--backend", +@@ -491,13 +974,27 @@ if __name__ == "__main__": + "--dataset-name", + type=str, + default="sharegpt", +- choices=["sharegpt", "sonnet"], ++ choices=["sharegpt", "sonnet", "random", "hf"], + help="Name of the dataset to benchmark on.", + ) + parser.add_argument("--dataset-path", + type=str, + default=None, +- help="Path to the dataset.") ++ help="Path to the sharegpt/sonnet dataset. " ++ "Or the huggingface dataset ID if using HF dataset.") ++ parser.add_argument( ++ "--max-concurrency", ++ type=int, ++ default=None, ++ help="Maximum number of concurrent requests. This can be used " ++ "to help simulate an environment where a higher level component " ++ "is enforcing a maximum number of concurrent requests. While the " ++ "--request-rate argument controls the rate at which requests are " ++ "initiated, this argument will control how many are actually allowed " ++ "to execute at a time. This means that when used in combination, the " ++ "actual request rate may be lower than specified with --request-rate, " ++ "if the server is not processing requests fast enough to keep up.") ++ + parser.add_argument( + "--model", + type=str, +@@ -508,7 +1005,7 @@ if __name__ == "__main__": + "--tokenizer", + type=str, + help= +- "Name or path of the tokenizer, if not using the default tokenizer.", ++ "Name or path of the tokenizer, if not using the default tokenizer.", # noqa: E501 + ) + parser.add_argument( + "--best-of", +@@ -525,31 +1022,14 @@ if __name__ == "__main__": + help="Number of prompts to process.", + ) + parser.add_argument( +- "--sharegpt-output-len", ++ "--logprobs", + type=int, + default=None, +- help="Output length for each request. Overrides the output length " +- "from the ShareGPT dataset.") +- parser.add_argument( +- "--sonnet-input-len", +- type=int, +- default=550, +- help= +- "Number of input tokens per request, used only for sonnet dataset.", +- ) +- parser.add_argument( +- "--sonnet-output-len", +- type=int, +- default=150, +- help= +- "Number of output tokens per request, used only for sonnet dataset.", +- ) +- parser.add_argument( +- "--sonnet-prefix-len", +- type=int, +- default=200, +- help= +- "Number of prefix tokens per request, used only for sonnet dataset.", ++ help=("Number of logprobs-per-token to compute & return as part of " ++ "the request. If unspecified, then either (1) if beam search " ++ "is disabled, no logprobs are computed & a single dummy " ++ "logprob is returned for each token; or (2) if beam search " ++ "is enabled 1 logprob per token is computed"), + ) + parser.add_argument( + "--request-rate", +@@ -557,8 +1037,20 @@ if __name__ == "__main__": + default=float("inf"), + help="Number of requests per second. If this is inf, " + "then all the requests are sent at time 0. " +- "Otherwise, we use Poisson process to synthesize " +- "the request arrival times.", ++ "Otherwise, we use Poisson process or gamma distribution " ++ "to synthesize the request arrival times.", ++ ) ++ parser.add_argument( ++ "--burstiness", ++ type=float, ++ default=1.0, ++ help="Burstiness factor of the request generation. " ++ "Only take effect when request_rate is not inf. " ++ "Default value is 1, which follows Poisson process. " ++ "Otherwise, the request intervals follow a gamma distribution. " ++ "A lower burstiness value (0 < burstiness < 1) results in more " ++ "bursty requests. A higher burstiness value (burstiness > 1) " ++ "results in a more uniform arrival of requests.", + ) + parser.add_argument("--seed", type=int, default=0) + parser.add_argument( +@@ -571,6 +1063,12 @@ if __name__ == "__main__": + action="store_true", + help="Specify to disable tqdm progress bar.", + ) ++ parser.add_argument( ++ "--profile", ++ action="store_true", ++ help="Use Torch Profiler. The endpoint must be launched with " ++ "VLLM_TORCH_PROFILER_DIR to enable profiler.", ++ ) + parser.add_argument( + "--save-result", + action="store_true", +@@ -591,6 +1089,138 @@ if __name__ == "__main__": + help="Specify directory to save benchmark json results." + "If not specified, results are saved in the current directory.", + ) ++ parser.add_argument( ++ "--result-filename", ++ type=str, ++ default=None, ++ help="Specify the filename to save benchmark json results." ++ "If not specified, results will be saved in " ++ "{backend}-{args.request_rate}qps-{base_model_id}-{current_dt}.json" ++ " format.", ++ ) ++ parser.add_argument( ++ "--ignore-eos", ++ action="store_true", ++ help="Set ignore_eos flag when sending the benchmark request." ++ "Warning: ignore_eos is not supported in deepspeed_mii and tgi.") ++ parser.add_argument( ++ "--percentile-metrics", ++ type=str, ++ default="ttft,tpot,itl", ++ help="Comma-seperated list of selected metrics to report percentils. " ++ "This argument specifies the metrics to report percentiles. " ++ "Allowed metric names are \"ttft\", \"tpot\", \"itl\", \"e2el\". " ++ "Default value is \"ttft,tpot,itl\".") ++ parser.add_argument( ++ "--metric-percentiles", ++ type=str, ++ default="99", ++ help="Comma-seperated list of percentiles for selected metrics. " ++ "To report 25-th, 50-th, and 75-th percentiles, use \"25,50,75\". " ++ "Default value is \"99\". " ++ "Use \"--percentile-metrics\" to select metrics.", ++ ) ++ parser.add_argument( ++ "--goodput", ++ nargs="+", ++ required=False, ++ help="Specify service level objectives for goodput as \"KEY:VALUE\" " ++ "pairs, where the key is a metric name, and the value is in " ++ "milliseconds. Multiple \"KEY:VALUE\" pairs can be provided, " ++ "separated by spaces. Allowed request level metric names are " ++ "\"ttft\", \"tpot\", \"e2el\". For more context on the definition of " ++ "goodput, refer to DistServe paper: https://arxiv.org/pdf/2401.09670 " ++ "and the blog: https://hao-ai-lab.github.io/blogs/distserve") ++ ++ # group for dataset specific arguments ++ sonnet_group = parser.add_argument_group("sonnet dataset options") ++ sonnet_group.add_argument( ++ "--sonnet-input-len", ++ type=int, ++ default=550, ++ help= ++ "Number of input tokens per request, used only for sonnet dataset.", ++ ) ++ sonnet_group.add_argument( ++ "--sonnet-output-len", ++ type=int, ++ default=150, ++ help= ++ "Number of output tokens per request, used only for sonnet dataset.", ++ ) ++ sonnet_group.add_argument( ++ "--sonnet-prefix-len", ++ type=int, ++ default=200, ++ help= ++ "Number of prefix tokens per request, used only for sonnet dataset.", ++ ) ++ ++ sharegpt_group = parser.add_argument_group("sharegpt dataset options") ++ sharegpt_group.add_argument( ++ "--sharegpt-output-len", ++ type=int, ++ default=None, ++ help="Output length for each request. Overrides the output length " ++ "from the ShareGPT dataset.") ++ ++ random_group = parser.add_argument_group("random dataset options") ++ random_group.add_argument( ++ "--random-input-len", ++ type=int, ++ default=1024, ++ help= ++ "Number of input tokens per request, used only for random sampling.", ++ ) ++ random_group.add_argument( ++ "--random-output-len", ++ type=int, ++ default=128, ++ help= ++ "Number of output tokens per request, used only for random sampling.", ++ ) ++ random_group.add_argument( ++ "--random-range-ratio", ++ type=float, ++ default=1.0, ++ help="Range of sampled ratio of input/output length, " ++ "used only for random sampling.", ++ ) ++ random_group.add_argument( ++ "--random-prefix-len", ++ type=int, ++ default=0, ++ help="Number of fixed prefix tokens before random " ++ " context. The length range of context in a random " ++ " request is [random-prefix-len, " ++ " random-prefix-len + random-prefix-len * random-range-ratio).") ++ ++ hf_group = parser.add_argument_group("hf dataset options") ++ hf_group.add_argument("--hf-subset", ++ type=str, ++ default=None, ++ help="Subset of the HF dataset.") ++ hf_group.add_argument("--hf-split", ++ type=str, ++ default=None, ++ help="Split of the HF dataset.") ++ hf_group.add_argument( ++ "--hf-output-len", ++ type=int, ++ default=None, ++ help="Output length for each request. Overrides the output lengths " ++ "from the sampled HF dataset.", ++ ) ++ ++ parser.add_argument( ++ '--tokenizer-mode', ++ type=str, ++ default="auto", ++ choices=['auto', 'slow', 'mistral'], ++ help='The tokenizer mode.\n\n* "auto" will use the ' ++ 'fast tokenizer if available.\n* "slow" will ' ++ 'always use the slow tokenizer. \n* ' ++ '"mistral" will always use the `mistral_common` tokenizer.') + + args = parser.parse_args() + main(args) +diff --git a/benchmarks/benchmark_serving_guided.py b/benchmarks/benchmark_serving_guided.py +new file mode 100644 +index 0000000..4435d87 +--- /dev/null ++++ b/benchmarks/benchmark_serving_guided.py +@@ -0,0 +1,881 @@ ++r"""Benchmark online serving throughput with guided decoding. ++ ++On the server side, run one of the following commands: ++ (vLLM OpenAI API server) ++ vllm serve --disable-log-requests ++ ++ (TGI backend) ++ ./launch_tgi_server.sh ++ ++On the client side, run: ++ python benchmarks/benchmark_serving.py \ ++ --backend \ ++ --model \ ++ --dataset json \ ++ --guided-decoding-ratio 1.0 \ ++ --guided-decoding-backend xgrammar \ ++ --request-rate 10 \ ++ --num-prompts 1000 ++ ++ when using tgi backend, add ++ --endpoint /generate_stream ++ to the end of the command above. ++""" ++import argparse ++import asyncio ++import dataclasses ++import json ++import os ++import random ++import time ++import warnings ++from dataclasses import dataclass ++from typing import AsyncGenerator, List, Optional, Tuple ++ ++import datasets ++import numpy as np ++import pandas as pd ++from backend_request_func import (ASYNC_REQUEST_FUNCS, RequestFuncInput, ++ RequestFuncOutput) ++from tqdm.asyncio import tqdm ++from transformers import PreTrainedTokenizerBase ++ ++try: ++ from vllm.transformers_utils.tokenizer import get_tokenizer ++except ImportError: ++ from backend_request_func import get_tokenizer ++ ++try: ++ from vllm.utils import FlexibleArgumentParser ++except ImportError: ++ from argparse import ArgumentParser as FlexibleArgumentParser ++ ++MILLISECONDS_TO_SECONDS_CONVERSION = 1000 ++ ++ ++@dataclass ++class BenchmarkMetrics: ++ completed: int ++ total_input: int ++ total_output: int ++ request_throughput: float ++ request_goodput: float ++ output_throughput: float ++ total_token_throughput: float ++ mean_ttft_ms: float ++ median_ttft_ms: float ++ std_ttft_ms: float ++ percentiles_ttft_ms: List[Tuple[float, float]] ++ mean_tpot_ms: float ++ median_tpot_ms: float ++ std_tpot_ms: float ++ percentiles_tpot_ms: List[Tuple[float, float]] ++ mean_itl_ms: float ++ median_itl_ms: float ++ std_itl_ms: float ++ percentiles_itl_ms: List[Tuple[float, float]] ++ # E2EL stands for end-to-end latency per request. ++ # It is the time taken on the client side from sending ++ # a request to receiving a complete response. ++ mean_e2el_ms: float ++ median_e2el_ms: float ++ std_e2el_ms: float ++ percentiles_e2el_ms: List[Tuple[float, float]] ++ ++ ++@dataclasses.dataclass ++class SampleRequest: ++ """A class representing a single inference request for benchmarking. ++ ++ Attributes: ++ prompt: The input text prompt for the model. ++ multi_modal_data: Optional dictionary containing multi-modal data (e.g. ++ images). ++ prompt_len: The length of the prompt in tokens. ++ expected_output_len: The expected length of the output in tokens. ++ """ ++ prompt: str ++ prompt_len: int ++ expected_output_len: int ++ schema: dict ++ structure_type: str ++ completion: str = None ++ ++ ++def sample_requests(tokenizer: PreTrainedTokenizerBase, ++ args: argparse.Namespace) -> List[SampleRequest]: ++ if args.dataset == 'json': ++ if args.json_schema_path is None: ++ dir_path = os.path.dirname(os.path.realpath(__file__)) ++ args.json_schema_path = os.path.join(dir_path, ++ "structured_schemas", ++ "structured_schema_1.json") ++ with open(args.json_schema_path) as f: ++ schema = json.load(f) ++ prompt = f"Generate an example of a user profile given the following schema: {json.dumps(schema)}" # noqa: E501 ++ input_len = len(tokenizer(prompt).input_ids) ++ print(f"Input length of the prompt: {input_len} tokens") ++ requests = [ ++ SampleRequest(prompt=prompt, ++ prompt_len=input_len, ++ expected_output_len=args.output_len, ++ schema=schema, ++ structure_type=args.structure_type) ++ for _ in range(args.num_prompts) ++ ] ++ ++ elif args.dataset == "grammar": ++ schema = """ ++ ?start: select_statement ++ ++ ?select_statement: "SELECT " column_list " FROM " table_name ++ ++ ?column_list: column_name ("," column_name)* ++ ++ ?table_name: identifier ++ ++ ?column_name: identifier ++ ++ ?identifier: /[a-zA-Z_][a-zA-Z0-9_]*/ ++ """ ++ prompt = "Generate an SQL query to show the 'username' \ ++ and 'email' from the 'users' table." ++ ++ input_len = len(tokenizer(prompt).input_ids) ++ print(f"Input length of the prompt: {input_len} tokens") ++ requests = [ ++ SampleRequest(prompt=prompt, ++ prompt_len=input_len, ++ expected_output_len=args.output_len, ++ schema=schema, ++ structure_type=args.structure_type) ++ for _ in range(args.num_prompts) ++ ] ++ ++ elif args.dataset == "regex": ++ regex = r"\w+@\w+\.com\n" ++ args.regex = regex ++ prompt = "Generate an email address for Alan Turing, \ ++ who works in Enigma. End in .com and new line. \ ++ Example result: alan.turing@enigma.com\n" ++ ++ input_len = len(tokenizer(prompt).input_ids) ++ print(f"Input length of the prompt: {input_len} tokens") ++ requests = [ ++ SampleRequest(prompt=prompt, ++ prompt_len=input_len, ++ expected_output_len=args.output_len, ++ schema=regex, ++ structure_type=args.structure_type) ++ for _ in range(args.num_prompts) ++ ] ++ ++ elif args.dataset == "choice": ++ choice = ["Positive", "Negative"] ++ args.choice = choice ++ prompt = "Classify this sentiment: vLLM is wonderful!" ++ input_len = len(tokenizer(prompt).input_ids) ++ print(f"Input length of the prompt: {input_len} tokens") ++ requests = [ ++ SampleRequest(prompt=prompt, ++ prompt_len=input_len, ++ expected_output_len=args.output_len, ++ schema=choice, ++ structure_type=args.structure_type) ++ for _ in range(args.num_prompts) ++ ] ++ ++ elif args.dataset == "xgrammar_bench": ++ requests: List[SampleRequest] = [] ++ dataset = datasets.load_dataset("NousResearch/json-mode-eval", ++ split="train") ++ print(f"dataset has {len(dataset)} entries") ++ len_dataset = len(dataset) ++ for data_point_idx in range(args.num_prompts): ++ idx = data_point_idx ++ while idx >= len_dataset: ++ idx -= len_dataset ++ schema = dataset["schema"][idx] ++ prompt = tokenizer.apply_chat_template(dataset["prompt"][idx], ++ tokenize=False) ++ input_len = len(tokenizer(prompt).input_ids) ++ completion = dataset["completion"][idx] ++ ++ requests.append( ++ SampleRequest(prompt=prompt, ++ prompt_len=input_len, ++ expected_output_len=args.output_len, ++ schema=schema, ++ structure_type=args.structure_type, ++ completion=completion)) ++ ++ return requests ++ ++ ++async def get_request( ++ input_requests: List[SampleRequest], ++ request_rate: float, ++ burstiness: float = 1.0, ++) -> AsyncGenerator[Tuple[int, SampleRequest], None]: ++ """ ++ Asynchronously generates requests at a specified rate ++ with OPTIONAL burstiness. ++ ++ Args: ++ input_requests: ++ A list of input requests, each represented as a tuple. ++ request_rate: ++ The rate at which requests are generated (requests/s). ++ burstiness (optional): ++ The burstiness factor of the request generation. ++ Only takes effect when request_rate is not inf. ++ Default value is 1, which follows a Poisson process. ++ Otherwise, the request intervals follow a gamma distribution. ++ A lower burstiness value (0 < burstiness < 1) results ++ in more bursty requests, while a higher burstiness value ++ (burstiness > 1) results in a more uniform arrival of requests. ++ """ ++ input_requests = iter(input_requests) ++ ++ # Calculate scale parameter theta to maintain the desired request_rate. ++ assert burstiness > 0, ( ++ f"A positive burstiness factor is expected, but given {burstiness}.") ++ theta = 1.0 / (request_rate * burstiness) ++ ++ for i, request in enumerate(input_requests): ++ yield i, request ++ ++ if request_rate == float("inf"): ++ # If the request rate is infinity, then we don't need to wait. ++ continue ++ ++ # Sample the request interval from the gamma distribution. ++ # If burstiness is 1, it follows exponential distribution. ++ interval = np.random.gamma(shape=burstiness, scale=theta) ++ # The next request will be sent after the interval. ++ await asyncio.sleep(interval) ++ ++ ++def calculate_metrics( ++ input_requests: List[Tuple[str, int, int]], ++ outputs: List[RequestFuncOutput], ++ dur_s: float, ++ tokenizer: PreTrainedTokenizerBase, ++ selected_percentile_metrics: List[str], ++ selected_percentiles: List[float], ++) -> Tuple[BenchmarkMetrics, List[int]]: ++ actual_output_lens: List[int] = [] ++ total_input = 0 ++ completed = 0 ++ good_completed = 0 ++ itls: List[float] = [] ++ tpots: List[float] = [] ++ all_tpots: List[float] = [] ++ ttfts: List[float] = [] ++ e2els: List[float] = [] ++ for i in range(len(outputs)): ++ if outputs[i].success: ++ # We use the tokenizer to count the number of output tokens for all ++ # serving backends instead of looking at len(outputs[i].itl) since ++ # multiple output tokens may be bundled together ++ # Note : this may inflate the output token count slightly ++ output_len = len( ++ tokenizer(outputs[i].generated_text, ++ add_special_tokens=False).input_ids) ++ actual_output_lens.append(output_len) ++ total_input += input_requests[i].prompt_len ++ tpot = 0 ++ if output_len > 1: ++ tpot = (outputs[i].latency - outputs[i].ttft) / (output_len - ++ 1) ++ tpots.append(tpot) ++ outputs[i].tpot = sum(tpots) / len(tpots) if len(tpots) else 0 ++ # Note: if output_len <= 1, we regard tpot as 0 for goodput ++ all_tpots.append(tpot) ++ itls += outputs[i].itl ++ ttfts.append(outputs[i].ttft) ++ e2els.append(outputs[i].latency) ++ completed += 1 ++ else: ++ actual_output_lens.append(0) ++ ++ if completed == 0: ++ warnings.warn( ++ "All requests failed. This is likely due to a misconfiguration " ++ "on the benchmark arguments.", ++ stacklevel=2) ++ metrics = BenchmarkMetrics( ++ completed=completed, ++ total_input=total_input, ++ total_output=sum(actual_output_lens), ++ request_throughput=completed / dur_s, ++ request_goodput=good_completed / dur_s, ++ output_throughput=sum(actual_output_lens) / dur_s, ++ total_token_throughput=(total_input + sum(actual_output_lens)) / dur_s, ++ mean_ttft_ms=np.mean(ttfts or 0) * ++ 1000, # ttfts is empty if streaming is not supported by backend ++ std_ttft_ms=np.std(ttfts or 0) * 1000, ++ median_ttft_ms=np.median(ttfts or 0) * 1000, ++ percentiles_ttft_ms=[(p, np.percentile(ttfts or 0, p) * 1000) ++ for p in selected_percentiles], ++ mean_tpot_ms=np.mean(tpots or 0) * 1000, ++ std_tpot_ms=np.std(tpots or 0) * 1000, ++ median_tpot_ms=np.median(tpots or 0) * 1000, ++ percentiles_tpot_ms=[(p, np.percentile(tpots or 0, p) * 1000) ++ for p in selected_percentiles], ++ mean_itl_ms=np.mean(itls or 0) * 1000, ++ std_itl_ms=np.std(itls or 0) * 1000, ++ median_itl_ms=np.median(itls or 0) * 1000, ++ percentiles_itl_ms=[(p, np.percentile(itls or 0, p) * 1000) ++ for p in selected_percentiles], ++ mean_e2el_ms=np.mean(e2els or 0) * 1000, ++ std_e2el_ms=np.std(e2els or 0) * 1000, ++ median_e2el_ms=np.median(e2els or 0) * 1000, ++ percentiles_e2el_ms=[(p, np.percentile(e2els or 0, p) * 1000) ++ for p in selected_percentiles], ++ ) ++ ++ return metrics, actual_output_lens ++ ++ ++async def benchmark( ++ backend: str, ++ api_url: str, ++ base_url: str, ++ model_id: str, ++ tokenizer: PreTrainedTokenizerBase, ++ input_requests: List[SampleRequest], ++ request_rate: float, ++ burstiness: float, ++ disable_tqdm: bool, ++ profile: bool, ++ selected_percentile_metrics: List[str], ++ selected_percentiles: List[str], ++ ignore_eos: bool, ++ max_concurrency: Optional[int], ++ guided_decoding_ratio: float, ++ guided_decoding_backend: str, ++): ++ if backend in ASYNC_REQUEST_FUNCS: ++ request_func = ASYNC_REQUEST_FUNCS[backend] ++ else: ++ raise ValueError(f"Unknown backend: {backend}") ++ ++ def prepare_extra_body(request) -> dict: ++ extra_body = {} ++ # Add the schema to the extra_body ++ extra_body[request.structure_type] = request.schema ++ # Add the specific guided_decoding_backend ++ extra_body["guided_decoding_backend"] = guided_decoding_backend ++ return extra_body ++ ++ print("Starting initial single prompt test run...") ++ guided_decoding_req_idx = random.sample( ++ range(len(input_requests)), ++ int(len(input_requests) * guided_decoding_ratio)) ++ ++ test_request = input_requests[0] ++ test_input = RequestFuncInput( ++ model=model_id, ++ prompt=test_request.prompt, ++ api_url=api_url, ++ prompt_len=test_request.prompt_len, ++ output_len=test_request.expected_output_len, ++ ignore_eos=ignore_eos, ++ extra_body=prepare_extra_body(test_request), ++ ) ++ test_output = await request_func(request_func_input=test_input) ++ if not test_output.success: ++ raise ValueError( ++ "Initial test run failed - Please make sure benchmark arguments " ++ f"are correctly specified. Error: {test_output.error}") ++ else: ++ print("Initial test run completed. Starting main benchmark run...") ++ ++ if profile: ++ print("Starting profiler...") ++ profile_input = RequestFuncInput( ++ model=model_id, ++ prompt=test_request.prompt, ++ api_url=base_url + "/start_profile", ++ prompt_len=test_request.prompt_len, ++ output_len=test_request.expected_output_len, ++ ignore_eos=ignore_eos, ++ extra_body=prepare_extra_body(test_request), ++ ) ++ profile_output = await request_func(request_func_input=profile_input) ++ if profile_output.success: ++ print("Profiler started") ++ ++ if burstiness == 1.0: ++ distribution = "Poisson process" ++ else: ++ distribution = "Gamma distribution" ++ ++ print(f"Traffic request rate: {request_rate}") ++ print(f"Burstiness factor: {burstiness} ({distribution})") ++ print(f"Maximum request concurrency: {max_concurrency}") ++ ++ pbar = None if disable_tqdm else tqdm(total=len(input_requests)) ++ ++ # This can be used once the minimum Python version is 3.10 or higher, ++ # and it will simplify the code in limited_request_func. ++ # semaphore = (asyncio.Semaphore(max_concurrency) ++ # if max_concurrency else contextlib.nullcontext()) ++ semaphore = (asyncio.Semaphore(max_concurrency) ++ if max_concurrency else None) ++ ++ async def limited_request_func(request_func_input, pbar): ++ if semaphore is None: ++ return await request_func(request_func_input=request_func_input, ++ pbar=pbar) ++ async with semaphore: ++ return await request_func(request_func_input=request_func_input, ++ pbar=pbar) ++ ++ benchmark_start_time = time.perf_counter() ++ tasks: List[asyncio.Task] = [] ++ expected: List[str] = [] ++ async for i, request in get_request(input_requests, request_rate, ++ burstiness): ++ extra_body = prepare_extra_body( ++ request) if i in guided_decoding_req_idx else None ++ request_func_input = RequestFuncInput( ++ model=model_id, ++ prompt=request.prompt, ++ api_url=api_url, ++ prompt_len=request.prompt_len, ++ output_len=request.expected_output_len, ++ ignore_eos=ignore_eos, ++ extra_body=extra_body, ++ ) ++ expected.append(request.completion) ++ tasks.append( ++ asyncio.create_task( ++ limited_request_func(request_func_input=request_func_input, ++ pbar=pbar))) ++ outputs: List[RequestFuncOutput] = await asyncio.gather(*tasks) ++ ++ if profile: ++ print("Stopping profiler...") ++ profile_input = RequestFuncInput( ++ model=model_id, ++ prompt=test_request.prompt, ++ api_url=base_url + "/stop_profile", ++ prompt_len=test_request.prompt_len, ++ output_len=test_request.expected_output_len, ++ extra_body={test_request.structure_type: test_request.schema}, ++ ) ++ profile_output = await request_func(request_func_input=profile_input) ++ if profile_output.success: ++ print("Profiler stopped") ++ ++ if pbar is not None: ++ pbar.close() ++ ++ benchmark_duration = time.perf_counter() - benchmark_start_time ++ ++ metrics, actual_output_lens = calculate_metrics( ++ input_requests=input_requests, ++ outputs=outputs, ++ dur_s=benchmark_duration, ++ tokenizer=tokenizer, ++ selected_percentile_metrics=selected_percentile_metrics, ++ selected_percentiles=selected_percentiles, ++ ) ++ ++ print("{s:{c}^{n}}".format(s=' Serving Benchmark Result ', n=50, c='=')) ++ print("{:<40} {:<10}".format("Successful requests:", metrics.completed)) ++ print("{:<40} {:<10.2f}".format("Benchmark duration (s):", ++ benchmark_duration)) ++ print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input)) ++ print("{:<40} {:<10}".format("Total generated tokens:", ++ metrics.total_output)) ++ print("{:<40} {:<10.2f}".format("Request throughput (req/s):", ++ metrics.request_throughput)) ++ print("{:<40} {:<10.2f}".format("Output token throughput (tok/s):", ++ metrics.output_throughput)) ++ print("{:<40} {:<10.2f}".format("Total Token throughput (tok/s):", ++ metrics.total_token_throughput)) ++ ++ result = { ++ "duration": ++ benchmark_duration, ++ "completed": ++ metrics.completed, ++ "total_input_tokens": ++ metrics.total_input, ++ "total_output_tokens": ++ metrics.total_output, ++ "request_throughput": ++ metrics.request_throughput, ++ "output_throughput": ++ metrics.output_throughput, ++ "total_token_throughput": ++ metrics.total_token_throughput, ++ "ttft_description": ++ pd.Series([output.ttft for output in outputs]).describe().to_dict(), ++ "tpot_description": ++ pd.Series([output.tpot for output in outputs]).describe().to_dict(), ++ "input_lens": [output.prompt_len for output in outputs], ++ "output_lens": ++ actual_output_lens, ++ "ttfts": [output.ttft for output in outputs], ++ "itls": [output.itl for output in outputs], ++ "errors": [output.error for output in outputs], ++ } ++ ++ ret = [{ ++ 'generated': output.generated_text, ++ 'expected': gt ++ } for output, gt in zip(outputs, expected)] ++ ++ def process_one_metric( ++ # E.g., "ttft" ++ metric_attribute_name: str, ++ # E.g., "TTFT" ++ metric_name: str, ++ # E.g., "Time to First Token" ++ metric_header: str, ++ ): ++ # This function prints and adds statistics of the specified ++ # metric. ++ if metric_attribute_name not in selected_percentile_metrics: ++ return ++ print("{s:{c}^{n}}".format(s=metric_header, n=50, c='-')) ++ print("{:<40} {:<10.2f}".format( ++ f"Mean {metric_name} (ms):", ++ getattr(metrics, f"mean_{metric_attribute_name}_ms"))) ++ print("{:<40} {:<10.2f}".format( ++ f"Median {metric_name} (ms):", ++ getattr(metrics, f"median_{metric_attribute_name}_ms"))) ++ result[f"mean_{metric_attribute_name}_ms"] = getattr( ++ metrics, f"mean_{metric_attribute_name}_ms") ++ result[f"median_{metric_attribute_name}_ms"] = getattr( ++ metrics, f"median_{metric_attribute_name}_ms") ++ result[f"std_{metric_attribute_name}_ms"] = getattr( ++ metrics, f"std_{metric_attribute_name}_ms") ++ for p, value in getattr(metrics, ++ f"percentiles_{metric_attribute_name}_ms"): ++ p_word = str(int(p)) if int(p) == p else str(p) ++ print("{:<40} {:<10.2f}".format(f"P{p_word} {metric_name} (ms):", ++ value)) ++ result[f"p{p_word}_{metric_attribute_name}_ms"] = value ++ ++ process_one_metric("ttft", "TTFT", "Time to First Token") ++ process_one_metric("tpot", "TPOT", ++ "Time per Output Token (excl. 1st token)") ++ process_one_metric("itl", "ITL", "Inter-token Latency") ++ process_one_metric("e2el", "E2EL", "End-to-end Latency") ++ ++ print("=" * 50) ++ ++ return result, ret ++ ++ ++def evaluate(ret, args): ++ ++ def _eval_correctness_json(expected, actual): ++ # extract json string from string using regex ++ import re ++ actual = actual.replace('\n', '').replace(' ', '').strip() ++ try: ++ actual = re.search(r'\{.*\}', actual).group() ++ actual = json.loads(actual) ++ except Exception: ++ return False ++ ++ return True ++ ++ def _eval_correctness_choice(expected, actual): ++ return actual in args.choice ++ ++ def _eval_correctness_regex(expected, actual): ++ import re ++ return re.match(args.regex, actual) is not None ++ ++ def _eval_correctness(expected, actual): ++ if args.structure_type == 'guided_json': ++ return _eval_correctness_json(expected, actual) ++ elif args.structure_type == 'guided_regex': ++ return _eval_correctness_regex(expected, actual) ++ elif args.structure_type == 'guided_choice': ++ return _eval_correctness_choice(expected, actual) ++ else: ++ return None ++ ++ scores = [] ++ for res in ret: ++ score = _eval_correctness(res['expected'], res['generated']) ++ res['correctness'] = score ++ scores.append(score) ++ ++ not_none_scores = [score for score in scores if score is not None] ++ ++ return (sum(not_none_scores) / len(not_none_scores) * ++ 100) if len(not_none_scores) > 0 else None ++ ++ ++def main(args: argparse.Namespace): ++ print(args) ++ random.seed(args.seed) ++ np.random.seed(args.seed) ++ ++ backend = args.backend ++ model_id = args.model ++ tokenizer_id = args.tokenizer if args.tokenizer is not None else args.model ++ ++ if args.base_url is not None: ++ api_url = f"{args.base_url}{args.endpoint}" ++ base_url = f"{args.base_url}" ++ else: ++ api_url = f"http://{args.host}:{args.port}{args.endpoint}" ++ base_url = f"http://{args.host}:{args.port}" ++ ++ tokenizer = get_tokenizer(tokenizer_id, ++ trust_remote_code=args.trust_remote_code) ++ ++ if args.dataset == 'grammar': ++ args.structure_type = 'guided_grammar' ++ elif args.dataset == 'regex': ++ args.structure_type = 'guided_regex' ++ elif args.dataset == 'choice': ++ args.structure_type = 'guided_choice' ++ else: ++ args.structure_type = 'guided_json' ++ ++ if args.no_guided_decoding: ++ args.guided_decoding_ratio = 0 ++ if args.save_results: ++ result_file_name = f'{args.guided_decoding_ratio}guided' ++ result_file_name += f"_{backend}" ++ result_file_name += f"_{args.request_rate}qps" ++ result_file_name += f"_{args.model.split('/')[-1]}" ++ result_file_name += f"_{args.dataset}" ++ result_file_name += f"_{args.num_prompts}" ++ result_file_name += f"_out{args.output_len}" ++ result_file_name += ".txt" ++ else: ++ result_file_name = None ++ ++ input_requests = sample_requests(tokenizer, args) ++ ++ benchmark_result, ret = asyncio.run( ++ benchmark( ++ backend=backend, ++ api_url=api_url, ++ base_url=base_url, ++ model_id=model_id, ++ tokenizer=tokenizer, ++ input_requests=input_requests, ++ request_rate=args.request_rate, ++ burstiness=args.burstiness, ++ disable_tqdm=args.disable_tqdm, ++ profile=args.profile, ++ selected_percentile_metrics=args.percentile_metrics.split(","), ++ selected_percentiles=[ ++ float(p) for p in args.metric_percentiles.split(",") ++ ], ++ ignore_eos=args.ignore_eos, ++ max_concurrency=args.max_concurrency, ++ guided_decoding_ratio=args.guided_decoding_ratio, ++ guided_decoding_backend=args.guided_decoding_backend, ++ )) ++ ++ # Save config and results to json ++ score = evaluate(ret, args) ++ print("correct_rate(%)", score, '\n') ++ if args.save_results: ++ results = { ++ "backend": ++ backend, ++ "model_id": ++ model_id, ++ "tokenizer_id": ++ tokenizer_id, ++ "num_prompts": ++ args.num_prompts, ++ "request_rate": ++ args.request_rate if args.request_rate < float("inf") else "inf", ++ "burstiness": ++ args.burstiness, ++ "max_concurrency": ++ args.max_concurrency, ++ "correct_rate(%)": ++ score ++ } ++ results = {"outputs": ret, **results, **benchmark_result} ++ ++ # Save to file ++ if args.result_filename: ++ result_file_name = args.result_filename ++ if args.result_dir: ++ result_file_name = os.path.join(args.result_dir, result_file_name) ++ with open(result_file_name, "w", encoding='utf-8') as outfile: ++ json.dump(results, outfile, indent=4) ++ ++ ++if __name__ == "__main__": ++ parser = FlexibleArgumentParser( ++ description="Benchmark the online serving throughput.") ++ parser.add_argument( ++ "--backend", ++ type=str, ++ default="vllm", ++ choices=list(ASYNC_REQUEST_FUNCS.keys()), ++ ) ++ parser.add_argument( ++ "--base-url", ++ type=str, ++ default=None, ++ help="Server or API base url if not using http host and port.", ++ ) ++ parser.add_argument("--host", type=str, default="localhost") ++ parser.add_argument("--port", type=int, default=8000) ++ parser.add_argument( ++ "--endpoint", ++ type=str, ++ default="/v1/completions", ++ help="API endpoint.", ++ ) ++ parser.add_argument( ++ "--dataset", ++ default='json', ++ choices=['json', 'grammar', 'regex', 'choice', 'xgrammar_bench']) ++ parser.add_argument("--json_schema_path", ++ type=str, ++ default=None, ++ help="Path to json schema.") ++ parser.add_argument( ++ "--max-concurrency", ++ type=int, ++ default=None, ++ help="Maximum number of concurrent requests. This can be used " ++ "to help simulate an environment where a higher level component " ++ "is enforcing a maximum number of concurrent requests. While the " ++ "--request-rate argument controls the rate at which requests are " ++ "initiated, this argument will control how many are actually allowed " ++ "to execute at a time. This means that when used in combination, the " ++ "actual request rate may be lower than specified with --request-rate, " ++ "if the server is not processing requests fast enough to keep up.") ++ parser.add_argument( ++ "--model", ++ type=str, ++ required=True, ++ help="Name of the model.", ++ ) ++ parser.add_argument( ++ "--tokenizer", ++ type=str, ++ help= ++ "Name or path of the tokenizer, if not using the default tokenizer.", # noqa: E501 ++ ) ++ parser.add_argument( ++ "--num-prompts", ++ type=int, ++ default=1000, ++ help="Number of prompts to process.", ++ ) ++ parser.add_argument( ++ "--output-len", ++ type=int, ++ default=128, ++ help="Number of output tokens.", ++ ) ++ parser.add_argument( ++ "--request-rate", ++ type=float, ++ default=float("inf"), ++ help="Number of requests per second. If this is inf, " ++ "then all the requests are sent at time 0. " ++ "Otherwise, we use Poisson process or gamma distribution " ++ "to synthesize the request arrival times.", ++ ) ++ parser.add_argument( ++ "--burstiness", ++ type=float, ++ default=1.0, ++ help="Burstiness factor of the request generation. " ++ "Only take effect when request_rate is not inf. " ++ "Default value is 1, which follows Poisson process. " ++ "Otherwise, the request intervals follow a gamma distribution. " ++ "A lower burstiness value (0 < burstiness < 1) results in more " ++ "bursty requests. A higher burstiness value (burstiness > 1) " ++ "results in a more uniform arrival of requests.", ++ ) ++ parser.add_argument("--seed", type=int, default=0) ++ parser.add_argument( ++ "--trust-remote-code", ++ action="store_true", ++ help="Trust remote code from huggingface", ++ ) ++ parser.add_argument( ++ "--disable-tqdm", ++ action="store_true", ++ help="Specify to disable tqdm progress bar.", ++ ) ++ parser.add_argument( ++ "--save-results", ++ action="store_true", ++ help="Specify to save benchmark results to a json file", ++ ) ++ parser.add_argument( ++ "--profile", ++ action="store_true", ++ help="Use Torch Profiler. The endpoint must be launched with " ++ "VLLM_TORCH_PROFILER_DIR to enable profiler.", ++ ) ++ parser.add_argument( ++ "--result-dir", ++ type=str, ++ default=None, ++ help="Specify directory to save benchmark json results." ++ "If not specified, results are saved in the current directory.", ++ ) ++ parser.add_argument( ++ "--result-filename", ++ type=str, ++ default=None, ++ help="Specify the filename to save benchmark json results." ++ "If not specified, results will be saved in " ++ "{backend}-{args.request_rate}qps-{base_model_id}-{current_dt}.json" ++ " format.", ++ ) ++ parser.add_argument( ++ "--ignore-eos", ++ action="store_true", ++ help="Set ignore_eos flag when sending the benchmark request." ++ "Warning: ignore_eos is not supported in deepspeed_mii and tgi.") ++ parser.add_argument( ++ "--percentile-metrics", ++ type=str, ++ default="ttft,tpot,itl", ++ help="Comma-seperated list of selected metrics to report percentils. " ++ "This argument specifies the metrics to report percentiles. " ++ "Allowed metric names are \"ttft\", \"tpot\", \"itl\", \"e2el\". " ++ "Default value is \"ttft,tpot,itl\".") ++ parser.add_argument( ++ "--metric-percentiles", ++ type=str, ++ default="99", ++ help="Comma-seperated list of percentiles for selected metrics. " ++ "To report 25-th, 50-th, and 75-th percentiles, use \"25,50,75\". " ++ "Default value is \"99\". " ++ "Use \"--percentile-metrics\" to select metrics.", ++ ) ++ parser.add_argument("--no-guided-decoding", ++ action='store_true', ++ default=False, ++ help="Whether to disable JSON decoding or not.") ++ parser.add_argument("--guided-decoding-ratio", ++ type=float, ++ default=1.0, ++ help="Ratio of Guided Decoding requests") ++ parser.add_argument("--guided-decoding-backend", ++ type=str, ++ choices=["outlines", "lm-format-enforcer", "xgrammar"], ++ default="xgrammar", ++ help="Backend to use for guided decoding") ++ ++ args = parser.parse_args() ++ main(args) +diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py +index 695d06e..c1b10b3 100644 +--- a/benchmarks/benchmark_throughput.py ++++ b/benchmarks/benchmark_throughput.py +@@ -1,24 +1,99 @@ + """Benchmark offline inference throughput.""" + import argparse ++import dataclasses + import json + import random + import time +-from typing import List, Optional, Tuple ++from functools import cache ++from typing import Dict, List, Optional, Tuple + + import torch ++import uvloop ++from PIL import Image + from tqdm import tqdm + from transformers import (AutoModelForCausalLM, AutoTokenizer, + PreTrainedTokenizerBase) + +-from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS ++from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs ++from vllm.entrypoints.openai.api_server import ( ++ build_async_engine_client_from_engine_args) ++from vllm.inputs import TextPrompt ++from vllm.lora.request import LoRARequest ++from vllm.lora.utils import get_adapter_absolute_path ++from vllm.multimodal import MultiModalDataDict ++from vllm.sampling_params import BeamSearchParams ++from vllm.transformers_utils.tokenizer import AnyTokenizer, get_lora_tokenizer ++from vllm.utils import FlexibleArgumentParser, merge_async_iterators + + +-def sample_requests( +- dataset_path: str, +- num_requests: int, +- tokenizer: PreTrainedTokenizerBase, +- fixed_output_len: Optional[int], +-) -> List[Tuple[str, int, int]]: ++@dataclasses.dataclass ++class SampleRequest: ++ """A class representing a single inference request for benchmarking. ++ ++ Attributes: ++ prompt: The input text prompt for the model. ++ prompt_len: The length of the prompt in tokens. ++ expected_output_len: The expected length of the output in tokens. ++ multi_modal_data: Optional dictionary containing multi-modal data (e.g. ++ images). ++ lora_request: Optional LoRARequest specifying the LoRA to use. ++ """ ++ prompt: str ++ prompt_len: int ++ expected_output_len: int ++ multi_modal_data: Optional[MultiModalDataDict] = None ++ lora_request: Optional[LoRARequest] = None ++ ++ ++def _get_prompt_for_image_model(question: str, *, model: str) -> str: ++ """Prepend and append special tokens around the question to form a prompt. ++ ++ Args: ++ question: The input question text to wrap with special tokens ++ model: The name of the model being used, to determine which special ++ tokens to add ++ ++ Returns: ++ The formatted prompt string with appropriate special tokens for the ++ model ++ ++ Raises: ++ ValueError: If an unsupported model name is provided ++ """ ++ model = model.lower() ++ if "pixtral" in model: ++ return f"[INST]{question}\n[IMG][/INST]" ++ raise ValueError(f"Unsupported model {model}") ++ ++ ++@cache ++def lora_path_on_disk(lora_path: str) -> str: ++ return get_adapter_absolute_path(lora_path) ++ ++ ++lora_tokenizer_cache: Dict[int, AnyTokenizer] = {} ++ ++ ++def get_random_lora_request( ++ args: argparse.Namespace ++) -> Tuple[LoRARequest, Optional[AnyTokenizer]]: ++ global lora_tokenizer_cache ++ lora_id = random.randint(1, args.max_loras) ++ lora_request = LoRARequest(lora_name=str(lora_id), ++ lora_int_id=lora_id, ++ lora_path=lora_path_on_disk(args.lora_path)) ++ if lora_id not in lora_tokenizer_cache: ++ lora_tokenizer_cache[lora_id] = get_lora_tokenizer(lora_request) ++ return lora_request, lora_tokenizer_cache[lora_id] ++ ++ ++def sample_requests(tokenizer: PreTrainedTokenizerBase, ++ args: argparse.Namespace) -> List[SampleRequest]: ++ ++ dataset_path: str = args.dataset ++ num_requests: int = args.num_prompts ++ fixed_output_len: Optional[int] = args.output_len ++ model: str = args.model + if fixed_output_len is not None and fixed_output_len < 4: + raise ValueError("output_len too small") + +@@ -27,24 +102,46 @@ def sample_requests( + dataset = json.load(f) + # Filter out the conversations with less than 2 turns. + dataset = [data for data in dataset if len(data["conversations"]) >= 2] +- # Only keep the first two turns of each conversation. +- dataset = [(data["conversations"][0]["value"], +- data["conversations"][1]["value"]) for data in dataset] +- + # Shuffle the dataset. + random.shuffle(dataset) + + # Filter out sequences that are too long or too short +- filtered_dataset: List[Tuple[str, int, int]] = [] +- for i in range(len(dataset)): ++ filtered_dataset: List[SampleRequest] = [] ++ for data in tqdm(dataset, ++ total=len(filtered_dataset), ++ desc="sampling requests"): + if len(filtered_dataset) == num_requests: + break + ++ # Only keep the first two turns of each conversation. ++ prompt = data["conversations"][0]["value"] ++ completion = data["conversations"][1]["value"] ++ ++ multi_modal_data: Optional[MultiModalDataDict] = None ++ if "image" in data: ++ multi_modal_data = multi_modal_data or {} ++ image_path = data["image"] ++ # TODO(vllm-project/vllm/issues/9778): Support multiple images. ++ assert isinstance(image_path, ++ str), "Only support single image input" ++ try: ++ multi_modal_data["image"] = Image.open(image_path).convert( ++ "RGB") ++ except FileNotFoundError: ++ # Ignore datapoint where asset is missing ++ continue ++ prompt = _get_prompt_for_image_model(question=prompt, model=model) ++ ++ request_tokenizer = tokenizer ++ lora_request: Optional[LoRARequest] = None ++ if args.enable_lora: ++ lora_request, lora_tokenizer = get_random_lora_request(args) ++ if lora_tokenizer: ++ request_tokenizer = lora_tokenizer ++ + # Tokenize the prompts and completions. +- prompt = dataset[i][0] +- prompt_token_ids = tokenizer(prompt).input_ids +- completion = dataset[i][1] +- completion_token_ids = tokenizer(completion).input_ids ++ prompt_token_ids = request_tokenizer(prompt).input_ids ++ completion_token_ids = request_tokenizer(completion).input_ids + prompt_len = len(prompt_token_ids) + output_len = len(completion_token_ids + ) if fixed_output_len is None else fixed_output_len +@@ -54,85 +151,124 @@ def sample_requests( + if prompt_len > 1024 or prompt_len + output_len > 2048: + # Prune too long sequences. + continue +- filtered_dataset.append((prompt, prompt_len, output_len)) ++ filtered_dataset.append( ++ SampleRequest(prompt=prompt, ++ prompt_len=prompt_len, ++ expected_output_len=output_len, ++ multi_modal_data=multi_modal_data, ++ lora_request=lora_request)) + + return filtered_dataset + + + def run_vllm( +- requests: List[Tuple[str, int, int]], +- model: str, +- tokenizer: str, +- quantization: Optional[str], +- tensor_parallel_size: int, +- seed: int, ++ requests: List[SampleRequest], + n: int, +- use_beam_search: bool, +- trust_remote_code: bool, +- dtype: str, +- max_model_len: Optional[int], +- enforce_eager: bool, +- kv_cache_dtype: str, +- quantization_param_path: Optional[str], +- device: str, +- enable_prefix_caching: bool, +- enable_chunked_prefill: bool, +- max_num_batched_tokens: int, +- gpu_memory_utilization: float = 0.9, +- download_dir: Optional[str] = None, ++ engine_args: EngineArgs, + ) -> float: + from vllm import LLM, SamplingParams +- llm = LLM( +- model=model, +- tokenizer=tokenizer, +- quantization=quantization, +- tensor_parallel_size=tensor_parallel_size, +- seed=seed, +- trust_remote_code=trust_remote_code, +- dtype=dtype, +- max_model_len=max_model_len, +- gpu_memory_utilization=gpu_memory_utilization, +- enforce_eager=enforce_eager, +- kv_cache_dtype=kv_cache_dtype, +- quantization_param_path=quantization_param_path, +- device=device, +- enable_prefix_caching=enable_prefix_caching, +- download_dir=download_dir, +- enable_chunked_prefill=enable_chunked_prefill, +- max_num_batched_tokens=max_num_batched_tokens, +- ) ++ llm = LLM(**dataclasses.asdict(engine_args)) + + # Add the requests to the engine. +- prompts = [] +- sampling_params = [] +- for prompt, _, output_len in requests: +- prompts.append(prompt) ++ prompts: List[TextPrompt] = [] ++ sampling_params: List[SamplingParams] = [] ++ for request in requests: ++ prompts.append( ++ TextPrompt(prompt=request.prompt, ++ multi_modal_data=request.multi_modal_data)) + sampling_params.append( + SamplingParams( + n=n, +- temperature=0.0 if use_beam_search else 1.0, ++ temperature=1.0, + top_p=1.0, +- use_beam_search=use_beam_search, + ignore_eos=True, +- max_tokens=output_len, ++ max_tokens=request.expected_output_len, + )) ++ lora_requests: Optional[List[LoRARequest]] = None ++ if engine_args.enable_lora: ++ lora_requests = [request.lora_request for request in requests] + +- start = time.perf_counter() +- llm.generate(prompts, sampling_params, use_tqdm=True) +- end = time.perf_counter() ++ use_beam_search = False ++ ++ if not use_beam_search: ++ start = time.perf_counter() ++ llm.generate(prompts, ++ sampling_params, ++ lora_request=lora_requests, ++ use_tqdm=True) ++ end = time.perf_counter() ++ else: ++ assert lora_requests is None, "BeamSearch API does not support LoRA" ++ prompts = [request.prompt for request in requests] ++ # output_len should be the same for all requests. ++ output_len = requests[0][2] ++ for request in requests: ++ assert request.expected_output_len == output_len ++ start = time.perf_counter() ++ llm.beam_search( ++ prompts, ++ BeamSearchParams( ++ beam_width=n, ++ max_tokens=output_len, ++ ignore_eos=True, ++ )) ++ end = time.perf_counter() + return end - start + + ++async def run_vllm_async( ++ requests: List[SampleRequest], ++ n: int, ++ engine_args: AsyncEngineArgs, ++ disable_frontend_multiprocessing: bool = False, ++) -> float: ++ from vllm import SamplingParams ++ ++ async with build_async_engine_client_from_engine_args( ++ engine_args, disable_frontend_multiprocessing) as llm: ++ ++ # Add the requests to the engine. ++ prompts: List[TextPrompt] = [] ++ sampling_params: List[SamplingParams] = [] ++ lora_requests: List[Optional[LoRARequest]] = [] ++ for request in requests: ++ prompts.append( ++ TextPrompt(prompt=request.prompt, ++ multi_modal_data=request.multi_modal_data)) ++ sampling_params.append( ++ SamplingParams( ++ n=n, ++ temperature=1.0, ++ top_p=1.0, ++ ignore_eos=True, ++ max_tokens=request.expected_output_len, ++ )) ++ lora_requests.append(request.lora_request) ++ ++ generators = [] ++ start = time.perf_counter() ++ for i, (prompt, sp, ++ lr) in enumerate(zip(prompts, sampling_params, lora_requests)): ++ generator = llm.generate(prompt, ++ sp, ++ lora_request=lr, ++ request_id=f"test{i}") ++ generators.append(generator) ++ all_gens = merge_async_iterators(*generators) ++ async for i, res in all_gens: ++ pass ++ end = time.perf_counter() ++ return end - start ++ ++ + def run_hf( +- requests: List[Tuple[str, int, int]], ++ requests: List[SampleRequest], + model: str, + tokenizer: PreTrainedTokenizerBase, + n: int, +- use_beam_search: bool, + max_batch_size: int, + trust_remote_code: bool, + ) -> float: +- assert not use_beam_search + llm = AutoModelForCausalLM.from_pretrained( + model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code) + if llm.config.model_type == "llama": +@@ -164,7 +300,7 @@ def run_hf( + padding=True).input_ids + llm_outputs = llm.generate( + input_ids=input_ids.cuda(), +- do_sample=not use_beam_search, ++ do_sample=True, + num_return_sequences=n, + temperature=1.0, + top_p=1.0, +@@ -184,14 +320,14 @@ def run_hf( + + + def run_mii( +- requests: List[Tuple[str, int, int]], ++ requests: List[SampleRequest], + model: str, + tensor_parallel_size: int, + output_len: int, + ) -> float: + from mii import client, serve + llm = serve(model, tensor_parallel=tensor_parallel_size) +- prompts = [prompt for prompt, _, _ in requests] ++ prompts = [request.prompt for request in requests] + + start = time.perf_counter() + llm.generate(prompts, max_new_tokens=output_len) +@@ -209,42 +345,99 @@ def main(args: argparse.Namespace): + tokenizer = AutoTokenizer.from_pretrained( + args.tokenizer, trust_remote_code=args.trust_remote_code) + if args.dataset is None: +- # Synthesize a prompt with the given input length. +- prompt = "hi" * (args.input_len - 1) +- requests = [(prompt, args.input_len, args.output_len) +- for _ in range(args.num_prompts)] ++ vocab_size = tokenizer.vocab_size ++ requests = [] ++ for _ in range(args.num_prompts): ++ ++ request_tokenizer = tokenizer ++ lora_request: Optional[LoRARequest] = None ++ if args.enable_lora: ++ lora_request, lora_tokenizer = get_random_lora_request(args) ++ if lora_tokenizer: ++ request_tokenizer = lora_tokenizer ++ ++ # Synthesize a prompt with the given input length. ++ candidate_ids = [ ++ random.randint(0, vocab_size - 1) ++ for _ in range(args.input_len) ++ ] ++ # As tokenizer may add additional tokens like BOS, we need to try ++ # different lengths to get the desired input length. ++ for _ in range(5): # Max attempts to correct ++ candidate_prompt = request_tokenizer.decode(candidate_ids) ++ tokenized_len = len(request_tokenizer.encode(candidate_prompt)) ++ ++ if tokenized_len == args.input_len: ++ break ++ ++ # Adjust length based on difference ++ diff = args.input_len - tokenized_len ++ if diff > 0: ++ candidate_ids.extend([ ++ random.randint(100, vocab_size - 100) ++ for _ in range(diff) ++ ]) ++ else: ++ candidate_ids = candidate_ids[:diff] ++ requests.append( ++ SampleRequest(prompt=candidate_prompt, ++ prompt_len=args.input_len, ++ expected_output_len=args.output_len, ++ lora_request=lora_request)) + else: +- requests = sample_requests(args.dataset, args.num_prompts, tokenizer, +- args.output_len) ++ requests = sample_requests(tokenizer, args) + ++ is_multi_modal = any(request.multi_modal_data is not None ++ for request in requests) + if args.backend == "vllm": +- elapsed_time = run_vllm( +- requests, args.model, args.tokenizer, args.quantization, +- args.tensor_parallel_size, args.seed, args.n, args.use_beam_search, +- args.trust_remote_code, args.dtype, args.max_model_len, +- args.enforce_eager, args.kv_cache_dtype, +- args.quantization_param_path, args.device, +- args.enable_prefix_caching, args.enable_chunked_prefill, +- args.max_num_batched_tokens, args.gpu_memory_utilization, +- args.download_dir) ++ if args.async_engine: ++ elapsed_time = uvloop.run( ++ run_vllm_async( ++ requests, ++ args.n, ++ AsyncEngineArgs.from_cli_args(args), ++ args.disable_frontend_multiprocessing, ++ )) ++ else: ++ elapsed_time = run_vllm(requests, args.n, ++ EngineArgs.from_cli_args(args)) + elif args.backend == "hf": + assert args.tensor_parallel_size == 1 + elapsed_time = run_hf(requests, args.model, tokenizer, args.n, +- args.use_beam_search, args.hf_max_batch_size, +- args.trust_remote_code) ++ args.hf_max_batch_size, args.trust_remote_code) + elif args.backend == "mii": + elapsed_time = run_mii(requests, args.model, args.tensor_parallel_size, + args.output_len) + else: + raise ValueError(f"Unknown backend: {args.backend}") +- total_num_tokens = sum(prompt_len + output_len +- for _, prompt_len, output_len in requests) ++ total_num_tokens = sum(request.prompt_len + request.expected_output_len ++ for request in requests) ++ total_output_tokens = sum(request.expected_output_len ++ for request in requests) ++ if is_multi_modal: ++ print("\033[91mWARNING\033[0m: Multi-modal request detected. The " ++ "following metrics are not accurate because image tokens are not" ++ " counted. See vllm-project/vllm/issues/9778 for details.") ++ # TODO(vllm-project/vllm/issues/9778): Count molti-modal token length. + print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, " +- f"{total_num_tokens / elapsed_time:.2f} tokens/s") ++ f"{total_num_tokens / elapsed_time:.2f} total tokens/s, " ++ f"{total_output_tokens / elapsed_time:.2f} output tokens/s") ++ ++ # Output JSON results if specified ++ if args.output_json: ++ results = { ++ "elapsed_time": elapsed_time, ++ "num_requests": len(requests), ++ "total_num_tokens": total_num_tokens, ++ "requests_per_second": len(requests) / elapsed_time, ++ "tokens_per_second": total_num_tokens / elapsed_time, ++ } ++ with open(args.output_json, "w") as f: ++ json.dump(results, f, indent=4) + + + if __name__ == "__main__": +- parser = argparse.ArgumentParser(description="Benchmark the throughput.") ++ parser = FlexibleArgumentParser(description="Benchmark the throughput.") + parser.add_argument("--backend", + type=str, + choices=["vllm", "hf", "mii"], +@@ -252,7 +445,9 @@ if __name__ == "__main__": + parser.add_argument("--dataset", + type=str, + default=None, +- help="Path to the dataset.") ++ help="Path to the dataset. The dataset is expected to " ++ "be a json in form of List[Dict[..., conversations: " ++ "List[Dict[..., value: ]]]]") + parser.add_argument("--input-len", + type=int, + default=None, +@@ -262,97 +457,40 @@ if __name__ == "__main__": + default=None, + help="Output length for each request. Overrides the " + "output length from the dataset.") +- parser.add_argument("--model", type=str, default="facebook/opt-125m") +- parser.add_argument("--tokenizer", type=str, default=None) +- parser.add_argument('--quantization', +- '-q', +- choices=[*QUANTIZATION_METHODS, None], +- default=None) +- parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1) + parser.add_argument("--n", + type=int, + default=1, + help="Number of generated sequences per prompt.") +- parser.add_argument("--use-beam-search", action="store_true") + parser.add_argument("--num-prompts", + type=int, + default=1000, + help="Number of prompts to process.") +- parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--hf-max-batch-size", + type=int, + default=None, + help="Maximum batch size for HF backend.") +- parser.add_argument('--trust-remote-code', +- action='store_true', +- help='trust remote code from huggingface') +- parser.add_argument( +- '--max-model-len', +- type=int, +- default=None, +- help='Maximum length of a sequence (including prompt and output). ' +- 'If None, will be derived from the model.') +- parser.add_argument( +- '--dtype', +- type=str, +- default='auto', +- choices=['auto', 'half', 'float16', 'bfloat16', 'float', 'float32'], +- help='data type for model weights and activations. ' +- 'The "auto" option will use FP16 precision ' +- 'for FP32 and FP16 models, and BF16 precision ' +- 'for BF16 models.') +- parser.add_argument('--gpu-memory-utilization', +- type=float, +- default=0.9, +- help='the fraction of GPU memory to be used for ' +- 'the model executor, which can range from 0 to 1.' +- 'If unspecified, will use the default value of 0.9.') +- parser.add_argument("--enforce-eager", +- action="store_true", +- help="enforce eager execution") +- parser.add_argument( +- "--kv-cache-dtype", +- type=str, +- choices=["auto", "fp8"], +- default="auto", +- help= +- 'Data type for kv cache storage. If "auto", will use model data type. ' +- 'FP8_E5M2 (without scaling) is only supported on cuda version greater ' +- 'than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead supported for ' +- 'common inference criteria.') + parser.add_argument( +- '--quantization-param-path', ++ '--output-json', + type=str, + default=None, +- help='Path to the JSON file containing the KV cache scaling factors. ' +- 'This should generally be supplied, when KV cache dtype is FP8. ' +- 'Otherwise, KV cache scaling factors default to 1.0, which may cause ' +- 'accuracy issues. FP8_E5M2 (without scaling) is only supported on ' +- 'cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is ' +- 'instead supported for common inference criteria.') ++ help='Path to save the throughput results in JSON format.') ++ parser.add_argument("--async-engine", ++ action='store_true', ++ default=False, ++ help="Use vLLM async engine rather than LLM class.") ++ parser.add_argument("--disable-frontend-multiprocessing", ++ action='store_true', ++ default=False, ++ help="Disable decoupled async engine frontend.") ++ # LoRA + parser.add_argument( +- "--device", ++ "--lora-path", + type=str, +- default="cuda", +- choices=["cuda", "cpu"], +- help='device type for vLLM execution, supporting CUDA and CPU.') +- parser.add_argument( +- "--enable-prefix-caching", +- action='store_true', +- help="enable automatic prefix caching for vLLM backend.") +- parser.add_argument("--enable-chunked-prefill", +- action='store_true', +- help="enable chunked prefill for vLLM backend.") +- parser.add_argument('--max-num-batched-tokens', +- type=int, +- default=None, +- help='maximum number of batched tokens per ' +- 'iteration') +- parser.add_argument('--download-dir', +- type=str, +- default=None, +- help='directory to download and load the weights, ' +- 'default to the default cache dir of huggingface') ++ default=None, ++ help="Path to the lora adapters to use. This can be an absolute path, " ++ "a relative path, or a Hugging Face model identifier.") ++ ++ parser = AsyncEngineArgs.add_cli_args(parser) + args = parser.parse_args() + if args.tokenizer is None: + args.tokenizer = args.model +@@ -361,6 +499,8 @@ if __name__ == "__main__": + assert args.output_len is not None + else: + assert args.input_len is None ++ if args.enable_lora: ++ assert args.lora_path is not None + + if args.backend == "vllm": + if args.hf_max_batch_size is not None: +@@ -370,13 +510,14 @@ if __name__ == "__main__": + raise ValueError("HF max batch size is required for HF backend.") + if args.quantization is not None: + raise ValueError("Quantization is only for vLLM backend.") ++ if args.enable_lora is not None: ++ raise ValueError("LoRA benchmarking is only supported for vLLM" ++ " backend") + elif args.backend == "mii": + if args.dtype != "auto": + raise ValueError("dtype must be auto for MII backend.") + if args.n != 1: + raise ValueError("n must be 1 for MII backend.") +- if args.use_beam_search: +- raise ValueError("Beam search is not supported for MII backend.") + if args.quantization is not None: + raise ValueError("Quantization is only for vLLM backend.") + if args.hf_max_batch_size is not None: +@@ -384,4 +525,7 @@ if __name__ == "__main__": + if args.tokenizer != args.model: + raise ValueError("Tokenizer must be the same as the model for MII " + "backend.") ++ if args.enable_lora is not None: ++ raise ValueError("LoRA benchmarking is only supported for vLLM" ++ " backend") + main(args) +diff --git a/benchmarks/cutlass_benchmarks/sparse_benchmarks.py b/benchmarks/cutlass_benchmarks/sparse_benchmarks.py +new file mode 100644 +index 0000000..3d1c5e3 +--- /dev/null ++++ b/benchmarks/cutlass_benchmarks/sparse_benchmarks.py +@@ -0,0 +1,384 @@ ++import argparse ++import copy ++import itertools ++import pickle as pkl ++import time ++from typing import Callable, Iterable, List, Tuple ++ ++import torch ++import torch.utils.benchmark as TBenchmark ++from torch.utils.benchmark import Measurement as TMeasurement ++from utils import make_rand_sparse_tensors ++from weight_shapes import WEIGHT_SHAPES ++ ++from vllm import _custom_ops as ops ++from vllm.utils import FlexibleArgumentParser ++ ++DEFAULT_MODELS = list(WEIGHT_SHAPES.keys()) ++DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512] ++DEFAULT_TP_SIZES = [1] ++ ++ ++# bench ++def bench_fn(label: str, sub_label: str, description: str, fn: Callable, *args, ++ **kwargs) -> TMeasurement: ++ min_run_time = 1 ++ ++ globals = { ++ "args": args, ++ "kwargs": kwargs, ++ "fn": fn, ++ } ++ return TBenchmark.Timer( ++ stmt="fn(*args, **kwargs)", ++ globals=globals, ++ label=label, ++ sub_label=sub_label, ++ description=description, ++ ).blocked_autorange(min_run_time=min_run_time) ++ ++ ++def bench_int8(dtype: torch.dtype, m: int, k: int, n: int, label: str, ++ sub_label: str) -> Iterable[TMeasurement]: ++ assert dtype == torch.int8 ++ b_compressed, e, a, b = make_rand_sparse_tensors(torch.int8, m, n, k) ++ scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32) ++ scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32) ++ bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16) ++ ++ out = ops.cutlass_scaled_sparse_mm(a, b_compressed, e, scale_a, scale_b, ++ torch.bfloat16) ++ out_ref = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16) ++ ++ if not torch.allclose(out, out_ref): ++ print("Incorrect results") ++ print(out) ++ print(out_ref) ++ else: ++ print("Correct results") ++ ++ timers = [] ++ # pytorch impl - bfloat16 ++ timers.append( ++ bench_fn(label, sub_label, "pytorch_bf16_bf16_bf16_matmul-no-scales", ++ torch.mm, a.to(dtype=torch.bfloat16), ++ b.to(dtype=torch.bfloat16))) ++ ++ # pytorch impl - float16 ++ timers.append( ++ bench_fn(label, sub_label, ++ "pytorch_fp16_fp16_fp16_matmul-no-scales", torch.mm, ++ a.to(dtype=torch.float16), b.to(dtype=torch.float16))) ++ ++ # cutlass impl ++ timers.append( ++ bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm", ++ ops.cutlass_scaled_mm, a, b, scale_a, scale_b, ++ torch.bfloat16)) ++ ++ # cutlass with bias ++ timers.append( ++ bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm_bias", ++ ops.cutlass_scaled_mm, a, b, scale_a, scale_b, torch.bfloat16, ++ bias)) ++ ++ # cutlass sparse impl ++ timers.append( ++ bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_sparse_mm", ++ ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a, ++ scale_b, torch.bfloat16)) ++ ++ # cutlass sparse with bias ++ timers.append( ++ bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_sparse_mm_bias", ++ ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a, ++ scale_b, torch.bfloat16, bias)) ++ ++ return timers ++ ++ ++def bench_fp8(dtype: torch.dtype, m: int, k: int, n: int, label: str, ++ sub_label: str) -> Iterable[TMeasurement]: ++ assert dtype == torch.float8_e4m3fn ++ b_compressed, e, a, b = make_rand_sparse_tensors(torch.float8_e4m3fn, m, n, ++ k) ++ scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32) ++ scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32) ++ bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16) ++ ++ out = ops.cutlass_scaled_sparse_mm(a, b_compressed, e, scale_a, scale_b, ++ torch.bfloat16) ++ out_ref = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16) ++ ++ if not torch.allclose(out, out_ref): ++ print("Incorrect results") ++ print(out) ++ print(out_ref) ++ else: ++ print("Correct results") ++ ++ timers = [] ++ ++ # pytorch impl w. bf16 ++ timers.append( ++ bench_fn(label, sub_label, "pytorch_bf16_bf16_bf16_matmul-no-scales", ++ torch.mm, a.to(dtype=torch.bfloat16, device="cuda"), ++ b.to(dtype=torch.bfloat16, device="cuda"))) ++ ++ # pytorch impl: bf16 output, without fp8 fast accum ++ timers.append( ++ bench_fn(label, ++ sub_label, ++ "pytorch_fp8_fp8_bf16_scaled_mm", ++ torch._scaled_mm, ++ a, ++ b, ++ scale_a=scale_a, ++ scale_b=scale_b, ++ out_dtype=torch.bfloat16)) ++ ++ # pytorch impl: bf16 output, with fp8 fast accum ++ timers.append( ++ bench_fn(label, ++ sub_label, ++ "pytorch_fp8_fp8_bf16_scaled_mm_fast_accum", ++ torch._scaled_mm, ++ a, ++ b, ++ scale_a=scale_a, ++ scale_b=scale_b, ++ out_dtype=torch.bfloat16, ++ use_fast_accum=True)) ++ ++ # pytorch impl: fp16 output, without fp8 fast accum ++ timers.append( ++ bench_fn(label, ++ sub_label, ++ "pytorch_fp8_fp8_fp16_scaled_mm", ++ torch._scaled_mm, ++ a, ++ b, ++ scale_a=scale_a, ++ scale_b=scale_b, ++ out_dtype=torch.float16)) ++ ++ # pytorch impl: fp16 output, with fp8 fast accum ++ timers.append( ++ bench_fn(label, ++ sub_label, ++ "pytorch_fp8_fp8_fp16_scaled_mm_fast_accum", ++ torch._scaled_mm, ++ a, ++ b, ++ scale_a=scale_a, ++ scale_b=scale_b, ++ out_dtype=torch.float16, ++ use_fast_accum=True)) ++ ++ # cutlass impl: bf16 output ++ timers.append( ++ bench_fn(label, sub_label, "cutlass_fp8_fp8_bf16_scaled_mm", ++ ops.cutlass_scaled_mm, a, b, scale_a, scale_b, ++ torch.bfloat16)) ++ ++ # cutlass impl: bf16 output ++ timers.append( ++ bench_fn(label, sub_label, "cutlass_fp8_fp8_bf16_scaled_sparse_mm", ++ ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a, ++ scale_b, torch.bfloat16)) ++ ++ # cutlass impl: fp16 output ++ timers.append( ++ bench_fn(label, sub_label, "cutlass_fp8_fp8_fp16_scaled_sparse_mm", ++ ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a, ++ scale_b, torch.float16)) ++ ++ # cutlass impl: bf16 output, with bias ++ timers.append( ++ bench_fn(label, sub_label, ++ "cutlass_fp8_fp8_bf16_scaled_sparse_mm_bias", ++ ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a, ++ scale_b, torch.bfloat16, bias)) ++ ++ # cutlass impl: fp16 output, with bias ++ timers.append( ++ bench_fn(label, sub_label, ++ "cutlass_fp8_fp8_fp16_scaled_sparse_mm_bias", ++ ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a, ++ scale_b, torch.float16, bias.to(dtype=torch.float16))) ++ ++ return timers ++ ++ ++def bench(dtype: torch.dtype, m: int, k: int, n: int, label: str, ++ sub_label: str) -> Iterable[TMeasurement]: ++ if dtype == torch.int8: ++ return bench_int8(dtype, m, k, n, label, sub_label) ++ if dtype == torch.float8_e4m3fn: ++ return bench_fp8(dtype, m, k, n, label, sub_label) ++ raise ValueError("unsupported type") ++ ++ ++# runner ++def print_timers(timers: Iterable[TMeasurement]): ++ compare = TBenchmark.Compare(timers) ++ compare.print() ++ ++ ++def run(dtype: torch.dtype, ++ MKNs: Iterable[Tuple[int, int, int]]) -> Iterable[TMeasurement]: ++ results = [] ++ for m, k, n in MKNs: ++ timers = bench(dtype, m, k, n, f"scaled-{dtype}-gemm", ++ f"MKN=({m}x{k}x{n})") ++ print_timers(timers) ++ results.extend(timers) ++ ++ return results ++ ++ ++# output makers ++def make_output(data: Iterable[TMeasurement], ++ MKNs: Iterable[Tuple[int, int, int]], ++ base_description: str, ++ timestamp=None): ++ print(f"== All Results {base_description} ====") ++ print_timers(data) ++ ++ # pickle all the results ++ timestamp = int(time.time()) if timestamp is None else timestamp ++ with open(f"{base_description}-{timestamp}.pkl", "wb") as f: ++ pkl.dump(data, f) ++ ++ ++# argparse runners ++ ++ ++def run_square_bench(args): ++ dim_sizes = list( ++ range(args.dim_start, args.dim_end + 1, args.dim_increment)) ++ MKNs = list(zip(dim_sizes, dim_sizes, dim_sizes)) ++ data = run(args.dtype, MKNs) ++ ++ make_output(data, MKNs, f"square_bench-{args.dtype}") ++ ++ ++def run_range_bench(args): ++ dim_sizes = list(range(args.dim_start, args.dim_end, args.dim_increment)) ++ n = len(dim_sizes) ++ Ms = [args.m_constant] * n if args.m_constant is not None else dim_sizes ++ Ks = [args.k_constant] * n if args.k_constant is not None else dim_sizes ++ Ns = [args.n_constant] * n if args.n_constant is not None else dim_sizes ++ MKNs = list(zip(Ms, Ks, Ns)) ++ data = run(args.dtype, MKNs) ++ ++ make_output(data, MKNs, f"range_bench-{args.dtype}") ++ ++ ++def run_model_bench(args): ++ print("Benchmarking models:") ++ for i, model in enumerate(args.models): ++ print(f"[{i}] {model}") ++ ++ def model_shapes(model_name: str, tp_size: int) -> List[Tuple[int, int]]: ++ KNs = [] ++ for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model_name]): ++ KN[tp_split_dim] = KN[tp_split_dim] // tp_size ++ KNs.append(KN) ++ return KNs ++ ++ model_bench_data = [] ++ models_tps = list(itertools.product(args.models, args.tp_sizes)) ++ for model, tp_size in models_tps: ++ Ms = args.batch_sizes ++ KNs = model_shapes(model, tp_size) ++ MKNs = [] ++ for m in Ms: ++ for k, n in KNs: ++ MKNs.append((m, k, n)) ++ ++ data = run(args.dtype, MKNs) ++ model_bench_data.append(data) ++ ++ # Print all results ++ for data, model_tp in zip(model_bench_data, models_tps): ++ model, tp_size = model_tp ++ print(f"== Results {args.dtype} {model}-TP{tp_size} ====") ++ print_timers(data) ++ ++ timestamp = int(time.time()) ++ ++ all_data = [] ++ for d in model_bench_data: ++ all_data.extend(d) ++ # pickle all data ++ with open(f"model_bench-{args.dtype}-{timestamp}.pkl", "wb") as f: ++ pkl.dump(all_data, f) ++ ++ ++if __name__ == '__main__': ++ ++ def to_torch_dtype(dt): ++ if dt == "int8": ++ return torch.int8 ++ if dt == "fp8": ++ return torch.float8_e4m3fn ++ raise ValueError("unsupported dtype") ++ ++ parser = FlexibleArgumentParser( ++ description=""" ++Benchmark Cutlass GEMM. ++ ++ To run square GEMMs: ++ python3 ./benchmarks/cutlass_benchmarks/sparse_benchmarks.py --dtype fp8 square_bench --dim-start 128 --dim-end 512 --dim-increment 64 ++ ++ To run constant N and K and sweep M: ++ python3 ./benchmarks/cutlass_benchmarks/sparse_benchmarks.py --dtype fp8 range_bench --dim-start 128 --dim-end 512 --dim-increment 64 --n-constant 16384 --k-constant 16384 ++ ++ To run dimensions from a model: ++ python3 ./benchmarks/cutlass_benchmarks/sparse_benchmarks.py --dtype fp8 model_bench --models meta-llama/Llama-2-7b-hf --batch-sizes 16 --tp-sizes 1 ++ ++ Output: ++ - a .pkl file, that is a list of raw torch.benchmark.utils.Measurements for the pytorch and cutlass implementations for the various GEMMs. ++ """, # noqa: E501 ++ formatter_class=argparse.RawTextHelpFormatter) ++ ++ parser.add_argument("--dtype", ++ type=to_torch_dtype, ++ required=True, ++ help="Available options are ['int8', 'fp8']") ++ subparsers = parser.add_subparsers(dest="cmd") ++ ++ square_parser = subparsers.add_parser("square_bench") ++ square_parser.add_argument("--dim-start", type=int, required=True) ++ square_parser.add_argument("--dim-end", type=int, required=True) ++ square_parser.add_argument("--dim-increment", type=int, required=True) ++ square_parser.set_defaults(func=run_square_bench) ++ ++ range_parser = subparsers.add_parser("range_bench") ++ range_parser.add_argument("--dim-start", type=int, required=True) ++ range_parser.add_argument("--dim-end", type=int, required=True) ++ range_parser.add_argument("--dim-increment", type=int, required=True) ++ range_parser.add_argument("--m-constant", type=int, default=None) ++ range_parser.add_argument("--n-constant", type=int, default=None) ++ range_parser.add_argument("--k-constant", type=int, default=None) ++ range_parser.set_defaults(func=run_range_bench) ++ ++ model_parser = subparsers.add_parser("model_bench") ++ model_parser.add_argument("--models", ++ nargs="+", ++ type=str, ++ default=DEFAULT_MODELS, ++ choices=WEIGHT_SHAPES.keys()) ++ model_parser.add_argument("--tp-sizes", ++ nargs="+", ++ type=int, ++ default=DEFAULT_TP_SIZES) ++ model_parser.add_argument("--batch-sizes", ++ nargs="+", ++ type=int, ++ default=DEFAULT_BATCH_SIZES) ++ model_parser.set_defaults(func=run_model_bench) ++ ++ args = parser.parse_args() ++ args.func(args) +diff --git a/benchmarks/cutlass_benchmarks/utils.py b/benchmarks/cutlass_benchmarks/utils.py +new file mode 100644 +index 0000000..ef06fcd +--- /dev/null ++++ b/benchmarks/cutlass_benchmarks/utils.py +@@ -0,0 +1,96 @@ ++# Cutlass bench utils ++from typing import Iterable, Tuple ++ ++import torch ++ ++import vllm._custom_ops as ops ++ ++ ++def to_fp8(tensor: torch.Tensor) -> torch.Tensor: ++ finfo = torch.finfo(torch.float8_e4m3fn) ++ return torch.round(tensor.clamp( ++ min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn) ++ ++ ++def to_int8(tensor: torch.Tensor) -> torch.Tensor: ++ return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8) ++ ++ ++def to_bf16(tensor: torch.Tensor) -> torch.Tensor: ++ return tensor.to(dtype=torch.bfloat16) ++ ++ ++def to_fp16(tensor: torch.Tensor) -> torch.Tensor: ++ return tensor.to(dtype=torch.float16) ++ ++ ++def make_rand_tensors(dtype: torch.dtype, m: int, n: int, ++ k: int) -> Tuple[torch.Tensor, torch.Tensor]: ++ a = torch.randn((m, k), device='cuda') * 5 ++ b = torch.randn((n, k), device='cuda').t() * 5 ++ ++ if dtype == torch.int8: ++ return to_int8(a), to_int8(b) ++ if dtype == torch.float8_e4m3fn: ++ return to_fp8(a), to_fp8(b) ++ ++ raise ValueError("unsupported dtype") ++ ++ ++def prune_to_2_4(tensor): ++ # Reshape tensor to [N, 4] where N is number of groups of 4 ++ original_shape = tensor.shape ++ reshaped = tensor.reshape(-1, 4) ++ ++ # Get indices of top 2 absolute values in each group of 4 ++ _, indices = torch.topk(torch.abs(reshaped), k=2, dim=1) ++ ++ # Create binary mask ++ mask = torch.zeros_like(reshaped) ++ mask.scatter_(dim=1, ++ index=indices, ++ src=torch.ones_like(indices, dtype=mask.dtype)) ++ ++ # Apply mask and reshape back ++ pruned = reshaped * mask ++ ++ # Turn all -0.0 to 0.0 ++ pruned[pruned == -0.0] = 0.0 ++ ++ return pruned.reshape(original_shape) ++ ++ ++def make_rand_sparse_tensors(dtype: torch.dtype, m: int, n: int, ++ k: int) -> Tuple[torch.Tensor, torch.Tensor]: ++ a = torch.randn((m, k), device='cuda') * 5 ++ b = torch.randn((n, k), device='cuda').t() * 5 ++ ++ b = prune_to_2_4(b.t()).t() ++ ++ if dtype == torch.int8: ++ a, b = to_int8(a), to_int8(b) ++ elif dtype == torch.float8_e4m3fn: ++ a, b = to_fp8(a), to_fp8(b) ++ elif dtype == torch.float16: ++ a, b = to_fp16(a), to_fp16(b) ++ elif dtype == torch.bfloat16: ++ a, b = to_bf16(a), to_bf16(b) ++ else: ++ raise ValueError("unsupported dtype") ++ ++ b_compressed, e = ops.cutlass_sparse_compress(b.t()) ++ ++ # Compressed B, Metadata, Original A, B ++ return b_compressed, e, a, b ++ ++ ++def make_n_rand_sparse_tensors(num_tensors: int, dtype: torch.dtype, ++ m: int, n: int, k: int) -> \ ++ Tuple[Iterable[torch.Tensor], Iterable[torch.Tensor]]: ++ ABs = [] ++ for _ in range(num_tensors): ++ b_comp, e, a, b = make_rand_sparse_tensors(dtype, m, n, k) ++ if b_comp is not None: ++ ABs.append(make_rand_sparse_tensors(dtype, m, n, k)) ++ BComps, Es, As, Bs = zip(*ABs) ++ return list(BComps), list(Es), list(As), list(Bs) +diff --git a/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py b/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py +new file mode 100644 +index 0000000..d0353bc +--- /dev/null ++++ b/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py +@@ -0,0 +1,365 @@ ++import argparse ++import copy ++import itertools ++import pickle as pkl ++import time ++from typing import Callable, Iterable, List, Tuple ++ ++import torch ++import torch.utils.benchmark as TBenchmark ++from torch.utils.benchmark import Measurement as TMeasurement ++from utils import make_rand_tensors ++from weight_shapes import WEIGHT_SHAPES ++ ++from vllm import _custom_ops as ops ++from vllm.utils import FlexibleArgumentParser ++ ++DEFAULT_MODELS = list(WEIGHT_SHAPES.keys()) ++DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512] ++DEFAULT_TP_SIZES = [1] ++ ++ ++# bench ++def bench_fn(label: str, sub_label: str, description: str, fn: Callable, *args, ++ **kwargs) -> TMeasurement: ++ min_run_time = 1 ++ ++ globals = { ++ "args": args, ++ "kwargs": kwargs, ++ "fn": fn, ++ } ++ return TBenchmark.Timer( ++ stmt="fn(*args, **kwargs)", ++ globals=globals, ++ label=label, ++ sub_label=sub_label, ++ description=description, ++ ).blocked_autorange(min_run_time=min_run_time) ++ ++ ++def bench_int8(dtype: torch.dtype, m: int, k: int, n: int, label: str, ++ sub_label: str) -> Iterable[TMeasurement]: ++ assert dtype == torch.int8 ++ a, b = make_rand_tensors(torch.int8, m, n, k) ++ scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32) ++ scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32) ++ bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16) ++ azp = torch.zeros((m, ), device="cuda", dtype=torch.int32) ++ azp_adj = torch.zeros((n, ), device="cuda", dtype=torch.int32) ++ ++ timers = [] ++ # pytorch impl - bfloat16 ++ timers.append( ++ bench_fn(label, sub_label, "pytorch_bf16_bf16_bf16_matmul-no-scales", ++ torch.mm, a.to(dtype=torch.bfloat16), ++ b.to(dtype=torch.bfloat16))) ++ ++ # pytorch impl - float16 ++ timers.append( ++ bench_fn(label, sub_label, ++ "pytorch_fp16_fp16_fp16_matmul-no-scales", torch.mm, ++ a.to(dtype=torch.float16), b.to(dtype=torch.float16))) ++ ++ # cutlass impl ++ timers.append( ++ bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm", ++ ops.cutlass_scaled_mm, a, b, scale_a, scale_b, ++ torch.bfloat16)) ++ ++ # cutlass with bias ++ timers.append( ++ bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm_bias", ++ ops.cutlass_scaled_mm, a, b, scale_a, scale_b, torch.bfloat16, ++ bias)) ++ ++ # cutlass with azp per-tensor ++ timers.append( ++ bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm_azp", ++ ops.cutlass_scaled_mm_azp, a, b, scale_a, scale_b, ++ torch.bfloat16, azp_adj)) ++ ++ # cutlass with azp per-tensor + bias ++ timers.append( ++ bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm_azp_bias", ++ ops.cutlass_scaled_mm_azp, a, b, scale_a, scale_b, ++ torch.bfloat16, azp_adj, None, bias)) ++ ++ # cutlass with azp per-token ++ timers.append( ++ bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm_azp_pt", ++ ops.cutlass_scaled_mm_azp, a, b, scale_a, scale_b, ++ torch.bfloat16, azp_adj, azp)) ++ ++ # cutlass with azp per-token + bias ++ timers.append( ++ bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm_azp_pt_bias", ++ ops.cutlass_scaled_mm_azp, a, b, scale_a, scale_b, ++ torch.bfloat16, azp_adj, azp, bias)) ++ ++ return timers ++ ++ ++def bench_fp8(dtype: torch.dtype, m: int, k: int, n: int, label: str, ++ sub_label: str) -> Iterable[TMeasurement]: ++ assert dtype == torch.float8_e4m3fn ++ a, b = make_rand_tensors(torch.float8_e4m3fn, m, n, k) ++ scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32) ++ scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32) ++ bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16) ++ ++ timers = [] ++ ++ # pytorch impl w. bf16 ++ timers.append( ++ bench_fn(label, sub_label, "pytorch_bf16_bf16_bf16_matmul-no-scales", ++ torch.mm, a.to(dtype=torch.bfloat16, device="cuda"), ++ b.to(dtype=torch.bfloat16, device="cuda"))) ++ ++ # pytorch impl: bf16 output, without fp8 fast accum ++ timers.append( ++ bench_fn(label, ++ sub_label, ++ "pytorch_fp8_fp8_bf16_scaled_mm", ++ torch._scaled_mm, ++ a, ++ b, ++ scale_a=scale_a, ++ scale_b=scale_b, ++ out_dtype=torch.bfloat16)) ++ ++ # pytorch impl: bf16 output, with fp8 fast accum ++ timers.append( ++ bench_fn(label, ++ sub_label, ++ "pytorch_fp8_fp8_bf16_scaled_mm_fast_accum", ++ torch._scaled_mm, ++ a, ++ b, ++ scale_a=scale_a, ++ scale_b=scale_b, ++ out_dtype=torch.bfloat16, ++ use_fast_accum=True)) ++ ++ # pytorch impl: fp16 output, without fp8 fast accum ++ timers.append( ++ bench_fn(label, ++ sub_label, ++ "pytorch_fp8_fp8_fp16_scaled_mm", ++ torch._scaled_mm, ++ a, ++ b, ++ scale_a=scale_a, ++ scale_b=scale_b, ++ out_dtype=torch.float16)) ++ ++ # pytorch impl: fp16 output, with fp8 fast accum ++ timers.append( ++ bench_fn(label, ++ sub_label, ++ "pytorch_fp8_fp8_fp16_scaled_mm_fast_accum", ++ torch._scaled_mm, ++ a, ++ b, ++ scale_a=scale_a, ++ scale_b=scale_b, ++ out_dtype=torch.float16, ++ use_fast_accum=True)) ++ ++ # cutlass impl: bf16 output ++ timers.append( ++ bench_fn(label, sub_label, "cutlass_fp8_fp8_bf16_scaled_mm", ++ ops.cutlass_scaled_mm, a, b, scale_a, scale_b, ++ torch.bfloat16)) ++ # cutlass impl: fp16 output ++ timers.append( ++ bench_fn(label, sub_label, "cutlass_fp8_fp8_fp16_scaled_mm", ++ ops.cutlass_scaled_mm, a, b, scale_a, scale_b, torch.float16)) ++ ++ # cutlass impl: bf16 output, with bias ++ timers.append( ++ bench_fn(label, sub_label, "cutlass_fp8_fp8_bf16_scaled_mm_bias", ++ ops.cutlass_scaled_mm, a, b, scale_a, scale_b, torch.bfloat16, ++ bias)) ++ ++ # cutlass impl: fp16 output, with bias ++ timers.append( ++ bench_fn(label, sub_label, "cutlass_fp8_fp8_fp16_scaled_mm_bias", ++ ops.cutlass_scaled_mm, a, b, scale_a, scale_b, torch.float16, ++ bias.to(dtype=torch.float16))) ++ ++ return timers ++ ++ ++def bench(dtype: torch.dtype, m: int, k: int, n: int, label: str, ++ sub_label: str) -> Iterable[TMeasurement]: ++ if dtype == torch.int8: ++ return bench_int8(dtype, m, k, n, label, sub_label) ++ if dtype == torch.float8_e4m3fn: ++ return bench_fp8(dtype, m, k, n, label, sub_label) ++ raise ValueError("unsupported type") ++ ++ ++# runner ++def print_timers(timers: Iterable[TMeasurement]): ++ compare = TBenchmark.Compare(timers) ++ compare.print() ++ ++ ++def run(dtype: torch.dtype, ++ MKNs: Iterable[Tuple[int, int, int]]) -> Iterable[TMeasurement]: ++ results = [] ++ for m, k, n in MKNs: ++ timers = bench(dtype, m, k, n, f"scaled-{dtype}-gemm", ++ f"MKN=({m}x{k}x{n})") ++ print_timers(timers) ++ results.extend(timers) ++ ++ return results ++ ++ ++# output makers ++def make_output(data: Iterable[TMeasurement], ++ MKNs: Iterable[Tuple[int, int, int]], ++ base_description: str, ++ timestamp=None): ++ print(f"== All Results {base_description} ====") ++ print_timers(data) ++ ++ # pickle all the results ++ timestamp = int(time.time()) if timestamp is None else timestamp ++ with open(f"{base_description}-{timestamp}.pkl", "wb") as f: ++ pkl.dump(data, f) ++ ++ ++# argparse runners ++ ++ ++def run_square_bench(args): ++ dim_sizes = list( ++ range(args.dim_start, args.dim_end + 1, args.dim_increment)) ++ MKNs = list(zip(dim_sizes, dim_sizes, dim_sizes)) ++ data = run(args.dtype, MKNs) ++ ++ make_output(data, MKNs, f"square_bench-{args.dtype}") ++ ++ ++def run_range_bench(args): ++ dim_sizes = list(range(args.dim_start, args.dim_end, args.dim_increment)) ++ n = len(dim_sizes) ++ Ms = [args.m_constant] * n if args.m_constant is not None else dim_sizes ++ Ks = [args.k_constant] * n if args.k_constant is not None else dim_sizes ++ Ns = [args.n_constant] * n if args.n_constant is not None else dim_sizes ++ MKNs = list(zip(Ms, Ks, Ns)) ++ data = run(args.dtype, MKNs) ++ ++ make_output(data, MKNs, f"range_bench-{args.dtype}") ++ ++ ++def run_model_bench(args): ++ print("Benchmarking models:") ++ for i, model in enumerate(args.models): ++ print(f"[{i}] {model}") ++ ++ def model_shapes(model_name: str, tp_size: int) -> List[Tuple[int, int]]: ++ KNs = [] ++ for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model_name]): ++ KN[tp_split_dim] = KN[tp_split_dim] // tp_size ++ KNs.append(KN) ++ return KNs ++ ++ model_bench_data = [] ++ models_tps = list(itertools.product(args.models, args.tp_sizes)) ++ for model, tp_size in models_tps: ++ Ms = args.batch_sizes ++ KNs = model_shapes(model, tp_size) ++ MKNs = [] ++ for m in Ms: ++ for k, n in KNs: ++ MKNs.append((m, k, n)) ++ ++ data = run(args.dtype, MKNs) ++ model_bench_data.append(data) ++ ++ # Print all results ++ for data, model_tp in zip(model_bench_data, models_tps): ++ model, tp_size = model_tp ++ print(f"== Results {args.dtype} {model}-TP{tp_size} ====") ++ print_timers(data) ++ ++ timestamp = int(time.time()) ++ ++ all_data = [] ++ for d in model_bench_data: ++ all_data.extend(d) ++ # pickle all data ++ with open(f"model_bench-{args.dtype}-{timestamp}.pkl", "wb") as f: ++ pkl.dump(all_data, f) ++ ++ ++if __name__ == '__main__': ++ ++ def to_torch_dtype(dt): ++ if dt == "int8": ++ return torch.int8 ++ if dt == "fp8": ++ return torch.float8_e4m3fn ++ raise ValueError("unsupported dtype") ++ ++ parser = FlexibleArgumentParser( ++ description=""" ++Benchmark Cutlass GEMM. ++ ++ To run square GEMMs: ++ python3 ./benchmarks/cutlass_benchmarks/w8a8_benchmarks.py --dtype fp8 square_bench --dim-start 128 --dim-end 512 --dim-increment 64 ++ ++ To run constant N and K and sweep M: ++ python3 ./benchmarks/cutlass_benchmarks/w8a8_benchmarks.py --dtype fp8 range_bench --dim-start 128 --dim-end 512 --dim-increment 64 --n-constant 16384 --k-constant 16384 ++ ++ To run dimensions from a model: ++ python3 ./benchmarks/cutlass_benchmarks/w8a8_benchmarks.py --dtype fp8 model_bench --models meta-llama/Llama-2-7b-hf --batch-sizes 16 --tp-sizes 1 ++ ++ Output: ++ - a .pkl file, that is a list of raw torch.benchmark.utils.Measurements for the pytorch and cutlass implementations for the various GEMMs. ++ """, # noqa: E501 ++ formatter_class=argparse.RawTextHelpFormatter) ++ ++ parser.add_argument("--dtype", ++ type=to_torch_dtype, ++ required=True, ++ help="Available options are ['int8', 'fp8']") ++ subparsers = parser.add_subparsers(dest="cmd") ++ ++ square_parser = subparsers.add_parser("square_bench") ++ square_parser.add_argument("--dim-start", type=int, required=True) ++ square_parser.add_argument("--dim-end", type=int, required=True) ++ square_parser.add_argument("--dim-increment", type=int, required=True) ++ square_parser.set_defaults(func=run_square_bench) ++ ++ range_parser = subparsers.add_parser("range_bench") ++ range_parser.add_argument("--dim-start", type=int, required=True) ++ range_parser.add_argument("--dim-end", type=int, required=True) ++ range_parser.add_argument("--dim-increment", type=int, required=True) ++ range_parser.add_argument("--m-constant", type=int, default=None) ++ range_parser.add_argument("--n-constant", type=int, default=None) ++ range_parser.add_argument("--k-constant", type=int, default=None) ++ range_parser.set_defaults(func=run_range_bench) ++ ++ model_parser = subparsers.add_parser("model_bench") ++ model_parser.add_argument("--models", ++ nargs="+", ++ type=str, ++ default=DEFAULT_MODELS, ++ choices=WEIGHT_SHAPES.keys()) ++ model_parser.add_argument("--tp-sizes", ++ nargs="+", ++ type=int, ++ default=DEFAULT_TP_SIZES) ++ model_parser.add_argument("--batch-sizes", ++ nargs="+", ++ type=int, ++ default=DEFAULT_BATCH_SIZES) ++ model_parser.set_defaults(func=run_model_bench) ++ ++ args = parser.parse_args() ++ args.func(args) +\ No newline at end of file +diff --git a/benchmarks/cutlass_benchmarks/weight_shapes.py b/benchmarks/cutlass_benchmarks/weight_shapes.py +new file mode 100644 +index 0000000..d58fb0b +--- /dev/null ++++ b/benchmarks/cutlass_benchmarks/weight_shapes.py +@@ -0,0 +1,43 @@ ++# Weight Shapes are in the format ++# ([K, N], TP_SPLIT_DIM) ++# Example: ++# A shape of ([14336, 4096], 0) indicates the following GEMM shape, ++# - TP1 : K = 14336, N = 4096 ++# - TP2 : K = 7168, N = 4096 ++# A shape of ([4096, 6144], 1) indicates the following GEMM shape, ++# - TP1 : K = 4096, N = 6144 ++# - TP4 : K = 4096, N = 1536 ++ ++# TP1 shapes ++WEIGHT_SHAPES = { ++ "mistralai/Mistral-7B-v0.1": [ ++ ([4096, 6144], 1), ++ ([4096, 4096], 0), ++ ([4096, 28672], 1), ++ ([14336, 4096], 0), ++ ], ++ "meta-llama/Llama-2-7b-hf": [ ++ ([4096, 12288], 1), ++ ([4096, 4096], 0), ++ ([4096, 22016], 1), ++ ([11008, 4096], 0), ++ ], ++ "meta-llama/Llama-3-8b": [ ++ ([4096, 6144], 1), ++ ([4096, 4096], 0), ++ ([4096, 28672], 1), ++ ([14336, 4096], 0), ++ ], ++ "meta-llama/Llama-2-13b-hf": [ ++ ([5120, 15360], 1), ++ ([5120, 5120], 0), ++ ([5120, 27648], 1), ++ ([13824, 5120], 0), ++ ], ++ "meta-llama/Llama-2-70b-hf": [ ++ ([8192, 10240], 1), ++ ([8192, 8192], 0), ++ ([8192, 57344], 1), ++ ([28672, 8192], 0), ++ ], ++} +\ No newline at end of file +diff --git a/benchmarks/disagg_benchmarks/disagg_overhead_benchmark.sh b/benchmarks/disagg_benchmarks/disagg_overhead_benchmark.sh +new file mode 100644 +index 0000000..9499963 +--- /dev/null ++++ b/benchmarks/disagg_benchmarks/disagg_overhead_benchmark.sh +@@ -0,0 +1,145 @@ ++#!/bin/bash ++ ++# benchmark the overhead of disaggregated prefill. ++# methodology: ++# - send all request to prefill vLLM instance. It will buffer KV cache. ++# - then send all request to decode instance. ++# - The TTFT of decode instance is the overhead. ++ ++set -ex ++ ++kill_gpu_processes() { ++ # kill all processes on GPU. ++ pgrep pt_main_thread | xargs -r kill -9 ++ pgrep python3 | xargs -r kill -9 ++ sleep 10 ++ ++ # remove vllm config file ++ rm -rf ~/.config/vllm ++ ++ # Print the GPU memory usage ++ # so that we know if all GPU processes are killed. ++ gpu_memory_usage=$(nvidia-smi --query-gpu=memory.used --format=csv,noheader,nounits -i 0) ++ # The memory usage should be 0 MB. ++ echo "GPU 0 Memory Usage: $gpu_memory_usage MB" ++} ++ ++wait_for_server() { ++ # wait for vllm server to start ++ # return 1 if vllm server crashes ++ local port=$1 ++ timeout 1200 bash -c " ++ until curl -s localhost:${port}/v1/completions > /dev/null; do ++ sleep 1 ++ done" && return 0 || return 1 ++} ++ ++ ++benchmark() { ++ ++ export VLLM_LOGGING_LEVEL=DEBUG ++ export VLLM_HOST_IP=$(hostname -I | awk '{print $1}') ++ ++ # compare chunked prefill with disaggregated prefill ++ ++ results_folder="./results" ++ model="meta-llama/Meta-Llama-3.1-8B-Instruct" ++ dataset_name="sonnet" ++ dataset_path="../sonnet_4x.txt" ++ num_prompts=10 ++ qps=$1 ++ prefix_len=50 ++ input_len=2048 ++ output_len=$2 ++ ++ ++ CUDA_VISIBLE_DEVICES=0 python3 \ ++ -m vllm.entrypoints.openai.api_server \ ++ --model $model \ ++ --port 8100 \ ++ --max-model-len 10000 \ ++ --gpu-memory-utilization 0.6 \ ++ --kv-transfer-config \ ++ '{"kv_connector":"PyNcclConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2,"kv_buffer_size":5e9}' & ++ ++ ++ CUDA_VISIBLE_DEVICES=1 python3 \ ++ -m vllm.entrypoints.openai.api_server \ ++ --model $model \ ++ --port 8200 \ ++ --max-model-len 10000 \ ++ --gpu-memory-utilization 0.6 \ ++ --kv-transfer-config \ ++ '{"kv_connector":"PyNcclConnector","kv_role":"kv_consumer","kv_rank":1,"kv_parallel_size":2,"kv_buffer_size":5e9}' & ++ ++ wait_for_server 8100 ++ wait_for_server 8200 ++ ++ # let the prefill instance finish prefill ++ python3 ../benchmark_serving.py \ ++ --backend vllm \ ++ --model $model \ ++ --dataset-name $dataset_name \ ++ --dataset-path $dataset_path \ ++ --sonnet-input-len $input_len \ ++ --sonnet-output-len "$output_len" \ ++ --sonnet-prefix-len $prefix_len \ ++ --num-prompts $num_prompts \ ++ --port 8100 \ ++ --save-result \ ++ --result-dir $results_folder \ ++ --result-filename disagg_prefill_tp1.json \ ++ --request-rate "inf" ++ ++ ++ # send the request to decode. ++ # The TTFT of this command will be the overhead of disagg prefill impl. ++ python3 ../benchmark_serving.py \ ++ --backend vllm \ ++ --model $model \ ++ --dataset-name $dataset_name \ ++ --dataset-path $dataset_path \ ++ --sonnet-input-len $input_len \ ++ --sonnet-output-len "$output_len" \ ++ --sonnet-prefix-len $prefix_len \ ++ --num-prompts $num_prompts \ ++ --port 8200 \ ++ --save-result \ ++ --result-dir $results_folder \ ++ --result-filename disagg_prefill_tp1_overhead.json \ ++ --request-rate "$qps" ++ kill_gpu_processes ++ ++} ++ ++ ++main() { ++ ++ (which wget && which curl) || (apt-get update && apt-get install -y wget curl) ++ (which jq) || (apt-get -y install jq) ++ (which socat) || (apt-get -y install socat) ++ ++ pip install quart httpx datasets ++ ++ cd "$(dirname "$0")" ++ ++ cd .. ++ # create sonnet-4x.txt ++ echo "" > sonnet_4x.txt ++ for _ in {1..4} ++ do ++ cat sonnet.txt >> sonnet_4x.txt ++ done ++ cd disagg_benchmarks ++ ++ rm -rf results ++ mkdir results ++ ++ default_qps=1 ++ default_output_len=1 ++ benchmark $default_qps $default_output_len ++ ++} ++ ++ ++main "$@" +diff --git a/benchmarks/disagg_benchmarks/disagg_performance_benchmark.sh b/benchmarks/disagg_benchmarks/disagg_performance_benchmark.sh +new file mode 100644 +index 0000000..eb5d891 +--- /dev/null ++++ b/benchmarks/disagg_benchmarks/disagg_performance_benchmark.sh +@@ -0,0 +1,163 @@ ++#!/bin/bash ++ ++# Requirement: 2x GPUs. ++ ++ ++# Model: meta-llama/Meta-Llama-3.1-8B-Instruct ++# Query: 1024 input tokens, 6 output tokens, QPS 2/4/6/8, 100 requests ++# Resource: 2x GPU ++# Approaches: ++# 2. Chunked prefill: 2 vllm instance with tp=4, equivalent to 1 tp=4 instance with QPS 4 ++# 3. Disaggregated prefill: 1 prefilling instance and 1 decoding instance ++# Prefilling instance: max_output_token=1 ++# Decoding instance: force the input tokens be the same across requests to bypass prefilling ++ ++set -ex ++ ++kill_gpu_processes() { ++ # kill all processes on GPU. ++ pgrep pt_main_thread | xargs -r kill -9 ++ pgrep python3 | xargs -r kill -9 ++ for port in 8000 8100 8200; do lsof -t -i:$port | xargs -r kill -9; done ++ sleep 1 ++} ++ ++wait_for_server() { ++ # wait for vllm server to start ++ # return 1 if vllm server crashes ++ local port=$1 ++ timeout 1200 bash -c " ++ until curl -s localhost:${port}/v1/completions > /dev/null; do ++ sleep 1 ++ done" && return 0 || return 1 ++} ++ ++ ++launch_chunked_prefill() { ++ model="meta-llama/Meta-Llama-3.1-8B-Instruct" ++ # disagg prefill ++ CUDA_VISIBLE_DEVICES=0 python3 \ ++ -m vllm.entrypoints.openai.api_server \ ++ --model $model \ ++ --port 8100 \ ++ --max-model-len 10000 \ ++ --enable-chunked-prefill \ ++ --gpu-memory-utilization 0.6 & ++ CUDA_VISIBLE_DEVICES=1 python3 \ ++ -m vllm.entrypoints.openai.api_server \ ++ --model $model \ ++ --port 8200 \ ++ --max-model-len 10000 \ ++ --enable-chunked-prefill \ ++ --gpu-memory-utilization 0.6 & ++ wait_for_server 8100 ++ wait_for_server 8200 ++ python3 round_robin_proxy.py & ++ sleep 1 ++} ++ ++ ++launch_disagg_prefill() { ++ model="meta-llama/Meta-Llama-3.1-8B-Instruct" ++ # disagg prefill ++ CUDA_VISIBLE_DEVICES=0 python3 \ ++ -m vllm.entrypoints.openai.api_server \ ++ --model $model \ ++ --port 8100 \ ++ --max-model-len 10000 \ ++ --gpu-memory-utilization 0.6 \ ++ --kv-transfer-config \ ++ '{"kv_connector":"PyNcclConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2,"kv_buffer_size":5e9}' & ++ ++ CUDA_VISIBLE_DEVICES=1 python3 \ ++ -m vllm.entrypoints.openai.api_server \ ++ --model $model \ ++ --port 8200 \ ++ --max-model-len 10000 \ ++ --gpu-memory-utilization 0.6 \ ++ --kv-transfer-config \ ++ '{"kv_connector":"PyNcclConnector","kv_role":"kv_consumer","kv_rank":1,"kv_parallel_size":2,"kv_buffer_size":5e9}' & ++ ++ wait_for_server 8100 ++ wait_for_server 8200 ++ python3 disagg_prefill_proxy_server.py & ++ sleep 1 ++} ++ ++ ++benchmark() { ++ results_folder="./results" ++ model="meta-llama/Meta-Llama-3.1-8B-Instruct" ++ dataset_name="sonnet" ++ dataset_path="../sonnet_4x.txt" ++ num_prompts=100 ++ qps=$1 ++ prefix_len=50 ++ input_len=1024 ++ output_len=$2 ++ tag=$3 ++ ++ python3 ../benchmark_serving.py \ ++ --backend vllm \ ++ --model $model \ ++ --dataset-name $dataset_name \ ++ --dataset-path $dataset_path \ ++ --sonnet-input-len $input_len \ ++ --sonnet-output-len "$output_len" \ ++ --sonnet-prefix-len $prefix_len \ ++ --num-prompts $num_prompts \ ++ --port 8000 \ ++ --save-result \ ++ --result-dir $results_folder \ ++ --result-filename "$tag"-qps-"$qps".json \ ++ --request-rate "$qps" ++ ++ sleep 2 ++} ++ ++ ++main() { ++ ++ (which wget && which curl) || (apt-get update && apt-get install -y wget curl) ++ (which jq) || (apt-get -y install jq) ++ (which socat) || (apt-get -y install socat) ++ (which lsof) || (apt-get -y install lsof) ++ ++ pip install quart httpx matplotlib aiohttp datasets ++ ++ cd "$(dirname "$0")" ++ ++ cd .. ++ # create sonnet-4x.txt so that we can sample 2048 tokens for input ++ echo "" > sonnet_4x.txt ++ for _ in {1..4} ++ do ++ cat sonnet.txt >> sonnet_4x.txt ++ done ++ cd disagg_benchmarks ++ ++ rm -rf results ++ mkdir results ++ ++ default_output_len=6 ++ ++ export VLLM_HOST_IP=$(hostname -I | awk '{print $1}') ++ ++ launch_chunked_prefill ++ for qps in 2 4 6 8; do ++ benchmark $qps $default_output_len chunked_prefill ++ done ++ kill_gpu_processes ++ ++ launch_disagg_prefill ++ for qps in 2 4 6 8; do ++ benchmark $qps $default_output_len disagg_prefill ++ done ++ kill_gpu_processes ++ ++ python3 visualize_benchmark_results.py ++ ++} ++ ++ ++main "$@" +diff --git a/benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py b/benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py +new file mode 100644 +index 0000000..4058b1c +--- /dev/null ++++ b/benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py +@@ -0,0 +1,61 @@ ++import os ++ ++import aiohttp ++from quart import Quart, make_response, request ++ ++AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60) ++ ++app = Quart(__name__) ++ ++ ++async def forward_request(url, data): ++ async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session: ++ headers = { ++ "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}" ++ } ++ async with session.post(url=url, json=data, ++ headers=headers) as response: ++ if response.status == 200: ++ # if response.headers.get('Transfer-Encoding') == 'chunked': ++ if True: ++ async for chunk_bytes in response.content.iter_chunked( ++ 1024): ++ yield chunk_bytes ++ else: ++ content = await response.read() ++ yield content ++ ++ ++@app.route('/v1/completions', methods=['POST']) ++async def handle_request(): ++ try: ++ original_request_data = await request.get_json() ++ ++ prefill_request = original_request_data.copy() ++ # change max_tokens = 1 to let it only do prefill ++ prefill_request['max_tokens'] = 1 ++ ++ # finish prefill ++ async for _ in forward_request('http://localhost:8100/v1/completions', ++ prefill_request): ++ continue ++ ++ # return decode ++ generator = forward_request('http://localhost:8200/v1/completions', ++ original_request_data) ++ response = await make_response(generator) ++ response.timeout = None ++ ++ return response ++ ++ except Exception as e: ++ import sys ++ import traceback ++ exc_info = sys.exc_info() ++ print("Error occurred in disagg prefill proxy server") ++ print(e) ++ print("".join(traceback.format_exception(*exc_info))) ++ ++ ++if __name__ == '__main__': ++ app.run(port=8000) +diff --git a/benchmarks/disagg_benchmarks/round_robin_proxy.py b/benchmarks/disagg_benchmarks/round_robin_proxy.py +new file mode 100644 +index 0000000..6eb5f63 +--- /dev/null ++++ b/benchmarks/disagg_benchmarks/round_robin_proxy.py +@@ -0,0 +1,60 @@ ++import asyncio ++import itertools ++ ++import aiohttp ++from aiohttp import web ++ ++ ++class RoundRobinProxy: ++ ++ def __init__(self, target_ports): ++ self.target_ports = target_ports ++ self.port_cycle = itertools.cycle(self.target_ports) ++ ++ async def handle_request(self, request): ++ target_port = next(self.port_cycle) ++ target_url = f"http://localhost:{target_port}{request.path_qs}" ++ ++ async with aiohttp.ClientSession() as session: ++ try: ++ # Forward the request ++ async with session.request( ++ method=request.method, ++ url=target_url, ++ headers=request.headers, ++ data=request.content, ++ ) as response: ++ # Start sending the response ++ resp = web.StreamResponse(status=response.status, ++ headers=response.headers) ++ await resp.prepare(request) ++ ++ # Stream the response content ++ async for chunk in response.content.iter_any(): ++ await resp.write(chunk) ++ ++ await resp.write_eof() ++ return resp ++ ++ except Exception as e: ++ return web.Response(text=f"Error: {str(e)}", status=500) ++ ++ ++async def main(): ++ proxy = RoundRobinProxy([8100, 8200]) ++ app = web.Application() ++ app.router.add_route('*', '/{path:.*}', proxy.handle_request) ++ ++ runner = web.AppRunner(app) ++ await runner.setup() ++ site = web.TCPSite(runner, 'localhost', 8000) ++ await site.start() ++ ++ print("Proxy server started on http://localhost:8000") ++ ++ # Keep the server running ++ await asyncio.Event().wait() ++ ++ ++if __name__ == '__main__': ++ asyncio.run(main()) +diff --git a/benchmarks/disagg_benchmarks/visualize_benchmark_results.py b/benchmarks/disagg_benchmarks/visualize_benchmark_results.py +new file mode 100644 +index 0000000..e59d8bb +--- /dev/null ++++ b/benchmarks/disagg_benchmarks/visualize_benchmark_results.py +@@ -0,0 +1,46 @@ ++import json ++ ++import matplotlib.pyplot as plt ++import pandas as pd ++ ++if __name__ == "__main__": ++ ++ data = [] ++ for name in ['disagg_prefill', 'chunked_prefill']: ++ for qps in [2, 4, 6, 8]: ++ with open(f"results/{name}-qps-{qps}.json") as f: ++ x = json.load(f) ++ x['name'] = name ++ x['qps'] = qps ++ data.append(x) ++ ++ df = pd.DataFrame.from_dict(data) ++ dis_df = df[df['name'] == 'disagg_prefill'] ++ chu_df = df[df['name'] == 'chunked_prefill'] ++ ++ plt.style.use('bmh') ++ plt.rcParams['font.size'] = 20 ++ ++ for key in [ ++ 'mean_ttft_ms', 'median_ttft_ms', 'p99_ttft_ms', 'mean_itl_ms', ++ 'median_itl_ms', 'p99_itl_ms' ++ ]: ++ ++ fig, ax = plt.subplots(figsize=(11, 7)) ++ plt.plot(dis_df['qps'], ++ dis_df[key], ++ label='disagg_prefill', ++ marker='o', ++ linewidth=4) ++ plt.plot(chu_df['qps'], ++ chu_df[key], ++ label='chunked_prefill', ++ marker='o', ++ linewidth=4) ++ ax.legend() ++ ++ ax.set_xlabel('QPS') ++ ax.set_ylabel(key) ++ ax.set_ylim(bottom=0) ++ fig.savefig(f'results/{key}.png') ++ plt.close(fig) +diff --git a/benchmarks/fused_kernels/layernorm_rms_benchmarks.py b/benchmarks/fused_kernels/layernorm_rms_benchmarks.py +new file mode 100644 +index 0000000..ef91f9f +--- /dev/null ++++ b/benchmarks/fused_kernels/layernorm_rms_benchmarks.py +@@ -0,0 +1,173 @@ ++import pickle as pkl ++import time ++from dataclasses import dataclass ++from itertools import product ++from typing import Callable, Iterable, List, Optional ++ ++import torch ++import torch.utils.benchmark as TBenchmark ++from torch.utils.benchmark import Measurement as TMeasurement ++from tqdm import tqdm ++ ++import vllm._custom_ops as ops ++from vllm.model_executor.layers.layernorm import RMSNorm ++ ++ ++@dataclass ++class bench_params_t: ++ num_tokens: int ++ hidden_size: int ++ add_residual: bool ++ dtype: torch.dtype ++ ++ def description(self): ++ return (f'N {self.num_tokens} ' ++ f'x D {self.hidden_size} ' ++ f'x R {self.add_residual} ' ++ f'x DT {self.dtype}') ++ ++ ++def get_bench_params() -> List[bench_params_t]: ++ ## Test Fixtures ++ NUM_TOKENS = [2**x for x in range(11)] ++ HIDDEN_SIZES = list(range(1024, 8129, 1024)) ++ ADD_RESIDUAL = [True, False] ++ DTYPES = [torch.bfloat16, torch.float] ++ ++ combinations = product(NUM_TOKENS, HIDDEN_SIZES, ADD_RESIDUAL, DTYPES) ++ bench_params = list(map(lambda x: \ ++ bench_params_t(x[0], x[1], x[2], x[3]), combinations)) ++ return bench_params ++ ++ ++# Reference impls ++def unfused_int8_impl(rms_norm_layer: RMSNorm, x: torch.Tensor, ++ residual: Optional[torch.Tensor], ++ quant_dtype: torch.dtype): ++ # Norm ++ torch_out = None ++ if residual is None: ++ torch_out = rms_norm_layer.forward_cuda(x, residual) ++ else: ++ torch_out, _ = rms_norm_layer.forward_cuda(x, residual) ++ ++ # Quant ++ torch_out, _, _ = ops.scaled_int8_quant(torch_out) ++ ++ ++def unfused_fp8_impl(rms_norm_layer: RMSNorm, x: torch.Tensor, ++ residual: Optional[torch.Tensor], ++ quant_dtype: torch.dtype): ++ # Norm ++ torch_out = None ++ if residual is None: ++ torch_out = rms_norm_layer.forward_cuda(x, residual) ++ else: ++ torch_out, _ = rms_norm_layer.forward_cuda(x, residual) ++ ++ # Quant ++ torch_out, _ = ops.scaled_fp8_quant(torch_out) ++ ++ ++def fused_impl( ++ rms_norm_layer: RMSNorm, # this stores the weights ++ x: torch.Tensor, ++ residual: Optional[torch.Tensor], ++ quant_dtype: torch.dtype): ++ out, _ = ops.rms_norm_dynamic_per_token_quant(x, ++ rms_norm_layer.weight, ++ 1e-6, ++ quant_dtype, ++ residual=residual) ++ ++ ++# Bench functions ++def bench_fn(rms_norm_layer: RMSNorm, x: torch.Tensor, residual: torch.Tensor, ++ quant_dtype: torch.dtype, label: str, sub_label: str, ++ fn: Callable, description: str) -> TMeasurement: ++ ++ min_run_time = 1 ++ ++ globals = { ++ "rms_norm_layer": rms_norm_layer, ++ "x": x, ++ "residual": residual, ++ "quant_dtype": quant_dtype, ++ "fn": fn, ++ } ++ return TBenchmark.Timer( ++ stmt="fn(rms_norm_layer, x, residual, quant_dtype)", ++ globals=globals, ++ label=label, ++ sub_label=sub_label, ++ description=description, ++ ).blocked_autorange(min_run_time=min_run_time) ++ ++def bench(params: bench_params_t, label: str, sub_label: str) \ ++ -> Iterable[TMeasurement]: ++ ++ # Make inputs ++ layer = RMSNorm(params.hidden_size, 1e-6).to(dtype=params.dtype) ++ # Make weights ++ layer.weight.data.normal_(mean=1.0, std=0.1) ++ # Make inputs ++ scale = 1 / params.hidden_size ++ x = torch.randn(params.num_tokens, ++ params.hidden_size, ++ dtype=params.dtype, ++ device='cuda') * scale ++ residual = (torch.randn_like(x) * scale).to(device='cuda') \ ++ if params.add_residual else None ++ ++ timers = [] ++ ++ # unfused int8 impl. ++ timers.append( ++ bench_fn(layer, x, residual, torch.int8, label, sub_label, ++ unfused_int8_impl, "unfused_int8_impl")) ++ ++ # unfused fp8 impl. ++ timers.append( ++ bench_fn(layer, x, residual, torch.float8_e4m3fn, label, sub_label, ++ unfused_fp8_impl, "unfused_fp8_impl")) ++ ++ # fused int8 impl. ++ timers.append( ++ bench_fn(layer, x, residual, torch.int8, label, sub_label, fused_impl, ++ "fused_int8_impl")) ++ ++ # fused fp8 impl. ++ timers.append( ++ bench_fn(layer, x, residual, torch.float8_e4m3fn, label, sub_label, ++ fused_impl, "fused_fp8_impl")) ++ ++ print_timers(timers) ++ ++ return timers ++ ++ ++# launch bench ++# runner ++def print_timers(timers: Iterable[TMeasurement]): ++ compare = TBenchmark.Compare(timers) ++ compare.print() ++ ++ ++def main(): ++ torch.set_default_device('cuda') ++ bench_params = get_bench_params() ++ ++ timers = [] ++ for bp in tqdm(bench_params): ++ timers.extend( ++ bench(bp, "rms-norm-dynamic-per-token-quant", bp.description())) ++ print_timers(timers) ++ ++ # pickle all the results ++ timestamp = int(time.time()) ++ with open(f"rms_norm_dpt_quant-{timestamp}.pkl", "wb") as f: ++ pkl.dump(timers, f) ++ ++ ++if __name__ == '__main__': ++ main() +diff --git a/benchmarks/kernels/benchmark_aqlm.py b/benchmarks/kernels/benchmark_aqlm.py +index 5939294..601c4ea 100644 +--- a/benchmarks/kernels/benchmark_aqlm.py ++++ b/benchmarks/kernels/benchmark_aqlm.py +@@ -1,4 +1,3 @@ +-import argparse + import os + import sys + from typing import Optional +@@ -10,6 +9,7 @@ from vllm import _custom_ops as ops + from vllm.model_executor.layers.quantization.aqlm import ( + dequantize_weight, generic_dequantize_gemm, get_int_dtype, + optimized_dequantize_gemm) ++from vllm.utils import FlexibleArgumentParser + + os.environ['CUDA_VISIBLE_DEVICES'] = '0' + +@@ -86,9 +86,9 @@ def dequant_no_scale( + # Compare the optimized 1x16 and 2x8 cuda decompression/dequant kernels against + # the generic pytorch version. + # Just visual comparison. +-def dequant_test(k: int, parts: torch.tensor, nbooks: int, bits: int) -> None: ++def dequant_test(k: int, parts: torch.Tensor, nbooks: int, bits: int) -> None: + +- n = parts.sum().item() ++ n = int(parts.sum().item()) + + device = torch.device('cuda:0') + +@@ -137,7 +137,7 @@ def dequant_test(k: int, parts: torch.tensor, nbooks: int, bits: int) -> None: + + def main(): + +- parser = argparse.ArgumentParser(description="Benchmark aqlm performance.") ++ parser = FlexibleArgumentParser(description="Benchmark aqlm performance.") + + # Add arguments + parser.add_argument("--nbooks", +@@ -204,7 +204,7 @@ def main(): + sys.stdout = sys.__stdout__ + + +-def run_grid(m: int, k: int, parts: torch.tensor, nbooks: int, bits: int, ++def run_grid(m: int, k: int, parts: torch.Tensor, nbooks: int, bits: int, + methods): + + # I didn't see visible improvements from increasing these, but feel free :) +@@ -252,10 +252,10 @@ def run_grid(m: int, k: int, parts: torch.tensor, nbooks: int, bits: int, + print('') + + +-def run_timing(num_calls: int, m: int, k: int, parts: torch.tensor, ++def run_timing(num_calls: int, m: int, k: int, parts: torch.Tensor, + nbooks: int, bits: int, method) -> float: + +- n = parts.sum().item() ++ n = int(parts.sum().item()) + + device = torch.device('cuda:0') + +diff --git a/benchmarks/kernels/benchmark_layernorm.py b/benchmarks/kernels/benchmark_layernorm.py +new file mode 100644 +index 0000000..7acea60 +--- /dev/null ++++ b/benchmarks/kernels/benchmark_layernorm.py +@@ -0,0 +1,86 @@ ++import time ++ ++import torch ++ ++from vllm.model_executor.layers.layernorm import RMSNorm ++from vllm.platforms import current_platform ++from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser ++ ++ ++@torch.inference_mode() ++def main(num_tokens: int, ++ hidden_size: int, ++ add_residual: bool, ++ dtype: torch.dtype, ++ seed: int = 0, ++ do_profile: bool = False, ++ num_warmup_iters: int = 5, ++ num_iters: int = 100) -> None: ++ current_platform.seed_everything(seed) ++ torch.set_default_device("cuda") ++ ++ layer = RMSNorm(hidden_size).to(dtype=dtype) ++ layer.weight.data.normal_(mean=1.0, std=0.1) ++ scale = 1 / (2 * hidden_size) ++ x = torch.randn(num_tokens, hidden_size, dtype=dtype) ++ x *= scale ++ residual = torch.randn_like(x) * scale if add_residual else None ++ ++ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: ++ torch.cuda.synchronize() ++ if profile: ++ torch.cuda.cudart().cudaProfilerStart() ++ start_time = time.perf_counter() ++ ++ for _ in range(num_iters): ++ layer(x, residual) ++ torch.cuda.synchronize() ++ ++ end_time = time.perf_counter() ++ if profile: ++ torch.cuda.cudart().cudaProfilerStart() ++ return (end_time - start_time) / num_iters ++ ++ # Warmup. ++ print("Warming up...") ++ run_benchmark = run_cuda_benchmark ++ run_benchmark(num_iters=num_warmup_iters, profile=False) ++ ++ # Benchmark. ++ if do_profile: ++ latency = run_benchmark(num_iters=1, profile=True) ++ else: ++ latency = run_benchmark(num_iters=num_iters, profile=False) ++ print(f"Kernel running time: {latency * 1000000:.3f} us") ++ ++ ++if __name__ == '__main__': ++ parser = FlexibleArgumentParser( ++ description="Benchmark the layernorm kernel.") ++ parser.add_argument("--num-tokens", type=int, default=4096) ++ parser.add_argument("--hidden-size", type=int, default=8192) ++ parser.add_argument("--add-residual", action="store_true") ++ parser.add_argument("--dtype", ++ type=str, ++ choices=["half", "bfloat16", "float"], ++ default="half") ++ parser.add_argument("--seed", type=int, default=0) ++ parser.add_argument("--profile", action="store_true") ++ parser.add_argument("--num-warmup-iters", type=int, default=5) ++ parser.add_argument("--num-iters", ++ type=int, ++ default=100, ++ help="Number of benchmark iterations. " ++ "If --profile is set, this number is ignored") ++ ++ args = parser.parse_args() ++ print(args) ++ ++ main(num_tokens=args.num_tokens, ++ hidden_size=args.hidden_size, ++ add_residual=args.add_residual, ++ dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype], ++ seed=args.seed, ++ do_profile=args.profile, ++ num_warmup_iters=args.num_warmup_iters, ++ num_iters=args.num_iters) +diff --git a/benchmarks/kernels/benchmark_machete.py b/benchmarks/kernels/benchmark_machete.py +new file mode 100644 +index 0000000..46bab74 +--- /dev/null ++++ b/benchmarks/kernels/benchmark_machete.py +@@ -0,0 +1,672 @@ ++import argparse ++import copy ++import itertools ++import math ++import os ++import pickle as pkl ++import time ++from dataclasses import dataclass ++from itertools import product ++from typing import Callable, Iterable, List, Optional, Tuple ++ ++import pandas as pd ++import torch ++import torch.utils.benchmark as TBenchmark ++from torch.utils.benchmark import Measurement as TMeasurement ++from weight_shapes import WEIGHT_SHAPES ++ ++from vllm import _custom_ops as ops ++from vllm.model_executor.layers.quantization.utils.marlin_utils import ( ++ GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N, marlin_permute_scales, ++ marlin_zero_points) ++from vllm.model_executor.layers.quantization.utils.marlin_utils_test import ( ++ MarlinWorkspace) ++from vllm.model_executor.layers.quantization.utils.quant_utils import ( ++ pack_rows, quantize_weights) ++from vllm.scalar_type import ScalarType, scalar_types ++from vllm.utils import FlexibleArgumentParser ++ ++DEFAULT_MODELS = ["meta-llama/Llama-3-8b", "meta-llama/Llama-2-70b-hf"] ++DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512, 1024] ++DEFAULT_TP_SIZES = [1] ++ ++NVTX_PROFILE = os.environ.get("NVTX_PROFILE", False) ++ ++if NVTX_PROFILE: ++ import nvtx ++ ++ ++def terse_type_name(dt): ++ return { ++ torch.bfloat16: "bf16", ++ torch.float16: "fp16", ++ torch.int8: "int8", ++ torch.float8_e4m3fn: "fp8", ++ torch.bfloat16: "bf16", ++ torch.float: "float", ++ torch.int: "int", ++ }[dt] ++ ++ ++@dataclass ++class BenchmarkTensors: ++ w_ref: torch.Tensor ++ a: torch.Tensor ++ ++ w_q: torch.Tensor ++ group_size: Optional[int] ++ wtype: ScalarType ++ w_g_s: torch.Tensor ++ w_g_zp: Optional[torch.Tensor] ++ w_ch_s: Optional[torch.Tensor] ++ w_tok_s: Optional[torch.Tensor] ++ ++ ++@dataclass ++class TypeConfig: ++ act_type: torch.dtype ++ weight_type: ScalarType ++ output_type: Optional[torch.dtype] ++ group_scale_type: Optional[torch.dtype] ++ group_zero_type: Optional[torch.dtype] ++ channel_scale_type: Optional[torch.dtype] ++ token_scale_type: Optional[torch.dtype] ++ ++ ++def rand_data(shape, dtype=torch.float16, scale=1): ++ if dtype.is_floating_point: ++ return (scale * torch.rand(shape, device="cuda") - 0.3).to(dtype) ++ else: ++ return torch.randint(-15, 15, shape, dtype=dtype, device="cuda") ++ ++ ++def quantize_and_pack(atype: torch.dtype, ++ w: torch.Tensor, ++ wtype: ScalarType, ++ stype: Optional[torch.dtype], ++ group_size: Optional[int], ++ zero_points: bool = False): ++ assert wtype.is_integer(), "TODO: support floating point weights" ++ ++ w_ref, w_q, w_s, w_zp = quantize_weights( ++ w, ++ wtype, ++ group_size=group_size, ++ zero_points=zero_points, ++ # to match how the kernel applies zps ++ ref_zero_points_after_scales=True) ++ ++ w_q = pack_rows(w_q, wtype.size_bits, *w_q.shape) ++ return w_ref, w_q, w_s, w_zp ++ ++ ++def create_bench_tensors(shape: Tuple[int, int, int], types: TypeConfig, ++ group_size: Optional[int]) -> List[BenchmarkTensors]: ++ m, n, k = shape ++ ++ # we want to make sure that weights don't fit into L2 cache between runs so ++ # we construct enough weights to exceed L2 cache, which is 50mb on a H100 ++ # so we target total weight size > 2*50mb ++ num_weights = math.ceil(2 * 50 * 1024**2 * 8 / ++ (k * n * types.weight_type.size_bits)) ++ ++ a = rand_data((m, k), types.act_type, scale=5) ++ ++ benchmark_tensors: List[BenchmarkTensors] = [] ++ for _ in range(num_weights): ++ w = rand_data((k, n), types.act_type, scale=5) ++ ++ if types.group_scale_type is not None: ++ w = w.to(types.group_scale_type) ++ if w.dtype.itemsize == 1: ++ w = w.to(torch.float16) ++ ++ w_ref, w_q_packed, w_s, w_zp = quantize_and_pack( ++ a.dtype, w, types.weight_type, types.group_scale_type, group_size, ++ types.group_zero_type is not None) ++ ++ if not a.dtype.is_floating_point: ++ aiinfo = torch.iinfo(a.dtype) ++ w_ref = w_ref.round().clamp(aiinfo.min, aiinfo.max) ++ ++ w_ref = w_ref.to(torch.float32) ++ ++ w_ch_s = None if types.channel_scale_type is None else\ ++ rand_data((n,), types.channel_scale_type) ++ w_tok_s = None if types.token_scale_type is None else\ ++ rand_data((m,), types.token_scale_type) ++ ++ benchmark_tensors.append( ++ BenchmarkTensors(w_ref=w_ref, ++ a=a, ++ w_q=w_q_packed, ++ wtype=types.weight_type, ++ w_g_s=w_s, ++ w_g_zp=w_zp, ++ group_size=group_size, ++ w_ch_s=w_ch_s, ++ w_tok_s=w_tok_s)) ++ ++ return benchmark_tensors ++ ++ ++def torch_matmul_f16_create_bench_fn(bt: BenchmarkTensors) -> Callable: ++ a = bt.a ++ w = bt.w_ref.to(bt.a.dtype) # use float reference tensor ++ if a.dtype not in [torch.float16, torch.bfloat16]: ++ a = a.to(torch.float16) ++ w = w.to(torch.float16) ++ return lambda: torch.matmul(a, w) ++ ++ ++def cutlass_scaled_mm_create_bench_fn(bt: BenchmarkTensors) -> Callable: ++ if bt.w_ch_s is not None and bt.w_tok_s is not None: ++ scale_a = bt.w_tok_s.to(torch.float32) ++ scale_b = bt.w_ch_s.to(torch.float32) ++ else: ++ scale_a = torch.tensor(1.0, dtype=torch.float32, device=bt.a.device) ++ scale_b = torch.tensor(1.0, dtype=torch.float32, device=bt.a.device) ++ w_col_major = bt.w_ref.to(bt.a.dtype).t().contiguous().t() ++ return lambda: ops.cutlass_scaled_mm( ++ bt.a, w_col_major, scale_a, scale_b, out_dtype=torch.float16) ++ ++ ++def marlin_create_bench_fn(bt: BenchmarkTensors) -> Callable: ++ device = bt.a.device ++ ++ workspace = MarlinWorkspace(bt.w_ref.shape[1], GPTQ_MARLIN_MIN_THREAD_N, ++ GPTQ_MARLIN_MAX_PARALLEL) ++ ++ if bt.w_g_zp is None: ++ w_zp = torch.empty(0, dtype=torch.int, device=device) ++ else: ++ w_zp = marlin_zero_points(bt.w_g_zp, bt.w_ref.shape[0], ++ bt.w_ref.shape[1], bt.wtype.size_bits) ++ ++ if bt.group_size is None: ++ w_s = torch.tensor([], device="cuda", dtype=torch.half) ++ else: ++ w_s = marlin_permute_scales(bt.w_g_s, bt.w_ref.shape[0], ++ bt.w_ref.shape[1], bt.group_size) ++ ++ sort_indices = torch.empty(0, dtype=torch.int, device=device) ++ g_idx = torch.empty(0, dtype=torch.int, device=device) ++ w_q = ops.gptq_marlin_repack(bt.w_q, sort_indices, bt.w_ref.shape[0], ++ bt.w_ref.shape[1], bt.wtype.size_bits) ++ ++ if bt.a.dtype.is_floating_point: ++ assert bt.w_ch_s is None ++ assert bt.w_tok_s is None ++ assert bt.group_size is not None ++ ++ fn = lambda: ops.gptq_marlin_gemm(a=bt.a, ++ b_q_weight=w_q, ++ b_scales=w_s, ++ b_zeros=w_zp, ++ g_idx=g_idx, ++ perm=sort_indices, ++ workspace=workspace.scratch, ++ b_q_type=bt.wtype, ++ size_m=bt.a.shape[0], ++ size_n=bt.w_ref.shape[1], ++ size_k=bt.w_ref.shape[0], ++ is_k_full=True, ++ is_zp_float=False) ++ else: ++ assert bt.a.dtype == torch.int8 ++ assert bt.wtype == scalar_types.uint4b8 ++ ++ if bt.w_ch_s is not None: ++ s_ch = bt.w_ch_s.to(torch.float32) ++ else: ++ s_ch = torch.ones(bt.w_ref.shape[1], ++ dtype=torch.float32, ++ device=device) ++ ++ if bt.w_tok_s is not None: ++ s_tok = bt.w_tok_s.to(torch.float32) ++ else: ++ s_tok = torch.ones(bt.a.shape[0], ++ dtype=torch.float32, ++ device=device) ++ ++ fn = lambda: ops.marlin_qqq_gemm(a=bt.a, ++ b_q_weight=w_q, ++ s_group=w_s, ++ s_tok=s_tok, ++ s_ch=s_ch, ++ workspace=workspace.scratch, ++ size_m=bt.a.shape[0], ++ size_n=bt.w_ref.shape[1], ++ size_k=bt.w_ref.shape[0]) ++ ++ return fn ++ ++ ++def machete_create_bench_fn(bt: BenchmarkTensors, ++ out_type=torch.dtype, ++ schedule=None) -> Callable: ++ w_q = bt.w_q.t().contiguous().t() # make col major ++ w_q = ops.machete_prepack_B(w_q, bt.a.dtype, bt.wtype, ++ None if bt.w_g_s is None else bt.w_g_s.dtype) ++ ++ w_g_zp = bt.w_g_zp ++ if w_g_zp is not None: ++ w_g_zp = -1 * bt.w_g_s * (w_g_zp.to(bt.w_g_s.dtype)) ++ ++ return lambda: ops.machete_mm( ++ a=bt.a, ++ b_q=bt.w_q, ++ b_type=bt.wtype, ++ b_group_scales=bt.w_g_s, ++ b_group_zeros=w_g_zp, ++ b_group_size=bt.group_size, ++ b_channel_scales=bt.w_ch_s, ++ a_token_scales=bt.w_tok_s, ++ out_type=out_type, ++ schedule=schedule, ++ ) ++ ++ ++# impl ++ ++# bench ++ ++ ++def bench_fns(label: str, sub_label: str, description: str, ++ fns: List[Callable]): ++ ++ min_run_time = 1 if not NVTX_PROFILE else 0.1 ++ res = TBenchmark.Timer( ++ stmt=""" ++ for fn in fns: ++ fn() ++ """, ++ globals={ ++ "fns": fns ++ }, ++ label=label, ++ sub_label=sub_label, ++ description=description, ++ ).blocked_autorange(min_run_time=min_run_time) ++ ++ if NVTX_PROFILE: ++ with nvtx.annotate("mm-bench"), nvtx.annotate( ++ f"{label}|{sub_label}|{description}"): ++ fns[0]() ++ ++ return res ++ ++ ++_SWEEP_SCHEDULES_RESULTS: Optional[pd.DataFrame] = None ++_SWEEP_SCHEDULES_RESULTS_CSV: Optional[str] = None ++ ++ ++def bench(types: TypeConfig, ++ group_size: int, ++ m: int, ++ k: int, ++ n: int, ++ label: str, ++ sub_label: str, ++ sweep_schedules: bool = True) -> List[TMeasurement]: ++ benchmark_tensors = create_bench_tensors((m, n, k), types, group_size) ++ sub_label += f", L={len(benchmark_tensors)}" ++ ++ name_type_string = f"W{types.weight_type}"+\ ++ f"-A{terse_type_name(types.act_type)}" ++ if types.group_scale_type is not None: ++ name_type_string += f"-GS{terse_type_name(types.group_scale_type)}" ++ if types.group_zero_type is not None: ++ name_type_string += f"-GZ{terse_type_name(types.group_zero_type)}" ++ if group_size is not None: ++ name_type_string += f"-G{group_size}" ++ if types.channel_scale_type is not None: ++ name_type_string += f"-CS{terse_type_name(types.channel_scale_type)}" ++ if types.token_scale_type is not None: ++ name_type_string += f"-TS{terse_type_name(types.token_scale_type)}" ++ ++ timers = [] ++ # pytorch impl ++ timers.append( ++ bench_fns( ++ label, sub_label, "torch.matmul (fp16)", ++ [torch_matmul_f16_create_bench_fn(bt) ++ for bt in benchmark_tensors])) ++ ++ if types.act_type == torch.int8 or types.act_type == torch.float8_e4m3fn: ++ timers.append( ++ bench_fns( ++ label, sub_label, ++ f"cutlass_scaled_mm ({terse_type_name(types.act_type)})", [ ++ cutlass_scaled_mm_create_bench_fn(bt) ++ for bt in benchmark_tensors ++ ])) ++ ++ if types.act_type != torch.float8_e4m3fn: ++ timers.append( ++ bench_fns(label, sub_label, f"marlin ({name_type_string})", ++ [marlin_create_bench_fn(bt) ++ for bt in benchmark_tensors])) ++ ++ # machete ++ timers.append( ++ bench_fns(label, sub_label, f"machete ({name_type_string})", [ ++ machete_create_bench_fn(bt, out_type=types.output_type) ++ for bt in benchmark_tensors ++ ])) ++ ++ if sweep_schedules: ++ global _SWEEP_SCHEDULES_RESULTS ++ ++ print("Finding best schedule for machete") ++ best = None ++ best_schedule = None ++ schedules = ops.machete_supported_schedules( ++ a_type=types.act_type, ++ b_type=types.weight_type, ++ group_scales_type=types.group_scale_type, ++ group_zeros_type=types.group_zero_type, ++ token_scales_type=types.token_scale_type, ++ channel_scales_type=types.channel_scale_type, ++ out_type=types.output_type) ++ ++ if schedules is None or len(schedules) == 0: ++ raise ValueError("No schedules found to sweep") ++ ++ for schedule in reversed(schedules): ++ schedule_M = int(schedule.split("_")[0].split("x")[1]) ++ ++ # Prune known bad schedules ++ if schedule_M >= 2 * max(m, 16) or schedule_M < m // 4: ++ continue ++ ++ res = bench_fns(label, sub_label, "machete_best", [ ++ machete_create_bench_fn( ++ bt, out_type=types.output_type, schedule=schedule) ++ for bt in benchmark_tensors ++ ]) ++ ++ results_row = { ++ "M": m, ++ "K": k, ++ "N": n, ++ "group_size": group_size, ++ "schedule": schedule, ++ "median": res.median, ++ } ++ if _SWEEP_SCHEDULES_RESULTS is None: ++ _SWEEP_SCHEDULES_RESULTS = pd.DataFrame( ++ columns=results_row.keys()) ++ _SWEEP_SCHEDULES_RESULTS.\ ++ loc[len(_SWEEP_SCHEDULES_RESULTS)] = results_row ++ ++ print(f" {res.median:5.5} ", schedule) ++ if not best or res.median < best.median: ++ best = res ++ best_schedule = schedule ++ print("Best schedule:", best_schedule) ++ timers.append(best) ++ ++ return timers ++ ++ ++# runner ++def print_timers(timers: List[TMeasurement]): ++ compare = TBenchmark.Compare(timers) ++ compare.print() ++ ++ ++def run(args, MKNs: Iterable[Tuple[int, int, int]]) -> Iterable[TMeasurement]: ++ types = TypeConfig( ++ act_type=args.act_type, ++ weight_type=scalar_types.uint4b8 if args.group_zero_type is None \ ++ else scalar_types.uint4, ++ output_type=args.out_type, ++ group_scale_type=args.group_scale_type, ++ group_zero_type=args.group_zero_type, ++ channel_scale_type=args.channel_scale_type, ++ token_scale_type=args.token_scale_type, ++ ) ++ ++ results: List[TMeasurement] = [] ++ for m, k, n in MKNs: ++ timers = bench(types, ++ args.group_size, ++ m, ++ k, ++ n, ++ f"{args.act_type}-gemm", ++ f"MKN=({m}x{k}x{n})", ++ sweep_schedules=args.sweep_schedules) ++ print_timers(timers) ++ results.extend(timers) ++ ++ return results ++ ++ ++# output makers ++def make_output( ++ data: List[TMeasurement], ++ MKNs: Iterable[Tuple[int, int, int]], ++ base_description: str, ++ timestamp=None, ++): ++ ++ print(f"== All Results {base_description} ====") ++ print_timers(data) ++ ++ # pickle all the results ++ timestamp = int(time.time()) if timestamp is None else timestamp ++ with open(f"{base_description}-{timestamp}.pkl", "wb") as f: ++ pkl.dump(data, f) ++ ++ ++# argparse runners ++ ++ ++def run_square_bench(args): ++ dim_sizes = list( ++ range(args.dim_start, args.dim_end + 1, args.dim_increment)) ++ MKNs = list(zip(dim_sizes, dim_sizes, dim_sizes)) ++ data = run(args.dtype, args.sweep_schedules, MKNs) ++ ++ make_output(data, MKNs, f"square_bench-{args.dtype}") ++ ++ ++def run_range_bench(args): ++ m_start, k_start, n_start = (int(x) for x in args.dim_start.split(",")) ++ m_end, k_end, n_end = (int(x) for x in args.dim_end.split(",")) ++ m_increment, k_increment, n_increment = \ ++ (int(x) for x in args.dim_increment.split(",")) ++ Ms = list(range(m_start, m_end + 1, m_increment)) ++ Ks = list(range(k_start, k_end + 1, k_increment)) ++ Ns = list(range(n_start, n_end + 1, n_increment)) ++ MKNs = list(product(Ms, Ks, Ns)) ++ ++ data = run(args.dtype, args.sweep_schedules, MKNs) ++ ++ make_output(data, MKNs, f"range_bench-{args.dtype}") ++ ++ ++def run_model_bench(args): ++ ++ print("Benchmarking models:") ++ for i, model in enumerate(args.models): ++ print(f"[{i}] {model}") ++ ++ def model_shapes(model_name: str, tp_size: int) -> List[Tuple[int, int]]: ++ KNs = [] ++ for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model_name]): ++ KN[tp_split_dim] = KN[tp_split_dim] // tp_size ++ KNs.append(KN) ++ return KNs ++ ++ model_bench_data = [] ++ models_tps = list(itertools.product(args.models, args.tp_sizes)) ++ for model, tp_size in models_tps: ++ Ms = args.batch_sizes ++ KNs = model_shapes(model, tp_size) ++ MKNs = [] ++ for m in Ms: ++ for k, n in KNs: ++ MKNs.append((m, k, n)) ++ ++ data = run(args, MKNs) ++ model_bench_data.append(data) ++ ++ type_string = f"{args.act_type}" ++ ++ # Print all results ++ for data, model_tp in zip(model_bench_data, models_tps): ++ model, tp_size = model_tp ++ print(f"== Results {type_string} {model}-TP{tp_size} ====") ++ print_timers(data) ++ ++ timestr = time.strftime("%Y%m%d-%H%M%S") ++ ++ all_results = [] ++ for d in model_bench_data: ++ all_results.extend(d) ++ ++ # pickle all data ++ with open(f"model_bench-{type_string}-{timestr}.pkl", "wb") as f: ++ args_dict = vars(args) ++ args_dict.pop("func") ++ pkl.dump({ ++ "args": args_dict, ++ "results": all_results, ++ }, f) ++ ++ ++if __name__ == "__main__": ++ ++ def to_torch_dtype(dt): ++ return { ++ "bfloat16": torch.bfloat16, ++ "float16": torch.float16, ++ "int8": torch.int8, ++ "float8_e4m3fn": torch.float8_e4m3fn, ++ "int": torch.int, ++ "float": torch.float, ++ }[dt] ++ ++ class ToTorchDtype(argparse.Action): ++ ++ def __call__(self, parser, namespace, values, option_string=None): ++ setattr(namespace, self.dest, to_torch_dtype(values)) ++ ++ parser = FlexibleArgumentParser( ++ description=""" ++Benchmark Machete GEMM. ++ ++ To run square GEMMs: ++ python3 ./benchmarks/kernels/benchmark_machete.py --dtype float16 square_bench --dim-start 128 --dim-end 512 --dim-increment 64 ++ ++ To run constant N and K and sweep M: ++ python3 ./benchmarks/kernels/benchmark_machete.py --dtype float16 range_bench --dim-start 128 --dim-end 512 --dim-increment 64 --n-constant 16384 --k-constant 16384 ++ ++ To run dimensions from a model: ++ python3 ./benchmarks/kernels/benchmark_machete.py --dtype float16 model_bench --models meta-llama/Llama-2-7b-hf --batch-sizes 16 --tp-sizes 1 ++ ++ Output: ++ - a .pkl file, that is a list of raw torch.benchmark.utils.Measurements for the pytorch and cutlass implementations for the various GEMMs. ++ """, # noqa: E501 ++ formatter_class=argparse.RawTextHelpFormatter, ++ ) ++ parser.add_argument( ++ "--act-type", ++ action=ToTorchDtype, ++ required=True, ++ choices=['bfloat16', 'float16', 'int8', 'float8_e4m3fn'], ++ ) ++ parser.add_argument( ++ "--group-scale-type", ++ action=ToTorchDtype, ++ choices=['bfloat16', 'float16'], ++ ) ++ parser.add_argument( ++ "--group-zero-type", ++ type=to_torch_dtype, ++ choices=['bfloat16', 'float16'], ++ ) ++ parser.add_argument( ++ "--channel-scale-type", ++ action=ToTorchDtype, ++ choices=['float'], ++ ) ++ parser.add_argument( ++ "--token-scale-type", ++ action=ToTorchDtype, ++ choices=['float'], ++ ) ++ parser.add_argument( ++ "--out-type", ++ action=ToTorchDtype, ++ choices=['bfloat16', 'float16'], ++ ) ++ parser.add_argument( ++ "--group-size", ++ type=int, ++ help="Available options are ['None', '-1', '128'], default=128", ++ default=128, ++ ) ++ parser.add_argument( ++ "--sweep-schedules", ++ action="store_true", ++ help="Run a sweep over all supported schedules", ++ ) ++ parser.add_argument("--sweep-csv-out", ++ help="CSV to store sweep results", ++ default="sch_sweep_results.csv") ++ subparsers = parser.add_subparsers(dest="cmd", required=True) ++ ++ square_parser = subparsers.add_parser("square_bench") ++ square_parser.add_argument("--dim-start", type=int, required=True) ++ square_parser.add_argument("--dim-end", type=int, required=True) ++ square_parser.add_argument("--dim-increment", type=int, required=True) ++ square_parser.set_defaults(func=run_square_bench) ++ ++ range_parser = subparsers.add_parser("range_bench") ++ range_parser.add_argument( ++ "--dim-start", ++ type=str, ++ required=True, ++ help="Start value for M,K,N as common separated list") ++ range_parser.add_argument( ++ "--dim-end", ++ type=str, ++ required=True, ++ help="End value (inclusive) for M,K,N as common separated list") ++ range_parser.add_argument( ++ "--dim-increment", ++ type=str, ++ required=True, ++ help="Increment value for M,K,N as common separated list") ++ range_parser.set_defaults(func=run_range_bench) ++ ++ model_parser = subparsers.add_parser("model_bench") ++ model_parser.add_argument( ++ "--models", ++ nargs="+", ++ type=str, ++ default=DEFAULT_MODELS, ++ choices=WEIGHT_SHAPES.keys(), ++ ) ++ model_parser.add_argument("--tp-sizes", ++ nargs="+", ++ type=int, ++ default=DEFAULT_TP_SIZES) ++ model_parser.add_argument("--batch-sizes", ++ nargs="+", ++ type=int, ++ default=DEFAULT_BATCH_SIZES) ++ model_parser.set_defaults(func=run_model_bench) ++ ++ args = parser.parse_args() ++ ++ _SWEEP_SCHEDULES_RESULTS_CSV = args.sweep_csv_out ++ args.func(args) ++ ++ if _SWEEP_SCHEDULES_RESULTS is not None: ++ _SWEEP_SCHEDULES_RESULTS.to_csv(_SWEEP_SCHEDULES_RESULTS_CSV) +diff --git a/benchmarks/kernels/benchmark_marlin.py b/benchmarks/kernels/benchmark_marlin.py +new file mode 100644 +index 0000000..8fb44e3 +--- /dev/null ++++ b/benchmarks/kernels/benchmark_marlin.py +@@ -0,0 +1,254 @@ ++from typing import List ++ ++import torch ++import torch.utils.benchmark as benchmark ++from benchmark_shapes import WEIGHT_SHAPES ++ ++from vllm import _custom_ops as ops ++from vllm.model_executor.layers.quantization.gptq_marlin_24 import ( ++ GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N, ++ GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES) ++from vllm.model_executor.layers.quantization.utils.marlin_utils import ( ++ GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N, ++ MARLIN_SUPPORTED_GROUP_SIZES, query_marlin_supported_quant_types) ++from vllm.model_executor.layers.quantization.utils.marlin_utils_test import ( ++ MarlinWorkspace, marlin_quantize) ++from vllm.model_executor.layers.quantization.utils.marlin_utils_test_24 import ( ++ marlin_24_quantize) ++from vllm.model_executor.layers.quantization.utils.quant_utils import ( ++ gptq_pack, gptq_quantize_weights, sort_weights) ++from vllm.scalar_type import ScalarType ++from vllm.utils import FlexibleArgumentParser ++ ++DEFAULT_MODELS = ["meta-llama/Llama-2-7b-hf/TP1"] ++DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512] ++ ++ACT_ORDER_OPTS = [False, True] ++K_FULL_OPTS = [False, True] ++ ++ ++def bench_run(results: List[benchmark.Measurement], model: str, ++ act_order: bool, is_k_full: bool, quant_type: ScalarType, ++ group_size: int, size_m: int, size_k: int, size_n: int): ++ label = "Quant Matmul" ++ ++ sub_label = ("{}, act={} k_full={}, q={}, g={}, " ++ "MKN=({}x{}x{})".format(model, act_order, is_k_full, ++ str(quant_type), group_size, size_m, ++ size_k, size_n)) ++ ++ print(f"Testing: {sub_label}") ++ ++ a = torch.randn(size_m, size_k).to(torch.half).cuda() ++ b = torch.rand(size_k, size_n).to(torch.half).cuda() ++ ++ a_tmp = (torch.zeros(size_m, size_k).to(torch.half).cuda()) ++ ++ # Marlin quant ++ ( ++ marlin_w_ref, ++ marlin_q_w, ++ marlin_s, ++ marlin_g_idx, ++ marlin_sort_indices, ++ marlin_rand_perm, ++ ) = marlin_quantize(b, quant_type, group_size, act_order) ++ ++ # Marlin_24 quant ++ (marlin_24_w_ref, marlin_24_q_w_comp, marlin_24_meta, ++ marlin_24_s) = marlin_24_quantize(b, quant_type, group_size) ++ ++ marlin_zp = torch.empty(0, dtype=torch.int, device=b.device) ++ ++ # GPTQ quant ++ (w_ref, q_w, s, g_idx, ++ rand_perm) = gptq_quantize_weights(b, quant_type, group_size, act_order) ++ q_w_gptq = gptq_pack(q_w, quant_type.size_bits, size_k, size_n) ++ ++ # For act_order, sort the "weights" and "g_idx" ++ # so that group ids are increasing ++ repack_sort_indices = torch.empty(0, dtype=torch.int, device=b.device) ++ if act_order: ++ (q_w, g_idx, repack_sort_indices) = sort_weights(q_w, g_idx) ++ ++ # Prepare ++ marlin_workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N, ++ GPTQ_MARLIN_MAX_PARALLEL) ++ ++ marlin_24_workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_24_MIN_THREAD_N, ++ GPTQ_MARLIN_24_MAX_PARALLEL) ++ marlin_zp = torch.zeros_like(marlin_s, dtype=torch.int) ++ ++ globals = { ++ # Gen params ++ "quant_type": quant_type, ++ "group_size": group_size, ++ "size_m": size_m, ++ "size_n": size_n, ++ "size_k": size_k, ++ "a": a, ++ "a_tmp": a_tmp, ++ # Marlin params ++ "marlin_w_ref": marlin_w_ref, ++ "marlin_q_w": marlin_q_w, ++ "marlin_s": marlin_s, ++ "marlin_zp": marlin_zp, ++ "marlin_g_idx": marlin_g_idx, ++ "marlin_sort_indices": marlin_sort_indices, ++ "marlin_rand_perm": marlin_rand_perm, ++ "marlin_workspace": marlin_workspace, ++ "is_k_full": is_k_full, ++ # Marlin_24 params ++ "marlin_24_w_ref": marlin_24_w_ref, ++ "marlin_24_q_w_comp": marlin_24_q_w_comp, ++ "marlin_24_meta": marlin_24_meta, ++ "marlin_24_s": marlin_24_s, ++ "marlin_24_workspace": marlin_24_workspace, ++ # GPTQ params ++ "q_w_gptq": q_w_gptq, ++ "repack_sort_indices": repack_sort_indices, ++ # Kernels ++ "gptq_marlin_gemm": ops.gptq_marlin_gemm, ++ "gptq_marlin_24_gemm": ops.gptq_marlin_24_gemm, ++ "gptq_marlin_repack": ops.gptq_marlin_repack, ++ } ++ ++ min_run_time = 1 ++ ++ # Warmup pytorch ++ for i in range(5): ++ torch.matmul(a, marlin_w_ref) ++ ++ results.append( ++ benchmark.Timer( ++ stmt="torch.matmul(a, marlin_w_ref)", ++ globals=globals, ++ label=label, ++ sub_label=sub_label, ++ description="pytorch_gemm", ++ ).blocked_autorange(min_run_time=min_run_time)) ++ ++ results.append( ++ benchmark.Timer( ++ stmt= ++ "output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, False, False)", # noqa: E501 ++ globals=globals, ++ label=label, ++ sub_label=sub_label, ++ description="gptq_marlin_gemm_fp16", ++ ).blocked_autorange(min_run_time=min_run_time)) ++ ++ results.append( ++ benchmark.Timer( ++ stmt= ++ "output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, True, False)", # noqa: E501 ++ globals=globals, ++ label=label, ++ sub_label=sub_label, ++ description="gptq_marlin_gemm_fp32", ++ ).blocked_autorange(min_run_time=min_run_time)) ++ ++ if (quant_type in GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES ++ and group_size in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES): ++ results.append( ++ benchmark.Timer( ++ stmt= ++ "output = gptq_marlin_24_gemm(a, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s, marlin_24_workspace.scratch, quant_type, size_m, size_n, size_k)", # noqa: E501 ++ globals=globals, ++ label=label, ++ sub_label=sub_label, ++ description="gptq_marlin_24_gemm", ++ ).blocked_autorange(min_run_time=min_run_time)) ++ ++ results.append( ++ benchmark.Timer( ++ stmt= ++ "q_res = gptq_marlin_repack(q_w_gptq, repack_sort_indices, size_k, size_n, quant_type.size_bits)", # noqa: E501 ++ globals=globals, ++ label=label, ++ sub_label=sub_label, ++ description="gptq_marlin_repack", ++ ).blocked_autorange(min_run_time=min_run_time)) ++ ++ ++def main(args): ++ print("Benchmarking models:") ++ for i, model in enumerate(args.models): ++ print(f"[{i}] {model}") ++ ++ results: List[benchmark.Measurement] = [] ++ ++ for model in args.models: ++ for layer in WEIGHT_SHAPES[model]: ++ size_k = layer[0] ++ size_n = layer[1] ++ ++ if len(args.limit_k) > 0 and size_k not in args.limit_k: ++ continue ++ ++ if len(args.limit_n) > 0 and size_n not in args.limit_n: ++ continue ++ ++ for act_order in ACT_ORDER_OPTS: ++ if len(args.limit_act_order ++ ) > 0 and act_order not in args.limit_act_order: ++ continue ++ ++ for is_k_full in K_FULL_OPTS: ++ if len(args.limit_k_full ++ ) > 0 and is_k_full not in args.limit_k_full: ++ continue ++ ++ for quant_type in query_marlin_supported_quant_types( ++ False): ++ if len(args.limit_num_bits) > 0 and \ ++ quant_type.size_bits not in args.limit_num_bits: ++ continue ++ ++ for group_size in MARLIN_SUPPORTED_GROUP_SIZES: ++ if len( ++ args.limit_group_size ++ ) > 0 and group_size not in args.limit_group_size: ++ continue ++ ++ # For act_order, the group_size must be less than ++ # size_k ++ if act_order and (group_size == size_k ++ or group_size == -1): ++ continue ++ ++ for size_m in args.batch_sizes: ++ bench_run(results, model, act_order, is_k_full, ++ quant_type, group_size, size_m, ++ size_k, size_n) ++ ++ compare = benchmark.Compare(results) ++ compare.print() ++ ++ ++# For quick benchmarking use: ++# python benchmark_marlin.py --batch-sizes 1 16 32 --limit-k 4096 --limit-n 4096 --limit-group-size 128 --limit-num-bits 4 --limit-act-order 0 --limit-k-full 1 # noqa E501 ++# ++if __name__ == "__main__": ++ parser = FlexibleArgumentParser( ++ description="Benchmark Marlin across specified models/shapes/batches") ++ parser.add_argument( ++ "--models", ++ nargs="+", ++ type=str, ++ default=DEFAULT_MODELS, ++ choices=WEIGHT_SHAPES.keys(), ++ ) ++ parser.add_argument("--batch-sizes", ++ nargs="+", ++ type=int, ++ default=DEFAULT_BATCH_SIZES) ++ parser.add_argument("--limit-k", nargs="+", type=int, default=[]) ++ parser.add_argument("--limit-n", nargs="+", type=int, default=[]) ++ parser.add_argument("--limit-group-size", nargs="+", type=int, default=[]) ++ parser.add_argument("--limit-num-bits", nargs="+", type=int, default=[]) ++ parser.add_argument("--limit-act-order", nargs="+", type=int, default=[]) ++ parser.add_argument("--limit-k-full", nargs="+", type=int, default=[]) ++ ++ args = parser.parse_args() ++ main(args) +diff --git a/benchmarks/kernels/benchmark_moe.py b/benchmarks/kernels/benchmark_moe.py +new file mode 100644 +index 0000000..8f538c2 +--- /dev/null ++++ b/benchmarks/kernels/benchmark_moe.py +@@ -0,0 +1,367 @@ ++import argparse ++import time ++from datetime import datetime ++from typing import Any, Dict, List, Tuple, TypedDict ++ ++import ray ++import torch ++import triton ++from ray.experimental.tqdm_ray import tqdm ++from transformers import AutoConfig ++ ++from vllm.model_executor.layers.fused_moe.fused_moe import * ++from vllm.platforms import current_platform ++from vllm.utils import FlexibleArgumentParser ++ ++ ++class BenchmarkConfig(TypedDict): ++ BLOCK_SIZE_M: int ++ BLOCK_SIZE_N: int ++ BLOCK_SIZE_K: int ++ GROUP_SIZE_M: int ++ num_warps: int ++ num_stages: int ++ ++ ++def benchmark_config( ++ config: BenchmarkConfig, ++ num_tokens: int, ++ num_experts: int, ++ shard_intermediate_size: int, ++ hidden_size: int, ++ topk: int, ++ dtype: torch.dtype, ++ use_fp8_w8a8: bool, ++ use_int8_w8a16: bool, ++ num_iters: int = 100, ++) -> float: ++ init_dtype = torch.float16 if use_fp8_w8a8 else dtype ++ x = torch.randn(num_tokens, hidden_size, dtype=dtype) ++ if use_int8_w8a16: ++ w1 = torch.randint(-127, ++ 127, ( ++ num_experts, ++ shard_intermediate_size, ++ hidden_size, ++ ), ++ dtype=torch.int8) ++ w2 = torch.randint(-127, ++ 127, ( ++ num_experts, ++ hidden_size, ++ shard_intermediate_size // 2, ++ ), ++ dtype=torch.int8) ++ else: ++ w1 = torch.randn(num_experts, ++ shard_intermediate_size, ++ hidden_size, ++ dtype=init_dtype) ++ w2 = torch.randn(num_experts, ++ hidden_size, ++ shard_intermediate_size // 2, ++ dtype=init_dtype) ++ gating_output = torch.randn(num_iters, ++ num_tokens, ++ num_experts, ++ dtype=torch.float32) ++ ++ w1_scale = None ++ w2_scale = None ++ a1_scale = None ++ a2_scale = None ++ if use_int8_w8a16: ++ w1_scale = torch.randn((num_experts, 2 * shard_intermediate_size), ++ dtype=torch.float32) ++ w2_scale = torch.randn((hidden_size, num_experts), dtype=torch.float32) ++ if use_fp8_w8a8: ++ w1_scale = torch.randn(num_experts, dtype=torch.float32) ++ w2_scale = torch.randn(num_experts, dtype=torch.float32) ++ a1_scale = torch.randn(1, dtype=torch.float32) ++ a2_scale = torch.randn(1, dtype=torch.float32) ++ ++ w1 = w1.to(torch.float8_e4m3fn) ++ w2 = w2.to(torch.float8_e4m3fn) ++ ++ input_gating = torch.empty(num_tokens, num_experts, dtype=torch.float32) ++ ++ def prepare(i: int): ++ input_gating.copy_(gating_output[i]) ++ ++ def run(): ++ from vllm.model_executor.layers.fused_moe import override_config ++ with override_config(config): ++ fused_moe( ++ x, ++ w1, ++ w2, ++ input_gating, ++ topk, ++ renormalize=True, ++ inplace=True, ++ use_fp8_w8a8=use_fp8_w8a8, ++ use_int8_w8a16=use_int8_w8a16, ++ w1_scale=w1_scale, ++ w2_scale=w2_scale, ++ a1_scale=a1_scale, ++ a2_scale=a2_scale, ++ ) ++ ++ # JIT compilation & warmup ++ run() ++ torch.cuda.synchronize() ++ ++ # Capture 10 invocations with CUDA graph ++ graph = torch.cuda.CUDAGraph() ++ with torch.cuda.graph(graph): ++ for _ in range(10): ++ run() ++ torch.cuda.synchronize() ++ ++ # Warmup ++ for _ in range(5): ++ graph.replay() ++ torch.cuda.synchronize() ++ ++ start_event = torch.cuda.Event(enable_timing=True) ++ end_event = torch.cuda.Event(enable_timing=True) ++ ++ latencies: List[float] = [] ++ for i in range(num_iters): ++ prepare(i) ++ torch.cuda.synchronize() ++ ++ start_event.record() ++ graph.replay() ++ end_event.record() ++ end_event.synchronize() ++ latencies.append(start_event.elapsed_time(end_event)) ++ avg = sum(latencies) / (num_iters * 10) * 1000 # us ++ graph.reset() ++ return avg ++ ++ ++def get_configs_compute_bound() -> List[Dict[str, int]]: ++ # Reduced search space for faster tuning. ++ # TODO(woosuk): Increase the search space and use a performance model to ++ # prune the search space. ++ configs: List[BenchmarkConfig] = [] ++ for num_stages in [2, 3, 4, 5]: ++ for block_m in [16, 32, 64, 128, 256]: ++ for block_k in [64, 128, 256]: ++ for block_n in [32, 64, 128, 256]: ++ for num_warps in [4, 8]: ++ for group_size in [1, 16, 32, 64]: ++ configs.append({ ++ "BLOCK_SIZE_M": block_m, ++ "BLOCK_SIZE_N": block_n, ++ "BLOCK_SIZE_K": block_k, ++ "GROUP_SIZE_M": group_size, ++ "num_warps": num_warps, ++ "num_stages": num_stages, ++ }) ++ return configs ++ ++ ++@ray.remote(num_gpus=1) ++class BenchmarkWorker: ++ ++ def __init__(self, seed: int) -> None: ++ torch.set_default_device("cuda") ++ current_platform.seed_everything(seed) ++ self.seed = seed ++ ++ def benchmark( ++ self, ++ num_tokens: int, ++ num_experts: int, ++ shard_intermediate_size: int, ++ hidden_size: int, ++ topk: int, ++ dtype: torch.dtype, ++ use_fp8_w8a8: bool, ++ use_int8_w8a16: bool, ++ ) -> Tuple[Dict[str, int], float]: ++ current_platform.seed_everything(self.seed) ++ dtype_str = get_config_dtype_str(dtype, ++ use_int8_w8a16=use_int8_w8a16, ++ use_fp8_w8a8=use_fp8_w8a8) ++ # NOTE(woosuk): The current naming convention uses w2.shape[2], which ++ # is the intermediate size after silu_and_mul. ++ op_config = get_moe_configs(num_experts, shard_intermediate_size // 2, ++ dtype_str) ++ if op_config is None: ++ config = get_default_config(num_tokens, num_experts, ++ shard_intermediate_size, hidden_size, ++ topk, dtype_str) ++ else: ++ config = op_config[min(op_config.keys(), ++ key=lambda x: abs(x - num_tokens))] ++ kernel_time = benchmark_config(config, num_tokens, num_experts, ++ shard_intermediate_size, hidden_size, ++ topk, dtype, use_fp8_w8a8, ++ use_int8_w8a16) ++ return config, kernel_time ++ ++ def tune( ++ self, ++ num_tokens: int, ++ num_experts: int, ++ shard_intermediate_size: int, ++ hidden_size: int, ++ topk: int, ++ dtype: torch.dtype, ++ use_fp8_w8a8: bool, ++ use_int8_w8a16: bool, ++ search_space: List[Dict[str, int]], ++ ) -> Dict[str, int]: ++ best_config = None ++ best_time = float("inf") ++ for config in tqdm(search_space): ++ try: ++ kernel_time = benchmark_config(config, ++ num_tokens, ++ num_experts, ++ shard_intermediate_size, ++ hidden_size, ++ topk, ++ dtype, ++ use_fp8_w8a8, ++ use_int8_w8a16, ++ num_iters=10) ++ except triton.runtime.autotuner.OutOfResources: ++ # Some configurations may be invalid and fail to compile. ++ continue ++ ++ if kernel_time < best_time: ++ best_time = kernel_time ++ best_config = config ++ now = datetime.now() ++ print(f"{now.ctime()}] Completed tuning for batch_size={num_tokens}") ++ assert best_config is not None ++ return best_config ++ ++ ++def sort_config(config: BenchmarkConfig) -> BenchmarkConfig: ++ return { ++ "BLOCK_SIZE_M": config["BLOCK_SIZE_M"], ++ "BLOCK_SIZE_N": config["BLOCK_SIZE_N"], ++ "BLOCK_SIZE_K": config["BLOCK_SIZE_K"], ++ "GROUP_SIZE_M": config["GROUP_SIZE_M"], ++ "num_warps": config["num_warps"], ++ "num_stages": config["num_stages"], ++ } ++ ++ ++def save_configs(configs: Dict[int, BenchmarkConfig], num_experts: int, ++ shard_intermediate_size: int, hidden_size: int, topk: int, ++ dtype: torch.dtype, use_fp8_w8a8: bool, ++ use_int8_w8a16: bool) -> None: ++ dtype_str = get_config_dtype_str(dtype, ++ use_int8_w8a16=use_int8_w8a16, ++ use_fp8_w8a8=use_fp8_w8a8) ++ ++ # NOTE(woosuk): The current naming convention uses w2.shape[2], which ++ # is the intermediate size after silu_and_mul. ++ filename = get_config_file_name(num_experts, shard_intermediate_size // 2, ++ dtype_str) ++ ++ print(f"Writing best config to {filename}...") ++ with open(filename, "w") as f: ++ json.dump(configs, f, indent=4) ++ f.write("\n") ++ ++ ++def main(args: argparse.Namespace): ++ print(args) ++ ++ config = AutoConfig.from_pretrained(args.model) ++ if config.architectures[0] == "DbrxForCausalLM": ++ E = config.ffn_config.moe_num_experts ++ topk = config.ffn_config.moe_top_k ++ intermediate_size = config.ffn_config.ffn_hidden_size ++ shard_intermediate_size = 2 * intermediate_size // args.tp_size ++ elif config.architectures[0] == "JambaForCausalLM": ++ E = config.num_experts ++ topk = config.num_experts_per_tok ++ intermediate_size = config.intermediate_size ++ shard_intermediate_size = 2 * intermediate_size // args.tp_size ++ else: ++ # Default: Mixtral. ++ E = config.num_local_experts ++ topk = config.num_experts_per_tok ++ intermediate_size = config.intermediate_size ++ shard_intermediate_size = 2 * intermediate_size // args.tp_size ++ ++ hidden_size = config.hidden_size ++ dtype = config.torch_dtype ++ use_fp8_w8a8 = args.dtype == "fp8_w8a8" ++ use_int8_w8a16 = args.dtype == "int8_w8a16" ++ ++ if args.batch_size is None: ++ batch_sizes = [ ++ 1, 2, 4, 8, 16, 24, 32, 48, 64, 96, 128, 256, 512, 1024, 1536, ++ 2048, 3072, 4096 ++ ] ++ else: ++ batch_sizes = [args.batch_size] ++ ++ ray.init() ++ num_gpus = int(ray.available_resources()["GPU"]) ++ workers = [BenchmarkWorker.remote(args.seed) for _ in range(num_gpus)] ++ ++ def _distribute(method: str, inputs: List[Any]) -> List[Any]: ++ outputs = [] ++ worker_idx = 0 ++ for input_args in inputs: ++ worker = workers[worker_idx] ++ worker_method = getattr(worker, method) ++ output = worker_method.remote(*input_args) ++ outputs.append(output) ++ worker_idx = (worker_idx + 1) % num_gpus ++ return ray.get(outputs) ++ ++ if args.tune: ++ search_space = get_configs_compute_bound() ++ print(f"Start tuning over {len(search_space)} configurations...") ++ ++ start = time.time() ++ configs = _distribute( ++ "tune", [(batch_size, E, shard_intermediate_size, hidden_size, ++ topk, dtype, use_fp8_w8a8, use_int8_w8a16, search_space) ++ for batch_size in batch_sizes]) ++ best_configs = { ++ M: sort_config(config) ++ for M, config in zip(batch_sizes, configs) ++ } ++ save_configs(best_configs, E, shard_intermediate_size, hidden_size, ++ topk, dtype, use_fp8_w8a8, use_int8_w8a16) ++ end = time.time() ++ print(f"Tuning took {end - start:.2f} seconds") ++ else: ++ outputs = _distribute( ++ "benchmark", [(batch_size, E, shard_intermediate_size, hidden_size, ++ topk, dtype, use_fp8_w8a8, use_int8_w8a16) ++ for batch_size in batch_sizes]) ++ ++ for batch_size, (config, kernel_time) in zip(batch_sizes, outputs): ++ print(f"Batch size: {batch_size}, config: {config}") ++ print(f"Kernel time: {kernel_time:.2f} us") ++ ++ ++if __name__ == "__main__": ++ parser = FlexibleArgumentParser() ++ parser.add_argument("--model", ++ type=str, ++ default="mistralai/Mixtral-8x7B-Instruct-v0.1") ++ parser.add_argument("--tp-size", "-tp", type=int, default=2) ++ parser.add_argument("--dtype", ++ type=str, ++ choices=["auto", "fp8_w8a8", "int8_w8a16"], ++ default="auto") ++ parser.add_argument("--seed", type=int, default=0) ++ parser.add_argument("--batch-size", type=int, required=False) ++ parser.add_argument("--tune", action="store_true") ++ args = parser.parse_args() ++ ++ main(args) +diff --git a/benchmarks/kernels/benchmark_paged_attention.py b/benchmarks/kernels/benchmark_paged_attention.py +index ca7967c..14eef00 100644 +--- a/benchmarks/kernels/benchmark_paged_attention.py ++++ b/benchmarks/kernels/benchmark_paged_attention.py +@@ -1,12 +1,13 @@ +-import argparse + import random + import time +-from typing import Optional ++from typing import List, Optional + + import torch + + from vllm import _custom_ops as ops +-from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, create_kv_caches_with_random ++from vllm.platforms import current_platform ++from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser, ++ create_kv_caches_with_random) + + NUM_BLOCKS = 1024 + PARTITION_SIZE = 512 +@@ -28,10 +29,7 @@ def main( + device: str = "cuda", + kv_cache_dtype: Optional[str] = None, + ) -> None: +- random.seed(seed) +- torch.random.manual_seed(seed) +- if torch.cuda.is_available(): +- torch.cuda.manual_seed(seed) ++ current_platform.seed_everything(seed) + + scale = float(1.0 / (head_size**0.5)) + query = torch.empty(num_seqs, +@@ -54,14 +52,17 @@ def main( + + # Create the block tables. + max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size +- block_tables = [] ++ block_tables_lst: List[List[int]] = [] + for _ in range(num_seqs): + block_table = [ + random.randint(0, NUM_BLOCKS - 1) + for _ in range(max_num_blocks_per_seq) + ] +- block_tables.append(block_table) +- block_tables = torch.tensor(block_tables, dtype=torch.int, device=device) ++ block_tables_lst.append(block_table) ++ ++ block_tables = torch.tensor(block_tables_lst, ++ dtype=torch.int, ++ device=device) + + # Create the KV cache. + key_caches, value_caches = create_kv_caches_with_random(NUM_BLOCKS, +@@ -97,7 +98,7 @@ def main( + start_time = time.perf_counter() + + # Using default kv_scale +- kv_scale = 1.0 ++ k_scale = v_scale = 1.0 + + for _ in range(num_iters): + if version == "v1": +@@ -114,7 +115,8 @@ def main( + max_seq_len, + alibi_slopes, + kv_cache_dtype, +- kv_scale, ++ k_scale, ++ v_scale, + ) + elif version == "v2": + ops.paged_attention_v2( +@@ -133,7 +135,8 @@ def main( + max_seq_len, + alibi_slopes, + kv_cache_dtype, +- kv_scale, ++ k_scale, ++ v_scale, + ) + else: + raise ValueError(f"Invalid version: {version}") +@@ -158,19 +161,19 @@ def main( + + + if __name__ == '__main__': +- parser = argparse.ArgumentParser( ++ parser = FlexibleArgumentParser( + description="Benchmark the paged attention kernel.") + parser.add_argument("--version", + type=str, + choices=["v1", "v2"], + default="v2") + parser.add_argument("--batch-size", type=int, default=8) +- parser.add_argument("--seq_len", type=int, default=4096) ++ parser.add_argument("--seq-len", type=int, default=4096) + parser.add_argument("--num-query-heads", type=int, default=64) + parser.add_argument("--num-kv-heads", type=int, default=8) + parser.add_argument("--head-size", + type=int, +- choices=[64, 80, 96, 112, 128, 256], ++ choices=[64, 80, 96, 112, 120, 128, 192, 256], + default=128) + parser.add_argument("--block-size", type=int, choices=[16, 32], default=16) + parser.add_argument("--use-alibi", action="store_true") +@@ -183,13 +186,11 @@ if __name__ == '__main__': + parser.add_argument( + "--kv-cache-dtype", + type=str, +- choices=["auto", "fp8"], ++ choices=["auto", "fp8", "fp8_e5m2", "fp8_e4m3"], + default="auto", +- help= +- 'Data type for kv cache storage. If "auto", will use model data type. ' +- 'FP8_E5M2 (without scaling) is only supported on cuda version greater ' +- 'than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead supported for ' +- 'common inference criteria.') ++ help="Data type for kv cache storage. If 'auto', will use model " ++ "data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. " ++ "ROCm (AMD GPU) supports fp8 (=fp8_e4m3)") + args = parser.parse_args() + print(args) + +diff --git a/benchmarks/kernels/benchmark_quant.py b/benchmarks/kernels/benchmark_quant.py +new file mode 100644 +index 0000000..1d62483 +--- /dev/null ++++ b/benchmarks/kernels/benchmark_quant.py +@@ -0,0 +1,100 @@ ++import time ++ ++import torch ++ ++from vllm import _custom_ops as ops ++from vllm.platforms import current_platform ++from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser ++ ++ ++@torch.inference_mode() ++def main(num_tokens: int, ++ hidden_size: int, ++ static_scale: bool, ++ quant_dtype: torch.dtype, ++ dtype: torch.dtype, ++ seed: int = 0, ++ do_profile: bool = False, ++ num_warmup_iters: int = 5, ++ num_iters: int = 100) -> None: ++ current_platform.seed_everything(seed) ++ torch.set_default_device("cuda") ++ ++ x = torch.randn(num_tokens, hidden_size, dtype=dtype) ++ scale = torch.randn(1, 1, dtype=torch.float32) if static_scale else None ++ ++ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: ++ torch.cuda.synchronize() ++ if profile: ++ torch.cuda.cudart().cudaProfilerStart() ++ start_time = time.perf_counter() ++ ++ for _ in range(num_iters): ++ if quant_dtype == torch.int8: ++ ops.scaled_int8_quant(x, scale) ++ else: ++ ops.scaled_fp8_quant(x, scale) ++ torch.cuda.synchronize() ++ ++ end_time = time.perf_counter() ++ if profile: ++ torch.cuda.cudart().cudaProfilerStart() ++ return (end_time - start_time) / num_iters ++ ++ # Warmup. ++ print("Warming up...") ++ run_benchmark = run_cuda_benchmark ++ run_benchmark(num_iters=num_warmup_iters, profile=False) ++ ++ # Benchmark. ++ if do_profile: ++ latency = run_benchmark(num_iters=1, profile=True) ++ else: ++ latency = run_benchmark(num_iters=num_iters, profile=False) ++ print(f"Kernel running time: {latency * 1000000:.3f} us") ++ ++ ++if __name__ == '__main__': ++ ++ def to_torch_dtype(dt): ++ if dt == "int8": ++ return torch.int8 ++ if dt == "fp8": ++ return torch.float8_e4m3fn ++ raise ValueError(f"Unsupported dtype: {dt}") ++ ++ parser = FlexibleArgumentParser( ++ description="Benchmark the quantization (fp8 or int8) kernel.") ++ parser.add_argument("--num-tokens", type=int, default=4096) ++ parser.add_argument("--hidden-size", type=int, default=8192) ++ parser.add_argument("--static-scale", action="store_true") ++ parser.add_argument("--quant-dtype", ++ type=str, ++ choices=["fp8", "int8"], ++ default="int8") ++ parser.add_argument("--dtype", ++ type=str, ++ choices=["half", "bfloat16", "float"], ++ default="half") ++ ++ parser.add_argument("--seed", type=int, default=0) ++ parser.add_argument("--profile", action="store_true") ++ parser.add_argument("--num-warmup-iters", type=int, default=5) ++ parser.add_argument("--num-iters", ++ type=int, ++ default=100, ++ help="Number of benchmark iterations. " ++ "If --profile is set, this number is ignored") ++ ++ args = parser.parse_args() ++ print(args) ++ ++ main(num_tokens=args.num_tokens, ++ hidden_size=args.hidden_size, ++ static_scale=args.static_scale, ++ quant_dtype=to_torch_dtype(args.quant_dtype), ++ dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype], ++ seed=args.seed, ++ do_profile=args.profile, ++ num_warmup_iters=args.num_warmup_iters, ++ num_iters=args.num_iters) +diff --git a/benchmarks/kernels/benchmark_rmsnorm.py b/benchmarks/kernels/benchmark_rmsnorm.py +new file mode 100644 +index 0000000..baa5de0 +--- /dev/null ++++ b/benchmarks/kernels/benchmark_rmsnorm.py +@@ -0,0 +1,262 @@ ++import itertools ++from typing import Optional, Tuple, Union ++ ++import torch ++import triton ++from flashinfer.norm import fused_add_rmsnorm, rmsnorm ++from torch import nn ++ ++from vllm import _custom_ops as vllm_ops ++ ++ ++class HuggingFaceRMSNorm(nn.Module): ++ ++ def __init__(self, hidden_size: int, eps: float = 1e-6) -> None: ++ super().__init__() ++ self.weight = nn.Parameter(torch.ones(hidden_size)) ++ self.variance_epsilon = eps ++ ++ def forward( ++ self, ++ x: torch.Tensor, ++ residual: Optional[torch.Tensor] = None, ++ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: ++ orig_dtype = x.dtype ++ x = x.to(torch.float32) ++ if residual is not None: ++ x = x + residual.to(torch.float32) ++ residual = x.to(orig_dtype) ++ ++ variance = x.pow(2).mean(dim=-1, keepdim=True) ++ x = x * torch.rsqrt(variance + self.variance_epsilon) ++ x = x.to(orig_dtype) * self.weight ++ if residual is None: ++ return x ++ else: ++ return x, residual ++ ++ ++def rmsnorm_naive( ++ x: torch.Tensor, ++ weight: torch.Tensor, ++ residual: Optional[torch.Tensor] = None, ++ eps: float = 1e-6, ++): ++ naive_norm = HuggingFaceRMSNorm(x.shape[-1], eps=eps) ++ naive_norm.weight = nn.Parameter(weight) ++ naive_norm = naive_norm.to(x.device) ++ ++ orig_shape = x.shape ++ x = x.view(-1, x.shape[-1]) ++ if residual is not None: ++ residual = residual.view(-1, residual.shape[-1]) ++ ++ output = naive_norm(x, residual) ++ ++ if isinstance(output, tuple): ++ output = (output[0].view(orig_shape), output[1].view(orig_shape)) ++ else: ++ output = output.view(orig_shape) ++ return output ++ ++ ++def rmsnorm_flashinfer( ++ x: torch.Tensor, ++ weight: torch.Tensor, ++ residual: Optional[torch.Tensor] = None, ++ eps: float = 1e-6, ++): ++ orig_shape = x.shape ++ x = x.view(-1, x.shape[-1]) ++ if residual is not None: ++ residual = residual.view(-1, residual.shape[-1]) ++ ++ if residual is not None: ++ fused_add_rmsnorm(x, residual, weight, eps) ++ output = (x, residual) ++ else: ++ output = rmsnorm(x, weight, eps) ++ ++ if isinstance(output, tuple): ++ output = (output[0].view(orig_shape), output[1].view(orig_shape)) ++ else: ++ output = output.view(orig_shape) ++ return output ++ ++ ++def rmsnorm_vllm( ++ x: torch.Tensor, ++ weight: torch.Tensor, ++ residual: Optional[torch.Tensor] = None, ++ eps: float = 1e-6, ++): ++ orig_shape = x.shape ++ x = x.view(-1, x.shape[-1]) ++ if residual is not None: ++ residual = residual.view(-1, residual.shape[-1]) ++ ++ if residual is not None: ++ vllm_ops.fused_add_rms_norm(x, residual, weight, eps) ++ output = (x, residual) ++ else: ++ out = torch.empty_like(x) ++ vllm_ops.rms_norm(out, x, weight, eps) ++ output = out ++ ++ if isinstance(output, tuple): ++ output = (output[0].view(orig_shape), output[1].view(orig_shape)) ++ else: ++ output = output.view(orig_shape) ++ return output ++ ++ ++def calculate_diff(batch_size, seq_len, hidden_size, use_residual=True): ++ dtype = torch.bfloat16 ++ x = torch.randn(batch_size, ++ seq_len, ++ hidden_size, ++ dtype=dtype, ++ device="cuda") ++ weight = torch.ones(hidden_size, dtype=dtype, device="cuda") ++ residual = torch.randn_like(x) if use_residual else None ++ ++ output_naive = rmsnorm_naive( ++ x.clone(), weight, ++ residual.clone() if residual is not None else None) ++ output_flashinfer = rmsnorm_flashinfer( ++ x.clone(), weight, ++ residual.clone() if residual is not None else None) ++ output_vllm = rmsnorm_vllm( ++ x.clone(), weight, ++ residual.clone() if residual is not None else None) ++ ++ if use_residual: ++ output_naive = output_naive[0] ++ output_flashinfer = output_flashinfer[0] ++ output_vllm = output_vllm[0] ++ ++ print(f"Naive output={output_naive}") ++ print(f"FlashInfer output={output_flashinfer}") ++ print(f"VLLM output={output_vllm}") ++ ++ if torch.allclose(output_naive, output_flashinfer, atol=1e-2, ++ rtol=1e-2) and torch.allclose( ++ output_naive, output_vllm, atol=1e-2, rtol=1e-2): ++ print("✅ All implementations match") ++ else: ++ print("❌ Implementations differ") ++ ++ ++batch_size_range = [2**i for i in range(0, 7, 2)] ++seq_length_range = [2**i for i in range(6, 11, 1)] ++head_num_range = [32, 48] ++configs = list( ++ itertools.product(head_num_range, batch_size_range, seq_length_range)) ++ ++ ++def get_benchmark(use_residual): ++ ++ @triton.testing.perf_report( ++ triton.testing.Benchmark( ++ x_names=["head_num", "batch_size", "seq_len"], ++ x_vals=[list(_) for _ in configs], ++ line_arg="provider", ++ line_vals=["huggingface", "flashinfer", "vllm"], ++ line_names=["HuggingFace", "FlashInfer", "vLLM"], ++ styles=[("blue", "-"), ("green", "-"), ("red", "-")], ++ ylabel="us", ++ plot_name= ++ f"rmsnorm-perf-{'with' if use_residual else 'without'}-residual", ++ args={}, ++ )) ++ def benchmark(head_num, batch_size, seq_len, provider): ++ dtype = torch.bfloat16 ++ hidden_size = head_num * 128 # assuming head_dim = 128 ++ ++ x = torch.randn(batch_size, ++ seq_len, ++ hidden_size, ++ dtype=dtype, ++ device="cuda") ++ weight = torch.ones(hidden_size, dtype=dtype, device="cuda") ++ residual = torch.randn_like(x) if use_residual else None ++ ++ quantiles = [0.5, 0.2, 0.8] ++ ++ if provider == "huggingface": ++ ms, min_ms, max_ms = triton.testing.do_bench( ++ lambda: rmsnorm_naive( ++ x.clone(), ++ weight, ++ residual.clone() if residual is not None else None, ++ ), ++ quantiles=quantiles, ++ ) ++ elif provider == "flashinfer": ++ ms, min_ms, max_ms = triton.testing.do_bench( ++ lambda: rmsnorm_flashinfer( ++ x.clone(), ++ weight, ++ residual.clone() if residual is not None else None, ++ ), ++ quantiles=quantiles, ++ ) ++ else: ++ ms, min_ms, max_ms = triton.testing.do_bench( ++ lambda: rmsnorm_vllm( ++ x.clone(), ++ weight, ++ residual.clone() if residual is not None else None, ++ ), ++ quantiles=quantiles, ++ ) ++ ++ return 1000 * ms, 1000 * max_ms, 1000 * min_ms ++ ++ return benchmark ++ ++ ++if __name__ == "__main__": ++ import argparse ++ ++ parser = argparse.ArgumentParser() ++ parser.add_argument( ++ "--batch-size", ++ type=int, ++ default=4, ++ help="Batch size", ++ ) ++ parser.add_argument( ++ "--seq-len", ++ type=int, ++ default=128, ++ help="Sequence length", ++ ) ++ parser.add_argument( ++ "--hidden-size", ++ type=int, ++ default=4096, ++ help="Hidden size (2nd dimension) of the sequence", ++ ) ++ parser.add_argument("--use-residual", ++ action="store_true", ++ help="Whether to use residual connection") ++ parser.add_argument( ++ "--save-path", ++ type=str, ++ default="./configs/rmsnorm/", ++ help="Path to save rmsnorm benchmark results", ++ ) ++ ++ args = parser.parse_args() ++ ++ # Run correctness test ++ calculate_diff(batch_size=args.batch_size, ++ seq_len=args.seq_len, ++ hidden_size=args.hidden_size, ++ use_residual=args.use_residual) ++ ++ # Get the benchmark function with proper use_residual setting ++ benchmark = get_benchmark(args.use_residual) ++ # Run performance benchmark ++ benchmark.run(print_data=True, save_path=args.save_path) +diff --git a/benchmarks/kernels/benchmark_rope.py b/benchmarks/kernels/benchmark_rope.py +index 9188e81..250d505 100644 +--- a/benchmarks/kernels/benchmark_rope.py ++++ b/benchmarks/kernels/benchmark_rope.py +@@ -1,11 +1,13 @@ +-import argparse + from itertools import accumulate +-from typing import Optional ++from typing import List, Optional + + import nvtx + import torch + +-from vllm.model_executor.layers.rotary_embedding import get_rope ++from vllm.model_executor.layers.rotary_embedding import (RotaryEmbedding, ++ get_rope) ++from vllm.platforms import current_platform ++from vllm.utils import FlexibleArgumentParser + + + def benchmark_rope_kernels_multi_lora( +@@ -21,9 +23,7 @@ def benchmark_rope_kernels_multi_lora( + max_position: int = 8192, + base: int = 10000, + ) -> None: +- torch.random.manual_seed(seed) +- if torch.cuda.is_available(): +- torch.cuda.manual_seed(seed) ++ current_platform.seed_everything(seed) + torch.set_default_device(device) + if rotary_dim is None: + rotary_dim = head_size +@@ -32,17 +32,17 @@ def benchmark_rope_kernels_multi_lora( + # batched RoPE can take multiple scaling factors + batched_rope = get_rope(head_size, rotary_dim, max_position, base, + is_neox_style, { +- "type": "linear", ++ "rope_type": "linear", + "factor": tuple(scaling_factors) + }) + # non-batched RoPE takes only one scaling factor, we create multiple + # instances to simulate the same behavior +- non_batched_ropes = [] ++ non_batched_ropes: List[RotaryEmbedding] = [] + for scaling_factor in scaling_factors: + non_batched_ropes.append( + get_rope(head_size, rotary_dim, max_position, base, is_neox_style, + { +- "type": "linear", ++ "rope_type": "linear", + "factor": (scaling_factor, ) + })) + +@@ -85,7 +85,7 @@ def benchmark_rope_kernels_multi_lora( + + + if __name__ == '__main__': +- parser = argparse.ArgumentParser( ++ parser = FlexibleArgumentParser( + description="Benchmark the rotary embedding kernels.") + parser.add_argument("--is-neox-style", type=bool, default=True) + parser.add_argument("--batch-size", type=int, default=16) +@@ -93,7 +93,7 @@ if __name__ == '__main__': + parser.add_argument("--num-heads", type=int, default=8) + parser.add_argument("--head-size", + type=int, +- choices=[64, 80, 96, 112, 128, 256], ++ choices=[64, 80, 96, 112, 120, 128, 192, 256], + default=128) + parser.add_argument("--rotary-dim", type=int, choices=[16, 32], default=32) + parser.add_argument("--dtype", +diff --git a/benchmarks/kernels/benchmark_shapes.py b/benchmarks/kernels/benchmark_shapes.py +new file mode 100644 +index 0000000..4eeeca3 +--- /dev/null ++++ b/benchmarks/kernels/benchmark_shapes.py +@@ -0,0 +1,75 @@ ++WEIGHT_SHAPES = { ++ "ideal": [[4 * 256 * 32, 256 * 32]], ++ "mistralai/Mistral-7B-v0.1/TP1": [ ++ [4096, 6144], ++ [4096, 4096], ++ [4096, 28672], ++ [14336, 4096], ++ ], ++ "mistralai/Mistral-7B-v0.1/TP2": [ ++ [4096, 3072], ++ [2048, 4096], ++ [4096, 14336], ++ [7168, 4096], ++ ], ++ "mistralai/Mistral-7B-v0.1/TP4": [ ++ [4096, 1536], ++ [1024, 4096], ++ [4096, 7168], ++ [3584, 4096], ++ ], ++ "meta-llama/Llama-2-7b-hf/TP1": [ ++ [4096, 12288], ++ [4096, 4096], ++ [4096, 22016], ++ [11008, 4096], ++ ], ++ "meta-llama/Llama-2-7b-hf/TP2": [ ++ [4096, 6144], ++ [2048, 4096], ++ [4096, 11008], ++ [5504, 4096], ++ ], ++ "meta-llama/Llama-2-7b-hf/TP4": [ ++ [4096, 3072], ++ [1024, 4096], ++ [4096, 5504], ++ [2752, 4096], ++ ], ++ "meta-llama/Llama-2-13b-hf/TP1": [ ++ [5120, 15360], ++ [5120, 5120], ++ [5120, 27648], ++ [13824, 5120], ++ ], ++ "meta-llama/Llama-2-13b-hf/TP2": [ ++ [5120, 7680], ++ [2560, 5120], ++ [5120, 13824], ++ [6912, 5120], ++ ], ++ "meta-llama/Llama-2-13b-hf/TP4": [ ++ [5120, 3840], ++ [1280, 5120], ++ [5120, 6912], ++ [3456, 5120], ++ ], ++ "meta-llama/Llama-2-70b-hf/TP1": [ ++ [8192, 10240], ++ [8192, 8192], ++ [8192, 57344], ++ [28672, 8192], ++ ], ++ "meta-llama/Llama-2-70b-hf/TP2": [ ++ [8192, 5120], ++ [4096, 8192], ++ [8192, 28672], ++ [14336, 8192], ++ ], ++ "meta-llama/Llama-2-70b-hf/TP4": [ ++ [8192, 2560], ++ [2048, 8192], ++ [8192, 14336], ++ [7168, 8192], ++ ], ++} +diff --git a/benchmarks/kernels/graph_machete_bench.py b/benchmarks/kernels/graph_machete_bench.py +new file mode 100644 +index 0000000..7d0bd84 +--- /dev/null ++++ b/benchmarks/kernels/graph_machete_bench.py +@@ -0,0 +1,63 @@ ++import math ++import pickle ++import re ++from collections import defaultdict ++from typing import List ++ ++import matplotlib.pyplot as plt ++import pandas as pd ++import seaborn as sns ++from torch.utils.benchmark import Measurement as TMeasurement ++ ++from vllm.utils import FlexibleArgumentParser ++ ++if __name__ == "__main__": ++ parser = FlexibleArgumentParser( ++ description='Benchmark the latency of processing a single batch of ' ++ 'requests till completion.') ++ parser.add_argument('filename', type=str) ++ ++ args = parser.parse_args() ++ ++ with open(args.filename, 'rb') as f: ++ data = pickle.load(f) ++ raw_results: List[TMeasurement] = data["results"] ++ ++ results = defaultdict(lambda: list()) ++ for v in raw_results: ++ result = re.search(r"MKN=\(\d+x(\d+x\d+)\)", v.task_spec.sub_label) ++ if result is not None: ++ KN = result.group(1) ++ else: ++ raise Exception("MKN not found") ++ result = re.search(r"MKN=\((\d+)x\d+x\d+\)", v.task_spec.sub_label) ++ if result is not None: ++ M = result.group(1) ++ else: ++ raise Exception("MKN not found") ++ ++ kernel = v.task_spec.description ++ results[KN].append({ ++ "kernel": kernel, ++ "batch_size": M, ++ "median": v.median ++ }) ++ ++ rows = int(math.ceil(len(results) / 2)) ++ fig, axs = plt.subplots(rows, 2, figsize=(12, 5 * rows)) ++ axs = axs.flatten() ++ for axs_idx, (shape, data) in enumerate(results.items()): ++ plt.sca(axs[axs_idx]) ++ df = pd.DataFrame(data) ++ sns.lineplot(data=df, ++ x="batch_size", ++ y="median", ++ hue="kernel", ++ style="kernel", ++ markers=True, ++ dashes=False, ++ palette="Dark2") ++ plt.title(f"Shape: {shape}") ++ plt.ylabel("time (median, s)") ++ plt.tight_layout() ++ plt.savefig("graph_machete_bench.pdf") +diff --git a/benchmarks/kernels/requirements.txt b/benchmarks/kernels/requirements.txt +new file mode 100644 +index 0000000..1411a4a +--- /dev/null ++++ b/benchmarks/kernels/requirements.txt +@@ -0,0 +1 @@ ++pandas +\ No newline at end of file +diff --git a/benchmarks/kernels/weight_shapes.py b/benchmarks/kernels/weight_shapes.py +new file mode 100644 +index 0000000..51f24f3 +--- /dev/null ++++ b/benchmarks/kernels/weight_shapes.py +@@ -0,0 +1,49 @@ ++# Weight Shapes are in the format ++# ([K, N], TP_SPLIT_DIM) ++# Example: ++# A shape of ([14336, 4096], 0) indicates the following GEMM shape, ++# - TP1 : K = 14336, N = 4096 ++# - TP2 : K = 7168, N = 4096 ++# A shape of ([4096, 6144], 1) indicates the following GEMM shape, ++# - TP1 : K = 4096, N = 6144 ++# - TP4 : K = 4096, N = 1536 ++ ++# TP1 shapes ++WEIGHT_SHAPES = { ++ "mistralai/Mistral-7B-v0.1": [ ++ ([4096, 6144], 1), ++ ([4096, 4096], 0), ++ ([4096, 28672], 1), ++ ([14336, 4096], 0), ++ ], ++ "meta-llama/Llama-2-7b-hf": [ ++ ([4096, 12288], 1), ++ ([4096, 4096], 0), ++ ([4096, 22016], 1), ++ ([11008, 4096], 0), ++ ], ++ "meta-llama/Llama-3-8b": [ ++ ([4096, 6144], 1), ++ ([4096, 4096], 0), ++ ([4096, 28672], 1), ++ ([14336, 4096], 0), ++ ], ++ "meta-llama/Llama-2-13b-hf": [ ++ ([5120, 15360], 1), ++ ([5120, 5120], 0), ++ ([5120, 27648], 1), ++ ([13824, 5120], 0), ++ ], ++ "meta-llama/Llama-2-70b-hf": [ ++ ([8192, 10240], 1), ++ ([8192, 8192], 0), ++ ([8192, 57344], 1), ++ ([28672, 8192], 0), ++ ], ++ "meta-llama/Llama-3.1-405b-hf": [ ++ ([16384, 18432], 1), ++ ([16384, 16384], 0), ++ ([16384, 106496], 1), ++ ([53248, 16384], 0), ++ ], ++} +diff --git a/benchmarks/launch_tgi_server.sh b/benchmarks/launch_tgi_server.sh +index 64d3c4f..ba7383d 100755 +--- a/benchmarks/launch_tgi_server.sh ++++ b/benchmarks/launch_tgi_server.sh +@@ -4,13 +4,13 @@ PORT=8000 + MODEL=$1 + TOKENS=$2 + +-docker run --gpus all --shm-size 1g -p $PORT:80 \ +- -v $PWD/data:/data \ +- ghcr.io/huggingface/text-generation-inference:1.4.0 \ +- --model-id $MODEL \ ++docker run -e "HF_TOKEN=$HF_TOKEN" --gpus all --shm-size 1g -p $PORT:80 \ ++ -v "$PWD/data:/data" \ ++ ghcr.io/huggingface/text-generation-inference:2.2.0 \ ++ --model-id "$MODEL" \ + --sharded false \ + --max-input-length 1024 \ + --max-total-tokens 2048 \ + --max-best-of 5 \ + --max-concurrent-requests 5000 \ +- --max-batch-total-tokens $TOKENS ++ --max-batch-total-tokens "$TOKENS" +diff --git a/benchmarks/overheads/benchmark_hashing.py b/benchmarks/overheads/benchmark_hashing.py +new file mode 100644 +index 0000000..d16d6f9 +--- /dev/null ++++ b/benchmarks/overheads/benchmark_hashing.py +@@ -0,0 +1,59 @@ ++import cProfile ++import pstats ++ ++from vllm import LLM, SamplingParams ++from vllm.utils import FlexibleArgumentParser ++ ++# A very long prompt, total number of tokens is about 15k. ++LONG_PROMPT = ["You are an expert in large language models, aren't you?" ++ ] * 1000 ++LONG_PROMPT = ' '.join(LONG_PROMPT) ++ ++ ++def main(args): ++ llm = LLM( ++ model=args.model, ++ enforce_eager=True, ++ enable_prefix_caching=True, ++ tensor_parallel_size=args.tensor_parallel_size, ++ ) ++ ++ sampling_params = SamplingParams(temperature=0, max_tokens=args.output_len) ++ profiler = cProfile.Profile() ++ ++ print("------warm up------") ++ for i in range(3): ++ output = llm.generate(LONG_PROMPT, sampling_params) ++ print(output[0].outputs[0].text) ++ ++ print("------start generating------") ++ for i in range(3): ++ profiler.runctx('llm.generate(LONG_PROMPT, sampling_params)', ++ globals(), locals()) ++ ++ # analyze the runtime of hashing function ++ stats = pstats.Stats(profiler) ++ stats.sort_stats('cumulative') ++ total_time = 0 ++ total_calls = 0 ++ for func in stats.stats: ++ if 'hash_of_block' in func[2]: ++ total_time = stats.stats[func][3] ++ total_calls = stats.stats[func][0] ++ percentage = (total_time / stats.total_tt) * 100 ++ print(f"Hashing took {total_time:.2f} seconds," ++ f"{percentage:.2f}% of the total runtime.") ++ ++ ++if __name__ == "__main__": ++ parser = FlexibleArgumentParser( ++ description='Benchmark the performance of hashing function in' ++ 'automatic prefix caching.') ++ parser.add_argument('--model', type=str, default='lmsys/longchat-7b-16k') ++ parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1) ++ parser.add_argument('--output-len', type=int, default=10) ++ parser.add_argument('--enable-prefix-caching', ++ action='store_true', ++ help='enable prefix caching') ++ args = parser.parse_args() ++ main(args) +diff --git a/benchmarks/structured_schemas/structured_schema_1.json b/benchmarks/structured_schemas/structured_schema_1.json +new file mode 100644 +index 0000000..6003698 +--- /dev/null ++++ b/benchmarks/structured_schemas/structured_schema_1.json +@@ -0,0 +1,113 @@ ++{ ++ "$schema": ++ "https://json-schema.org/draft/2020-12/schema", ++ "title": ++ "User Profile", ++ "type": ++ "object", ++ "properties": { ++ "userId": { ++ "type": "string", ++ "description": "Unique identifier for the user." ++ }, ++ "personalInfo": { ++ "type": "object", ++ "properties": { ++ "firstName": { ++ "type": "string", ++ "description": "The user's first name." ++ }, ++ "lastName": { ++ "type": "string", ++ "description": "The user's last name." ++ }, ++ "age": { ++ "type": "integer", ++ "minimum": 0, ++ "description": "The user's age." ++ }, ++ "phoneNumbers": { ++ "type": ++ "array", ++ "items": { ++ "type": "object", ++ "properties": { ++ "type": { ++ "type": "string", ++ "enum": ["home", "work", "mobile"], ++ "description": "Type of phone number." ++ }, ++ "number": { ++ "type": "string", ++ "pattern": "^\\+?[1-9]\\d{1,14}$", ++ "description": "Phone number in E.164 format." ++ } ++ }, ++ "required": ["type", "number"] ++ }, ++ "description": ++ "List of phone numbers associated with the user." ++ } ++ }, ++ "required": ["firstName", "lastName"] ++ }, ++ "address": { ++ "type": "object", ++ "properties": { ++ "street": { ++ "type": "string", ++ "description": "Street address." ++ }, ++ "city": { ++ "type": "string", ++ "description": "City name." ++ }, ++ "state": { ++ "type": "string", ++ "description": "State or province." ++ }, ++ "postalCode": { ++ "type": "string", ++ "pattern": "^\\d{5}(-\\d{4})?$", ++ "description": "Postal code." ++ }, ++ "country": { ++ "type": "string", ++ "description": "Country name." ++ } ++ }, ++ "required": ["street", "city", "state", "postalCode", "country"] ++ }, ++ "preferences": { ++ "type": "object", ++ "properties": { ++ "newsletterSubscribed": { ++ "type": ++ "boolean", ++ "description": ++ "Indicates if the user is subscribed to the newsletter." ++ }, ++ "favoriteCategories": { ++ "type": "array", ++ "items": { ++ "type": "string" ++ }, ++ "description": "List of user's favorite categories." ++ } ++ }, ++ "required": ["newsletterSubscribed"] ++ }, ++ "accountStatus": { ++ "type": "string", ++ "enum": ["active", "inactive", "suspended"], ++ "description": "Current status of the user's account." ++ }, ++ "registrationDate": { ++ "type": "string", ++ "format": "date-time", ++ "description": "ISO 8601 formatted date-time of user registration." ++ } ++ }, ++ "required": ++ ["userId", "personalInfo", "address", "accountStatus", "registrationDate"] ++} +\ No newline at end of file +diff --git a/cmake/cpu_extension.cmake b/cmake/cpu_extension.cmake +index 0cf3776..714abca 100644 +--- a/cmake/cpu_extension.cmake ++++ b/cmake/cpu_extension.cmake +@@ -1,5 +1,14 @@ ++include(FetchContent) ++ ++set(CMAKE_CXX_STANDARD_REQUIRED ON) ++set(CMAKE_CXX_EXTENSIONS ON) + set(CMAKE_EXPORT_COMPILE_COMMANDS ON) + ++if (${CMAKE_SYSTEM_NAME} MATCHES "Darwin") ++ set(MACOSX_FOUND TRUE) ++endif() ++ ++ + # + # Define environment variables for special configurations + # +@@ -9,21 +18,40 @@ endif() + + include_directories("${CMAKE_SOURCE_DIR}/csrc") + ++ ++set (ENABLE_NUMA TRUE) ++ + # + # Check the compile flags + # +-list(APPEND CXX_COMPILE_FLAGS +- "-fopenmp" +- "-DVLLM_CPU_EXTENSION") + +-execute_process(COMMAND cat /proc/cpuinfo +- RESULT_VARIABLE CPUINFO_RET +- OUTPUT_VARIABLE CPUINFO) ++if (CMAKE_SYSTEM_PROCESSOR MATCHES "x86_64") ++ list(APPEND CXX_COMPILE_FLAGS ++ "-mf16c" ++ ) ++endif() + +-if (NOT CPUINFO_RET EQUAL 0) +- message(FATAL_ERROR "Failed to check CPU features via /proc/cpuinfo") ++if(MACOSX_FOUND) ++ list(APPEND CXX_COMPILE_FLAGS ++ "-Xpreprocessor" ++ "-fopenmp" ++ "-DVLLM_CPU_EXTENSION") ++else() ++ list(APPEND CXX_COMPILE_FLAGS ++ "-fopenmp" ++ "-DVLLM_CPU_EXTENSION") + endif() + ++if (NOT MACOSX_FOUND) ++ execute_process(COMMAND cat /proc/cpuinfo ++ RESULT_VARIABLE CPUINFO_RET ++ OUTPUT_VARIABLE CPUINFO) ++ if (NOT CPUINFO_RET EQUAL 0) ++ message(FATAL_ERROR "Failed to check CPU features via /proc/cpuinfo") ++ endif() ++endif() ++ ++ + function (find_isa CPUINFO TARGET OUT) + string(FIND ${CPUINFO} ${TARGET} ISA_FOUND) + if(NOT ISA_FOUND EQUAL -1) +@@ -33,9 +61,30 @@ function (find_isa CPUINFO TARGET OUT) + endif() + endfunction() + +-find_isa(${CPUINFO} "avx512f" AVX512_FOUND) ++function (is_avx512_disabled OUT) ++ set(DISABLE_AVX512 $ENV{VLLM_CPU_DISABLE_AVX512}) ++ if(DISABLE_AVX512 AND DISABLE_AVX512 STREQUAL "true") ++ set(${OUT} ON PARENT_SCOPE) ++ else() ++ set(${OUT} OFF PARENT_SCOPE) ++ endif() ++endfunction() ++ ++is_avx512_disabled(AVX512_DISABLED) + +-if (AVX512_FOUND) ++if (MACOSX_FOUND AND CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64") ++ set(APPLE_SILICON_FOUND TRUE) ++else() ++ find_isa(${CPUINFO} "avx2" AVX2_FOUND) ++ find_isa(${CPUINFO} "avx512f" AVX512_FOUND) ++ find_isa(${CPUINFO} "POWER10" POWER10_FOUND) ++ find_isa(${CPUINFO} "POWER9" POWER9_FOUND) ++ find_isa(${CPUINFO} "asimd" ASIMD_FOUND) # Check for ARM NEON support ++ find_isa(${CPUINFO} "bf16" ARM_BF16_FOUND) # Check for ARM BF16 support ++endif() ++ ++ ++if (AVX512_FOUND AND NOT AVX512_DISABLED) + list(APPEND CXX_COMPILE_FLAGS + "-mavx512f" + "-mavx512vl" +@@ -44,8 +93,8 @@ if (AVX512_FOUND) + + find_isa(${CPUINFO} "avx512_bf16" AVX512BF16_FOUND) + if (AVX512BF16_FOUND OR ENABLE_AVX512BF16) +- if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND +- CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 12.3) ++ if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND ++ CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 12.3) + list(APPEND CXX_COMPILE_FLAGS "-mavx512bf16") + else() + message(WARNING "Disable AVX512-BF16 ISA support, requires gcc/g++ >= 12.3") +@@ -53,16 +102,75 @@ if (AVX512_FOUND) + else() + message(WARNING "Disable AVX512-BF16 ISA support, no avx512_bf16 found in local CPU flags." " If cross-compilation is required, please set env VLLM_CPU_AVX512BF16=1.") + endif() ++ ++elseif (AVX2_FOUND) ++ list(APPEND CXX_COMPILE_FLAGS "-mavx2") ++ message(WARNING "vLLM CPU backend using AVX2 ISA") ++ ++elseif (POWER9_FOUND OR POWER10_FOUND) ++ message(STATUS "PowerPC detected") ++ # Check for PowerPC VSX support ++ list(APPEND CXX_COMPILE_FLAGS ++ "-mvsx" ++ "-mcpu=native" ++ "-mtune=native") ++ ++elseif (ASIMD_FOUND) ++ message(STATUS "ARMv8 or later architecture detected") ++ if(ARM_BF16_FOUND) ++ message(STATUS "BF16 extension detected") ++ set(MARCH_FLAGS "-march=armv8.2-a+bf16+dotprod+fp16") ++ add_compile_definitions(ARM_BF16_SUPPORT) ++ else() ++ message(WARNING "BF16 functionality is not available") ++ set(MARCH_FLAGS "-march=armv8.2-a+dotprod+fp16") ++ endif() ++ list(APPEND CXX_COMPILE_FLAGS ${MARCH_FLAGS}) ++elseif(APPLE_SILICON_FOUND) ++ message(STATUS "Apple Silicon Detected") ++ set(ENABLE_NUMA OFF) + else() +- message(FATAL_ERROR "vLLM CPU backend requires AVX512 ISA support.") ++ message(FATAL_ERROR "vLLM CPU backend requires AVX512, AVX2, Power9+ ISA or ARMv8 support.") + endif() + +-message(STATUS "CPU extension compile flags: ${CXX_COMPILE_FLAGS}") +- +- + # +-# Define extension targets ++# Build oneDNN for W8A8 GEMM kernels (only for x86-AVX512 platforms) + # ++if (AVX512_FOUND AND NOT AVX512_DISABLED) ++ FetchContent_Declare( ++ oneDNN ++ GIT_REPOSITORY https://github.com/oneapi-src/oneDNN.git ++ GIT_TAG v3.6 ++ GIT_PROGRESS TRUE ++ GIT_SHALLOW TRUE ++ ) ++ ++ set(ONEDNN_LIBRARY_TYPE "STATIC") ++ set(ONEDNN_BUILD_DOC "OFF") ++ set(ONEDNN_BUILD_EXAMPLES "OFF") ++ set(ONEDNN_BUILD_TESTS "OFF") ++ set(ONEDNN_ENABLE_WORKLOAD "INFERENCE") ++ set(ONEDNN_ENABLE_PRIMITIVE "MATMUL;REORDER") ++ set(ONEDNN_BUILD_GRAPH "OFF") ++ set(ONEDNN_ENABLE_JIT_PROFILING "OFF") ++ set(ONEDNN_ENABLE_ITT_TASKS "OFF") ++ set(ONEDNN_ENABLE_MAX_CPU_ISA "OFF") ++ set(ONEDNN_ENABLE_CPU_ISA_HINTS "OFF") ++ set(CMAKE_POLICY_DEFAULT_CMP0077 NEW) ++ ++ FetchContent_MakeAvailable(oneDNN) ++ ++ list(APPEND LIBS dnnl) ++endif() ++ ++message(STATUS "CPU extension compile flags: ${CXX_COMPILE_FLAGS}") ++ ++if(ENABLE_NUMA) ++ list(APPEND LIBS numa) ++else() ++ message(STATUS "NUMA is disabled") ++ add_compile_definitions(-DVLLM_NUMA_DISABLED) ++endif() + + # + # _C extension +@@ -71,20 +179,30 @@ set(VLLM_EXT_SRC + "csrc/cpu/activation.cpp" + "csrc/cpu/attention.cpp" + "csrc/cpu/cache.cpp" ++ "csrc/cpu/utils.cpp" + "csrc/cpu/layernorm.cpp" + "csrc/cpu/pos_encoding.cpp" +- "csrc/cpu/pybind.cpp") ++ "csrc/cpu/torch_bindings.cpp") ++ ++if (AVX512_FOUND AND NOT AVX512_DISABLED) ++ set(VLLM_EXT_SRC ++ "csrc/cpu/quant.cpp" ++ ${VLLM_EXT_SRC}) ++endif() ++ ++# ++# Define extension targets ++# + + define_gpu_extension_target( + _C + DESTINATION vllm + LANGUAGE CXX + SOURCES ${VLLM_EXT_SRC} ++ LIBRARIES ${LIBS} + COMPILE_FLAGS ${CXX_COMPILE_FLAGS} +- WITH_SOABI ++ USE_SABI 3 ++ WITH_SOABI + ) + +-add_custom_target(default) +-message(STATUS "Enabling C extension.") +-add_dependencies(default _C) +- ++message(STATUS "Enabling C extension.") +\ No newline at end of file +diff --git a/cmake/utils.cmake b/cmake/utils.cmake +index 7c71673..40430da 100644 +--- a/cmake/utils.cmake ++++ b/cmake/utils.cmake +@@ -5,7 +5,7 @@ + macro (find_python_from_executable EXECUTABLE SUPPORTED_VERSIONS) + file(REAL_PATH ${EXECUTABLE} EXECUTABLE) + set(Python_EXECUTABLE ${EXECUTABLE}) +- find_package(Python COMPONENTS Interpreter Development.Module) ++ find_package(Python COMPONENTS Interpreter Development.Module Development.SABIModule) + if (NOT Python_FOUND) + message(FATAL_ERROR "Unable to find python matching: ${EXECUTABLE}.") + endif() +@@ -99,7 +99,7 @@ function (get_torch_gpu_compiler_flags OUT_GPU_FLAGS GPU_LANG) + "Failed to determine torch nvcc compiler flags") + + if (CUDA_VERSION VERSION_GREATER_EQUAL 11.8) +- list(APPEND GPU_FLAGS "-DENABLE_FP8_E5M2") ++ list(APPEND GPU_FLAGS "-DENABLE_FP8") + endif() + if (CUDA_VERSION VERSION_GREATER_EQUAL 12.0) + list(REMOVE_ITEM GPU_FLAGS +@@ -119,7 +119,7 @@ function (get_torch_gpu_compiler_flags OUT_GPU_FLAGS GPU_LANG) + + list(APPEND GPU_FLAGS + "-DUSE_ROCM" +- "-DENABLE_FP8_E4M3" ++ "-DENABLE_FP8" + "-U__HIP_NO_HALF_CONVERSIONS__" + "-U__HIP_NO_HALF_OPERATORS__" + "-fno-gpu-rdc") +@@ -133,10 +133,181 @@ macro(string_to_ver OUT_VER IN_STR) + string(REGEX REPLACE "\([0-9]+\)\([0-9]\)" "\\1.\\2" ${OUT_VER} ${IN_STR}) + endmacro() + ++# ++# Clear all `-gencode` flags from `CMAKE_CUDA_FLAGS` and store them in ++# `CUDA_ARCH_FLAGS`. ++# ++# Example: ++# CMAKE_CUDA_FLAGS="-Wall -gencode arch=compute_70,code=sm_70 -gencode arch=compute_75,code=sm_75" ++# clear_cuda_arches(CUDA_ARCH_FLAGS) ++# CUDA_ARCH_FLAGS="-gencode arch=compute_70,code=sm_70;-gencode arch=compute_75,code=sm_75" ++# CMAKE_CUDA_FLAGS="-Wall" ++# ++macro(clear_cuda_arches CUDA_ARCH_FLAGS) ++ # Extract all `-gencode` flags from `CMAKE_CUDA_FLAGS` ++ string(REGEX MATCHALL "-gencode arch=[^ ]+" CUDA_ARCH_FLAGS ++ ${CMAKE_CUDA_FLAGS}) ++ ++ # Remove all `-gencode` flags from `CMAKE_CUDA_FLAGS` since they will be modified ++ # and passed back via the `CUDA_ARCHITECTURES` property. ++ string(REGEX REPLACE "-gencode arch=[^ ]+ *" "" CMAKE_CUDA_FLAGS ++ ${CMAKE_CUDA_FLAGS}) ++endmacro() ++ ++# ++# Extract unique CUDA architectures from a list of compute capabilities codes in ++# the form `[]`, convert them to the form sort ++# `.`, dedupes them and then sorts them in ascending order and ++# stores them in `OUT_ARCHES`. ++# ++# Example: ++# CUDA_ARCH_FLAGS="-gencode arch=compute_75,code=sm_75;...;-gencode arch=compute_90a,code=sm_90a" ++# extract_unique_cuda_archs_ascending(OUT_ARCHES CUDA_ARCH_FLAGS) ++# OUT_ARCHES="7.5;...;9.0" ++function(extract_unique_cuda_archs_ascending OUT_ARCHES CUDA_ARCH_FLAGS) ++ set(_CUDA_ARCHES) ++ foreach(_ARCH ${CUDA_ARCH_FLAGS}) ++ string(REGEX MATCH "arch=compute_\([0-9]+a?\)" _COMPUTE ${_ARCH}) ++ if (_COMPUTE) ++ set(_COMPUTE ${CMAKE_MATCH_1}) ++ endif() ++ ++ string_to_ver(_COMPUTE_VER ${_COMPUTE}) ++ list(APPEND _CUDA_ARCHES ${_COMPUTE_VER}) ++ endforeach() ++ ++ list(REMOVE_DUPLICATES _CUDA_ARCHES) ++ list(SORT _CUDA_ARCHES COMPARE NATURAL ORDER ASCENDING) ++ set(${OUT_ARCHES} ${_CUDA_ARCHES} PARENT_SCOPE) ++endfunction() ++ ++# ++# For a specific file set the `-gencode` flag in compile options conditionally ++# for the CUDA language. ++# ++# Example: ++# set_gencode_flag_for_srcs( ++# SRCS "foo.cu" ++# ARCH "compute_75" ++# CODE "sm_75") ++# adds: "-gencode arch=compute_75,code=sm_75" to the compile options for ++# `foo.cu` (only for the CUDA language). ++# ++macro(set_gencode_flag_for_srcs) ++ set(options) ++ set(oneValueArgs ARCH CODE) ++ set(multiValueArgs SRCS) ++ cmake_parse_arguments(arg "${options}" "${oneValueArgs}" ++ "${multiValueArgs}" ${ARGN} ) ++ set(_FLAG -gencode arch=${arg_ARCH},code=${arg_CODE}) ++ set_property( ++ SOURCE ${arg_SRCS} ++ APPEND PROPERTY ++ COMPILE_OPTIONS "$<$:${_FLAG}>" ++ ) ++ ++ message(DEBUG "Setting gencode flag for ${arg_SRCS}: ${_FLAG}") ++endmacro(set_gencode_flag_for_srcs) ++ ++# ++# For a list of source files set the `-gencode` flags in the files specific ++# compile options (specifically for the CUDA language). ++# ++# arguments are: ++# SRCS: list of source files ++# CUDA_ARCHS: list of CUDA architectures in the form `.[letter]` ++# BUILD_PTX_FOR_ARCH: if set to true, then the PTX code will be built ++# for architecture `BUILD_PTX_FOR_ARCH` if there is a CUDA_ARCH in CUDA_ARCHS ++# that is larger than BUILD_PTX_FOR_ARCH. ++# ++macro(set_gencode_flags_for_srcs) ++ set(options) ++ set(oneValueArgs BUILD_PTX_FOR_ARCH) ++ set(multiValueArgs SRCS CUDA_ARCHS) ++ cmake_parse_arguments(arg "${options}" "${oneValueArgs}" ++ "${multiValueArgs}" ${ARGN} ) ++ ++ foreach(_ARCH ${arg_CUDA_ARCHS}) ++ string(REPLACE "." "" _ARCH "${_ARCH}") ++ set_gencode_flag_for_srcs( ++ SRCS ${arg_SRCS} ++ ARCH "compute_${_ARCH}" ++ CODE "sm_${_ARCH}") ++ endforeach() ++ ++ if (${arg_BUILD_PTX_FOR_ARCH}) ++ list(SORT arg_CUDA_ARCHS COMPARE NATURAL ORDER ASCENDING) ++ list(GET arg_CUDA_ARCHS -1 _HIGHEST_ARCH) ++ if (_HIGHEST_ARCH VERSION_GREATER_EQUAL ${arg_BUILD_PTX_FOR_ARCH}) ++ string(REPLACE "." "" _PTX_ARCH "${arg_BUILD_PTX_FOR_ARCH}") ++ set_gencode_flag_for_srcs( ++ SRCS ${arg_SRCS} ++ ARCH "compute_${_PTX_ARCH}" ++ CODE "compute_${_PTX_ARCH}") ++ endif() ++ endif() ++endmacro() ++ ++# ++# For the given `SRC_CUDA_ARCHS` list of gencode versions in the form ++# `.[letter]` compute the "loose intersection" with the ++# `TGT_CUDA_ARCHS` list of gencodes. ++# The loose intersection is defined as: ++# { max{ x \in tgt | x <= y } | y \in src, { x \in tgt | x <= y } != {} } ++# where `<=` is the version comparison operator. ++# In other words, for each version in `TGT_CUDA_ARCHS` find the highest version ++# in `SRC_CUDA_ARCHS` that is less or equal to the version in `TGT_CUDA_ARCHS`. ++# We have special handling for 9.0a, if 9.0a is in `SRC_CUDA_ARCHS` and 9.0 is ++# in `TGT_CUDA_ARCHS` then we should remove 9.0a from `SRC_CUDA_ARCHS` and add ++# 9.0a to the result. ++# The result is stored in `OUT_CUDA_ARCHS`. ++# ++# Example: ++# SRC_CUDA_ARCHS="7.5;8.0;8.6;9.0;9.0a" ++# TGT_CUDA_ARCHS="8.0;8.9;9.0" ++# cuda_archs_loose_intersection(OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS) ++# OUT_CUDA_ARCHS="8.0;8.6;9.0;9.0a" ++# ++function(cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS) ++ list(REMOVE_DUPLICATES SRC_CUDA_ARCHS) ++ ++ # if 9.0a is in SRC_CUDA_ARCHS and 9.0 is in CUDA_ARCHS then we should ++ # remove 9.0a from SRC_CUDA_ARCHS and add 9.0a to _CUDA_ARCHS ++ set(_CUDA_ARCHS) ++ if ("9.0a" IN_LIST SRC_CUDA_ARCHS) ++ list(REMOVE_ITEM SRC_CUDA_ARCHS "9.0a") ++ if ("9.0" IN_LIST TGT_CUDA_ARCHS) ++ set(_CUDA_ARCHS "9.0a") ++ endif() ++ endif() ++ ++ list(SORT SRC_CUDA_ARCHS COMPARE NATURAL ORDER ASCENDING) ++ ++ # for each ARCH in CUDA_ARCHS find the highest arch in SRC_CUDA_ARCHS that is ++ # less or eqault to ARCH ++ foreach(_ARCH ${CUDA_ARCHS}) ++ set(_TMP_ARCH) ++ foreach(_SRC_ARCH ${SRC_CUDA_ARCHS}) ++ if (_SRC_ARCH VERSION_LESS_EQUAL _ARCH) ++ set(_TMP_ARCH ${_SRC_ARCH}) ++ else() ++ break() ++ endif() ++ endforeach() ++ if (_TMP_ARCH) ++ list(APPEND _CUDA_ARCHS ${_TMP_ARCH}) ++ endif() ++ endforeach() ++ ++ list(REMOVE_DUPLICATES _CUDA_ARCHS) ++ set(${OUT_CUDA_ARCHS} ${_CUDA_ARCHS} PARENT_SCOPE) ++endfunction() ++ + # + # Override the GPU architectures detected by cmake/torch and filter them by + # `GPU_SUPPORTED_ARCHES`. Sets the final set of architectures in +-# `GPU_ARCHES`. ++# `GPU_ARCHES`. This only applies to the HIP language since for CUDA we set ++# the architectures on a per file basis. + # + # Note: this is defined as a macro since it updates `CMAKE_CUDA_FLAGS`. + # +@@ -147,16 +318,23 @@ macro(override_gpu_arches GPU_ARCHES GPU_LANG GPU_SUPPORTED_ARCHES) + if (${GPU_LANG} STREQUAL "HIP") + # + # `GPU_ARCHES` controls the `--offload-arch` flags. +- # `CMAKE_HIP_ARCHITECTURES` is set up by torch and can be controlled +- # via the `PYTORCH_ROCM_ARCH` env variable. + # +- ++ # If PYTORCH_ROCM_ARCH env variable exists, then we take it as a list, ++ # if not, then we use CMAKE_HIP_ARCHITECTURES which was generated by calling ++ # "rocm_agent_enumerator" in "enable_language(HIP)" ++ # (in file Modules/CMakeDetermineHIPCompiler.cmake) ++ # ++ if(DEFINED ENV{PYTORCH_ROCM_ARCH}) ++ set(HIP_ARCHITECTURES $ENV{PYTORCH_ROCM_ARCH}) ++ else() ++ set(HIP_ARCHITECTURES ${CMAKE_HIP_ARCHITECTURES}) ++ endif() + # + # Find the intersection of the supported + detected architectures to + # set the module architecture flags. + # + set(${GPU_ARCHES}) +- foreach (_ARCH ${CMAKE_HIP_ARCHITECTURES}) ++ foreach (_ARCH ${HIP_ARCHITECTURES}) + if (_ARCH IN_LIST _GPU_SUPPORTED_ARCHES_LIST) + list(APPEND ${GPU_ARCHES} ${_ARCH}) + endif() +@@ -164,112 +342,10 @@ macro(override_gpu_arches GPU_ARCHES GPU_LANG GPU_SUPPORTED_ARCHES) + + if(NOT ${GPU_ARCHES}) + message(FATAL_ERROR +- "None of the detected ROCm architectures: ${CMAKE_HIP_ARCHITECTURES} is" ++ "None of the detected ROCm architectures: ${HIP_ARCHITECTURES} is" + " supported. Supported ROCm architectures are: ${_GPU_SUPPORTED_ARCHES_LIST}.") + endif() +- +- elseif(${GPU_LANG} STREQUAL "CUDA") +- # +- # Setup/process CUDA arch flags. +- # +- # The torch cmake setup hardcodes the detected architecture flags in +- # `CMAKE_CUDA_FLAGS`. Since `CMAKE_CUDA_FLAGS` is a "global" variable, it +- # can't modified on a per-target basis, e.g. for the `punica` extension. +- # So, all the `-gencode` flags need to be extracted and removed from +- # `CMAKE_CUDA_FLAGS` for processing so they can be passed by another method. +- # Since it's not possible to use `target_compiler_options` for adding target +- # specific `-gencode` arguments, the target's `CUDA_ARCHITECTURES` property +- # must be used instead. This requires repackaging the architecture flags +- # into a format that cmake expects for `CUDA_ARCHITECTURES`. +- # +- # This is a bit fragile in that it depends on torch using `-gencode` as opposed +- # to one of the other nvcc options to specify architectures. +- # +- # Note: torch uses the `TORCH_CUDA_ARCH_LIST` environment variable to override +- # detected architectures. +- # +- message(DEBUG "initial CMAKE_CUDA_FLAGS: ${CMAKE_CUDA_FLAGS}") +- +- # Extract all `-gencode` flags from `CMAKE_CUDA_FLAGS` +- string(REGEX MATCHALL "-gencode arch=[^ ]+" _CUDA_ARCH_FLAGS +- ${CMAKE_CUDA_FLAGS}) +- +- # Remove all `-gencode` flags from `CMAKE_CUDA_FLAGS` since they will be modified +- # and passed back via the `CUDA_ARCHITECTURES` property. +- string(REGEX REPLACE "-gencode arch=[^ ]+ *" "" CMAKE_CUDA_FLAGS +- ${CMAKE_CUDA_FLAGS}) +- +- # If this error is triggered, it might mean that torch has changed how it sets +- # up nvcc architecture code generation flags. +- if (NOT _CUDA_ARCH_FLAGS) +- message(FATAL_ERROR +- "Could not find any architecture related code generation flags in " +- "CMAKE_CUDA_FLAGS. (${CMAKE_CUDA_FLAGS})") +- endif() +- +- message(DEBUG "final CMAKE_CUDA_FLAGS: ${CMAKE_CUDA_FLAGS}") +- message(DEBUG "arch flags: ${_CUDA_ARCH_FLAGS}") +- +- # Initialize the architecture lists to empty. +- set(${GPU_ARCHES}) +- +- # Process each `gencode` flag. +- foreach(_ARCH ${_CUDA_ARCH_FLAGS}) +- # For each flag, extract the version number and whether it refers to PTX +- # or native code. +- # Note: if a regex matches then `CMAKE_MATCH_1` holds the binding +- # for that match. +- +- string(REGEX MATCH "arch=compute_\([0-9]+a?\)" _COMPUTE ${_ARCH}) +- if (_COMPUTE) +- set(_COMPUTE ${CMAKE_MATCH_1}) +- endif() +- +- string(REGEX MATCH "code=sm_\([0-9]+a?\)" _SM ${_ARCH}) +- if (_SM) +- set(_SM ${CMAKE_MATCH_1}) +- endif() +- +- string(REGEX MATCH "code=compute_\([0-9]+a?\)" _CODE ${_ARCH}) +- if (_CODE) +- set(_CODE ${CMAKE_MATCH_1}) +- endif() +- +- # Make sure the virtual architecture can be matched. +- if (NOT _COMPUTE) +- message(FATAL_ERROR +- "Could not determine virtual architecture from: ${_ARCH}.") +- endif() +- +- # One of sm_ or compute_ must exist. +- if ((NOT _SM) AND (NOT _CODE)) +- message(FATAL_ERROR +- "Could not determine a codegen architecture from: ${_ARCH}.") +- endif() +- +- if (_SM) +- # -real suffix let CMake to only generate elf code for the kernels. +- # we want this, otherwise the added ptx (default) will increase binary size. +- set(_VIRT "-real") +- set(_CODE_ARCH ${_SM}) +- else() +- # -virtual suffix let CMake to generate ptx code for the kernels. +- set(_VIRT "-virtual") +- set(_CODE_ARCH ${_CODE}) +- endif() +- +- # Check if the current version is in the supported arch list. +- string_to_ver(_CODE_VER ${_CODE_ARCH}) +- if (NOT _CODE_VER IN_LIST _GPU_SUPPORTED_ARCHES_LIST) +- message(STATUS "discarding unsupported CUDA arch ${_VER}.") +- continue() +- endif() +- +- # Add it to the arch list. +- list(APPEND ${GPU_ARCHES} "${_CODE_ARCH}${_VIRT}") +- endforeach() + endif() +- message(STATUS "${GPU_LANG} target arches: ${${GPU_ARCHES}}") + endmacro() + + # +@@ -294,6 +370,7 @@ endmacro() + # INCLUDE_DIRECTORIES - Extra include directories. + # LIBRARIES - Extra link libraries. + # WITH_SOABI - Generate library with python SOABI suffix name. ++# USE_SABI - Use python stable api + # + # Note: optimization level/debug info is set via cmake build type. + # +@@ -301,7 +378,7 @@ function (define_gpu_extension_target GPU_MOD_NAME) + cmake_parse_arguments(PARSE_ARGV 1 + GPU + "WITH_SOABI" +- "DESTINATION;LANGUAGE" ++ "DESTINATION;LANGUAGE;USE_SABI" + "SOURCES;ARCHITECTURES;COMPILE_FLAGS;INCLUDE_DIRECTORIES;LIBRARIES") + + # Add hipify preprocessing step when building with HIP/ROCm. +@@ -315,7 +392,11 @@ function (define_gpu_extension_target GPU_MOD_NAME) + set(GPU_WITH_SOABI) + endif() + +- Python_add_library(${GPU_MOD_NAME} MODULE "${GPU_SOURCES}" ${GPU_WITH_SOABI}) ++ if (GPU_USE_SABI) ++ Python_add_library(${GPU_MOD_NAME} MODULE USE_SABI ${GPU_USE_SABI} ${GPU_WITH_SOABI} "${GPU_SOURCES}") ++ else() ++ Python_add_library(${GPU_MOD_NAME} MODULE ${GPU_WITH_SOABI} "${GPU_SOURCES}") ++ endif() + + if (GPU_LANGUAGE STREQUAL "HIP") + # Make this target dependent on the hipify preprocessor step. +@@ -338,17 +419,15 @@ function (define_gpu_extension_target GPU_MOD_NAME) + target_include_directories(${GPU_MOD_NAME} PRIVATE csrc + ${GPU_INCLUDE_DIRECTORIES}) + +- target_link_libraries(${GPU_MOD_NAME} PRIVATE torch ${torch_python_LIBRARY} +- ${GPU_LIBRARIES}) ++ target_link_libraries(${GPU_MOD_NAME} PRIVATE torch ${GPU_LIBRARIES}) + + # Don't use `TORCH_LIBRARIES` for CUDA since it pulls in a bunch of + # dependencies that are not necessary and may not be installed. + if (GPU_LANGUAGE STREQUAL "CUDA") +- target_link_libraries(${GPU_MOD_NAME} PRIVATE ${CUDA_CUDA_LIB} +- ${CUDA_LIBRARIES}) ++ target_link_libraries(${GPU_MOD_NAME} PRIVATE CUDA::cudart CUDA::cuda_driver) + else() + target_link_libraries(${GPU_MOD_NAME} PRIVATE ${TORCH_LIBRARIES}) + endif() + +- install(TARGETS ${GPU_MOD_NAME} LIBRARY DESTINATION ${GPU_DESTINATION}) ++ install(TARGETS ${GPU_MOD_NAME} LIBRARY DESTINATION ${GPU_DESTINATION} COMPONENT ${GPU_MOD_NAME}) + endfunction() +diff --git a/collect_env.py b/collect_env.py +index 1ecfeb8..254c19b 100644 +--- a/collect_env.py ++++ b/collect_env.py +@@ -1,17 +1,19 @@ + # ruff: noqa + # code borrowed from https://github.com/pytorch/pytorch/blob/main/torch/utils/collect_env.py + +-# Unlike the rest of the PyTorch this file must be python2 compliant. +-# This script outputs relevant system environment info +-# Run it with `python collect_env.py` or `python -m torch.utils.collect_env` + import datetime + import locale + import os + import re + import subprocess + import sys ++# Unlike the rest of the PyTorch this file must be python2 compliant. ++# This script outputs relevant system environment info ++# Run it with `python collect_env.py` or `python -m torch.utils.collect_env` + from collections import namedtuple + ++from vllm.envs import environment_variables ++ + try: + import torch + TORCH_AVAILABLE = True +@@ -52,6 +54,7 @@ SystemEnv = namedtuple( + 'vllm_version', # vllm specific field + 'vllm_build_flags', # vllm specific field + 'gpu_topo', # vllm specific field ++ 'env_vars', + ]) + + DEFAULT_CONDA_PATTERNS = { +@@ -64,6 +67,10 @@ DEFAULT_CONDA_PATTERNS = { + "triton", + "optree", + "nccl", ++ "transformers", ++ "zmq", ++ "nvidia", ++ "pynvml", + } + + DEFAULT_PIP_PATTERNS = { +@@ -75,6 +82,10 @@ DEFAULT_PIP_PATTERNS = { + "optree", + "onnx", + "nccl", ++ "transformers", ++ "zmq", ++ "nvidia", ++ "pynvml", + } + + +@@ -259,12 +270,16 @@ def get_neuron_sdk_version(run_lambda): + + + def get_vllm_version(): +- try: +- import vllm +- return vllm.__version__ +- except ImportError: +- return 'N/A' ++ from vllm import __version__, __version_tuple__ ++ ++ if __version__ == "dev": ++ return "N/A (dev)" + ++ if len(__version_tuple__) == 4: # dev build ++ git_sha = __version_tuple__[-1][1:] # type: ignore ++ return f"{__version__} (git sha: {git_sha}" ++ ++ return __version__ + + def summarize_vllm_build_flags(): + # This could be a static method if the flags are constant, or dynamic if you need to check environment variables, etc. +@@ -276,9 +291,14 @@ def summarize_vllm_build_flags(): + + + def get_gpu_topo(run_lambda): ++ output = None ++ + if get_platform() == 'linux': +- return run_and_read_all(run_lambda, 'nvidia-smi topo -m') +- return None ++ output = run_and_read_all(run_lambda, 'nvidia-smi topo -m') ++ if output is None: ++ output = run_and_read_all(run_lambda, 'rocm-smi --showtopo') ++ ++ return output + + + # example outputs of CPU infos +@@ -495,6 +515,22 @@ def is_xnnpack_available(): + else: + return "N/A" + ++def get_env_vars(): ++ env_vars = '' ++ secret_terms=('secret', 'token', 'api', 'access', 'password') ++ report_prefix = ("TORCH", "NCCL", "PYTORCH", ++ "CUDA", "CUBLAS", "CUDNN", ++ "OMP_", "MKL_", ++ "NVIDIA") ++ for k, v in os.environ.items(): ++ if any(term in k.lower() for term in secret_terms): ++ continue ++ if k in environment_variables: ++ env_vars = env_vars + "{}={}".format(k, v) + "\n" ++ if k.startswith(report_prefix): ++ env_vars = env_vars + "{}={}".format(k, v) + "\n" ++ ++ return env_vars + + def get_env_info(): + run_lambda = run +@@ -566,6 +602,7 @@ def get_env_info(): + vllm_version=vllm_version, + vllm_build_flags=vllm_build_flags, + gpu_topo=gpu_topo, ++ env_vars=get_env_vars(), + ) + + +@@ -601,6 +638,11 @@ Versions of relevant libraries: + {conda_packages} + """.strip() + ++# both the above code and the following code use `strip()` to ++# remove leading/trailing whitespaces, so we need to add a newline ++# in between to separate the two sections ++env_info_fmt += "\n" ++ + env_info_fmt += """ + ROCM Version: {rocm_version} + Neuron SDK Version: {neuron_sdk_version} +@@ -609,6 +651,8 @@ vLLM Build Flags: + {vllm_build_flags} + GPU Topology: + {gpu_topo} ++ ++{env_vars} + """.strip() + + +diff --git a/csrc/activation_kernels.cu b/csrc/activation_kernels.cu +index 24d9727..839dc36 100644 +--- a/csrc/activation_kernels.cu ++++ b/csrc/activation_kernels.cu +@@ -1,5 +1,5 @@ + #include +-#include ++#include + #include + + #include +@@ -10,11 +10,11 @@ + namespace vllm { + + // Activation and gating kernel template. +-template ++template + __global__ void act_and_mul_kernel( +- scalar_t* __restrict__ out, // [..., d] +- const scalar_t* __restrict__ input, // [..., 2, d] +- const int d) { ++ scalar_t* __restrict__ out, // [..., d] ++ const scalar_t* __restrict__ input, // [..., 2, d] ++ const int d) { + const int64_t token_idx = blockIdx.x; + for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) { + const scalar_t x = VLLM_LDG(&input[token_idx * 2 * d + idx]); +@@ -23,84 +23,120 @@ __global__ void act_and_mul_kernel( + } + } + +-template ++template + __device__ __forceinline__ T silu_kernel(const T& x) { + // x * sigmoid(x) +- return (T) (((float) x) / (1.0f + expf((float) -x))); ++ return (T)(((float)x) / (1.0f + expf((float)-x))); + } + +-template ++template + __device__ __forceinline__ T gelu_kernel(const T& x) { + // Equivalent to PyTorch GELU with 'none' approximation. + // Refer to: + // https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L36-L38 +- const float f = (float) x; ++ const float f = (float)x; + constexpr float ALPHA = M_SQRT1_2; +- return (T) (f * 0.5f * (1.0f + ::erf(f * ALPHA))); ++ return (T)(f * 0.5f * (1.0f + ::erf(f * ALPHA))); + } + +-template ++template + __device__ __forceinline__ T gelu_tanh_kernel(const T& x) { + // Equivalent to PyTorch GELU with 'tanh' approximation. + // Refer to: + // https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L25-L30 +- const float f = (float) x; ++ const float f = (float)x; + constexpr float BETA = M_SQRT2 * M_2_SQRTPI * 0.5f; + constexpr float KAPPA = 0.044715; + float x_cube = f * f * f; + float inner = BETA * (f + KAPPA * x_cube); +- return (T) (0.5f * f * (1.0f + ::tanhf(inner))); ++ return (T)(0.5f * f * (1.0f + ::tanhf(inner))); + } + +-} // namespace vllm ++} // namespace vllm + + // Launch activation and gating kernel. +-#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL) \ +- int d = input.size(-1) / 2; \ +- int64_t num_tokens = input.numel() / input.size(-1); \ +- dim3 grid(num_tokens); \ +- dim3 block(std::min(d, 1024)); \ +- const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \ +- const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \ +- VLLM_DISPATCH_FLOATING_TYPES( \ +- input.scalar_type(), \ +- "act_and_mul_kernel", \ +- [&] { \ +- vllm::act_and_mul_kernel><<>>( \ +- out.data_ptr(), \ +- input.data_ptr(), \ +- d); \ +- }); +- +-void silu_and_mul( +- torch::Tensor& out, // [..., d] +- torch::Tensor& input) // [..., 2 * d] ++#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL) \ ++ int d = input.size(-1) / 2; \ ++ int64_t num_tokens = input.numel() / input.size(-1); \ ++ dim3 grid(num_tokens); \ ++ dim3 block(std::min(d, 1024)); \ ++ const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \ ++ const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \ ++ VLLM_DISPATCH_FLOATING_TYPES( \ ++ input.scalar_type(), "act_and_mul_kernel", [&] { \ ++ vllm::act_and_mul_kernel> \ ++ <<>>(out.data_ptr(), \ ++ input.data_ptr(), d); \ ++ }); ++ ++void silu_and_mul(torch::Tensor& out, // [..., d] ++ torch::Tensor& input) // [..., 2 * d] + { + LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel); + } + +-void gelu_and_mul( +- torch::Tensor& out, // [..., d] +- torch::Tensor& input) // [..., 2 * d] ++void gelu_and_mul(torch::Tensor& out, // [..., d] ++ torch::Tensor& input) // [..., 2 * d] + { + LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_kernel); + } + +-void gelu_tanh_and_mul( +- torch::Tensor& out, // [..., d] +- torch::Tensor& input) // [..., 2 * d] ++void gelu_tanh_and_mul(torch::Tensor& out, // [..., d] ++ torch::Tensor& input) // [..., 2 * d] + { + LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_tanh_kernel); + } + + namespace vllm { + ++template ++__device__ __forceinline__ T fatrelu_kernel(const T& x, const float threshold) { ++ const float f = (float)x; ++ return (T)(f > threshold ? f : 0.0f); ++} ++ ++template ++__global__ void act_and_mul_kernel_with_param( ++ scalar_t* __restrict__ out, const scalar_t* __restrict__ input, const int d, ++ const float param) { ++ const int64_t token_idx = blockIdx.x; ++ for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) { ++ const scalar_t x = VLLM_LDG(&input[token_idx * 2 * d + idx]); ++ const scalar_t y = VLLM_LDG(&input[token_idx * 2 * d + d + idx]); ++ out[token_idx * d + idx] = ACT_FN(x, param) * y; ++ } ++} ++ ++} // namespace vllm ++ ++#define LAUNCH_ACTIVATION_GATE_KERNEL_WITH_PARAM(KERNEL, PARAM) \ ++ int d = input.size(-1) / 2; \ ++ int64_t num_tokens = input.numel() / input.size(-1); \ ++ dim3 grid(num_tokens); \ ++ dim3 block(std::min(d, 1024)); \ ++ const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \ ++ const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \ ++ VLLM_DISPATCH_FLOATING_TYPES( \ ++ input.scalar_type(), "act_and_mul_kernel_with_param", [&] { \ ++ vllm::act_and_mul_kernel_with_param> \ ++ <<>>(out.data_ptr(), \ ++ input.data_ptr(), d, \ ++ PARAM); \ ++ }); ++ ++void fatrelu_and_mul(torch::Tensor& out, // [..., d], ++ torch::Tensor& input, // [..., 2 * d] ++ double threshold) { ++ LAUNCH_ACTIVATION_GATE_KERNEL_WITH_PARAM(vllm::fatrelu_kernel, threshold); ++} ++namespace vllm { ++ + // Element-wise activation kernel template. +-template ++template + __global__ void activation_kernel( +- scalar_t* __restrict__ out, // [..., d] +- const scalar_t* __restrict__ input, // [..., d] +- const int d) { ++ scalar_t* __restrict__ out, // [..., d] ++ const scalar_t* __restrict__ input, // [..., d] ++ const int d) { + const int64_t token_idx = blockIdx.x; + for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) { + const scalar_t x = VLLM_LDG(&input[token_idx * d + idx]); +@@ -108,54 +144,61 @@ __global__ void activation_kernel( + } + } + +-} // namespace vllm ++} // namespace vllm + + // Launch element-wise activation kernel. +-#define LAUNCH_ACTIVATION_KERNEL(KERNEL) \ +- int d = input.size(-1); \ +- int64_t num_tokens = input.numel() / d; \ +- dim3 grid(num_tokens); \ +- dim3 block(std::min(d, 1024)); \ +- const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \ +- const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \ +- VLLM_DISPATCH_FLOATING_TYPES( \ +- input.scalar_type(), \ +- "activation_kernel", \ +- [&] { \ +- vllm::activation_kernel><<>>( \ +- out.data_ptr(), \ +- input.data_ptr(), \ +- d); \ +- }); ++#define LAUNCH_ACTIVATION_KERNEL(KERNEL) \ ++ int d = input.size(-1); \ ++ int64_t num_tokens = input.numel() / d; \ ++ dim3 grid(num_tokens); \ ++ dim3 block(std::min(d, 1024)); \ ++ const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \ ++ const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \ ++ VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "activation_kernel", [&] { \ ++ vllm::activation_kernel> \ ++ <<>>(out.data_ptr(), \ ++ input.data_ptr(), d); \ ++ }); + + namespace vllm { + +-template ++template + __device__ __forceinline__ T gelu_new_kernel(const T& x) { +- const float x3 = (float) (x * x * x); +- const T t = (T) tanhf((T) (0.79788456f * (float) (x + (T) (0.044715f * x3)))); +- return ((T) 0.5) * x * (((T) 1.0) + t); ++ const float x3 = (float)(x * x * x); ++ const T t = (T)tanhf((T)(0.79788456f * (float)(x + (T)(0.044715f * x3)))); ++ return ((T)0.5) * x * (((T)1.0) + t); + } + +-template ++template + __device__ __forceinline__ T gelu_fast_kernel(const T& x) { +- const float f = (float) x; +- const T t = (T) tanhf(((T) (f * 0.79788456f)) * (((T) 1.0) + (T) (0.044715f * f) * x)); +- return ((T) 0.5) * x * (((T) 1.0) + t); ++ const float f = (float)x; ++ const T t = ++ (T)tanhf(((T)(f * 0.79788456f)) * (((T)1.0) + (T)(0.044715f * f) * x)); ++ return ((T)0.5) * x * (((T)1.0) + t); ++} ++ ++template ++__device__ __forceinline__ T gelu_quick_kernel(const T& x) { ++ // x * sigmoid(1.702 * x) ++ return (T)(((float)x) / (1.0f + expf(-1.702f * (float)x))); + } + +-} // namespace vllm ++} // namespace vllm + +-void gelu_new( +- torch::Tensor& out, // [..., d] +- torch::Tensor& input) // [..., d] ++void gelu_new(torch::Tensor& out, // [..., d] ++ torch::Tensor& input) // [..., d] + { + LAUNCH_ACTIVATION_KERNEL(vllm::gelu_new_kernel); + } + +-void gelu_fast( +- torch::Tensor& out, // [..., d] +- torch::Tensor& input) // [..., d] ++void gelu_fast(torch::Tensor& out, // [..., d] ++ torch::Tensor& input) // [..., d] + { + LAUNCH_ACTIVATION_KERNEL(vllm::gelu_fast_kernel); + } ++ ++void gelu_quick(torch::Tensor& out, // [..., d] ++ torch::Tensor& input) // [..., d] ++{ ++ LAUNCH_ACTIVATION_KERNEL(vllm::gelu_quick_kernel); ++} +diff --git a/csrc/attention/attention_generic.cuh b/csrc/attention/attention_generic.cuh +index 31fb401..62409c0 100644 +--- a/csrc/attention/attention_generic.cuh ++++ b/csrc/attention/attention_generic.cuh +@@ -1,5 +1,6 @@ + /* +- * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h ++ * Adapted from ++ * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h + * Copyright (c) 2023, The vLLM team. + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * +@@ -22,31 +23,31 @@ + namespace vllm { + + // A vector type to store Q, K, V elements. +-template ++template + struct Vec {}; + + // A vector type to store FP32 accumulators. +-template ++template + struct FloatVec {}; + + // Template vector operations. +-template ++template + inline __device__ Acc mul(A a, B b); + +-template ++template + inline __device__ float sum(T v); + +-template ++template + inline __device__ float dot(T a, T b) { + return sum(mul(a, b)); + } + +-template ++template + inline __device__ float dot(T a, T b) { + return sum(mul(a, b)); + } + +-template ++template + inline __device__ void zero(T& dst) { + constexpr int WORDS = sizeof(T) / 4; + union { +@@ -61,4 +62,4 @@ inline __device__ void zero(T& dst) { + dst = tmp.raw; + } + +-} // namespace vllm ++} // namespace vllm +diff --git a/csrc/attention/attention_kernels.cuh b/csrc/attention/attention_kernels.cuh +new file mode 100644 +index 0000000..563e143 +--- /dev/null ++++ b/csrc/attention/attention_kernels.cuh +@@ -0,0 +1,676 @@ ++/* ++ * Adapted from ++ * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp ++ * Copyright (c) 2023, The vLLM team. ++ * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. ++ * ++ * Licensed under the Apache License, Version 2.0 (the "License"); ++ * you may not use this file except in compliance with the License. ++ * You may obtain a copy of the License at ++ * ++ * http://www.apache.org/licenses/LICENSE-2.0 ++ * ++ * Unless required by applicable law or agreed to in writing, software ++ * distributed under the License is distributed on an "AS IS" BASIS, ++ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++ * See the License for the specific language governing permissions and ++ * limitations under the License. ++ */ ++ ++#include ++#include ++#include ++#include ++ ++#include "attention_dtypes.h" ++#include "attention_utils.cuh" ++ ++#ifdef USE_ROCM ++ #include ++ #include "../quantization/fp8/amd/quant_utils.cuh" ++typedef __hip_bfloat16 __nv_bfloat16; ++#else ++ #include "../quantization/fp8/nvidia/quant_utils.cuh" ++#endif ++ ++#ifndef USE_ROCM ++ #define WARP_SIZE 32 ++#else ++ #define WARP_SIZE warpSize ++#endif ++ ++#define MAX(a, b) ((a) > (b) ? (a) : (b)) ++#define MIN(a, b) ((a) < (b) ? (a) : (b)) ++#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b)) ++ ++namespace vllm { ++ ++// Utility function for attention softmax. ++template ++inline __device__ float block_sum(float* red_smem, float sum) { ++ // Decompose the thread index into warp / lane. ++ int warp = threadIdx.x / WARP_SIZE; ++ int lane = threadIdx.x % WARP_SIZE; ++ ++ // Compute the sum per warp. ++#pragma unroll ++ for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { ++ sum += VLLM_SHFL_XOR_SYNC(sum, mask); ++ } ++ ++ // Warp leaders store the data to shared memory. ++ if (lane == 0) { ++ red_smem[warp] = sum; ++ } ++ ++ // Make sure the data is in shared memory. ++ __syncthreads(); ++ ++ // The warps compute the final sums. ++ if (lane < NUM_WARPS) { ++ sum = red_smem[lane]; ++ } ++ ++ // Parallel reduction inside the warp. ++#pragma unroll ++ for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { ++ sum += VLLM_SHFL_XOR_SYNC(sum, mask); ++ } ++ ++ // Broadcast to other threads. ++ return VLLM_SHFL_SYNC(sum, 0); ++} ++ ++// TODO(woosuk): Merge the last two dimensions of the grid. ++// Grid: (num_heads, num_seqs, max_num_partitions). ++template // Zero means no partitioning. ++__device__ void paged_attention_kernel( ++ float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] ++ float* __restrict__ max_logits, // [num_seqs, num_heads, ++ // max_num_partitions] ++ scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, ++ // head_size] ++ const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] ++ const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, ++ // head_size/x, block_size, x] ++ const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, ++ // head_size, block_size] ++ const int num_kv_heads, // [num_heads] ++ const float scale, ++ const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] ++ const int* __restrict__ seq_lens, // [num_seqs] ++ const int max_num_blocks_per_seq, ++ const float* __restrict__ alibi_slopes, // [num_heads] ++ const int q_stride, const int kv_block_stride, const int kv_head_stride, ++ const float k_scale, const float v_scale, const int tp_rank, ++ const int blocksparse_local_blocks, const int blocksparse_vert_stride, ++ const int blocksparse_block_size, const int blocksparse_head_sliding_step) { ++ const int seq_idx = blockIdx.y; ++ const int partition_idx = blockIdx.z; ++ const int max_num_partitions = gridDim.z; ++ constexpr bool USE_PARTITIONING = PARTITION_SIZE > 0; ++ const int seq_len = seq_lens[seq_idx]; ++ if (USE_PARTITIONING && partition_idx * PARTITION_SIZE >= seq_len) { ++ // No work to do. Terminate the thread block. ++ return; ++ } ++ ++ const int num_seq_blocks = DIVIDE_ROUND_UP(seq_len, BLOCK_SIZE); ++ const int num_blocks_per_partition = ++ USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_seq_blocks; ++ ++ // [start_block_idx, end_block_idx) is the range of blocks to process. ++ const int start_block_idx = ++ USE_PARTITIONING ? partition_idx * num_blocks_per_partition : 0; ++ const int end_block_idx = ++ MIN(start_block_idx + num_blocks_per_partition, num_seq_blocks); ++ const int num_blocks = end_block_idx - start_block_idx; ++ ++ // [start_token_idx, end_token_idx) is the range of tokens to process. ++ const int start_token_idx = start_block_idx * BLOCK_SIZE; ++ const int end_token_idx = ++ MIN(start_token_idx + num_blocks * BLOCK_SIZE, seq_len); ++ const int num_tokens = end_token_idx - start_token_idx; ++ ++ constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1); ++ constexpr int NUM_THREAD_GROUPS = ++ NUM_THREADS / THREAD_GROUP_SIZE; // Note: This assumes THREAD_GROUP_SIZE ++ // divides NUM_THREADS ++ assert(NUM_THREADS % THREAD_GROUP_SIZE == 0); ++ constexpr int NUM_TOKENS_PER_THREAD_GROUP = ++ DIVIDE_ROUND_UP(BLOCK_SIZE, WARP_SIZE); ++ constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; ++ const int thread_idx = threadIdx.x; ++ const int warp_idx = thread_idx / WARP_SIZE; ++ const int lane = thread_idx % WARP_SIZE; ++ ++ const int head_idx = blockIdx.x; ++ const int num_heads = gridDim.x; ++ const int num_queries_per_kv = num_heads / num_kv_heads; ++ const int kv_head_idx = head_idx / num_queries_per_kv; ++ const float alibi_slope = ++ alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx]; ++ ++ // A vector type to store a part of a key or a query. ++ // The vector size is configured in such a way that the threads in a thread ++ // group fetch or compute 16 bytes at a time. For example, if the size of a ++ // thread group is 4 and the data type is half, then the vector size is 16 / ++ // (4 * sizeof(half)) == 2. ++ constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1); ++ using K_vec = typename Vec::Type; ++ using Q_vec = typename Vec::Type; ++ using Quant_vec = typename Vec::Type; ++ ++ constexpr int NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE; ++ constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE; ++ ++ const int thread_group_idx = thread_idx / THREAD_GROUP_SIZE; ++ const int thread_group_offset = thread_idx % THREAD_GROUP_SIZE; ++ ++ // Load the query to registers. ++ // Each thread in a thread group has a different part of the query. ++ // For example, if the the thread group size is 4, then the first thread in ++ // the group has 0, 4, 8, ... th vectors of the query, and the second thread ++ // has 1, 5, 9, ... th vectors of the query, and so on. NOTE(woosuk): Because ++ // q is split from a qkv tensor, it may not be contiguous. ++ const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE; ++ __shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD]; ++#pragma unroll ++ for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD; ++ i += NUM_THREAD_GROUPS) { ++ const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE; ++ q_vecs[thread_group_offset][i] = ++ *reinterpret_cast(q_ptr + vec_idx * VEC_SIZE); ++ } ++ __syncthreads(); // TODO(naed90): possible speedup if this is replaced with a ++ // memory wall right before we use q_vecs ++ ++ // Memory planning. ++ extern __shared__ char shared_mem[]; ++ // NOTE(woosuk): We use FP32 for the softmax logits for better accuracy. ++ float* logits = reinterpret_cast(shared_mem); ++ // Workspace for reduction. ++ __shared__ float red_smem[2 * NUM_WARPS]; ++ ++ // x == THREAD_GROUP_SIZE * VEC_SIZE ++ // Each thread group fetches x elements from the key at a time. ++ constexpr int x = 16 / sizeof(cache_t); ++ float qk_max = -FLT_MAX; ++ ++ // Iterate over the key blocks. ++ // Each warp fetches a block of keys for each iteration. ++ // Each thread group in a warp fetches a key from the block, and computes ++ // dot product with the query. ++ const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq; ++ ++ // blocksparse specific vars ++ int bs_block_offset; ++ int q_bs_block_id; ++ if constexpr (IS_BLOCK_SPARSE) { ++ // const int num_blocksparse_blocks = DIVIDE_ROUND_UP(seq_len, ++ // blocksparse_block_size); ++ q_bs_block_id = (seq_len - 1) / blocksparse_block_size; ++ if (blocksparse_head_sliding_step >= 0) ++ // sliding on q heads ++ bs_block_offset = ++ (tp_rank * num_heads + head_idx) * blocksparse_head_sliding_step + 1; ++ else ++ // sliding on kv heads ++ bs_block_offset = (tp_rank * num_kv_heads + kv_head_idx) * ++ (-blocksparse_head_sliding_step) + ++ 1; ++ } ++ ++ for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; ++ block_idx += NUM_WARPS) { ++ // NOTE(woosuk): The block number is stored in int32. However, we cast it to ++ // int64 because int32 can lead to overflow when this variable is multiplied ++ // by large numbers (e.g., kv_block_stride). ++ // For blocksparse attention: skip computation on blocks that are not ++ // attended ++ if constexpr (IS_BLOCK_SPARSE) { ++ const int k_bs_block_id = block_idx * BLOCK_SIZE / blocksparse_block_size; ++ const bool is_remote = ++ ((k_bs_block_id + bs_block_offset) % blocksparse_vert_stride == 0); ++ const bool is_local = ++ (k_bs_block_id > q_bs_block_id - blocksparse_local_blocks); ++ if (!is_remote && !is_local) { ++ for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) { ++ const int physical_block_offset = ++ (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE; ++ const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; ++ ++ if (thread_group_offset == 0) { ++ // NOTE(linxihui): assign very large number to skipped tokens to ++ // avoid contribution to the sumexp softmax normalizer. This will ++ // not be used at computing sum(softmax*v) as the blocks will be ++ // skipped. ++ logits[token_idx - start_token_idx] = -FLT_MAX; ++ } ++ } ++ continue; ++ } ++ } ++ const int64_t physical_block_number = ++ static_cast(block_table[block_idx]); ++ ++ // Load a key to registers. ++ // Each thread in a thread group has a different part of the key. ++ // For example, if the the thread group size is 4, then the first thread in ++ // the group has 0, 4, 8, ... th vectors of the key, and the second thread ++ // has 1, 5, 9, ... th vectors of the key, and so on. ++ for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) { ++ const int physical_block_offset = ++ (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE; ++ const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; ++ K_vec k_vecs[NUM_VECS_PER_THREAD]; ++ ++#pragma unroll ++ for (int j = 0; j < NUM_VECS_PER_THREAD; j++) { ++ const cache_t* k_ptr = ++ k_cache + physical_block_number * kv_block_stride + ++ kv_head_idx * kv_head_stride + physical_block_offset * x; ++ const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE; ++ const int offset1 = (vec_idx * VEC_SIZE) / x; ++ const int offset2 = (vec_idx * VEC_SIZE) % x; ++ ++ if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) { ++ k_vecs[j] = *reinterpret_cast( ++ k_ptr + offset1 * BLOCK_SIZE * x + offset2); ++ } else { ++ // Vector conversion from Quant_vec to K_vec. ++ Quant_vec k_vec_quant = *reinterpret_cast( ++ k_ptr + offset1 * BLOCK_SIZE * x + offset2); ++ k_vecs[j] = fp8::scaled_convert( ++ k_vec_quant, k_scale); ++ } ++ } ++ ++ // Compute dot product. ++ // This includes a reduction across the threads in the same thread group. ++ float qk = scale * Qk_dot::dot( ++ q_vecs[thread_group_offset], k_vecs); ++ // Add the ALiBi bias if slopes are given. ++ qk += (alibi_slope != 0) ? alibi_slope * (token_idx - seq_len + 1) : 0; ++ ++ if (thread_group_offset == 0) { ++ // Store the partial reductions to shared memory. ++ // NOTE(woosuk): It is required to zero out the masked logits. ++ const bool mask = token_idx >= seq_len; ++ logits[token_idx - start_token_idx] = mask ? 0.f : qk; ++ // Update the max value. ++ qk_max = mask ? qk_max : fmaxf(qk_max, qk); ++ } ++ } ++ } ++ ++ // Perform reduction across the threads in the same warp to get the ++ // max qk value for each "warp" (not across the thread block yet). ++ // The 0-th thread of each thread group already has its max qk value. ++#pragma unroll ++ for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) { ++ qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask)); ++ } ++ if (lane == 0) { ++ red_smem[warp_idx] = qk_max; ++ } ++ __syncthreads(); ++ ++ // TODO(woosuk): Refactor this part. ++ // Get the max qk value for the sequence. ++ qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX; ++#pragma unroll ++ for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { ++ qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask)); ++ } ++ // Broadcast the max qk value to all threads. ++ qk_max = VLLM_SHFL_SYNC(qk_max, 0); ++ ++ // Get the sum of the exp values. ++ float exp_sum = 0.f; ++ for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) { ++ float val = __expf(logits[i] - qk_max); ++ logits[i] = val; ++ exp_sum += val; ++ } ++ exp_sum = block_sum(&red_smem[NUM_WARPS], exp_sum); ++ ++ // Compute softmax. ++ const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f); ++ for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) { ++ logits[i] *= inv_sum; ++ } ++ __syncthreads(); ++ ++ // If partitioning is enabled, store the max logit and exp_sum. ++ if (USE_PARTITIONING && thread_idx == 0) { ++ float* max_logits_ptr = max_logits + ++ seq_idx * num_heads * max_num_partitions + ++ head_idx * max_num_partitions + partition_idx; ++ *max_logits_ptr = qk_max; ++ float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions + ++ head_idx * max_num_partitions + partition_idx; ++ *exp_sums_ptr = exp_sum; ++ } ++ ++ // Each thread will fetch 16 bytes from the value cache at a time. ++ constexpr int V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE); ++ using V_vec = typename Vec::Type; ++ using L_vec = typename Vec::Type; ++ using V_quant_vec = typename Vec::Type; ++ using Float_L_vec = typename FloatVec::Type; ++ ++ constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE; ++ constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW; ++ constexpr int NUM_ROWS_PER_THREAD = ++ DIVIDE_ROUND_UP(HEAD_SIZE, NUM_ROWS_PER_ITER); ++ ++ // NOTE(woosuk): We use FP32 for the accumulator for better accuracy. ++ float accs[NUM_ROWS_PER_THREAD]; ++#pragma unroll ++ for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { ++ accs[i] = 0.f; ++ } ++ ++ scalar_t zero_value; ++ zero(zero_value); ++ for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; ++ block_idx += NUM_WARPS) { ++ // NOTE(woosuk): The block number is stored in int32. However, we cast it to ++ // int64 because int32 can lead to overflow when this variable is multiplied ++ // by large numbers (e.g., kv_block_stride). ++ // For blocksparse attention: skip computation on blocks that are not ++ // attended ++ if constexpr (IS_BLOCK_SPARSE) { ++ int v_bs_block_id = block_idx * BLOCK_SIZE / blocksparse_block_size; ++ if (!((v_bs_block_id + bs_block_offset) % blocksparse_vert_stride == 0) && ++ !((v_bs_block_id > q_bs_block_id - blocksparse_local_blocks))) { ++ continue; ++ } ++ } ++ const int64_t physical_block_number = ++ static_cast(block_table[block_idx]); ++ const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE; ++ const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; ++ L_vec logits_vec; ++ from_float(logits_vec, *reinterpret_cast(logits + token_idx - ++ start_token_idx)); ++ ++ const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride + ++ kv_head_idx * kv_head_stride; ++#pragma unroll ++ for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { ++ const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; ++ if (row_idx < HEAD_SIZE) { ++ const int offset = row_idx * BLOCK_SIZE + physical_block_offset; ++ V_vec v_vec; ++ ++ if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) { ++ v_vec = *reinterpret_cast(v_ptr + offset); ++ } else { ++ V_quant_vec v_quant_vec = ++ *reinterpret_cast(v_ptr + offset); ++ // Vector conversion from V_quant_vec to V_vec. ++ v_vec = fp8::scaled_convert(v_quant_vec, ++ v_scale); ++ } ++ if (block_idx == num_seq_blocks - 1) { ++ // NOTE(woosuk): When v_vec contains the tokens that are out of the ++ // context, we should explicitly zero out the values since they may ++ // contain NaNs. See ++ // https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472 ++ scalar_t* v_vec_ptr = reinterpret_cast(&v_vec); ++#pragma unroll ++ for (int j = 0; j < V_VEC_SIZE; j++) { ++ v_vec_ptr[j] = token_idx + j < seq_len ? v_vec_ptr[j] : zero_value; ++ } ++ } ++ accs[i] += dot(logits_vec, v_vec); ++ } ++ } ++ } ++ ++ // Perform reduction within each warp. ++#pragma unroll ++ for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { ++ float acc = accs[i]; ++#pragma unroll ++ for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) { ++ acc += VLLM_SHFL_XOR_SYNC(acc, mask); ++ } ++ accs[i] = acc; ++ } ++ ++ // NOTE(woosuk): A barrier is required because the shared memory space for ++ // logits is reused for the output. ++ __syncthreads(); ++ ++ // Perform reduction across warps. ++ float* out_smem = reinterpret_cast(shared_mem); ++#pragma unroll ++ for (int i = NUM_WARPS; i > 1; i /= 2) { ++ int mid = i / 2; ++ // Upper warps write to shared memory. ++ if (warp_idx >= mid && warp_idx < i) { ++ float* dst = &out_smem[(warp_idx - mid) * HEAD_SIZE]; ++#pragma unroll ++ for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { ++ const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; ++ if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { ++ dst[row_idx] = accs[i]; ++ } ++ } ++ } ++ __syncthreads(); ++ ++ // Lower warps update the output. ++ if (warp_idx < mid) { ++ const float* src = &out_smem[warp_idx * HEAD_SIZE]; ++#pragma unroll ++ for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { ++ const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; ++ if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { ++ accs[i] += src[row_idx]; ++ } ++ } ++ } ++ __syncthreads(); ++ } ++ ++ // Write the final output. ++ if (warp_idx == 0) { ++ scalar_t* out_ptr = ++ out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + ++ head_idx * max_num_partitions * HEAD_SIZE + partition_idx * HEAD_SIZE; ++#pragma unroll ++ for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { ++ const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; ++ if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { ++ from_float(*(out_ptr + row_idx), accs[i]); ++ } ++ } ++ } ++} ++ ++// Grid: (num_heads, num_seqs, 1). ++template ++__global__ void paged_attention_v1_kernel( ++ scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] ++ const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] ++ const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, ++ // head_size/x, block_size, x] ++ const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, ++ // head_size, block_size] ++ const int num_kv_heads, // [num_heads] ++ const float scale, ++ const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] ++ const int* __restrict__ seq_lens, // [num_seqs] ++ const int max_num_blocks_per_seq, ++ const float* __restrict__ alibi_slopes, // [num_heads] ++ const int q_stride, const int kv_block_stride, const int kv_head_stride, ++ const float k_scale, const float v_scale, const int tp_rank, ++ const int blocksparse_local_blocks, const int blocksparse_vert_stride, ++ const int blocksparse_block_size, const int blocksparse_head_sliding_step) { ++ paged_attention_kernel( ++ /* exp_sums */ nullptr, /* max_logits */ nullptr, out, q, k_cache, ++ v_cache, num_kv_heads, scale, block_tables, seq_lens, ++ max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, ++ kv_head_stride, k_scale, v_scale, tp_rank, blocksparse_local_blocks, ++ blocksparse_vert_stride, blocksparse_block_size, ++ blocksparse_head_sliding_step); ++} ++ ++// Grid: (num_heads, num_seqs, max_num_partitions). ++template ++__global__ void paged_attention_v2_kernel( ++ float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] ++ float* __restrict__ max_logits, // [num_seqs, num_heads, ++ // max_num_partitions] ++ scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, ++ // max_num_partitions, head_size] ++ const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] ++ const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, ++ // head_size/x, block_size, x] ++ const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, ++ // head_size, block_size] ++ const int num_kv_heads, // [num_heads] ++ const float scale, ++ const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] ++ const int* __restrict__ seq_lens, // [num_seqs] ++ const int max_num_blocks_per_seq, ++ const float* __restrict__ alibi_slopes, // [num_heads] ++ const int q_stride, const int kv_block_stride, const int kv_head_stride, ++ const float k_scale, const float v_scale, const int tp_rank, ++ const int blocksparse_local_blocks, const int blocksparse_vert_stride, ++ const int blocksparse_block_size, const int blocksparse_head_sliding_step) { ++ paged_attention_kernel( ++ exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale, ++ block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes, q_stride, ++ kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank, ++ blocksparse_local_blocks, blocksparse_vert_stride, blocksparse_block_size, ++ blocksparse_head_sliding_step); ++} ++ ++// Grid: (num_heads, num_seqs). ++template ++__global__ void paged_attention_v2_reduce_kernel( ++ scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] ++ const float* __restrict__ exp_sums, // [num_seqs, num_heads, ++ // max_num_partitions] ++ const float* __restrict__ max_logits, // [num_seqs, num_heads, ++ // max_num_partitions] ++ const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, ++ // max_num_partitions, head_size] ++ const int* __restrict__ seq_lens, // [num_seqs] ++ const int max_num_partitions) { ++ const int num_heads = gridDim.x; ++ const int head_idx = blockIdx.x; ++ const int seq_idx = blockIdx.y; ++ const int seq_len = seq_lens[seq_idx]; ++ const int num_partitions = DIVIDE_ROUND_UP(seq_len, PARTITION_SIZE); ++ if (num_partitions == 1) { ++ // No need to reduce. Only copy tmp_out to out. ++ scalar_t* out_ptr = ++ out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; ++ const scalar_t* tmp_out_ptr = ++ tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + ++ head_idx * max_num_partitions * HEAD_SIZE; ++ for (int i = threadIdx.x; i < HEAD_SIZE; i += blockDim.x) { ++ out_ptr[i] = tmp_out_ptr[i]; ++ } ++ // Terminate the thread block. ++ return; ++ } ++ ++ constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; ++ const int warp_idx = threadIdx.x / WARP_SIZE; ++ const int lane = threadIdx.x % WARP_SIZE; ++ ++ // Size: 2 * num_partitions. ++ extern __shared__ char shared_mem[]; ++ // Workspace for reduction. ++ __shared__ float red_smem[2 * NUM_WARPS]; ++ ++ // Load max logits to shared memory. ++ float* shared_max_logits = reinterpret_cast(shared_mem); ++ const float* max_logits_ptr = max_logits + ++ seq_idx * num_heads * max_num_partitions + ++ head_idx * max_num_partitions; ++ float max_logit = -FLT_MAX; ++ for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) { ++ const float l = max_logits_ptr[i]; ++ shared_max_logits[i] = l; ++ max_logit = fmaxf(max_logit, l); ++ } ++ __syncthreads(); ++ ++ // Get the global max logit. ++ // Reduce within the warp. ++#pragma unroll ++ for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { ++ max_logit = fmaxf(max_logit, VLLM_SHFL_XOR_SYNC(max_logit, mask)); ++ } ++ if (lane == 0) { ++ red_smem[warp_idx] = max_logit; ++ } ++ __syncthreads(); ++ // Reduce across warps. ++ max_logit = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX; ++#pragma unroll ++ for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { ++ max_logit = fmaxf(max_logit, VLLM_SHFL_XOR_SYNC(max_logit, mask)); ++ } ++ // Broadcast the max value to all threads. ++ max_logit = VLLM_SHFL_SYNC(max_logit, 0); ++ ++ // Load rescaled exp sums to shared memory. ++ float* shared_exp_sums = ++ reinterpret_cast(shared_mem + sizeof(float) * num_partitions); ++ const float* exp_sums_ptr = exp_sums + ++ seq_idx * num_heads * max_num_partitions + ++ head_idx * max_num_partitions; ++ float global_exp_sum = 0.0f; ++ for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) { ++ float l = shared_max_logits[i]; ++ float rescaled_exp_sum = exp_sums_ptr[i] * expf(l - max_logit); ++ global_exp_sum += rescaled_exp_sum; ++ shared_exp_sums[i] = rescaled_exp_sum; ++ } ++ __syncthreads(); ++ global_exp_sum = block_sum(&red_smem[NUM_WARPS], global_exp_sum); ++ const float inv_global_exp_sum = __fdividef(1.0f, global_exp_sum + 1e-6f); ++ ++ // Aggregate tmp_out to out. ++ const scalar_t* tmp_out_ptr = ++ tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + ++ head_idx * max_num_partitions * HEAD_SIZE; ++ scalar_t* out_ptr = ++ out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; ++#pragma unroll ++ for (int i = threadIdx.x; i < HEAD_SIZE; i += NUM_THREADS) { ++ float acc = 0.0f; ++ for (int j = 0; j < num_partitions; ++j) { ++ acc += to_float(tmp_out_ptr[j * HEAD_SIZE + i]) * shared_exp_sums[j] * ++ inv_global_exp_sum; ++ } ++ from_float(out_ptr[i], acc); ++ } ++} ++ ++} // namespace vllm ++ ++#undef WARP_SIZE ++#undef MAX ++#undef MIN ++#undef DIVIDE_ROUND_UP +diff --git a/csrc/attention/attention_utils.cuh b/csrc/attention/attention_utils.cuh +index ff64c4b..826b0ed 100644 +--- a/csrc/attention/attention_utils.cuh ++++ b/csrc/attention/attention_utils.cuh +@@ -1,5 +1,6 @@ + /* +- * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp ++ * Adapted from ++ * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp + * Copyright (c) 2023, The vLLM team. + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * +@@ -26,14 +27,14 @@ + namespace vllm { + + // Q*K^T operation. +-template ++template + inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) { + using A_vec = typename FloatVec::Type; + // Compute the parallel products for Q*K^T (treat vector lanes separately). + A_vec qk_vec = mul(q[0], k[0]); + #pragma unroll + for (int ii = 1; ii < N; ++ii) { +- qk_vec = fma(q[ii], k[ii], qk_vec); ++ qk_vec = vllm::fma(q[ii], k[ii], qk_vec); + } + + // Finalize the reduction across lanes. +@@ -45,12 +46,12 @@ inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) { + return qk; + } + +-template ++template + struct Qk_dot { +- template ++ template + static inline __device__ float dot(const Vec (&q)[N], const Vec (&k)[N]) { + return qk_dot_(q, k); + } + }; + +-} // namespace vllm ++} // namespace vllm +diff --git a/csrc/attention/dtype_bfloat16.cuh b/csrc/attention/dtype_bfloat16.cuh +index 31e0cee..97a25ba 100644 +--- a/csrc/attention/dtype_bfloat16.cuh ++++ b/csrc/attention/dtype_bfloat16.cuh +@@ -1,6 +1,8 @@ + /* +- * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp +- * and https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h ++ * Adapted from ++ * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp ++ * and ++ * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h + * Copyright (c) 2023, The vLLM team. + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * +@@ -28,8 +30,8 @@ + #include + #include + +- typedef __hip_bfloat162 __nv_bfloat162; +- typedef __hip_bfloat16 __nv_bfloat16; ++typedef __hip_bfloat162 __nv_bfloat162; ++typedef __hip_bfloat16 __nv_bfloat16; + #endif + + #include +@@ -50,37 +52,37 @@ struct bf16_8_t { + }; + + // BF16 vector types for Q, K, V. +-template<> ++template <> + struct Vec<__nv_bfloat16, 1> { + using Type = __nv_bfloat16; + }; +-template<> ++template <> + struct Vec<__nv_bfloat16, 2> { + using Type = __nv_bfloat162; + }; +-template<> ++template <> + struct Vec<__nv_bfloat16, 4> { + using Type = bf16_4_t; + }; +-template<> ++template <> + struct Vec<__nv_bfloat16, 8> { + using Type = bf16_8_t; + }; + + // FP32 accumulator vector types corresponding to Vec. +-template<> ++template <> + struct FloatVec<__nv_bfloat16> { + using Type = float; + }; +-template<> ++template <> + struct FloatVec<__nv_bfloat162> { + using Type = float2; + }; +-template<> ++template <> + struct FloatVec { + using Type = Float4_; + }; +-template<> ++template <> + struct FloatVec { + using Type = Float8_; + }; +@@ -92,6 +94,7 @@ inline __device__ float2 bf1622float2(const __nv_bfloat162 val) { + #else + return __bfloat1622float2(val); + #endif ++ __builtin_unreachable(); // Suppress missing return statement warning + } + + inline __device__ __nv_bfloat162 bf162bf162(const __nv_bfloat16 val) { +@@ -100,6 +103,7 @@ inline __device__ __nv_bfloat162 bf162bf162(const __nv_bfloat16 val) { + #else + return __bfloat162bfloat162(val); + #endif ++ __builtin_unreachable(); // Suppress missing return statement warning + } + + // Vector addition. +@@ -108,11 +112,12 @@ inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b) { + assert(false); + #else + #ifndef USE_ROCM +- return a + b; ++ return a + b; + #else +- return __hadd(a, b); ++ return __hadd(a, b); + #endif + #endif ++ __builtin_unreachable(); // Suppress missing return statement warning + } + + inline __device__ __nv_bfloat162 add(__nv_bfloat162 a, __nv_bfloat162 b) { +@@ -121,6 +126,7 @@ inline __device__ __nv_bfloat162 add(__nv_bfloat162 a, __nv_bfloat162 b) { + #else + return __hadd2(a, b); + #endif ++ __builtin_unreachable(); // Suppress missing return statement warning + } + + inline __device__ bf16_4_t add(bf16_4_t a, bf16_4_t b) { +@@ -161,30 +167,32 @@ inline __device__ Float8_ add(bf16_8_t a, Float8_ fb) { + } + + // Vector multiplication. +-template<> ++template <> + inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b) { + #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + assert(false); + #else + return __hmul(a, b); + #endif ++ __builtin_unreachable(); // Suppress missing return statement warning + } + +-template<> ++template <> + inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b) { + #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + assert(false); + #else + return __hmul2(a, b); + #endif ++ __builtin_unreachable(); // Suppress missing return statement warning + } + +-template<> ++template <> + inline __device__ __nv_bfloat162 mul(__nv_bfloat16 a, __nv_bfloat162 b) { + return mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(bf162bf162(a), b); + } + +-template<> ++template <> + inline __device__ bf16_4_t mul(bf16_4_t a, bf16_4_t b) { + bf16_4_t c; + c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x); +@@ -192,7 +200,7 @@ inline __device__ bf16_4_t mul(bf16_4_t a, bf16_4_t b) { + return c; + } + +-template<> ++template <> + inline __device__ bf16_4_t mul(__nv_bfloat16 a, bf16_4_t b) { + __nv_bfloat162 s = bf162bf162(a); + bf16_4_t c; +@@ -201,7 +209,7 @@ inline __device__ bf16_4_t mul(__nv_bfloat16 a, bf16_4_t b) { + return c; + } + +-template<> ++template <> + inline __device__ bf16_8_t mul(bf16_8_t a, bf16_8_t b) { + bf16_8_t c; + c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x); +@@ -211,7 +219,7 @@ inline __device__ bf16_8_t mul(bf16_8_t a, bf16_8_t b) { + return c; + } + +-template<> ++template <> + inline __device__ bf16_8_t mul(__nv_bfloat16 a, bf16_8_t b) { + __nv_bfloat162 s = bf162bf162(a); + bf16_8_t c; +@@ -222,26 +230,26 @@ inline __device__ bf16_8_t mul(__nv_bfloat16 a, bf16_8_t b) { + return c; + } + +-template<> ++template <> + inline __device__ float mul(__nv_bfloat16 a, __nv_bfloat16 b) { + float fa = __bfloat162float(a); + float fb = __bfloat162float(b); + return fa * fb; + } + +-template<> ++template <> + inline __device__ float2 mul(__nv_bfloat162 a, __nv_bfloat162 b) { + float2 fa = bf1622float2(a); + float2 fb = bf1622float2(b); + return mul(fa, fb); + } + +-template<> ++template <> + inline __device__ float2 mul(__nv_bfloat16 a, __nv_bfloat162 b) { + return mul(bf162bf162(a), b); + } + +-template<> ++template <> + inline __device__ Float4_ mul(bf16_4_t a, bf16_4_t b) { + Float4_ fc; + fc.x = mul(a.x, b.x); +@@ -249,7 +257,7 @@ inline __device__ Float4_ mul(bf16_4_t a, bf16_4_t b) { + return fc; + } + +-template<> ++template <> + inline __device__ Float4_ mul(__nv_bfloat16 a, bf16_4_t b) { + __nv_bfloat162 s = bf162bf162(a); + Float4_ fc; +@@ -258,7 +266,7 @@ inline __device__ Float4_ mul(__nv_bfloat16 a, bf16_4_t b) { + return fc; + } + +-template<> ++template <> + inline __device__ Float8_ mul(bf16_8_t a, bf16_8_t b) { + Float8_ fc; + fc.x = mul(a.x, b.x); +@@ -268,7 +276,7 @@ inline __device__ Float8_ mul(bf16_8_t a, bf16_8_t b) { + return fc; + } + +-template<> ++template <> + inline __device__ Float8_ mul(__nv_bfloat16 a, bf16_8_t b) { + __nv_bfloat162 s = bf162bf162(a); + Float8_ fc; +@@ -280,20 +288,24 @@ inline __device__ Float8_ mul(__nv_bfloat16 a, bf16_8_t b) { + } + + // Vector fused multiply-add. +-inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) { ++inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, ++ __nv_bfloat162 c) { + #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + assert(false); + #else + return __hfma2(a, b, c); + #endif ++ __builtin_unreachable(); // Suppress missing return statement warning + } + +-inline __device__ __nv_bfloat162 fma(__nv_bfloat16 a, __nv_bfloat162 b, __nv_bfloat162 c) { ++inline __device__ __nv_bfloat162 fma(__nv_bfloat16 a, __nv_bfloat162 b, ++ __nv_bfloat162 c) { + #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + assert(false); + #else + return __hfma2(bf162bf162(a), b, c); + #endif ++ __builtin_unreachable(); // Suppress missing return statement warning + } + + inline __device__ bf16_4_t fma(bf16_4_t a, bf16_4_t b, bf16_4_t c) { +@@ -379,23 +391,23 @@ inline __device__ Float8_ fma(__nv_bfloat16 a, bf16_8_t b, Float8_ fc) { + } + + // Vector sum. +-template<> ++template <> + inline __device__ float sum(__nv_bfloat16 v) { + return __bfloat162float(v); + } + +-template<> ++template <> + inline __device__ float sum(__nv_bfloat162 v) { + float2 vf = bf1622float2(v); + return vf.x + vf.y; + } + +-template<> ++template <> + inline __device__ float sum(bf16_4_t v) { + return sum(v.x) + sum(v.y); + } + +-template<> ++template <> + inline __device__ float sum(bf16_8_t v) { + return sum(v.x) + sum(v.y) + sum(v.z) + sum(v.w); + } +@@ -448,4 +460,4 @@ inline __device__ void zero(__nv_bfloat16& dst) { + #endif + } + +-} // namespace vllm ++} // namespace vllm +diff --git a/csrc/attention/dtype_float16.cuh b/csrc/attention/dtype_float16.cuh +index d3271e6..3a1815f 100644 +--- a/csrc/attention/dtype_float16.cuh ++++ b/csrc/attention/dtype_float16.cuh +@@ -1,6 +1,8 @@ + /* +- * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp +- * and https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h ++ * Adapted from ++ * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp ++ * and ++ * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h + * Copyright (c) 2023, The vLLM team. + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * +@@ -30,37 +32,37 @@ + namespace vllm { + + // FP16 vector types for Q, K, V. +-template<> ++template <> + struct Vec { + using Type = uint16_t; + }; +-template<> ++template <> + struct Vec { + using Type = uint32_t; + }; +-template<> ++template <> + struct Vec { + using Type = uint2; + }; +-template<> ++template <> + struct Vec { + using Type = uint4; + }; + + // FP32 accumulator vector types corresponding to Vec. +-template<> ++template <> + struct FloatVec { + using Type = float; + }; +-template<> ++template <> + struct FloatVec { + using Type = float2; + }; +-template<> ++template <> + struct FloatVec { + using Type = Float4_; + }; +-template<> ++template <> + struct FloatVec { + using Type = Float8_; + }; +@@ -73,8 +75,8 @@ inline __device__ uint32_t h0_h0(uint16_t a) { + return b; + #else + union { +- uint32_t u32; +- uint16_t u16[2]; ++ uint32_t u32; ++ uint16_t u16[2]; + } tmp; + tmp.u16[0] = a; + tmp.u16[1] = a; +@@ -130,10 +132,12 @@ inline __device__ uint32_t float2_to_half2(float2 f) { + } tmp; + #ifndef USE_ROCM + #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +- asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(tmp.u32) : "f"(f.y), "f"(f.x)); ++ asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" ++ : "=r"(tmp.u32) ++ : "f"(f.y), "f"(f.x)); + #else +- asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x)); +- asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y)); ++ asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x)); ++ asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y)); + #endif + #else + tmp.u16[0] = float_to_half(f.x); +@@ -201,7 +205,7 @@ inline __device__ Float8_ add(uint4 a, Float8_ fb) { + } + + // Vector multiplication. +-template<> ++template <> + inline __device__ uint16_t mul(uint16_t a, uint16_t b) { + uint16_t c; + #ifndef USE_ROCM +@@ -212,7 +216,7 @@ inline __device__ uint16_t mul(uint16_t a, uint16_t b) { + return c; + } + +-template<> ++template <> + inline __device__ uint32_t mul(uint32_t a, uint32_t b) { + uint32_t c; + #ifndef USE_ROCM +@@ -223,12 +227,12 @@ inline __device__ uint32_t mul(uint32_t a, uint32_t b) { + return c; + } + +-template<> ++template <> + inline __device__ uint32_t mul(uint16_t a, uint32_t b) { + return mul(h0_h0(a), b); + } + +-template<> ++template <> + inline __device__ uint2 mul(uint2 a, uint2 b) { + uint2 c; + c.x = mul(a.x, b.x); +@@ -236,7 +240,7 @@ inline __device__ uint2 mul(uint2 a, uint2 b) { + return c; + } + +-template<> ++template <> + inline __device__ uint2 mul(uint16_t a, uint2 b) { + uint32_t s = h0_h0(a); + uint2 c; +@@ -245,7 +249,7 @@ inline __device__ uint2 mul(uint16_t a, uint2 b) { + return c; + } + +-template<> ++template <> + inline __device__ uint4 mul(uint4 a, uint4 b) { + uint4 c; + c.x = mul(a.x, b.x); +@@ -255,7 +259,7 @@ inline __device__ uint4 mul(uint4 a, uint4 b) { + return c; + } + +-template<> ++template <> + inline __device__ uint4 mul(uint16_t a, uint4 b) { + uint32_t s = h0_h0(a); + uint4 c; +@@ -266,26 +270,26 @@ inline __device__ uint4 mul(uint16_t a, uint4 b) { + return c; + } + +-template<> ++template <> + inline __device__ float mul(uint16_t a, uint16_t b) { + float fa = half_to_float(a); + float fb = half_to_float(b); + return fa * fb; + } + +-template<> ++template <> + inline __device__ float2 mul(uint32_t a, uint32_t b) { + float2 fa = half2_to_float2(a); + float2 fb = half2_to_float2(b); + return mul(fa, fb); + } + +-template<> ++template <> + inline __device__ float2 mul(uint16_t a, uint32_t b) { + return mul(h0_h0(a), b); + } + +-template<> ++template <> + inline __device__ Float4_ mul(uint2 a, uint2 b) { + Float4_ fc; + fc.x = mul(a.x, b.x); +@@ -293,7 +297,7 @@ inline __device__ Float4_ mul(uint2 a, uint2 b) { + return fc; + } + +-template<> ++template <> + inline __device__ Float4_ mul(uint16_t a, uint2 b) { + uint32_t s = h0_h0(a); + Float4_ fc; +@@ -302,7 +306,7 @@ inline __device__ Float4_ mul(uint16_t a, uint2 b) { + return fc; + } + +-template<> ++template <> + inline __device__ Float8_ mul(uint4 a, uint4 b) { + Float8_ fc; + fc.x = mul(a.x, b.x); +@@ -312,7 +316,7 @@ inline __device__ Float8_ mul(uint4 a, uint4 b) { + return fc; + } + +-template<> ++template <> + inline __device__ Float8_ mul(uint16_t a, uint4 b) { + uint32_t s = h0_h0(a); + Float8_ fc; +@@ -327,9 +331,13 @@ inline __device__ Float8_ mul(uint16_t a, uint4 b) { + inline __device__ uint32_t fma(uint32_t a, uint32_t b, uint32_t c) { + uint32_t d; + #ifndef USE_ROCM +- asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(d) : "r"(a), "r"(b), "r"(c)); ++ asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" ++ : "=r"(d) ++ : "r"(a), "r"(b), "r"(c)); + #else +- asm volatile("v_pk_fma_f16 %0, %1, %2, %3;\n" : "=v"(d) : "v"(a), "v"(b), "v"(c)); ++ asm volatile("v_pk_fma_f16 %0, %1, %2, %3;\n" ++ : "=v"(d) ++ : "v"(a), "v"(b), "v"(c)); + #endif + return d; + } +@@ -423,24 +431,24 @@ inline __device__ Float8_ fma(uint16_t a, uint4 b, Float8_ fc) { + } + + // Vector sum. +-template<> ++template <> + inline __device__ float sum(uint16_t v) { + return half_to_float(v); + } + +-template<> ++template <> + inline __device__ float sum(uint32_t v) { + float2 tmp = half2_to_float2(v); + return tmp.x + tmp.y; + } + +-template<> ++template <> + inline __device__ float sum(uint2 v) { + uint32_t c = add(v.x, v.y); + return sum(c); + } + +-template<> ++template <> + inline __device__ float sum(uint4 v) { + uint32_t c = add(v.x, v.y); + c = add(c, v.z); +@@ -470,13 +478,9 @@ inline __device__ void from_float(uint4& dst, Float8_ src) { + } + + // From float16 to float32. +-inline __device__ float to_float(uint16_t u) { +- return half_to_float(u); +-} ++inline __device__ float to_float(uint16_t u) { return half_to_float(u); } + +-inline __device__ float2 to_float(uint32_t u) { +- return half2_to_float2(u); +-} ++inline __device__ float2 to_float(uint32_t u) { return half2_to_float2(u); } + + inline __device__ Float4_ to_float(uint2 u) { + Float4_ tmp; +@@ -495,8 +499,6 @@ inline __device__ Float8_ to_float(uint4 u) { + } + + // Zero-out a variable. +-inline __device__ void zero(uint16_t& dst) { +- dst = uint16_t(0); +-} ++inline __device__ void zero(uint16_t& dst) { dst = uint16_t(0); } + +-} // namespace vllm ++} // namespace vllm +diff --git a/csrc/attention/dtype_float32.cuh b/csrc/attention/dtype_float32.cuh +index b200d2d..7c6a686 100644 +--- a/csrc/attention/dtype_float32.cuh ++++ b/csrc/attention/dtype_float32.cuh +@@ -1,6 +1,8 @@ + /* +- * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp +- * and https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h ++ * Adapted from ++ * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp ++ * and ++ * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h + * Copyright (c) 2023, The vLLM team. + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * +@@ -38,37 +40,35 @@ struct Float8_ { + }; + + // FP32 vector types for Q, K, V. +-template<> ++template <> + struct Vec { + using Type = float; + }; +-template<> ++template <> + struct Vec { + using Type = float2; + }; +-template<> ++template <> + struct Vec { + using Type = float4; + }; + + // FP32 accumulator vector types corresponding to Vec. +-template<> ++template <> + struct FloatVec { + using Type = float; + }; +-template<> ++template <> + struct FloatVec { + using Type = float2; + }; +-template<> ++template <> + struct FloatVec { + using Type = float4; + }; + + // Vector addition. +-inline __device__ float add(float a, float b) { +- return a + b; +-} ++inline __device__ float add(float a, float b) { return a + b; } + + inline __device__ float2 add(float2 a, float2 b) { + float2 c; +@@ -87,12 +87,12 @@ inline __device__ float4 add(float4 a, float4 b) { + } + + // Vector multiplication. +-template<> ++template <> + inline __device__ float mul(float a, float b) { + return a * b; + } + +-template<> ++template <> + inline __device__ float2 mul(float2 a, float2 b) { + float2 c; + c.x = a.x * b.x; +@@ -100,7 +100,7 @@ inline __device__ float2 mul(float2 a, float2 b) { + return c; + } + +-template<> ++template <> + inline __device__ float2 mul(float a, float2 b) { + float2 c; + c.x = a * b.x; +@@ -108,7 +108,7 @@ inline __device__ float2 mul(float a, float2 b) { + return c; + } + +-template<> ++template <> + inline __device__ float4 mul(float4 a, float4 b) { + float4 c; + c.x = a.x * b.x; +@@ -118,7 +118,7 @@ inline __device__ float4 mul(float4 a, float4 b) { + return c; + } + +-template<> ++template <> + inline __device__ float4 mul(float a, float4 b) { + float4 c; + c.x = a * b.x; +@@ -129,9 +129,7 @@ inline __device__ float4 mul(float a, float4 b) { + } + + // Vector fused multiply-add. +-inline __device__ float fma(float a, float b, float c) { +- return a * b + c; +-} ++inline __device__ float fma(float a, float b, float c) { return a * b + c; } + + inline __device__ float2 fma(float2 a, float2 b, float2 c) { + float2 d; +@@ -182,35 +180,33 @@ inline __device__ Float8_ fma(float a, Float8_ b, Float8_ c) { + } + + // Vector sum. +-template<> ++template <> + inline __device__ float sum(float v) { + return v; + } + +-template<> ++template <> + inline __device__ float sum(float2 v) { + return v.x + v.y; + } + +-template<> ++template <> + inline __device__ float sum(float4 v) { + return v.x + v.y + v.z + v.w; + } + +-template<> ++template <> + inline __device__ float sum(Float4_ v) { + return v.x.x + v.x.y + v.y.x + v.y.y; + } + +-template<> ++template <> + inline __device__ float sum(Float8_ v) { + return v.x.x + v.x.y + v.y.x + v.y.y + v.z.x + v.z.y + v.w.x + v.w.y; + } + + // Vector dot product. +-inline __device__ float dot(float a, float b) { +- return a * b; +-} ++inline __device__ float dot(float a, float b) { return a * b; } + + inline __device__ float dot(float2 a, float2 b) { + float2 c = mul(a, b); +@@ -232,42 +228,24 @@ inline __device__ float dot(Float8_ a, Float8_ b) { + } + + // From float to float. +-inline __device__ void from_float(float& dst, float src) { +- dst = src; +-} ++inline __device__ void from_float(float& dst, float src) { dst = src; } + +-inline __device__ void from_float(float2& dst, float2 src) { +- dst = src; +-} ++inline __device__ void from_float(float2& dst, float2 src) { dst = src; } + +-inline __device__ void from_float(float4& dst, float4 src) { +- dst = src; +-} ++inline __device__ void from_float(float4& dst, float4 src) { dst = src; } + + // From float to float. +-inline __device__ float to_float(float u) { +- return u; +-} ++inline __device__ float to_float(float u) { return u; } + +-inline __device__ float2 to_float(float2 u) { +- return u; +-} ++inline __device__ float2 to_float(float2 u) { return u; } + +-inline __device__ float4 to_float(float4 u) { +- return u; +-} ++inline __device__ float4 to_float(float4 u) { return u; } + +-inline __device__ Float4_ to_float(Float4_ u) { +- return u; +-} ++inline __device__ Float4_ to_float(Float4_ u) { return u; } + +-inline __device__ Float8_ to_float(Float8_ u) { +- return u; +-} ++inline __device__ Float8_ to_float(Float8_ u) { return u; } + + // Zero-out a variable. +-inline __device__ void zero(float& dst) { +- dst = 0.f; +-} ++inline __device__ void zero(float& dst) { dst = 0.f; } + +-} // namespace vllm ++} // namespace vllm +diff --git a/csrc/attention/dtype_fp8.cuh b/csrc/attention/dtype_fp8.cuh +index d11dee9..e714e32 100644 +--- a/csrc/attention/dtype_fp8.cuh ++++ b/csrc/attention/dtype_fp8.cuh +@@ -3,33 +3,39 @@ + #include "attention_generic.cuh" + + #include +-#ifdef ENABLE_FP8_E5M2 +-#include +-#endif ++#ifdef ENABLE_FP8 ++ #ifndef USE_ROCM ++ #include ++ #endif // USE_ROCM ++#endif // ENABLE_FP8 + + namespace vllm { +-#if defined(ENABLE_FP8_E5M2) || defined(ENABLE_FP8_E4M3) +-// fp8 vector types for quantization of kv cache + +-template<> ++enum class Fp8KVCacheDataType { ++ kAuto = 0, ++ kFp8E4M3 = 1, ++ kFp8E5M2 = 2, ++}; ++ ++// fp8 vector types for quantization of kv cache ++template <> + struct Vec { +- using Type = uint8_t; ++ using Type = uint8_t; + }; + +-template<> ++template <> + struct Vec { +- using Type = uint16_t; ++ using Type = uint16_t; + }; + +-template<> ++template <> + struct Vec { +- using Type = uint32_t; ++ using Type = uint32_t; + }; + +-template<> ++template <> + struct Vec { +- using Type = uint2; ++ using Type = uint2; + }; +-#endif // ENABLE_FP8_E5M2 + +-} // namespace vllm ++} // namespace vllm +diff --git a/csrc/attention/paged_attention_v1.cu b/csrc/attention/paged_attention_v1.cu +new file mode 100644 +index 0000000..2732114 +--- /dev/null ++++ b/csrc/attention/paged_attention_v1.cu +@@ -0,0 +1,193 @@ ++/* ++ * Adapted from ++ * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp ++ * Copyright (c) 2023, The vLLM team. ++ * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. ++ * ++ * Licensed under the Apache License, Version 2.0 (the "License"); ++ * you may not use this file except in compliance with the License. ++ * You may obtain a copy of the License at ++ * ++ * http://www.apache.org/licenses/LICENSE-2.0 ++ * ++ * Unless required by applicable law or agreed to in writing, software ++ * distributed under the License is distributed on an "AS IS" BASIS, ++ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++ * See the License for the specific language governing permissions and ++ * limitations under the License. ++ */ ++ ++#include "attention_kernels.cuh" ++ ++#ifndef USE_ROCM ++ #define WARP_SIZE 32 ++#else ++ #define WARP_SIZE warpSize ++#endif ++ ++#define MAX(a, b) ((a) > (b) ? (a) : (b)) ++#define MIN(a, b) ((a) < (b) ? (a) : (b)) ++#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b)) ++ ++#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \ ++ VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \ ++ ((void*)vllm::paged_attention_v1_kernel), \ ++ shared_mem_size); \ ++ vllm::paged_attention_v1_kernel \ ++ <<>>( \ ++ out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, \ ++ scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \ ++ alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \ ++ k_scale, v_scale, tp_rank, blocksparse_local_blocks, \ ++ blocksparse_vert_stride, blocksparse_block_size, \ ++ blocksparse_head_sliding_step); ++ ++// TODO(woosuk): Tune NUM_THREADS. ++template ++void paged_attention_v1_launcher( ++ torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache, ++ torch::Tensor& value_cache, int num_kv_heads, float scale, ++ torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len, ++ const std::optional& alibi_slopes, float k_scale, ++ float v_scale, const int tp_rank, const int blocksparse_local_blocks, ++ const int blocksparse_vert_stride, const int blocksparse_block_size, ++ const int blocksparse_head_sliding_step) { ++ int num_seqs = query.size(0); ++ int num_heads = query.size(1); ++ int head_size = query.size(2); ++ int max_num_blocks_per_seq = block_tables.size(1); ++ int q_stride = query.stride(0); ++ int kv_block_stride = key_cache.stride(0); ++ int kv_head_stride = key_cache.stride(1); ++ ++ [[maybe_unused]] int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1); ++ assert(head_size % thread_group_size == 0); ++ ++ // NOTE: alibi_slopes is optional. ++ const float* alibi_slopes_ptr = ++ alibi_slopes ++ ? reinterpret_cast(alibi_slopes.value().data_ptr()) ++ : nullptr; ++ ++ T* out_ptr = reinterpret_cast(out.data_ptr()); ++ T* query_ptr = reinterpret_cast(query.data_ptr()); ++ CACHE_T* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); ++ CACHE_T* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); ++ int* block_tables_ptr = block_tables.data_ptr(); ++ int* seq_lens_ptr = seq_lens.data_ptr(); ++ ++ constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; ++ int padded_max_seq_len = ++ DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE; ++ int logits_size = padded_max_seq_len * sizeof(float); ++ int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float); ++ // Python-side check in vllm.worker.worker._check_if_can_support_max_seq_len ++ // Keep that in sync with the logic here! ++ int shared_mem_size = std::max(logits_size, outputs_size); ++ ++ dim3 grid(num_heads, num_seqs, 1); ++ dim3 block(NUM_THREADS); ++ const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); ++ const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); ++ switch (head_size) { ++ // NOTE(woosuk): To reduce the compilation time, we only compile for the ++ // head sizes that we use in the model. However, we can easily extend this ++ // to support any head size which is a multiple of 16. ++ case 32: ++ LAUNCH_PAGED_ATTENTION_V1(32); ++ break; ++ case 64: ++ LAUNCH_PAGED_ATTENTION_V1(64); ++ break; ++ case 80: ++ LAUNCH_PAGED_ATTENTION_V1(80); ++ break; ++ case 96: ++ LAUNCH_PAGED_ATTENTION_V1(96); ++ break; ++ case 112: ++ LAUNCH_PAGED_ATTENTION_V1(112); ++ break; ++ case 120: ++ LAUNCH_PAGED_ATTENTION_V1(120); ++ break; ++ case 128: ++ LAUNCH_PAGED_ATTENTION_V1(128); ++ break; ++ case 192: ++ LAUNCH_PAGED_ATTENTION_V1(192); ++ break; ++ case 256: ++ LAUNCH_PAGED_ATTENTION_V1(256); ++ break; ++ default: ++ TORCH_CHECK(false, "Unsupported head size: ", head_size); ++ break; ++ } ++} ++ ++#define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, IS_BLOCK_SPARSE) \ ++ paged_attention_v1_launcher( \ ++ out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \ ++ seq_lens, max_seq_len, alibi_slopes, k_scale, v_scale, tp_rank, \ ++ blocksparse_local_blocks, blocksparse_vert_stride, \ ++ blocksparse_block_size, blocksparse_head_sliding_step); ++ ++#define CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \ ++ if (is_block_sparse) { \ ++ CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, true); \ ++ } else { \ ++ CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, false); \ ++ } ++ ++// NOTE(woosuk): To reduce the compilation time, we omitted block sizes ++// 1, 2, 4, 64, 128, 256. ++#define CALL_V1_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_DTYPE) \ ++ switch (block_size) { \ ++ case 8: \ ++ CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 8, KV_DTYPE); \ ++ break; \ ++ case 16: \ ++ CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 16, KV_DTYPE); \ ++ break; \ ++ case 32: \ ++ CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 32, KV_DTYPE); \ ++ break; \ ++ default: \ ++ TORCH_CHECK(false, "Unsupported block size: ", block_size); \ ++ break; \ ++ } ++ ++void paged_attention_v1( ++ torch::Tensor& out, // [num_seqs, num_heads, head_size] ++ torch::Tensor& query, // [num_seqs, num_heads, head_size] ++ torch::Tensor& ++ key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] ++ torch::Tensor& ++ value_cache, // [num_blocks, num_heads, head_size, block_size] ++ int64_t num_kv_heads, // [num_heads] ++ double scale, ++ torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] ++ torch::Tensor& seq_lens, // [num_seqs] ++ int64_t block_size, int64_t max_seq_len, ++ const std::optional& alibi_slopes, ++ const std::string& kv_cache_dtype, double k_scale, double v_scale, ++ const int64_t tp_rank, const int64_t blocksparse_local_blocks, ++ const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, ++ const int64_t blocksparse_head_sliding_step) { ++ const bool is_block_sparse = (blocksparse_vert_stride > 1); ++ ++ DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype, ++ CALL_V1_LAUNCHER_BLOCK_SIZE) ++} ++ ++#undef WARP_SIZE ++#undef MAX ++#undef MIN ++#undef DIVIDE_ROUND_UP +\ No newline at end of file +diff --git a/csrc/attention/paged_attention_v2.cu b/csrc/attention/paged_attention_v2.cu +new file mode 100644 +index 0000000..a453b22 +--- /dev/null ++++ b/csrc/attention/paged_attention_v2.cu +@@ -0,0 +1,203 @@ ++/* ++ * Adapted from ++ * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp ++ * Copyright (c) 2023, The vLLM team. ++ * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. ++ * ++ * Licensed under the Apache License, Version 2.0 (the "License"); ++ * you may not use this file except in compliance with the License. ++ * You may obtain a copy of the License at ++ * ++ * http://www.apache.org/licenses/LICENSE-2.0 ++ * ++ * Unless required by applicable law or agreed to in writing, software ++ * distributed under the License is distributed on an "AS IS" BASIS, ++ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++ * See the License for the specific language governing permissions and ++ * limitations under the License. ++ */ ++ ++#include "attention_kernels.cuh" ++ ++#ifndef USE_ROCM ++ #define WARP_SIZE 32 ++#else ++ #define WARP_SIZE warpSize ++#endif ++ ++#define MAX(a, b) ((a) > (b) ? (a) : (b)) ++#define MIN(a, b) ((a) < (b) ? (a) : (b)) ++#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b)) ++ ++#define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE) \ ++ vllm::paged_attention_v2_kernel \ ++ <<>>( \ ++ exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, \ ++ value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \ ++ seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \ ++ kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank, \ ++ blocksparse_local_blocks, blocksparse_vert_stride, \ ++ blocksparse_block_size, blocksparse_head_sliding_step); \ ++ vllm::paged_attention_v2_reduce_kernel \ ++ <<>>( \ ++ out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, seq_lens_ptr, \ ++ max_num_partitions); ++ ++template ++void paged_attention_v2_launcher( ++ torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, ++ torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, ++ torch::Tensor& value_cache, int num_kv_heads, float scale, ++ torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len, ++ const std::optional& alibi_slopes, float k_scale, ++ float v_scale, const int tp_rank, const int blocksparse_local_blocks, ++ const int blocksparse_vert_stride, const int blocksparse_block_size, ++ const int blocksparse_head_sliding_step) { ++ int num_seqs = query.size(0); ++ int num_heads = query.size(1); ++ int head_size = query.size(2); ++ int max_num_blocks_per_seq = block_tables.size(1); ++ int q_stride = query.stride(0); ++ int kv_block_stride = key_cache.stride(0); ++ int kv_head_stride = key_cache.stride(1); ++ ++ [[maybe_unused]] int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1); ++ assert(head_size % thread_group_size == 0); ++ ++ // NOTE: alibi_slopes is optional. ++ const float* alibi_slopes_ptr = ++ alibi_slopes ++ ? reinterpret_cast(alibi_slopes.value().data_ptr()) ++ : nullptr; ++ ++ T* out_ptr = reinterpret_cast(out.data_ptr()); ++ float* exp_sums_ptr = reinterpret_cast(exp_sums.data_ptr()); ++ float* max_logits_ptr = reinterpret_cast(max_logits.data_ptr()); ++ T* tmp_out_ptr = reinterpret_cast(tmp_out.data_ptr()); ++ T* query_ptr = reinterpret_cast(query.data_ptr()); ++ CACHE_T* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); ++ CACHE_T* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); ++ int* block_tables_ptr = block_tables.data_ptr(); ++ int* seq_lens_ptr = seq_lens.data_ptr(); ++ ++ constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; ++ int max_num_partitions = DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE); ++ int logits_size = PARTITION_SIZE * sizeof(float); ++ int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float); ++ ++ // For paged attention v2 kernel. ++ dim3 grid(num_heads, num_seqs, max_num_partitions); ++ int shared_mem_size = std::max(logits_size, outputs_size); ++ // For paged attention v2 reduce kernel. ++ dim3 reduce_grid(num_heads, num_seqs); ++ int reduce_shared_mem_size = 2 * max_num_partitions * sizeof(float); ++ ++ dim3 block(NUM_THREADS); ++ const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); ++ const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); ++ switch (head_size) { ++ // NOTE(woosuk): To reduce the compilation time, we only compile for the ++ // head sizes that we use in the model. However, we can easily extend this ++ // to support any head size which is a multiple of 16. ++ case 32: ++ LAUNCH_PAGED_ATTENTION_V2(32); ++ break; ++ case 64: ++ LAUNCH_PAGED_ATTENTION_V2(64); ++ break; ++ case 80: ++ LAUNCH_PAGED_ATTENTION_V2(80); ++ break; ++ case 96: ++ LAUNCH_PAGED_ATTENTION_V2(96); ++ break; ++ case 112: ++ LAUNCH_PAGED_ATTENTION_V2(112); ++ break; ++ case 120: ++ LAUNCH_PAGED_ATTENTION_V2(120); ++ break; ++ case 128: ++ LAUNCH_PAGED_ATTENTION_V2(128); ++ break; ++ case 192: ++ LAUNCH_PAGED_ATTENTION_V2(192); ++ break; ++ case 256: ++ LAUNCH_PAGED_ATTENTION_V2(256); ++ break; ++ default: ++ TORCH_CHECK(false, "Unsupported head size: ", head_size); ++ break; ++ } ++} ++ ++#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, IS_BLOCK_SPARSE) \ ++ paged_attention_v2_launcher( \ ++ out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \ ++ num_kv_heads, scale, block_tables, seq_lens, max_seq_len, alibi_slopes, \ ++ k_scale, v_scale, tp_rank, blocksparse_local_blocks, \ ++ blocksparse_vert_stride, blocksparse_block_size, \ ++ blocksparse_head_sliding_step); ++ ++#define CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \ ++ if (is_block_sparse) { \ ++ CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, true); \ ++ } else { \ ++ CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, false); \ ++ } ++ ++// NOTE(woosuk): To reduce the compilation time, we omitted block sizes ++// 1, 2, 4, 64, 128, 256. ++#define CALL_V2_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_DTYPE) \ ++ switch (block_size) { \ ++ case 8: \ ++ CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 8, KV_DTYPE); \ ++ break; \ ++ case 16: \ ++ CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 16, KV_DTYPE); \ ++ break; \ ++ case 32: \ ++ CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 32, KV_DTYPE); \ ++ break; \ ++ default: \ ++ TORCH_CHECK(false, "Unsupported block size: ", block_size); \ ++ break; \ ++ } ++ ++void paged_attention_v2( ++ torch::Tensor& out, // [num_seqs, num_heads, head_size] ++ torch::Tensor& exp_sums, // [num_seqs, num_heads, max_num_partitions] ++ torch::Tensor& max_logits, // [num_seqs, num_heads, max_num_partitions] ++ torch::Tensor& ++ tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size] ++ torch::Tensor& query, // [num_seqs, num_heads, head_size] ++ torch::Tensor& ++ key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] ++ torch::Tensor& ++ value_cache, // [num_blocks, num_heads, head_size, block_size] ++ int64_t num_kv_heads, // [num_heads] ++ double scale, ++ torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] ++ torch::Tensor& seq_lens, // [num_seqs] ++ int64_t block_size, int64_t max_seq_len, ++ const std::optional& alibi_slopes, ++ const std::string& kv_cache_dtype, double k_scale, double v_scale, ++ const int64_t tp_rank, const int64_t blocksparse_local_blocks, ++ const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, ++ const int64_t blocksparse_head_sliding_step) { ++ const bool is_block_sparse = (blocksparse_vert_stride > 1); ++ DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype, ++ CALL_V2_LAUNCHER_BLOCK_SIZE) ++} ++ ++#undef WARP_SIZE ++#undef MAX ++#undef MIN ++#undef DIVIDE_ROUND_UP +\ No newline at end of file +diff --git a/csrc/cache.h b/csrc/cache.h +index 4c142ce..11c4c50 100644 +--- a/csrc/cache.h ++++ b/csrc/cache.h +@@ -1,38 +1,33 @@ + #pragma once + +-#include ++#include + + #include + #include + +-void swap_blocks( +- torch::Tensor& src, +- torch::Tensor& dst, +- const std::map& block_mapping); ++void swap_blocks(torch::Tensor& src, torch::Tensor& dst, ++ const torch::Tensor& block_mapping); + +-void copy_blocks( +- std::vector& key_caches, +- std::vector& value_caches, +- const std::map>& block_mapping); ++// Note: the key_caches and value_caches vectors are constant but ++// not the Tensors they contain. The vectors need to be const refs ++// in order to satisfy pytorch's C++ operator registration code. ++void copy_blocks(std::vector const& key_caches, ++ std::vector const& value_caches, ++ const torch::Tensor& block_mapping); + +-void reshape_and_cache( +- torch::Tensor& key, +- torch::Tensor& value, +- torch::Tensor& key_cache, +- torch::Tensor& value_cache, +- torch::Tensor& slot_mapping, +- const std::string& kv_cache_dtype, +- const float kv_scale); ++void reshape_and_cache(torch::Tensor& key, torch::Tensor& value, ++ torch::Tensor& key_cache, torch::Tensor& value_cache, ++ torch::Tensor& slot_mapping, ++ const std::string& kv_cache_dtype, const double k_scale, ++ const double v_scale); + +-void reshape_and_cache_flash( +- torch::Tensor& key, +- torch::Tensor& value, +- torch::Tensor& key_cache, +- torch::Tensor& value_cache, +- torch::Tensor& slot_mapping, +- const std::string& kv_cache_dtype); ++void reshape_and_cache_flash(torch::Tensor& key, torch::Tensor& value, ++ torch::Tensor& key_cache, ++ torch::Tensor& value_cache, ++ torch::Tensor& slot_mapping, ++ const std::string& kv_cache_dtype, ++ const double k_scale, const double v_scale); + + // Just for unittest +-void convert_fp8( +- torch::Tensor& src_cache, +- torch::Tensor& dst_cache); ++void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache, ++ const double scale, const std::string& kv_cache_dtype); +diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu +index 42f884c..8a95279 100644 +--- a/csrc/cache_kernels.cu ++++ b/csrc/cache_kernels.cu +@@ -1,13 +1,14 @@ +-#include ++#include + #include + #include + + #include "cuda_compat.h" + #include "dispatch_utils.h" +-#if defined(ENABLE_FP8_E5M2) +-#include "quantization/fp8_e5m2_kvcache/quant_utils.cuh" +-#elif defined(ENABLE_FP8_E4M3) +-#include "quantization/fp8/amd_detail/quant_utils.cuh" ++ ++#ifdef USE_ROCM ++ #include "quantization/fp8/amd/quant_utils.cuh" ++#else ++ #include "quantization/fp8/nvidia/quant_utils.cuh" + #endif + + #include +@@ -17,20 +18,17 @@ + + #ifdef USE_ROCM + #include +- typedef __hip_bfloat16 __nv_bfloat16; ++typedef __hip_bfloat16 __nv_bfloat16; + #endif + +-void swap_blocks( +- torch::Tensor& src, +- torch::Tensor& dst, +- const std::map& block_mapping) { ++void swap_blocks(torch::Tensor& src, torch::Tensor& dst, ++ const torch::Tensor& block_mapping) { + torch::Device src_device = src.device(); + torch::Device dst_device = dst.device(); + cudaMemcpyKind memcpy_type; + if (src_device.is_cuda() && dst_device.is_cuda()) { +- TORCH_CHECK( +- src_device.index() == dst_device.index(), +- "src and dst must be on the same GPU"); ++ TORCH_CHECK(src_device.index() == dst_device.index(), ++ "src and dst must be on the same GPU"); + memcpy_type = cudaMemcpyDeviceToDevice; + } else if (src_device.is_cuda() && dst_device.is_cpu()) { + memcpy_type = cudaMemcpyDeviceToHost; +@@ -40,41 +38,44 @@ void swap_blocks( + TORCH_CHECK(false, "Invalid device combination"); + } + +- char *src_ptr = static_cast(src.data_ptr()); +- char *dst_ptr = static_cast(dst.data_ptr()); ++ // NOTE(youkaichao): keep in mind that `block_mapping` should be ++ // a cpu tensor, otherwise every `item` call will require a gpu-cpu ++ // synchronization. ++ TORCH_CHECK(block_mapping.device().is_cpu(), "block_mapping must be on CPU"); ++ ++ char* src_ptr = static_cast(src.data_ptr()); ++ char* dst_ptr = static_cast(dst.data_ptr()); + + const int64_t block_size_in_bytes = src.element_size() * src[0].numel(); +- const at::cuda::OptionalCUDAGuard device_guard(src_device.is_cuda() ? src_device : dst_device); ++ const at::cuda::OptionalCUDAGuard device_guard( ++ src_device.is_cuda() ? src_device : dst_device); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + // NOTE(woosuk): This can be slow if the number of blocks is large. +- for (const auto& pair : block_mapping) { +- int64_t src_block_number = pair.first; +- int64_t dst_block_number = pair.second; ++ const int64_t num_blocks = block_mapping.size(0); ++ for (size_t i = 0; i < num_blocks; i++) { ++ int64_t src_block_number = block_mapping[i][0].item(); ++ int64_t dst_block_number = block_mapping[i][1].item(); + int64_t src_offset = src_block_number * block_size_in_bytes; + int64_t dst_offset = dst_block_number * block_size_in_bytes; +- cudaMemcpyAsync( +- dst_ptr + dst_offset, +- src_ptr + src_offset, +- block_size_in_bytes, +- memcpy_type, +- stream); ++ cudaMemcpyAsync(dst_ptr + dst_offset, src_ptr + src_offset, ++ block_size_in_bytes, memcpy_type, stream); + } + } + + namespace vllm { + + // Grid: (num_layers, num_pairs) +-template +-__global__ void copy_blocks_kernel( +- int64_t* key_cache_ptrs, +- int64_t* value_cache_ptrs, +- const int64_t* __restrict__ block_mapping, +- const int numel_per_block) { ++template ++__global__ void copy_blocks_kernel(int64_t* key_cache_ptrs, ++ int64_t* value_cache_ptrs, ++ const int64_t* __restrict__ block_mapping, ++ const int numel_per_block) { + const int layer_idx = blockIdx.x; + const int pair_idx = blockIdx.y; + + scalar_t* key_cache = reinterpret_cast(key_cache_ptrs[layer_idx]); +- scalar_t* value_cache = reinterpret_cast(value_cache_ptrs[layer_idx]); ++ scalar_t* value_cache = ++ reinterpret_cast(value_cache_ptrs[layer_idx]); + int64_t src_block_number = block_mapping[2 * pair_idx]; + int64_t dst_block_number = block_mapping[2 * pair_idx + 1]; + +@@ -92,12 +93,14 @@ __global__ void copy_blocks_kernel( + } + } + +-} // namespace vllm ++} // namespace vllm + +-void copy_blocks( +- std::vector& key_caches, +- std::vector& value_caches, +- const std::map>& block_mapping) { ++// Note: the key_caches and value_caches vectors are constant but ++// not the Tensors they contain. The vectors need to be const refs ++// in order to satisfy pytorch's C++ operator registration code. ++void copy_blocks(std::vector const& key_caches, ++ std::vector const& value_caches, ++ const torch::Tensor& block_mapping) { + int num_layers = key_caches.size(); + TORCH_CHECK(num_layers == value_caches.size()); + if (num_layers == 0) { +@@ -111,29 +114,23 @@ void copy_blocks( + int64_t key_cache_ptrs[num_layers]; + int64_t value_cache_ptrs[num_layers]; + for (int layer_idx = 0; layer_idx < num_layers; ++layer_idx) { +- key_cache_ptrs[layer_idx] = reinterpret_cast(key_caches[layer_idx].data_ptr()); +- value_cache_ptrs[layer_idx] = reinterpret_cast(value_caches[layer_idx].data_ptr()); ++ key_cache_ptrs[layer_idx] = ++ reinterpret_cast(key_caches[layer_idx].data_ptr()); ++ value_cache_ptrs[layer_idx] = ++ reinterpret_cast(value_caches[layer_idx].data_ptr()); + } +- // Create block mapping array. +- std::vector block_mapping_vec; +- for (const auto& pair : block_mapping) { +- int64_t src_block_number = pair.first; +- for (int64_t dst_block_number : pair.second) { +- block_mapping_vec.push_back(src_block_number); +- block_mapping_vec.push_back(dst_block_number); +- } +- } +- int64_t* block_mapping_array = block_mapping_vec.data(); +- int num_pairs = block_mapping_vec.size() / 2; ++ ++ // block_mapping is a 2D tensor with shape (num_pairs, 2). ++ int num_pairs = block_mapping.size(0); + + // Move the data structures to the GPU. + // NOTE: This synchronizes the CPU and GPU. +- torch::Tensor key_cache_ptrs_tensor = torch::from_blob( +- key_cache_ptrs, {num_layers}, torch::kInt64).to(cache_device); +- torch::Tensor value_cache_ptrs_tensor = torch::from_blob( +- value_cache_ptrs, {num_layers}, torch::kInt64).to(cache_device); +- torch::Tensor block_mapping_tensor = torch::from_blob( +- block_mapping_array, {2 * num_pairs}, torch::kInt64).to(cache_device); ++ torch::Tensor key_cache_ptrs_tensor = ++ torch::from_blob(key_cache_ptrs, {num_layers}, torch::kInt64) ++ .to(cache_device); ++ torch::Tensor value_cache_ptrs_tensor = ++ torch::from_blob(value_cache_ptrs, {num_layers}, torch::kInt64) ++ .to(cache_device); + + // Launch the kernel. + const int numel_per_block = key_caches[0][0].numel(); +@@ -142,31 +139,28 @@ void copy_blocks( + const at::cuda::OptionalCUDAGuard device_guard(cache_device); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES( +- key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] { +- vllm::copy_blocks_kernel<<>>( +- key_cache_ptrs_tensor.data_ptr(), +- value_cache_ptrs_tensor.data_ptr(), +- block_mapping_tensor.data_ptr(), +- numel_per_block); +- })); ++ key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] { ++ vllm::copy_blocks_kernel<<>>( ++ key_cache_ptrs_tensor.data_ptr(), ++ value_cache_ptrs_tensor.data_ptr(), ++ block_mapping.data_ptr(), numel_per_block); ++ })); + } + + namespace vllm { + +-template ++template + __global__ void reshape_and_cache_kernel( +- const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size] +- const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size] +- cache_t* __restrict__ key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] +- cache_t* __restrict__ value_cache, // [num_blocks, num_heads, head_size, block_size] +- const int64_t* __restrict__ slot_mapping, // [num_tokens] +- const int key_stride, +- const int value_stride, +- const int num_heads, +- const int head_size, +- const int block_size, +- const int x, +- const float kv_scale) { ++ const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size] ++ const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size] ++ cache_t* __restrict__ key_cache, // [num_blocks, num_heads, head_size/x, ++ // block_size, x] ++ cache_t* __restrict__ value_cache, // [num_blocks, num_heads, head_size, ++ // block_size] ++ const int64_t* __restrict__ slot_mapping, // [num_tokens] ++ const int key_stride, const int value_stride, const int num_heads, ++ const int head_size, const int block_size, const int x, const float k_scale, ++ const float v_scale) { + const int64_t token_idx = blockIdx.x; + const int64_t slot_idx = slot_mapping[token_idx]; + if (slot_idx < 0) { +@@ -187,47 +181,40 @@ __global__ void reshape_and_cache_kernel( + const int x_idx = head_offset / x; + const int x_offset = head_offset % x; + +- const int64_t tgt_key_idx = block_idx * num_heads * (head_size / x) * block_size * x +- + head_idx * (head_size / x) * block_size * x +- + x_idx * block_size * x +- + block_offset * x +- + x_offset; +- const int64_t tgt_value_idx = block_idx * num_heads * head_size * block_size +- + head_idx * head_size * block_size +- + head_offset * block_size +- + block_offset; ++ const int64_t tgt_key_idx = ++ block_idx * num_heads * (head_size / x) * block_size * x + ++ head_idx * (head_size / x) * block_size * x + x_idx * block_size * x + ++ block_offset * x + x_offset; ++ const int64_t tgt_value_idx = ++ block_idx * num_heads * head_size * block_size + ++ head_idx * head_size * block_size + head_offset * block_size + ++ block_offset; + scalar_t tgt_key = key[src_key_idx]; + scalar_t tgt_value = value[src_value_idx]; +- if constexpr (is_fp8_kv_cache) { +-#if defined(ENABLE_FP8_E5M2) +- key_cache[tgt_key_idx] = fp8_e5m2_unscaled::vec_conversion(tgt_key); +- value_cache[tgt_value_idx] = fp8_e5m2_unscaled::vec_conversion(tgt_value); +-#elif defined(ENABLE_FP8_E4M3) +- key_cache[tgt_key_idx] = fp8_e4m3::scaled_vec_conversion(tgt_key, kv_scale); +- value_cache[tgt_value_idx] = fp8_e4m3::scaled_vec_conversion(tgt_value, kv_scale); +-#else +- assert(false); +-#endif +- } else { ++ if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) { + key_cache[tgt_key_idx] = tgt_key; + value_cache[tgt_value_idx] = tgt_value; ++ } else { ++ key_cache[tgt_key_idx] = ++ fp8::scaled_convert(tgt_key, k_scale); ++ value_cache[tgt_value_idx] = ++ fp8::scaled_convert(tgt_value, v_scale); + } + } + } + +-template ++template + __global__ void reshape_and_cache_flash_kernel( +- const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size] +- const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size] +- scalar_t* __restrict__ k_cache, // [num_blocks, block_size, num_heads, head_size] +- scalar_t* __restrict__ v_cache, // [num_blocks, block_size, num_heads, head_size] +- const int64_t* __restrict__ slot_mapping, // [num_tokens] +- const int block_stride, +- const int key_stride, +- const int value_stride, +- const int num_heads, +- const int head_size, +- const int block_size) { ++ const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size] ++ const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size] ++ cache_t* __restrict__ key_cache, // [num_blocks, block_size, num_heads, ++ // head_size] ++ cache_t* __restrict__ value_cache, // [num_blocks, block_size, num_heads, ++ // head_size] ++ const int64_t* __restrict__ slot_mapping, // [num_tokens] ++ const int block_stride, const int key_stride, const int value_stride, ++ const int num_heads, const int head_size, const int block_size, ++ const float k_scale, const float v_scale) { + const int64_t token_idx = blockIdx.x; + const int64_t slot_idx = slot_mapping[token_idx]; + // NOTE: slot_idx can be -1 if the token is padded +@@ -242,40 +229,47 @@ __global__ void reshape_and_cache_flash_kernel( + const int64_t src_value_idx = token_idx * value_stride + i; + const int head_idx = i / head_size; + const int head_offset = i % head_size; +- const int64_t tgt_value_idx = block_idx * block_stride +- + block_offset * num_heads * head_size +- + head_idx * head_size +- + head_offset; +- k_cache[tgt_value_idx] = key[src_key_idx]; +- v_cache[tgt_value_idx] = value[src_value_idx]; ++ const int64_t tgt_key_value_idx = block_idx * block_stride + ++ block_offset * num_heads * head_size + ++ head_idx * head_size + head_offset; ++ scalar_t tgt_key = key[src_key_idx]; ++ scalar_t tgt_value = value[src_value_idx]; ++ if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) { ++ key_cache[tgt_key_value_idx] = tgt_key; ++ value_cache[tgt_key_value_idx] = tgt_value; ++ } else { ++ key_cache[tgt_key_value_idx] = ++ fp8::scaled_convert(tgt_key, k_scale); ++ value_cache[tgt_key_value_idx] = ++ fp8::scaled_convert(tgt_value, v_scale); ++ } + } + } +-} // namespace vllm +- +-#define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, IS_FP8_KV_CACHE) \ +- vllm::reshape_and_cache_kernel<<>>( \ +- reinterpret_cast(key.data_ptr()), \ +- reinterpret_cast(value.data_ptr()), \ +- reinterpret_cast(key_cache.data_ptr()), \ +- reinterpret_cast(value_cache.data_ptr()), \ +- slot_mapping.data_ptr(), \ +- key_stride, \ +- value_stride, \ +- num_heads, \ +- head_size, \ +- block_size, \ +- x, \ +- kv_scale); ++} // namespace vllm ++ ++// KV_T is the stored data type of kv-cache. ++// CACHE_T is the data type of key and value tensors. ++// KV_DTYPE is the real data type of kv-cache. ++#define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, KV_DTYPE) \ ++ vllm::reshape_and_cache_kernel \ ++ <<>>( \ ++ reinterpret_cast(key.data_ptr()), \ ++ reinterpret_cast(value.data_ptr()), \ ++ reinterpret_cast(key_cache.data_ptr()), \ ++ reinterpret_cast(value_cache.data_ptr()), \ ++ slot_mapping.data_ptr(), key_stride, value_stride, \ ++ num_heads, head_size, block_size, x, k_scale, v_scale); + + void reshape_and_cache( +- torch::Tensor& key, // [num_tokens, num_heads, head_size] +- torch::Tensor& value, // [num_tokens, num_heads, head_size] +- torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] +- torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size] +- torch::Tensor& slot_mapping, // [num_tokens] +- const std::string& kv_cache_dtype, +- const float kv_scale) +-{ ++ torch::Tensor& key, // [num_tokens, num_heads, head_size] ++ torch::Tensor& value, // [num_tokens, num_heads, head_size] ++ torch::Tensor& ++ key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] ++ torch::Tensor& ++ value_cache, // [num_blocks, num_heads, head_size, block_size] ++ torch::Tensor& slot_mapping, // [num_tokens] ++ const std::string& kv_cache_dtype, const double k_scale, ++ const double v_scale) { + int num_tokens = key.size(0); + int num_heads = key.size(1); + int head_size = key.size(2); +@@ -289,111 +283,93 @@ void reshape_and_cache( + dim3 block(std::min(num_heads * head_size, 512)); + const at::cuda::OptionalCUDAGuard device_guard(device_of(key)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); +- if (kv_cache_dtype == "auto") { +- if (key.dtype() == at::ScalarType::Float) { +- CALL_RESHAPE_AND_CACHE(float, float, false); +- } else if (key.dtype() == at::ScalarType::Half) { +- CALL_RESHAPE_AND_CACHE(uint16_t, uint16_t, false); +- } else if (key.dtype() == at::ScalarType::BFloat16) { +- CALL_RESHAPE_AND_CACHE(__nv_bfloat16, __nv_bfloat16, false); +- } +- } else if (kv_cache_dtype == "fp8") { +- if (key.dtype() == at::ScalarType::Float) { +- CALL_RESHAPE_AND_CACHE(float, uint8_t, true); +- } else if (key.dtype() == at::ScalarType::Half) { +- CALL_RESHAPE_AND_CACHE(uint16_t, uint8_t, true); +- } else if (key.dtype() == at::ScalarType::BFloat16) { +- CALL_RESHAPE_AND_CACHE(__nv_bfloat16, uint8_t, true); +- } +- } else { +- TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype); +- } ++ ++ DISPATCH_BY_KV_CACHE_DTYPE(key.dtype(), kv_cache_dtype, ++ CALL_RESHAPE_AND_CACHE) + } + ++// KV_T is the stored data type of kv-cache. ++// CACHE_T is the data type of key and value tensors. ++// KV_DTYPE is the real data type of kv-cache. ++#define CALL_RESHAPE_AND_CACHE_FLASH(KV_T, CACHE_T, KV_DTYPE) \ ++ vllm::reshape_and_cache_flash_kernel \ ++ <<>>( \ ++ reinterpret_cast(key.data_ptr()), \ ++ reinterpret_cast(value.data_ptr()), \ ++ reinterpret_cast(key_cache.data_ptr()), \ ++ reinterpret_cast(value_cache.data_ptr()), \ ++ slot_mapping.data_ptr(), block_stride, key_stride, \ ++ value_stride, num_heads, head_size, block_size, k_scale, v_scale); ++ + void reshape_and_cache_flash( +- torch::Tensor& key, // [num_tokens, num_heads, head_size] +- torch::Tensor& value, // [num_tokens, num_heads, head_size] +- torch::Tensor& k_cache, // [num_blocks, block_size, num_heads, head_size] +- torch::Tensor& v_cache, // [num_blocks, block_size, num_heads, head_size] +- torch::Tensor& slot_mapping, // [num_tokens] +- const std::string& kv_cache_dtype) +-{ +- // FIXME: only support auto datatype, does not support fp8 +- if (kv_cache_dtype != "auto") { +- TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype); +- } +- int num_tokens = key.size(0); ++ torch::Tensor& key, // [num_tokens, num_heads, head_size] ++ torch::Tensor& value, // [num_tokens, num_heads, head_size] ++ torch::Tensor& key_cache, // [num_blocks, block_size, num_heads, head_size] ++ torch::Tensor& ++ value_cache, // [num_blocks, block_size, num_heads, head_size] ++ torch::Tensor& slot_mapping, // [num_tokens] or [num_actual_tokens] ++ const std::string& kv_cache_dtype, const double k_scale, ++ const double v_scale) { ++ // NOTE(woosuk): In vLLM V1, key.size(0) can be different from ++ // slot_mapping.size(0) because of padding for CUDA graphs. ++ // In vLLM V0, key.size(0) is always equal to slot_mapping.size(0) because ++ // both include padding. ++ // In vLLM V1, however, key.size(0) can be larger than slot_mapping.size(0) ++ // since key includes padding for CUDA graphs, while slot_mapping does not. ++ // In this case, slot_mapping.size(0) represents the actual number of tokens ++ // before padding. ++ // For compatibility with both cases, we use slot_mapping.size(0) as the ++ // number of tokens. ++ int num_tokens = slot_mapping.size(0); + int num_heads = key.size(1); + int head_size = key.size(2); +- int block_size = k_cache.size(1); ++ int block_size = key_cache.size(1); + + int key_stride = key.stride(0); + int value_stride = value.stride(0); +- int block_stride = k_cache.stride(0); +- TORCH_CHECK(k_cache.stride(0) == v_cache.stride(0)); ++ int block_stride = key_cache.stride(0); ++ TORCH_CHECK(key_cache.stride(0) == value_cache.stride(0)); + + dim3 grid(num_tokens); + dim3 block(std::min(num_heads * head_size, 512)); + const at::cuda::OptionalCUDAGuard device_guard(device_of(key)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); +- VLLM_DISPATCH_FLOATING_TYPES( +- key.scalar_type(), +- "reshape_and_cache_flash", +- [&] { +- vllm::reshape_and_cache_flash_kernel<<>>( +- key.data_ptr(), +- value.data_ptr(), +- k_cache.data_ptr(), +- v_cache.data_ptr(), +- slot_mapping.data_ptr(), +- block_stride, +- key_stride, +- value_stride, +- num_heads, +- head_size, +- block_size); +- }); ++ ++ DISPATCH_BY_KV_CACHE_DTYPE(key.dtype(), kv_cache_dtype, ++ CALL_RESHAPE_AND_CACHE_FLASH); + } + + namespace vllm { + +-template +-__global__ void convert_fp8_kernel( +- const Tin* __restrict__ src_cache, +- Tout* __restrict__ dst_cache, +- const int64_t block_stride) { ++template ++__global__ void convert_fp8_kernel(const Tin* __restrict__ src_cache, ++ Tout* __restrict__ dst_cache, ++ const float scale, ++ const int64_t block_stride) { + const int64_t block_idx = blockIdx.x; + for (int i = threadIdx.x; i < block_stride; i += blockDim.x) { + int64_t idx = block_idx * block_stride + i; +-#if defined(ENABLE_FP8_E5M2) +- dst_cache[idx] = fp8_e5m2_unscaled::vec_conversion(src_cache[idx]); +-#elif defined(ENABLE_FP8_E4M3) +- dst_cache[idx] = fp8_e4m3::vec_conversion(src_cache[idx]); +-#else +- assert(false); +-#endif ++ dst_cache[idx] = ++ fp8::scaled_convert(src_cache[idx], scale); + } + } + +-} // namespace vllm ++} // namespace vllm + +-#define CALL_CONVERT_FP8(Tout, Tin) \ +- vllm::convert_fp8_kernel<<>>( \ +- reinterpret_cast(src_cache.data_ptr()), \ +- reinterpret_cast(dst_cache.data_ptr()), \ +- block_stride); ++#define CALL_CONVERT_FP8(Tout, Tin, KV_DTYPE) \ ++ vllm::convert_fp8_kernel<<>>( \ ++ reinterpret_cast(src_cache.data_ptr()), \ ++ reinterpret_cast(dst_cache.data_ptr()), scale, block_stride); + +-void convert_fp8( +- torch::Tensor& src_cache, +- torch::Tensor& dst_cache) +-{ ++// Only for testing. ++void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache, ++ const double scale, const std::string& kv_cache_dtype) { + torch::Device src_device = src_cache.device(); + torch::Device dst_device = dst_cache.device(); + TORCH_CHECK(src_device.is_cuda(), "src must be on a GPU") + TORCH_CHECK(dst_device.is_cuda(), "dst must be on a GPU") +- TORCH_CHECK( +- src_device.index() == dst_device.index(), +- "src and dst must be on the same GPU"); ++ TORCH_CHECK(src_device.index() == dst_device.index(), ++ "src and dst must be on the same GPU"); + at::cuda::OptionalCUDAGuard device_guard(src_device); + + int64_t num_blocks = src_cache.size(0); +@@ -403,17 +379,37 @@ void convert_fp8( + dim3 block(std::min(block_stride, int64_t(512))); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + +- if (src_cache.dtype() == at::ScalarType::Float) { +- CALL_CONVERT_FP8(uint8_t, float); +- } else if (src_cache.dtype() == at::ScalarType::Half) { +- CALL_CONVERT_FP8(uint8_t, uint16_t); +- } else if (src_cache.dtype() == at::ScalarType::BFloat16) { +- CALL_CONVERT_FP8(uint8_t, __nv_bfloat16); +- } else if (dst_cache.dtype() == at::ScalarType::Float) { +- CALL_CONVERT_FP8(float, uint8_t); +- } else if (dst_cache.dtype() == at::ScalarType::Half) { +- CALL_CONVERT_FP8(uint16_t, uint8_t); +- } else if (dst_cache.dtype() == at::ScalarType::BFloat16) { +- CALL_CONVERT_FP8(__nv_bfloat16, uint8_t); ++ if (kv_cache_dtype == "auto") { ++ if (src_cache.dtype() == at::ScalarType::Float) { ++ CALL_CONVERT_FP8(uint8_t, float, vllm::Fp8KVCacheDataType::kAuto); ++ } else if (src_cache.dtype() == at::ScalarType::Half) { ++ CALL_CONVERT_FP8(uint8_t, uint16_t, vllm::Fp8KVCacheDataType::kAuto); ++ } else if (src_cache.dtype() == at::ScalarType::BFloat16) { ++ CALL_CONVERT_FP8(uint8_t, __nv_bfloat16, vllm::Fp8KVCacheDataType::kAuto); ++ } else if (dst_cache.dtype() == at::ScalarType::Float) { ++ CALL_CONVERT_FP8(float, uint8_t, vllm::Fp8KVCacheDataType::kAuto); ++ } else if (dst_cache.dtype() == at::ScalarType::Half) { ++ CALL_CONVERT_FP8(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kAuto); ++ } else if (dst_cache.dtype() == at::ScalarType::BFloat16) { ++ CALL_CONVERT_FP8(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kAuto); ++ } ++ } else if (kv_cache_dtype == "fp8" || kv_cache_dtype == "fp8_e4m3") { ++ if (src_cache.dtype() == at::ScalarType::Float) { ++ CALL_CONVERT_FP8(uint8_t, float, vllm::Fp8KVCacheDataType::kFp8E4M3); ++ } else if (src_cache.dtype() == at::ScalarType::Half) { ++ CALL_CONVERT_FP8(uint8_t, uint16_t, vllm::Fp8KVCacheDataType::kFp8E4M3); ++ } else if (src_cache.dtype() == at::ScalarType::BFloat16) { ++ CALL_CONVERT_FP8(uint8_t, __nv_bfloat16, ++ vllm::Fp8KVCacheDataType::kFp8E4M3); ++ } else if (dst_cache.dtype() == at::ScalarType::Float) { ++ CALL_CONVERT_FP8(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); ++ } else if (dst_cache.dtype() == at::ScalarType::Half) { ++ CALL_CONVERT_FP8(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); ++ } else if (dst_cache.dtype() == at::ScalarType::BFloat16) { ++ CALL_CONVERT_FP8(__nv_bfloat16, uint8_t, ++ vllm::Fp8KVCacheDataType::kFp8E4M3); ++ } ++ } else { ++ TORCH_CHECK(false, "Unsupported data type: ", kv_cache_dtype); + } + } +diff --git a/csrc/core/exception.hpp b/csrc/core/exception.hpp +new file mode 100644 +index 0000000..f3b2ffa +--- /dev/null ++++ b/csrc/core/exception.hpp +@@ -0,0 +1,3 @@ ++#pragma once ++ ++#define VLLM_IMPLIES(p, q) (!(p) || (q)) +diff --git a/csrc/core/math.hpp b/csrc/core/math.hpp +new file mode 100644 +index 0000000..ba9f40a +--- /dev/null ++++ b/csrc/core/math.hpp +@@ -0,0 +1,7 @@ ++#include ++#include ++ ++inline uint32_t next_pow_2(uint32_t const num) { ++ if (num <= 1) return num; ++ return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1)); ++} +\ No newline at end of file +diff --git a/csrc/core/registration.h b/csrc/core/registration.h +new file mode 100644 +index 0000000..4d0ce1c +--- /dev/null ++++ b/csrc/core/registration.h +@@ -0,0 +1,27 @@ ++#pragma once ++ ++#include ++ ++#define _CONCAT(A, B) A##B ++#define CONCAT(A, B) _CONCAT(A, B) ++ ++#define _STRINGIFY(A) #A ++#define STRINGIFY(A) _STRINGIFY(A) ++ ++// A version of the TORCH_LIBRARY macro that expands the NAME, i.e. so NAME ++// could be a macro instead of a literal token. ++#define TORCH_LIBRARY_EXPAND(NAME, MODULE) TORCH_LIBRARY(NAME, MODULE) ++ ++// A version of the TORCH_LIBRARY_IMPL macro that expands the NAME, i.e. so NAME ++// could be a macro instead of a literal token. ++#define TORCH_LIBRARY_IMPL_EXPAND(NAME, DEVICE, MODULE) \ ++ TORCH_LIBRARY_IMPL(NAME, DEVICE, MODULE) ++ ++// REGISTER_EXTENSION allows the shared library to be loaded and initialized ++// via python's import statement. ++#define REGISTER_EXTENSION(NAME) \ ++ PyMODINIT_FUNC CONCAT(PyInit_, NAME)() { \ ++ static struct PyModuleDef module = {PyModuleDef_HEAD_INIT, \ ++ STRINGIFY(NAME), nullptr, 0, nullptr}; \ ++ return PyModule_Create(&module); \ ++ } +diff --git a/csrc/core/scalar_type.hpp b/csrc/core/scalar_type.hpp +new file mode 100644 +index 0000000..408e736 +--- /dev/null ++++ b/csrc/core/scalar_type.hpp +@@ -0,0 +1,347 @@ ++#pragma once ++ ++// For TORCH_CHECK ++#include ++ ++namespace vllm { ++ ++// ++// ScalarType can represent a wide range of floating point and integer types, ++// in particular it can be used to represent sub-byte data types (something ++// that torch.dtype currently does not support). ++// ++// The type definitions on the Python side can be found in: vllm/scalar_type.py ++// these type definitions should be kept up to date with any Python API changes ++// here. ++// ++class ScalarType { ++ public: ++ enum NanRepr : uint8_t { ++ NAN_NONE = 0, // nans are not supported ++ NAN_IEEE_754 = 1, // nans are: exp all 1s, mantissa not all 0s ++ NAN_EXTD_RANGE_MAX_MIN = 2, // nans are: exp all 1s, mantissa all 1s ++ ++ NAN_REPR_ID_MAX ++ }; ++ ++ constexpr ScalarType(uint8_t exponent, uint8_t mantissa, bool signed_, ++ int32_t bias, bool finite_values_only = false, ++ NanRepr nan_repr = NAN_IEEE_754) ++ : exponent(exponent), ++ mantissa(mantissa), ++ signed_(signed_), ++ bias(bias), ++ finite_values_only(finite_values_only), ++ nan_repr(nan_repr){}; ++ ++ static constexpr ScalarType int_(uint8_t size_bits, int32_t bias = 0) { ++ return ScalarType(0, size_bits - 1, true, bias); ++ } ++ ++ static constexpr ScalarType uint(uint8_t size_bits, int32_t bias = 0) { ++ return ScalarType(0, size_bits, false, bias); ++ } ++ ++ // IEEE 754 compliant floating point type ++ static constexpr ScalarType float_IEEE754(uint8_t exponent, ++ uint8_t mantissa) { ++ TORCH_CHECK(mantissa > 0 && exponent > 0); ++ return ScalarType(exponent, mantissa, true, 0, false, NAN_IEEE_754); ++ } ++ ++ // IEEE 754 non-compliant floating point type ++ static constexpr ScalarType float_(uint8_t exponent, uint8_t mantissa, ++ bool finite_values_only, ++ NanRepr nan_repr) { ++ TORCH_CHECK(nan_repr < NAN_REPR_ID_MAX, "Invalid NanRepr"); ++ TORCH_CHECK(mantissa > 0 && exponent > 0); ++ TORCH_CHECK(nan_repr != NAN_IEEE_754, ++ "use `float_IEEE754` constructor for floating point types that " ++ "follow IEEE 754 conventions"); ++ return ScalarType(exponent, mantissa, true, 0, finite_values_only, ++ nan_repr); ++ } ++ ++ uint8_t const exponent; // size of the exponent field (0 for integer types) ++ uint8_t const mantissa; // size of the mantissa field (size of the integer ++ // excluding the sign bit for integer types) ++ bool const signed_; // flag if the type supports negative numbers (i.e. has a ++ // sign bit) ++ int32_t const bias; // stored values equal value + bias, ++ // used for quantized type ++ ++ // Extra Floating point info ++ bool const finite_values_only; // i.e. no +/-inf if true ++ NanRepr const nan_repr; // how NaNs are represented ++ // (not applicable for integer types) ++ ++ using Id = int64_t; ++ ++ private: ++ // Field size in id ++ template ++ static constexpr size_t member_id_field_width() { ++ using T = std::decay_t; ++ return std::is_same_v ? 1 : sizeof(T) * 8; ++ } ++ ++ template ++ static constexpr auto reduce_members_helper(Fn f, Init val, Member member, ++ Rest... rest) { ++ auto new_val = f(val, member); ++ if constexpr (sizeof...(rest) > 0) { ++ return reduce_members_helper(f, new_val, rest...); ++ } else { ++ return new_val; ++ }; ++ } ++ ++ template ++ constexpr auto reduce_members(Fn f, Init init) const { ++ // Should be in constructor order for `from_id` ++ return reduce_members_helper(f, init, exponent, mantissa, signed_, bias, ++ finite_values_only, nan_repr); ++ }; ++ ++ template ++ static constexpr auto reduce_member_types(Fn f, Init init) { ++ constexpr auto dummy_type = ScalarType(0, 0, false, 0, false, NAN_NONE); ++ return dummy_type.reduce_members(f, init); ++ }; ++ ++ static constexpr auto id_size_bits() { ++ return reduce_member_types( ++ [](int acc, auto member) -> int { ++ return acc + member_id_field_width(); ++ }, ++ 0); ++ } ++ ++ public: ++ // unique id for this scalar type that can be computed at compile time for ++ // c++17 template specialization this is not needed once we migrate to ++ // c++20 and can pass literal classes as template parameters ++ constexpr Id id() const { ++ static_assert(id_size_bits() <= sizeof(Id) * 8, ++ "ScalarType id is too large to be stored"); ++ ++ auto or_and_advance = [](std::pair result, ++ auto member) -> std::pair { ++ auto [id, bit_offset] = result; ++ auto constexpr bits = member_id_field_width(); ++ return {id | (int64_t(member) & ((uint64_t(1) << bits) - 1)) ++ << bit_offset, ++ bit_offset + bits}; ++ }; ++ return reduce_members(or_and_advance, std::pair{}).first; ++ } ++ ++ // create a ScalarType from an id, for c++17 template specialization, ++ // this is not needed once we migrate to c++20 and can pass literal ++ // classes as template parameters ++ static constexpr ScalarType from_id(Id id) { ++ auto extract_and_advance = [id](auto result, auto member) { ++ using T = decltype(member); ++ auto [tuple, bit_offset] = result; ++ auto constexpr bits = member_id_field_width(); ++ auto extracted_val = static_cast((int64_t(id) >> bit_offset) & ++ ((uint64_t(1) << bits) - 1)); ++ auto new_tuple = std::tuple_cat(tuple, std::make_tuple(extracted_val)); ++ return std::pair{new_tuple, bit_offset + bits}; ++ }; ++ ++ auto [tuple_args, _] = reduce_member_types(extract_and_advance, ++ std::pair, int>{}); ++ return std::apply([](auto... args) { return ScalarType(args...); }, ++ tuple_args); ++ } ++ ++ constexpr int64_t size_bits() const { ++ return mantissa + exponent + is_signed(); ++ } ++ constexpr bool is_signed() const { return signed_; } ++ constexpr bool is_integer() const { return exponent == 0; } ++ constexpr bool is_floating_point() const { return exponent > 0; } ++ constexpr bool is_ieee_754() const { ++ return is_floating_point() && finite_values_only == false && ++ nan_repr == NAN_IEEE_754; ++ } ++ constexpr bool has_nans() const { ++ return is_floating_point() && nan_repr != NAN_NONE; ++ } ++ constexpr bool has_infs() const { ++ return is_floating_point() && finite_values_only == false; ++ } ++ constexpr bool has_bias() const { return bias != 0; } ++ ++ private: ++ double _floating_point_max() const { ++ TORCH_CHECK(mantissa <= 52 && exponent <= 11, ++ "Cannot represent max/min as a double for type ", str()); ++ ++ uint64_t max_mantissa = (uint64_t(1) << mantissa) - 1; ++ if (nan_repr == NAN_EXTD_RANGE_MAX_MIN) { ++ max_mantissa -= 1; ++ } ++ ++ uint64_t max_exponent = (uint64_t(1) << exponent) - 2; ++ if (nan_repr == NAN_EXTD_RANGE_MAX_MIN || nan_repr == NAN_NONE) { ++ TORCH_CHECK(exponent < 11, ++ "Cannot represent max/min as a double for type ", str()); ++ max_exponent += 1; ++ } ++ ++ // adjust the exponent to match that of a double ++ // for now we assume the exponent bias is the standard 2^(e-1) -1, (where e ++ // is the exponent bits), there is some precedent for non-standard biases, ++ // example `float8_e4m3b11fnuz` here: https://github.com/jax-ml/ml_dtypes ++ // but to avoid premature over complication we are just assuming the ++ // standard exponent bias until there is a need to support non-standard ++ // biases ++ uint64_t exponent_bias = (uint64_t(1) << (exponent - 1)) - 1; ++ uint64_t exponent_bias_double = (uint64_t(1) << 10) - 1; // double e = 11 ++ ++ uint64_t max_exponent_double = ++ max_exponent - exponent_bias + exponent_bias_double; ++ ++ // shift the mantissa into the position for a double and ++ // the exponent ++ uint64_t double_raw = ++ (max_mantissa << (52 - mantissa)) | (max_exponent_double << 52); ++ ++ return *reinterpret_cast(&double_raw); ++ } ++ ++ constexpr std::variant _raw_max() const { ++ if (is_floating_point()) { ++ return {_floating_point_max()}; ++ } else { ++ TORCH_CHECK(size_bits() < 64 || size_bits() == 64 && is_signed(), ++ "Cannot represent max as a int64_t"); ++ return {(int64_t(1) << mantissa) - 1}; ++ } ++ } ++ ++ constexpr std::variant _raw_min() const { ++ if (is_floating_point()) { ++ TORCH_CHECK(is_signed(), ++ "We currently assume all floating point types are signed"); ++ constexpr uint64_t sign_bit_double = (uint64_t(1) << 63); ++ ++ double max = _floating_point_max(); ++ uint64_t max_raw = *reinterpret_cast(&max); ++ uint64_t min_raw = max_raw | sign_bit_double; ++ return {*reinterpret_cast(&min_raw)}; ++ } else { ++ TORCH_CHECK(!is_signed() || size_bits() <= 64, ++ "Cannot represent min as a int64_t"); ++ if (is_signed()) { ++ // set the top bit to 1 (i.e. INT64_MIN) and the rest to 0 ++ // then perform an arithmetic shift right to set all the bits above ++ // (size_bits() - 1) to 1 ++ return {INT64_MIN >> (64 - size_bits())}; ++ } else { ++ return {int64_t(0)}; ++ } ++ } ++ } ++ ++ public: ++ // Max representable value for this scalar type. ++ // (accounting for bias if there is one) ++ constexpr std::variant max() const { ++ return std::visit( ++ [this](auto x) -> std::variant { return {x - bias}; }, ++ _raw_max()); ++ } ++ ++ // Min representable value for this scalar type. ++ // (accounting for bias if there is one) ++ constexpr std::variant min() const { ++ return std::visit( ++ [this](auto x) -> std::variant { return {x - bias}; }, ++ _raw_min()); ++ } ++ ++ std::string str() const { ++ /* naming generally follows: https://github.com/jax-ml/ml_dtypes ++ * for floating point types (leading f) the scheme is: ++ * `float_em[flags]` ++ * flags: ++ * - no-flags: means it follows IEEE 754 conventions ++ * - f: means finite values only (no infinities) ++ * - n: means nans are supported (non-standard encoding) ++ * for integer types the scheme is: ++ * `[u]int[b]` ++ * - if bias is not present it means its zero ++ */ ++ if (is_floating_point()) { ++ auto ret = "float" + std::to_string(size_bits()) + "_e" + ++ std::to_string(exponent) + "m" + std::to_string(mantissa); ++ if (!is_ieee_754()) { ++ if (finite_values_only) { ++ ret += "f"; ++ } ++ if (nan_repr != NAN_NONE) { ++ ret += "n"; ++ } ++ } ++ return ret; ++ } else { ++ auto ret = ((is_signed()) ? "int" : "uint") + std::to_string(size_bits()); ++ if (has_bias()) { ++ ret += "b" + std::to_string(bias); ++ } ++ return ret; ++ } ++ } ++ ++ constexpr bool operator==(ScalarType const& other) const { ++ return mantissa == other.mantissa && exponent == other.exponent && ++ bias == other.bias && signed_ == other.signed_ && ++ finite_values_only == other.finite_values_only && ++ nan_repr == other.nan_repr; ++ } ++}; ++ ++using ScalarTypeId = ScalarType::Id; ++ ++// "rust style" names generally following: ++// https://github.com/pytorch/pytorch/blob/6d9f74f0af54751311f0dd71f7e5c01a93260ab3/torch/csrc/api/include/torch/types.h#L60-L70 ++static inline constexpr auto kS4 = ScalarType::int_(4); ++static inline constexpr auto kU4 = ScalarType::uint(4); ++static inline constexpr auto kU4B8 = ScalarType::uint(4, 8); ++static inline constexpr auto kS8 = ScalarType::int_(8); ++static inline constexpr auto kU8 = ScalarType::uint(8); ++static inline constexpr auto kU8B128 = ScalarType::uint(8, 128); ++ ++static inline constexpr auto kFE3M2f = ++ ScalarType::float_(3, 2, true, ScalarType::NAN_NONE); ++static inline constexpr auto kFE4M3fn = ++ ScalarType::float_(4, 3, true, ScalarType::NAN_EXTD_RANGE_MAX_MIN); ++static inline constexpr auto kFE5M2 = ScalarType::float_IEEE754(5, 2); ++static inline constexpr auto kFE8M7 = ScalarType::float_IEEE754(8, 7); ++static inline constexpr auto kFE5M10 = ScalarType::float_IEEE754(5, 10); ++ ++// Fixed width style names, generally following: ++// https://github.com/pytorch/pytorch/blob/6d9f74f0af54751311f0dd71f7e5c01a93260ab3/torch/csrc/api/include/torch/types.h#L47-L57 ++static inline constexpr auto kInt4 = kS4; ++static inline constexpr auto kUint4 = kU4; ++static inline constexpr auto kUint4b8 = kU4B8; ++static inline constexpr auto kInt8 = kS8; ++static inline constexpr auto kUint8 = kU8; ++static inline constexpr auto kUint8b128 = kU8B128; ++ ++static inline constexpr auto kFloat6_e3m2f = kFE3M2f; ++static inline constexpr auto kFloat8_e4m3fn = kFE4M3fn; ++static inline constexpr auto kFloat8_e5m2 = kFE5M2; ++static inline constexpr auto kFloat16_e8m7 = kFE8M7; ++static inline constexpr auto kFloat16_e5m10 = kFE5M10; ++ ++// colloquial names ++static inline constexpr auto kHalf = kFE5M10; ++static inline constexpr auto kFloat16 = kHalf; ++static inline constexpr auto kBFloat16 = kFE8M7; ++ ++static inline constexpr auto kFloat16Id = kFloat16.id(); ++}; // namespace vllm +diff --git a/csrc/cpu/activation.cpp b/csrc/cpu/activation.cpp +index 1bd24eb..039b8d5 100644 +--- a/csrc/cpu/activation.cpp ++++ b/csrc/cpu/activation.cpp +@@ -1,10 +1,10 @@ + #include "cpu_types.hpp" + + namespace { +-template +-void activation_kernel(int num_tokens, int d, scalar_t *__restrict__ input, +- scalar_t *__restrict__ output) { ++void activation_kernel(int num_tokens, int d, scalar_t* __restrict__ input, ++ scalar_t* __restrict__ output) { + using scalar_vec_t = vec_op::vec_t; + constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num(); + +@@ -34,13 +34,13 @@ void activation_kernel(int num_tokens, int d, scalar_t *__restrict__ input, + } + } + +-FORCE_INLINE vec_op::FP32Vec8 silu_act(const vec_op::FP32Vec8 &x) { ++FORCE_INLINE vec_op::FP32Vec8 silu_act(const vec_op::FP32Vec8& x) { + const vec_op::FP32Vec8 zeros(0.0); + const vec_op::FP32Vec8 ones(1.0); + return x / (ones + (zeros - x).exp()); + } + +-FORCE_INLINE vec_op::FP32Vec8 gelu_new_act(const vec_op::FP32Vec8 &x) { ++FORCE_INLINE vec_op::FP32Vec8 gelu_new_act(const vec_op::FP32Vec8& x) { + const vec_op::FP32Vec8 ones(1.0); + const vec_op::FP32Vec8 w1(0.79788456f); + const vec_op::FP32Vec8 w2(0.044715f); +@@ -50,7 +50,7 @@ FORCE_INLINE vec_op::FP32Vec8 gelu_new_act(const vec_op::FP32Vec8 &x) { + return w3 * x * (ones + t); + } + +-FORCE_INLINE vec_op::FP32Vec8 gelu_fast_act(const vec_op::FP32Vec8 &x) { ++FORCE_INLINE vec_op::FP32Vec8 gelu_fast_act(const vec_op::FP32Vec8& x) { + const vec_op::FP32Vec8 ones(1.0); + const vec_op::FP32Vec8 w1(0.79788456f); + const vec_op::FP32Vec8 w2(0.044715f); +@@ -59,14 +59,21 @@ FORCE_INLINE vec_op::FP32Vec8 gelu_fast_act(const vec_op::FP32Vec8 &x) { + return w3 * x * (ones + t); + } + +-FORCE_INLINE vec_op::FP32Vec8 gelu_act(const vec_op::FP32Vec8 &x) { ++FORCE_INLINE vec_op::FP32Vec8 gelu_quick_act(const vec_op::FP32Vec8& x) { ++ const vec_op::FP32Vec8 zeros(0.0); ++ const vec_op::FP32Vec8 ones(1.0); ++ const vec_op::FP32Vec8 w1(1.702f); ++ return x / (ones + (zeros - w1 * x).exp()); ++} ++ ++FORCE_INLINE vec_op::FP32Vec8 gelu_act(const vec_op::FP32Vec8& x) { + const vec_op::FP32Vec8 ones(1.0); + const vec_op::FP32Vec8 w1(M_SQRT1_2); + const vec_op::FP32Vec8 w2(0.5); + return x * w2 * (ones + (x * w1).er()); + } + +-FORCE_INLINE vec_op::FP32Vec8 gelu_tanh_act(const vec_op::FP32Vec8 &x) { ++FORCE_INLINE vec_op::FP32Vec8 gelu_tanh_act(const vec_op::FP32Vec8& x) { + const vec_op::FP32Vec8 ones(1.0); + const vec_op::FP32Vec8 w1(M_SQRT2 * M_2_SQRTPI * 0.5); + const vec_op::FP32Vec8 w2(0.5); +@@ -75,40 +82,36 @@ FORCE_INLINE vec_op::FP32Vec8 gelu_tanh_act(const vec_op::FP32Vec8 &x) { + const vec_op::FP32Vec8 inner = w1 * (x + x_3 * w3); + return x * w2 * (ones + inner.tanh()); + } +-}; // namespace ++}; // namespace + +-void silu_and_mul(torch::Tensor &out, torch::Tensor &input) { ++void silu_and_mul(torch::Tensor& out, torch::Tensor& input) { + int num_tokens = input.numel() / input.size(-1); + int d = input.size(-1) / 2; + +- VLLM_DISPATCH_FLOATING_TYPES( +- input.scalar_type(), "silu_and_mul_impl", [&] { +- CPU_KERNEL_GUARD_IN(silu_and_mul_impl) +- activation_kernel(num_tokens, d, +- input.data_ptr(), +- out.data_ptr()); +- CPU_KERNEL_GUARD_OUT(silu_and_mul_impl) +- }); ++ VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "silu_and_mul_impl", [&] { ++ CPU_KERNEL_GUARD_IN(silu_and_mul_impl) ++ activation_kernel( ++ num_tokens, d, input.data_ptr(), out.data_ptr()); ++ CPU_KERNEL_GUARD_OUT(silu_and_mul_impl) ++ }); + } + +-void gelu_and_mul(torch::Tensor &out, // [..., d] +- torch::Tensor &input) // [..., 2 * d] ++void gelu_and_mul(torch::Tensor& out, // [..., d] ++ torch::Tensor& input) // [..., 2 * d] + { + int num_tokens = input.numel() / input.size(-1); + int d = input.size(-1) / 2; + +- VLLM_DISPATCH_FLOATING_TYPES( +- input.scalar_type(), "gelu_and_mul_impl", [&] { +- CPU_KERNEL_GUARD_IN(gelu_and_mul_impl) +- activation_kernel(num_tokens, d, +- input.data_ptr(), +- out.data_ptr()); +- CPU_KERNEL_GUARD_OUT(gelu_and_mul_impl) +- }); ++ VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "gelu_and_mul_impl", [&] { ++ CPU_KERNEL_GUARD_IN(gelu_and_mul_impl) ++ activation_kernel( ++ num_tokens, d, input.data_ptr(), out.data_ptr()); ++ CPU_KERNEL_GUARD_OUT(gelu_and_mul_impl) ++ }); + } + +-void gelu_tanh_and_mul(torch::Tensor &out, // [..., d] +- torch::Tensor &input) // [..., 2 * d] ++void gelu_tanh_and_mul(torch::Tensor& out, // [..., d] ++ torch::Tensor& input) // [..., 2 * d] + { + int num_tokens = input.numel() / input.size(-1); + int d = input.size(-1) / 2; +@@ -123,7 +126,7 @@ void gelu_tanh_and_mul(torch::Tensor &out, // [..., d] + }); + } + +-void gelu_new(torch::Tensor &out, torch::Tensor &input) { ++void gelu_new(torch::Tensor& out, torch::Tensor& input) { + int num_tokens = input.numel() / input.size(-1); + int d = input.size(-1); + +@@ -135,7 +138,7 @@ void gelu_new(torch::Tensor &out, torch::Tensor &input) { + }); + } + +-void gelu_fast(torch::Tensor &out, torch::Tensor &input) { ++void gelu_fast(torch::Tensor& out, torch::Tensor& input) { + int num_tokens = input.numel() / input.size(-1); + int d = input.size(-1); + +@@ -146,3 +149,15 @@ void gelu_fast(torch::Tensor &out, torch::Tensor &input) { + CPU_KERNEL_GUARD_OUT(gelu_fast_impl) + }); + } ++ ++void gelu_quick(torch::Tensor& out, torch::Tensor& input) { ++ int num_tokens = input.numel() / input.size(-1); ++ int d = input.size(-1); ++ ++ VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "gelu_quick_impl", [&] { ++ CPU_KERNEL_GUARD_IN(gelu_quick_impl) ++ activation_kernel( ++ num_tokens, d, input.data_ptr(), out.data_ptr()); ++ CPU_KERNEL_GUARD_OUT(gelu_quick_impl) ++ }); ++} +diff --git a/csrc/cpu/attention.cpp b/csrc/cpu/attention.cpp +index c1d765b..ef5b140 100644 +--- a/csrc/cpu/attention.cpp ++++ b/csrc/cpu/attention.cpp +@@ -2,7 +2,8 @@ + + namespace { + +-template struct KernelVecType { ++template ++struct KernelVecType { + using q_load_vec_type = void; + using q_vec_type = void; + using k_load_vec_type = void; +@@ -11,7 +12,8 @@ template struct KernelVecType { + using v_load_vec_type = void; + }; + +-template <> struct KernelVecType { ++template <> ++struct KernelVecType { + using q_load_vec_type = vec_op::FP32Vec4; + using q_vec_type = vec_op::FP32Vec16; + using k_load_vec_type = vec_op::FP32Vec16; +@@ -20,8 +22,27 @@ template <> struct KernelVecType { + using v_load_vec_type = vec_op::FP32Vec16; + }; + ++template <> ++struct KernelVecType { ++#ifdef __powerpc64__ ++ // Power architecture-specific vector types ++ using q_load_vec_type = vec_op::FP32Vec8; ++ using k_load_vec_type = vec_op::FP32Vec16; ++ using v_load_vec_type = vec_op::FP32Vec16; ++#else ++ // Fallback for other architectures, including x86 ++ using q_load_vec_type = vec_op::FP16Vec8; ++ using k_load_vec_type = vec_op::FP16Vec16; ++ using v_load_vec_type = vec_op::FP16Vec16; ++#endif ++ using q_vec_type = vec_op::FP32Vec16; ++ using k_vec_type = vec_op::FP32Vec16; ++ using qk_acc_vec_type = vec_op::FP32Vec16; ++}; ++ + #ifdef __AVX512BF16__ +-template <> struct KernelVecType { ++template <> ++struct KernelVecType { + using q_load_vec_type = vec_op::BF16Vec8; + using q_vec_type = vec_op::BF16Vec32; + using k_load_vec_type = vec_op::BF16Vec32; +@@ -30,7 +51,12 @@ template <> struct KernelVecType { + using v_load_vec_type = vec_op::BF16Vec16; + }; + #else +-template <> struct KernelVecType { ++ #ifdef __aarch64__ ++ #ifndef ARM_BF16_SUPPORT ++ // pass ++ #else ++template <> ++struct KernelVecType { + using q_load_vec_type = vec_op::BF16Vec8; + using q_vec_type = vec_op::FP32Vec16; + using k_load_vec_type = vec_op::BF16Vec16; +@@ -38,10 +64,22 @@ template <> struct KernelVecType { + using qk_acc_vec_type = vec_op::FP32Vec16; + using v_load_vec_type = vec_op::BF16Vec16; + }; ++ #endif ++ #else ++template <> ++struct KernelVecType { ++ using q_load_vec_type = vec_op::BF16Vec8; ++ using q_vec_type = vec_op::FP32Vec16; ++ using k_load_vec_type = vec_op::BF16Vec16; ++ using k_vec_type = vec_op::FP32Vec16; ++ using qk_acc_vec_type = vec_op::FP32Vec16; ++ using v_load_vec_type = vec_op::BF16Vec16; ++}; ++ #endif + #endif + + template +-FORCE_INLINE std::pair reduceSoftmax(T *data, const int size, ++FORCE_INLINE std::pair reduceSoftmax(T* data, const int size, + const int capacity) { + T max = data[0]; + for (int i = 1; i < size; ++i) { +@@ -67,10 +105,11 @@ FORCE_INLINE std::pair reduceSoftmax(T *data, const int size, + } + + template +-FORCE_INLINE std::pair +-reduceSoftmaxAlibi(T *data, const int size, const int capacity, +- const float alibi_slope, const int start_index, +- const int seq_len) { ++FORCE_INLINE std::pair reduceSoftmaxAlibi(T* data, const int size, ++ const int capacity, ++ const float alibi_slope, ++ const int start_index, ++ const int seq_len) { + data[0] += alibi_slope * (start_index - seq_len + 1); + T max = data[0]; + for (int i = 1; i < size; ++i) { +@@ -98,7 +137,7 @@ reduceSoftmaxAlibi(T *data, const int size, const int capacity, + } + + template +-FORCE_INLINE void reducePartitonSoftmax(const T *max_data, T *sum_data, ++FORCE_INLINE void reducePartitonSoftmax(const T* max_data, T* sum_data, + const int size) { + T max = max_data[0]; + for (int i = 1; i < size; ++i) { +@@ -132,9 +171,9 @@ struct reduceQKBlockKernel { + static_assert(k_load_vec_type::get_elem_num() % x == 0); + static_assert(q_load_vec_type::get_elem_num() * sizeof(scalar_t) == 16); + +- FORCE_INLINE static void call(const scalar_t *__restrict__ q, +- const scalar_t *__restrict__ k_block, +- float *__restrict__ logits, float scale, ++ FORCE_INLINE static void call(const scalar_t* __restrict__ q, ++ const scalar_t* __restrict__ k_block, ++ float* __restrict__ logits, float scale, + const int token_num) { + const int group_num = (token_num + TOKEN_PER_GROUP - 1) / TOKEN_PER_GROUP; + +@@ -196,8 +235,8 @@ struct reduceQKBlockKernel { + + template +-FORCE_INLINE void reduceValueBlock(const float *prob, const scalar_t *v_block, +- acc_t &&acc) { ++FORCE_INLINE void reduceValueBlock(const float* prob, const scalar_t* v_block, ++ acc_t&& acc) { + using v_load_vec_type = typename KernelVecType::v_load_vec_type; + constexpr int ELEM_NUM = v_load_vec_type::get_elem_num(); + static_assert(BLOCK_SIZE == ELEM_NUM); +@@ -209,27 +248,27 @@ FORCE_INLINE void reduceValueBlock(const float *prob, const scalar_t *v_block, + acc[head_elem_idx] = acc[head_elem_idx] + prob_vec * fp32_v_vec; + }); + } +-}; // namespace ++}; // namespace + + // Paged attention v1 + namespace { + template + struct paged_attention_v1_impl { +- static void +- call(scalar_t *__restrict__ out, // [num_seqs, num_heads, head_size] +- const scalar_t *__restrict__ q, // [num_seqs, num_heads, head_size] +- const scalar_t *__restrict__ k_cache, // [num_blocks, num_kv_heads, ++ static void call( ++ scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] ++ const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] ++ const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, + // head_size/x, block_size, x] +- const scalar_t *__restrict__ v_cache, // [num_blocks, num_kv_heads, ++ const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, + // head_size, block_size] +- const int num_kv_heads, const float scale, +- const int +- *__restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] +- const int *__restrict__ seq_lens, // [num_seqs] +- const int max_num_blocks_per_seq, +- const float *__restrict__ alibi_slopes, // [num_heads] +- const int q_stride, const int kv_block_stride, const int kv_head_stride, +- const int num_seqs, const int num_heads) { ++ const int num_kv_heads, const float scale, ++ const int* __restrict__ block_tables, // [num_seqs, ++ // max_num_blocks_per_seq] ++ const int* __restrict__ seq_lens, // [num_seqs] ++ const int max_num_blocks_per_seq, ++ const float* __restrict__ alibi_slopes, // [num_heads] ++ const int q_stride, const int kv_block_stride, const int kv_head_stride, ++ const int num_seqs, const int num_heads) { + constexpr int x = 16 / sizeof(scalar_t); + const int num_queries_per_kv = num_heads / num_kv_heads; + +@@ -243,32 +282,31 @@ struct paged_attention_v1_impl { + + size_t logits_bytes = + parallel_work_item_num * max_seq_len_padded * sizeof(float); +- float *logits = (float *)std::aligned_alloc( +- 64, logits_bytes); // Cacheline alignment for each context token. +- // [parallel_work_item_num, max_seq_len_padded] ++ float* logits = (float*)std::aligned_alloc( ++ 64, logits_bytes); // Cacheline alignment for each context token. ++ // [parallel_work_item_num, max_seq_len_padded] + + #pragma omp parallel for collapse(2) schedule(dynamic, 1) + for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) { + for (int head_idx = 0; head_idx < num_heads; ++head_idx) { + int seq_len = seq_lens[seq_idx]; +- const int *seq_block_table = ++ const int* seq_block_table = + block_tables + max_num_blocks_per_seq * seq_idx; + const int block_num = (seq_len + BLOCK_SIZE - 1) / BLOCK_SIZE; + const int64_t kv_head_idx = head_idx / num_queries_per_kv; +- const scalar_t *__restrict__ q_vec_ptr = ++ const scalar_t* __restrict__ q_vec_ptr = + q + seq_idx * q_stride + head_idx * HEAD_SIZE; +- const int last_block_token_num = +- seq_len - (block_num - 1) * BLOCK_SIZE; +- float *__restrict__ thread_block_logits = ++ const int last_block_token_num = seq_len - (block_num - 1) * BLOCK_SIZE; ++ float* __restrict__ thread_block_logits = + logits + omp_get_thread_num() * max_seq_len_padded; + + // Compute logits + for (int block_idx = 0; block_idx < block_num; ++block_idx) { + const int64_t physical_block_idx = seq_block_table[block_idx]; +- const scalar_t *__restrict__ k_block_cache_ptr = ++ const scalar_t* __restrict__ k_block_cache_ptr = + k_cache + physical_block_idx * kv_block_stride + + kv_head_idx * kv_head_stride; +- float *__restrict__ head_block_logits = ++ float* __restrict__ head_block_logits = + thread_block_logits + block_idx * BLOCK_SIZE; + + reduceQKBlockKernel::call( +@@ -282,8 +320,7 @@ struct paged_attention_v1_impl { + block_num * BLOCK_SIZE, alibi_slopes[head_idx], 0, + seq_len); + } else { +- reduceSoftmax(thread_block_logits, seq_len, +- block_num * BLOCK_SIZE); ++ reduceSoftmax(thread_block_logits, seq_len, block_num * BLOCK_SIZE); + } + + // Compute value +@@ -293,14 +330,14 @@ struct paged_attention_v1_impl { + for (int head_part_idx = 0; head_part_idx < head_partition_num; + ++head_part_idx) { + vec_op::FP32Vec16 accums[head_elem_num_per_partition]; +- scalar_t *__restrict__ out_ptr = ++ scalar_t* __restrict__ out_ptr = + out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE + + head_part_idx * head_elem_num_per_partition; + for (int block_idx = 0; block_idx < block_num; ++block_idx) { + const int64_t physical_block_idx = seq_block_table[block_idx]; +- const float *__restrict__ prob_vec_ptr = ++ const float* __restrict__ prob_vec_ptr = + thread_block_logits + block_idx * BLOCK_SIZE; +- const scalar_t *__restrict__ v_block_cache_ptr = ++ const scalar_t* __restrict__ v_block_cache_ptr = + v_cache + physical_block_idx * kv_block_stride + + kv_head_idx * kv_head_stride + + BLOCK_SIZE * head_part_idx * head_elem_num_per_partition; +@@ -311,7 +348,7 @@ struct paged_attention_v1_impl { + if (block_idx != block_num - 1) { + const int64_t next_physical_block_idx = + seq_block_table[block_idx + 1]; +- const scalar_t *__restrict__ next_v_block_cache_ptr = ++ const scalar_t* __restrict__ next_v_block_cache_ptr = + v_cache + next_physical_block_idx * kv_block_stride + + kv_head_idx * kv_head_stride + + BLOCK_SIZE * head_part_idx * head_elem_num_per_partition; +@@ -340,16 +377,16 @@ struct paged_attention_v1_impl { + #define LAUNCH_V1_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE) \ + paged_attention_v1_impl::call( \ + out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \ +- block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \ ++ block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \ + alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, num_seqs, \ + num_heads); + + template + void paged_attention_v1_impl_launcher( +- torch::Tensor &out, torch::Tensor &query, torch::Tensor &key_cache, +- torch::Tensor &value_cache, int num_kv_heads, float scale, +- torch::Tensor &block_tables, torch::Tensor &seq_lens, +- int max_seq_len, const c10::optional &alibi_slopes) { ++ torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache, ++ torch::Tensor& value_cache, int num_kv_heads, float scale, ++ torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len, ++ const std::optional& alibi_slopes) { + int num_seqs = query.size(0); + int num_heads = query.size(1); + int head_size = query.size(2); +@@ -359,68 +396,77 @@ void paged_attention_v1_impl_launcher( + int kv_head_stride = key_cache.stride(1); + + // NOTE: alibi_slopes is optional. +- const float *alibi_slopes_ptr = ++ const float* alibi_slopes_ptr = + alibi_slopes +- ? reinterpret_cast(alibi_slopes.value().data_ptr()) ++ ? reinterpret_cast(alibi_slopes.value().data_ptr()) + : nullptr; + +- T *out_ptr = reinterpret_cast(out.data_ptr()); +- T *query_ptr = reinterpret_cast(query.data_ptr()); +- T *key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); +- T *value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); +- int *block_tables_ptr = block_tables.data_ptr(); +- int *seq_lens_ptr = seq_lens.data_ptr(); ++ T* out_ptr = reinterpret_cast(out.data_ptr()); ++ T* query_ptr = reinterpret_cast(query.data_ptr()); ++ T* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); ++ T* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); ++ int* block_tables_ptr = block_tables.data_ptr(); ++ int* seq_lens_ptr = seq_lens.data_ptr(); + + switch (head_size) { +- case 64: +- LAUNCH_V1_ATTENTION_KERNEL(T, 64, BLOCK_SIZE); +- break; +- case 80: +- LAUNCH_V1_ATTENTION_KERNEL(T, 80, BLOCK_SIZE); +- break; +- case 96: +- LAUNCH_V1_ATTENTION_KERNEL(T, 96, BLOCK_SIZE); +- break; +- case 112: +- LAUNCH_V1_ATTENTION_KERNEL(T, 112, BLOCK_SIZE); +- break; +- case 128: +- LAUNCH_V1_ATTENTION_KERNEL(T, 128, BLOCK_SIZE); +- break; +- case 256: +- LAUNCH_V1_ATTENTION_KERNEL(T, 256, BLOCK_SIZE); +- break; +- default: +- TORCH_CHECK(false, "Unsupported head size: ", head_size); +- break; ++ case 32: ++ LAUNCH_V1_ATTENTION_KERNEL(T, 32, BLOCK_SIZE); ++ break; ++ case 64: ++ LAUNCH_V1_ATTENTION_KERNEL(T, 64, BLOCK_SIZE); ++ break; ++ case 80: ++ LAUNCH_V1_ATTENTION_KERNEL(T, 80, BLOCK_SIZE); ++ break; ++ case 96: ++ LAUNCH_V1_ATTENTION_KERNEL(T, 96, BLOCK_SIZE); ++ break; ++ case 112: ++ LAUNCH_V1_ATTENTION_KERNEL(T, 112, BLOCK_SIZE); ++ break; ++ case 128: ++ LAUNCH_V1_ATTENTION_KERNEL(T, 128, BLOCK_SIZE); ++ break; ++ case 192: ++ LAUNCH_V1_ATTENTION_KERNEL(T, 192, BLOCK_SIZE); ++ break; ++ case 256: ++ LAUNCH_V1_ATTENTION_KERNEL(T, 256, BLOCK_SIZE); ++ break; ++ default: ++ TORCH_CHECK(false, "Unsupported head size: ", head_size); ++ break; + } + } + +-#define CALL_V1_KERNEL_LAUNCHER(T, BLOCK_SIZE) \ +- paged_attention_v1_impl_launcher( \ +- out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \ ++#define CALL_V1_KERNEL_LAUNCHER(T, BLOCK_SIZE) \ ++ paged_attention_v1_impl_launcher( \ ++ out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \ + seq_lens, max_seq_len, alibi_slopes); + +-#define CALL_V1_KERNEL_LAUNCHER_BLOCK_SIZE(T) \ +- switch (block_size) { \ +- case 16: \ +- CALL_V1_KERNEL_LAUNCHER(T, 16); \ +- break; \ +- default: \ +- TORCH_CHECK(false, "Unsupported block size: ", block_size); \ +- break; \ ++#define CALL_V1_KERNEL_LAUNCHER_BLOCK_SIZE(T) \ ++ switch (block_size) { \ ++ case 16: \ ++ CALL_V1_KERNEL_LAUNCHER(T, 16); \ ++ break; \ ++ default: \ ++ TORCH_CHECK(false, "Unsupported block size: ", block_size); \ ++ break; \ + } +-} // namespace +- +-void paged_attention_v1(torch::Tensor &out, torch::Tensor &query, +- torch::Tensor &key_cache, torch::Tensor &value_cache, +- int num_kv_heads, float scale, +- torch::Tensor &block_tables, +- torch::Tensor &seq_lens, int block_size, +- int max_seq_len, +- const c10::optional &alibi_slopes, +- const std::string &kv_cache_dtype, float kv_scale) { +- TORCH_CHECK(kv_scale == 1.0f); ++} // namespace ++ ++void paged_attention_v1( ++ torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache, ++ torch::Tensor& value_cache, int64_t num_kv_heads, double scale, ++ torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, ++ int64_t max_seq_len, const std::optional& alibi_slopes, ++ const std::string& kv_cache_dtype, double k_scale, double v_scale, ++ const int64_t tp_rank, const int64_t blocksparse_local_blocks, ++ const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, ++ const int64_t blocksparse_head_sliding_step) { ++ TORCH_CHECK(k_scale == 1.0f && v_scale == 1.0f); ++ TORCH_CHECK(blocksparse_vert_stride <= 1, ++ "CPU backend does not support blocksparse attention yet."); + VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v1_impl", + [&] { + CPU_KERNEL_GUARD_IN(paged_attention_v1_impl) +@@ -434,23 +480,24 @@ namespace { + template + struct paged_attention_v2_impl { + static void call( +- scalar_t *__restrict__ out, // [num_seqs, num_heads, head_size] +- float *__restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] +- float +- *__restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] +- scalar_t *__restrict__ tmp_out, // [num_seqs, num_heads, +- // max_num_partitions, head_size] +- const scalar_t *__restrict__ q, // [num_seqs, num_heads, head_size] +- const scalar_t *__restrict__ k_cache, // [num_blocks, num_kv_heads, +- // head_size/x, block_size, x] +- const scalar_t *__restrict__ v_cache, // [num_blocks, num_kv_heads, +- // head_size, block_size] ++ scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] ++ float* __restrict__ exp_sums, // [num_seqs, num_heads, ++ // max_num_partitions] ++ float* __restrict__ max_logits, // [num_seqs, num_heads, ++ // max_num_partitions] ++ scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, ++ // max_num_partitions, head_size] ++ const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] ++ const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, ++ // head_size/x, block_size, x] ++ const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, ++ // head_size, block_size] + const int num_kv_heads, const float scale, +- const int +- *__restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] +- const int *__restrict__ seq_lens, // [num_seqs] ++ const int* __restrict__ block_tables, // [num_seqs, ++ // max_num_blocks_per_seq] ++ const int* __restrict__ seq_lens, // [num_seqs] + const int max_num_blocks_per_seq, +- const float *__restrict__ alibi_slopes, // [num_heads] ++ const float* __restrict__ alibi_slopes, // [num_heads] + const int q_stride, const int kv_block_stride, const int kv_head_stride, + const int num_seqs, const int num_heads, const int max_num_partitions) { + constexpr int x = 16 / sizeof(scalar_t); +@@ -468,8 +515,7 @@ struct paged_attention_v2_impl { + const int seq_len = seq_lens[seq_idx]; + const int start_token_idx = partition_idx * PARTITION_SIZE; + +- if (start_token_idx >= seq_len) +- continue; ++ if (start_token_idx >= seq_len) continue; + + const int partition_num = + (seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE; +@@ -477,15 +523,14 @@ struct paged_attention_v2_impl { + const int token_num = + (std::min(seq_len, start_token_idx + PARTITION_SIZE) - + start_token_idx); +- const int block_num = +- (token_num + BLOCK_SIZE - 1) / BLOCK_SIZE; ++ const int block_num = (token_num + BLOCK_SIZE - 1) / BLOCK_SIZE; + const int last_block_token_num = + token_num - (block_num - 1) * BLOCK_SIZE; +- const int *seq_block_table = block_tables + ++ const int* seq_block_table = block_tables + + max_num_blocks_per_seq * seq_idx + + start_token_idx / BLOCK_SIZE; + const int64_t kv_head_idx = head_idx / num_queries_per_kv; +- const scalar_t *__restrict__ q_vec_ptr = ++ const scalar_t* __restrict__ q_vec_ptr = + q + seq_idx * q_stride + head_idx * HEAD_SIZE; + + float logits[PARTITION_SIZE] __attribute__((aligned(64))) = {0}; +@@ -493,10 +538,10 @@ struct paged_attention_v2_impl { + // Compute logits + for (int block_idx = 0; block_idx < block_num; ++block_idx) { + const int64_t physical_block_idx = seq_block_table[block_idx]; +- const scalar_t *__restrict__ k_block_cache_ptr = ++ const scalar_t* __restrict__ k_block_cache_ptr = + k_cache + physical_block_idx * kv_block_stride + + kv_head_idx * kv_head_stride; +- float *__restrict__ head_block_logits = ++ float* __restrict__ head_block_logits = + logits + block_idx * BLOCK_SIZE; + + reduceQKBlockKernel::call( +@@ -510,13 +555,13 @@ struct paged_attention_v2_impl { + logits, token_num, block_num * BLOCK_SIZE, + alibi_slopes[head_idx], start_token_idx, seq_len); + } else { +- max_and_sum = reduceSoftmax(logits, token_num, +- block_num * BLOCK_SIZE); ++ max_and_sum = ++ reduceSoftmax(logits, token_num, block_num * BLOCK_SIZE); + } + +- auto &&[max_logit, exp_sum] = max_and_sum; ++ auto&& [max_logit, exp_sum] = max_and_sum; + +- scalar_t *__restrict__ output_buffer = nullptr; ++ scalar_t* __restrict__ output_buffer = nullptr; + if (!no_reduce) { + auto idx = seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions + partition_idx; +@@ -538,13 +583,13 @@ struct paged_attention_v2_impl { + for (int head_part_idx = 0; head_part_idx < head_partition_num; + ++head_part_idx) { + vec_op::FP32Vec16 accums[head_elem_num_per_partition]; +- scalar_t *__restrict__ out_ptr = ++ scalar_t* __restrict__ out_ptr = + output_buffer + head_part_idx * head_elem_num_per_partition; + for (int block_idx = 0; block_idx < block_num; ++block_idx) { + const int64_t physical_block_idx = seq_block_table[block_idx]; +- const float *__restrict__ prob_vec_ptr = ++ const float* __restrict__ prob_vec_ptr = + logits + block_idx * BLOCK_SIZE; +- const scalar_t *__restrict__ v_block_cache_ptr = ++ const scalar_t* __restrict__ v_block_cache_ptr = + v_cache + physical_block_idx * kv_block_stride + + kv_head_idx * kv_head_stride + + BLOCK_SIZE * head_part_idx * head_elem_num_per_partition; +@@ -555,7 +600,7 @@ struct paged_attention_v2_impl { + if (block_idx != block_num - 1) { + const int64_t next_physical_block_idx = + seq_block_table[block_idx + 1]; +- const scalar_t *__restrict__ next_v_block_cache_ptr = ++ const scalar_t* __restrict__ next_v_block_cache_ptr = + v_cache + next_physical_block_idx * kv_block_stride + + kv_head_idx * kv_head_stride + + BLOCK_SIZE * head_part_idx * head_elem_num_per_partition; +@@ -587,8 +632,7 @@ struct paged_attention_v2_impl { + const int partition_num = + (seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE; + +- if (partition_num == 1) +- continue; ++ if (partition_num == 1) continue; + + reducePartitonSoftmax( + max_logits + seq_idx * num_heads * max_num_partitions + +@@ -603,11 +647,11 @@ struct paged_attention_v2_impl { + using v_load_vec_type = typename KernelVecType::v_load_vec_type; + static_assert(v_load_vec_type::get_elem_num() == BLOCK_SIZE); + constexpr int head_elem_num_per_group = +- 16; // Note: didn't align with the cacheline size, due to some HEAD_SIZE +- // didn't align with 64 bytes ++ 16; // Note: didn't align with the cacheline size, due to some ++ // HEAD_SIZE didn't align with 64 bytes + static_assert(HEAD_SIZE % head_elem_num_per_group == 0); + constexpr int head_group_num = HEAD_SIZE / head_elem_num_per_group; +- const float *__restrict__ rescale_factors = exp_sums; ++ const float* __restrict__ rescale_factors = exp_sums; + #pragma omp parallel for collapse(3) schedule(static, 1) + for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) { + for (int head_idx = 0; head_idx < num_heads; ++head_idx) { +@@ -616,17 +660,16 @@ struct paged_attention_v2_impl { + const int partition_num = + (seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE; + +- if (partition_num == 1) +- continue; ++ if (partition_num == 1) continue; + +- const float *__restrict__ seq_head_rescale_factors = ++ const float* __restrict__ seq_head_rescale_factors = + rescale_factors + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions; +- const scalar_t *__restrict__ seq_head_tmp_out = ++ const scalar_t* __restrict__ seq_head_tmp_out = + tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + + head_idx * max_num_partitions * HEAD_SIZE + + group_idx * head_elem_num_per_group; +- scalar_t *__restrict__ seq_head_output = ++ scalar_t* __restrict__ seq_head_output = + out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE + + group_idx * head_elem_num_per_group; + +@@ -645,21 +688,21 @@ struct paged_attention_v2_impl { + } + }; + +-#define LAUNCH_V2_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE) \ +- paged_attention_v2_impl::call( \ +- out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, \ +- key_cache_ptr, value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \ +- seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \ +- kv_block_stride, kv_head_stride, num_seqs, num_heads, \ ++#define LAUNCH_V2_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE) \ ++ paged_attention_v2_impl::call( \ ++ out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, \ ++ key_cache_ptr, value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \ ++ seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \ ++ kv_block_stride, kv_head_stride, num_seqs, num_heads, \ + max_num_partitions); + + template + void paged_attention_v2_impl_launcher( +- torch::Tensor &out, torch::Tensor &exp_sums, torch::Tensor &max_logits, +- torch::Tensor &tmp_out, torch::Tensor &query, torch::Tensor &key_cache, +- torch::Tensor &value_cache, int num_kv_heads, float scale, +- torch::Tensor &block_tables, torch::Tensor &seq_lens, int block_size, +- int max_seq_len, const c10::optional &alibi_slopes) { ++ torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, ++ torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, ++ torch::Tensor& value_cache, int num_kv_heads, float scale, ++ torch::Tensor& block_tables, torch::Tensor& seq_lens, int block_size, ++ int max_seq_len, const std::optional& alibi_slopes) { + int num_seqs = query.size(0); + int num_heads = query.size(1); + int head_size = query.size(2); +@@ -670,77 +713,86 @@ void paged_attention_v2_impl_launcher( + int max_num_partitions = exp_sums.size(-1); + + // NOTE: alibi_slopes is optional. +- const float *alibi_slopes_ptr = ++ const float* alibi_slopes_ptr = + alibi_slopes +- ? reinterpret_cast(alibi_slopes.value().data_ptr()) ++ ? reinterpret_cast(alibi_slopes.value().data_ptr()) + : nullptr; + +- T *out_ptr = reinterpret_cast(out.data_ptr()); +- float *exp_sums_ptr = reinterpret_cast(exp_sums.data_ptr()); +- float *max_logits_ptr = reinterpret_cast(max_logits.data_ptr()); +- T *tmp_out_ptr = reinterpret_cast(tmp_out.data_ptr()); +- T *query_ptr = reinterpret_cast(query.data_ptr()); +- T *key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); +- T *value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); +- int *block_tables_ptr = block_tables.data_ptr(); +- int *seq_lens_ptr = seq_lens.data_ptr(); ++ T* out_ptr = reinterpret_cast(out.data_ptr()); ++ float* exp_sums_ptr = reinterpret_cast(exp_sums.data_ptr()); ++ float* max_logits_ptr = reinterpret_cast(max_logits.data_ptr()); ++ T* tmp_out_ptr = reinterpret_cast(tmp_out.data_ptr()); ++ T* query_ptr = reinterpret_cast(query.data_ptr()); ++ T* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); ++ T* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); ++ int* block_tables_ptr = block_tables.data_ptr(); ++ int* seq_lens_ptr = seq_lens.data_ptr(); + + switch (head_size) { +- case 64: +- LAUNCH_V2_ATTENTION_KERNEL(T, 64, BLOCK_SIZE); +- break; +- case 80: +- LAUNCH_V2_ATTENTION_KERNEL(T, 80, BLOCK_SIZE); +- break; +- case 96: +- LAUNCH_V2_ATTENTION_KERNEL(T, 96, BLOCK_SIZE); +- break; +- case 112: +- LAUNCH_V2_ATTENTION_KERNEL(T, 112, BLOCK_SIZE); +- break; +- case 128: +- LAUNCH_V2_ATTENTION_KERNEL(T, 128, BLOCK_SIZE); +- break; +- case 256: +- LAUNCH_V2_ATTENTION_KERNEL(T, 256, BLOCK_SIZE); +- break; +- default: +- TORCH_CHECK(false, "Unsupported head size: ", head_size); +- break; ++ case 32: ++ LAUNCH_V2_ATTENTION_KERNEL(T, 32, BLOCK_SIZE); ++ break; ++ case 64: ++ LAUNCH_V2_ATTENTION_KERNEL(T, 64, BLOCK_SIZE); ++ break; ++ case 80: ++ LAUNCH_V2_ATTENTION_KERNEL(T, 80, BLOCK_SIZE); ++ break; ++ case 96: ++ LAUNCH_V2_ATTENTION_KERNEL(T, 96, BLOCK_SIZE); ++ break; ++ case 112: ++ LAUNCH_V2_ATTENTION_KERNEL(T, 112, BLOCK_SIZE); ++ break; ++ case 128: ++ LAUNCH_V2_ATTENTION_KERNEL(T, 128, BLOCK_SIZE); ++ break; ++ case 192: ++ LAUNCH_V2_ATTENTION_KERNEL(T, 192, BLOCK_SIZE); ++ break; ++ case 256: ++ LAUNCH_V2_ATTENTION_KERNEL(T, 256, BLOCK_SIZE); ++ break; ++ default: ++ TORCH_CHECK(false, "Unsupported head size: ", head_size); ++ break; + } + } + +-#define CALL_V2_KERNEL_LAUNCHER(T, BLOCK_SIZE) \ +- paged_attention_v2_impl_launcher( \ +- out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \ +- num_kv_heads, scale, block_tables, seq_lens, block_size, \ +- max_seq_len, alibi_slopes); +- +-#define CALL_V2_KERNEL_LAUNCHER_BLOCK_SIZE(T) \ +- switch (block_size) { \ +- case 16: \ +- CALL_V2_KERNEL_LAUNCHER(T, 16); \ +- break; \ +- default: \ +- TORCH_CHECK(false, "Unsupported block size: ", block_size); \ +- break; \ ++#define CALL_V2_KERNEL_LAUNCHER(T, BLOCK_SIZE) \ ++ paged_attention_v2_impl_launcher( \ ++ out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \ ++ num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len, \ ++ alibi_slopes); ++ ++#define CALL_V2_KERNEL_LAUNCHER_BLOCK_SIZE(T) \ ++ switch (block_size) { \ ++ case 16: \ ++ CALL_V2_KERNEL_LAUNCHER(T, 16); \ ++ break; \ ++ default: \ ++ TORCH_CHECK(false, "Unsupported block size: ", block_size); \ ++ break; \ + } +-} // namespace +- +-void paged_attention_v2(torch::Tensor &out, torch::Tensor &exp_sums, +- torch::Tensor &max_logits, torch::Tensor &tmp_out, +- torch::Tensor &query, torch::Tensor &key_cache, +- torch::Tensor &value_cache, int num_kv_heads, +- float scale, torch::Tensor &block_tables, +- torch::Tensor &seq_lens, int block_size, +- int max_seq_len, +- const c10::optional &alibi_slopes, +- const std::string &kv_cache_dtype, float kv_scale) { +- TORCH_CHECK(kv_scale == 1.0f); ++} // namespace ++ ++void paged_attention_v2( ++ torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, ++ torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, ++ torch::Tensor& value_cache, int64_t num_kv_heads, double scale, ++ torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, ++ int64_t max_seq_len, const std::optional& alibi_slopes, ++ const std::string& kv_cache_dtype, double k_scale, double v_scale, ++ const int64_t tp_rank, const int64_t blocksparse_local_blocks, ++ const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, ++ const int64_t blocksparse_head_sliding_step) { ++ TORCH_CHECK(k_scale == 1.0f && v_scale == 1.0f); ++ TORCH_CHECK(blocksparse_vert_stride <= 1, ++ "CPU backend does not support blocksparse attention yet."); + VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v2_impl", + [&] { + CPU_KERNEL_GUARD_IN(paged_attention_v2_impl) + CALL_V2_KERNEL_LAUNCHER_BLOCK_SIZE(scalar_t); + CPU_KERNEL_GUARD_OUT(paged_attention_v2_impl) + }); +-} ++} +\ No newline at end of file +diff --git a/csrc/cpu/cache.cpp b/csrc/cpu/cache.cpp +index 7849a5d..31d4543 100644 +--- a/csrc/cpu/cache.cpp ++++ b/csrc/cpu/cache.cpp +@@ -5,25 +5,26 @@ + + namespace { + template +-void copy_blocks_cpu_impl( +- std::vector &key_caches, +- std::vector &value_caches, +- const std::vector> mapping_pairs, +- const int element_num_per_block, const int layer_num) { +- const size_t pair_num = mapping_pairs.size(); ++void copy_blocks_cpu_impl(std::vector const& key_caches, ++ std::vector const& value_caches, ++ const torch::Tensor& mapping_pairs, ++ const int element_num_per_block, ++ const int layer_num) { ++ const size_t pair_num = mapping_pairs.size(0); + const size_t block_bytes = sizeof(scalar_t) * element_num_per_block; + #pragma omp parallel for collapse(2) + for (int layer = 0; layer < layer_num; ++layer) { + for (size_t pair = 0; pair < pair_num; ++pair) { +- int64_t source_offset = element_num_per_block * mapping_pairs[pair].first; ++ int64_t source_offset = ++ element_num_per_block * mapping_pairs[pair][0].item(); + int64_t target_offset = +- element_num_per_block * mapping_pairs[pair].second; +- scalar_t *key_cache_ptr = key_caches[layer].data_ptr(); +- scalar_t *source_ptr = key_cache_ptr + source_offset; +- scalar_t *target_ptr = key_cache_ptr + target_offset; ++ element_num_per_block * mapping_pairs[pair][1].item(); ++ scalar_t* key_cache_ptr = key_caches[layer].data_ptr(); ++ scalar_t* source_ptr = key_cache_ptr + source_offset; ++ scalar_t* target_ptr = key_cache_ptr + target_offset; + std::memcpy(target_ptr, source_ptr, block_bytes); + +- scalar_t *value_cache_ptr = value_caches[layer].data_ptr(); ++ scalar_t* value_cache_ptr = value_caches[layer].data_ptr(); + source_ptr = value_cache_ptr + source_offset; + target_ptr = value_cache_ptr + target_offset; + std::memcpy(target_ptr, source_ptr, block_bytes); +@@ -33,9 +34,9 @@ void copy_blocks_cpu_impl( + + template + void reshape_and_cache_cpu_impl( +- const scalar_t *__restrict__ key, const scalar_t *__restrict__ value, +- scalar_t *__restrict__ key_cache, scalar_t *__restrict__ value_cache, +- const int64_t *__restrict__ slot_mapping, const int num_tokens, ++ const scalar_t* __restrict__ key, const scalar_t* __restrict__ value, ++ scalar_t* __restrict__ key_cache, scalar_t* __restrict__ value_cache, ++ const int64_t* __restrict__ slot_mapping, const int num_tokens, + const int key_stride, const int value_stride, const int num_heads, + const int head_size, const int block_size, const int x) { + const int block_elem_num = num_heads * head_size * block_size; +@@ -48,14 +49,14 @@ void reshape_and_cache_cpu_impl( + int src_key_head_idx = token_idx * key_stride + head_idx * head_size; + int src_value_head_idx = + token_idx * value_stride + head_idx * head_size; +- const scalar_t *src_key_head_ptr = key + src_key_head_idx; +- const scalar_t *src_value_head_ptr = value + src_value_head_idx; ++ const scalar_t* src_key_head_ptr = key + src_key_head_idx; ++ const scalar_t* src_value_head_ptr = value + src_value_head_idx; + const int64_t block_index = slot_idx / block_size; + const int64_t block_offset = slot_idx % block_size; +- scalar_t *target_key_head_ptr = key_cache + ++ scalar_t* target_key_head_ptr = key_cache + + block_elem_num * block_index + + head_idx * block_size * head_size; +- scalar_t *target_value_head_ptr = value_cache + ++ scalar_t* target_value_head_ptr = value_cache + + block_elem_num * block_index + + head_idx * block_size * head_size; + +@@ -79,40 +80,36 @@ void reshape_and_cache_cpu_impl( + } + } + } +-}; // namespace ++}; // namespace + +-void copy_blocks(std::vector &key_caches, +- std::vector &value_caches, +- const std::map> &block_mapping) { +- int num_layers = key_caches.size(); ++// Note: the key_caches and value_caches vectors are constant but ++// not the Tensors they contain. The vectors need to be const refs ++// in order to satisfy pytorch's C++ operator registration code. ++void copy_blocks(std::vector const& key_caches, ++ std::vector const& value_caches, ++ const torch::Tensor& block_mapping) { ++ unsigned num_layers = key_caches.size(); + TORCH_CHECK(num_layers == value_caches.size()); + if (num_layers == 0) { + return; + } + +- std::vector> mapping_pairs; +- mapping_pairs.reserve(block_mapping.size()); +- for (const auto &pair : block_mapping) { +- for (const auto &dst : pair.second) { +- mapping_pairs.emplace_back(pair.first, dst); +- } +- } +- + const int element_num_per_block = key_caches[0][0].numel(); + VLLM_DISPATCH_FLOATING_TYPES( + key_caches[0].scalar_type(), "copy_blocks_cpu_impl", [&] { + CPU_KERNEL_GUARD_IN(copy_blocks_cpu_impl) +- copy_blocks_cpu_impl(key_caches, value_caches, mapping_pairs, ++ copy_blocks_cpu_impl(key_caches, value_caches, block_mapping, + element_num_per_block, num_layers); + CPU_KERNEL_GUARD_OUT(copy_blocks_cpu_impl) + }); + } + +-void reshape_and_cache(torch::Tensor &key, torch::Tensor &value, +- torch::Tensor &key_cache, torch::Tensor &value_cache, +- torch::Tensor &slot_mapping, +- const std::string &kv_cache_dtype, float kv_scale) { +- TORCH_CHECK(kv_scale == 1.0f); ++void reshape_and_cache(torch::Tensor& key, torch::Tensor& value, ++ torch::Tensor& key_cache, torch::Tensor& value_cache, ++ torch::Tensor& slot_mapping, ++ const std::string& kv_cache_dtype, double k_scale, ++ double v_scale) { ++ TORCH_CHECK(k_scale == 1.0f && v_scale == 1.0f); + + int num_tokens = key.size(0); + int num_heads = key.size(1); +@@ -135,7 +132,7 @@ void reshape_and_cache(torch::Tensor &key, torch::Tensor &value, + }); + } + +-void swap_blocks(torch::Tensor &src, torch::Tensor &dst, +- const std::map &block_mapping) { ++void swap_blocks(torch::Tensor& src, torch::Tensor& dst, ++ const torch::Tensor& block_mapping) { + TORCH_CHECK(false, "swap_blocks is unsupported on CPU.") + } +diff --git a/csrc/cpu/cpu_types.hpp b/csrc/cpu/cpu_types.hpp +index c1d3ec0..28db047 100644 +--- a/csrc/cpu/cpu_types.hpp ++++ b/csrc/cpu/cpu_types.hpp +@@ -1,352 +1,17 @@ +- + #ifndef CPU_TYPES_HPP + #define CPU_TYPES_HPP + +-#include +-#include +- +-namespace vec_op { +- +-// FIXME: FP16 is not fully supported in Torch-CPU +-#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \ +- AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ +- AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) +- +-#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \ +- AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)) +- +-#ifndef CPU_OP_GUARD +-#define CPU_KERNEL_GUARD_IN(NAME) +-#define CPU_KERNEL_GUARD_OUT(NAME) ++#if defined(__x86_64__) ++ //x86 implementation ++ #include "cpu_types_x86.hpp" ++#elif defined(__POWER9_VECTOR__) ++ //ppc implementation ++ #include "cpu_types_vsx.hpp" ++#elif defined(__aarch64__) ++ //arm implementation ++ #include "cpu_types_arm.hpp" + #else +-#define CPU_KERNEL_GUARD_IN(NAME) \ +- std::cout << #NAME << " invoked." << std::endl; +-#define CPU_KERNEL_GUARD_OUT(NAME) std::cout << #NAME << " exit." << std::endl; +-#endif +- +-#define FORCE_INLINE __attribute__((always_inline)) inline +- +-namespace { +-template +-constexpr void unroll_loop_item(std::integer_sequence, F &&f) { +- (f(std::integral_constant{}), ...); +-} +-}; // namespace +- +-template >> +-constexpr void unroll_loop(F &&f) { +- unroll_loop_item(std::make_integer_sequence{}, std::forward(f)); +-} +- +-template struct Vec { +- constexpr static int get_elem_num() { return T::VEC_ELEM_NUM; } +-}; +- +-struct FP32Vec8; +-struct FP32Vec16; +- +-#ifdef __AVX512FP16__ +-struct FP16Vec8 : public Vec { +- constexpr static int VEC_ELEM_NUM = 8; +- +- __m128h reg; +- +- explicit FP16Vec8(_Float16 v) : reg(_mm_set1_ph(v)) {} +- +- explicit FP16Vec8(const void *ptr) : reg(_mm_loadu_ph(ptr)) {} +- +- explicit FP16Vec8(__m128h data) : reg(data) {} +- +- FP16Vec8 operator*(const FP16Vec8 &b) const { +- return FP16Vec8(_mm_mul_ph(reg, b.reg)); +- } +- +- FP16Vec8 operator+(const FP16Vec8 &b) const { +- return FP16Vec8(_mm_add_ph(reg, b.reg)); +- } +- +- FP16Vec8 operator-(const FP16Vec8 &b) const { +- return FP16Vec8(_mm_sub_ph(reg, b.reg)); +- } +- +- FP16Vec8 operator/(const FP16Vec8 &b) const { +- return FP16Vec8(_mm_div_ph(reg, b.reg)); +- } +- +- void save(void *ptr) const { _mm_storeu_ph(ptr, reg); } +-}; ++ #warning "unsupported vLLM cpu implementation" + #endif + +-struct BF16Vec8 : public Vec { +- constexpr static int VEC_ELEM_NUM = 8; +- +- __m128i reg; +- +- explicit BF16Vec8(const void *ptr) +- : reg((__m128i)_mm_loadu_si128((__m128i *)ptr)) {} +- +- explicit BF16Vec8(const FP32Vec8 &); +- +- void save(void *ptr) const { *reinterpret_cast<__m128i *>(ptr) = reg; } +-}; +- +-struct BF16Vec16 : public Vec { +- constexpr static int VEC_ELEM_NUM = 16; +- +- __m256i reg; +- +- explicit BF16Vec16(const void *ptr) +- : reg((__m256i)_mm256_loadu_si256((__m256i *)ptr)) {} +- +- explicit BF16Vec16(const FP32Vec16 &); +- +- void save(void *ptr) const { *reinterpret_cast<__m256i *>(ptr) = reg; } +-}; +- +-struct BF16Vec32 : public Vec { +- constexpr static int VEC_ELEM_NUM = 32; +- +- __m512i reg; +- +- explicit BF16Vec32(const void *ptr) : reg((__m512i)_mm512_loadu_si512(ptr)) {} +- +- explicit BF16Vec32(__m512i data) : reg(data) {} +- +- explicit BF16Vec32(BF16Vec8 &vec8_data) +- : reg((__m512i)_mm512_inserti32x4( +- _mm512_inserti32x4(_mm512_inserti32x4(_mm512_castsi128_si512( +- (__m128i)vec8_data.reg), +- (__m128i)vec8_data.reg, 1), +- (__m128i)vec8_data.reg, 2), +- (__m128i)vec8_data.reg, 3)) {} +- +- void save(void *ptr) const { *reinterpret_cast<__m512i *>(ptr) = reg; } +-}; +- +-struct FP32Vec4 : public Vec { +- constexpr static int VEC_ELEM_NUM = 4; +- union AliasReg { +- __m128 reg; +- float values[VEC_ELEM_NUM]; +- }; +- +- __m128 reg; +- +- explicit FP32Vec4(float v) : reg(_mm_set1_ps(v)) {} +- +- explicit FP32Vec4() : reg(_mm_set1_ps(0.0)) {} +- +- explicit FP32Vec4(const float *ptr) : reg(_mm_loadu_ps(ptr)) {} +- +- explicit FP32Vec4(__m128 data) : reg(data) {} +- +- explicit FP32Vec4(const FP32Vec4 &data) : reg(data.reg) {} +-}; +- +-struct FP32Vec8 : public Vec { +- constexpr static int VEC_ELEM_NUM = 8; +- union AliasReg { +- __m256 reg; +- float values[VEC_ELEM_NUM]; +- }; +- +- __m256 reg; +- +- explicit FP32Vec8(float v) : reg(_mm256_set1_ps(v)) {} +- +- explicit FP32Vec8() : reg(_mm256_set1_ps(0.0)) {} +- +- explicit FP32Vec8(const float *ptr) : reg(_mm256_loadu_ps(ptr)) {} +- +- explicit FP32Vec8(__m256 data) : reg(data) {} +- +- explicit FP32Vec8(const FP32Vec8 &data) : reg(data.reg) {} +- +-#ifdef __AVX512FP16__ +- explicit FP32Vec8(__m128h v) : reg(_mm256_cvtph_ps(_mm_castph_si128(v))) {} +-#endif +- +- explicit FP32Vec8(const BF16Vec8 &v) +- : reg(_mm256_castsi256_ps( +- _mm256_bslli_epi128(_mm256_cvtepu16_epi32(v.reg), 2))) {} +- +- float reduce_sum() const { +- AliasReg ar; +- ar.reg = reg; +- float result = 0; +- unroll_loop([&result, &ar](int i) { result += ar.values[i]; }); +- +- return result; +- } +- +- FP32Vec8 exp() const { +- AliasReg ar; +- ar.reg = reg; +- return FP32Vec8(_mm256_set_ps(expf(ar.values[7]), expf(ar.values[6]), +- expf(ar.values[5]), expf(ar.values[4]), +- expf(ar.values[3]), expf(ar.values[2]), +- expf(ar.values[1]), expf(ar.values[0]))); +- } +- +- FP32Vec8 tanh() const { +- AliasReg ar; +- ar.reg = reg; +- return FP32Vec8(_mm256_set_ps(tanhf(ar.values[7]), tanhf(ar.values[6]), +- tanhf(ar.values[5]), tanhf(ar.values[4]), +- tanhf(ar.values[3]), tanhf(ar.values[2]), +- tanhf(ar.values[1]), tanhf(ar.values[0]))); +- } +- +- FP32Vec8 er() const { +- AliasReg ar; +- ar.reg = reg; +- return FP32Vec8(_mm256_set_ps(erf(ar.values[7]), erf(ar.values[6]), +- erf(ar.values[5]), erf(ar.values[4]), +- erf(ar.values[3]), erf(ar.values[2]), +- erf(ar.values[1]), erf(ar.values[0]))); +- } +- +- FP32Vec8 operator*(const FP32Vec8 &b) const { +- return FP32Vec8(_mm256_mul_ps(reg, b.reg)); +- } +- +- FP32Vec8 operator+(const FP32Vec8 &b) const { +- return FP32Vec8(_mm256_add_ps(reg, b.reg)); +- } +- +- FP32Vec8 operator-(const FP32Vec8 &b) const { +- return FP32Vec8(_mm256_sub_ps(reg, b.reg)); +- } +- +- FP32Vec8 operator/(const FP32Vec8 &b) const { +- return FP32Vec8(_mm256_div_ps(reg, b.reg)); +- } +- +- void save(float *ptr) const { _mm256_storeu_ps(ptr, reg); } +-}; +- +-struct FP32Vec16 : public Vec { +- constexpr static int VEC_ELEM_NUM = 16; +- union AliasReg { +- __m512 reg; +- float values[VEC_ELEM_NUM]; +- }; +- +- __m512 reg; +- +- explicit FP32Vec16(float v) : reg(_mm512_set1_ps(v)) {} +- +- explicit FP32Vec16() : reg(_mm512_set1_ps(0.0)) {} +- +- explicit FP32Vec16(const float *ptr) : reg(_mm512_loadu_ps(ptr)) {} +- +- explicit FP32Vec16(__m512 data) : reg(data) {} +- +- explicit FP32Vec16(const FP32Vec16 &data) : reg(data.reg) {} +- +- explicit FP32Vec16(const FP32Vec4 &data) +- : reg((__m512)_mm512_inserti32x4( +- _mm512_inserti32x4( +- _mm512_inserti32x4(_mm512_castsi128_si512((__m128i)data.reg), +- (__m128i)data.reg, 1), +- (__m128i)data.reg, 2), +- (__m128i)data.reg, 3)) {} +- +- explicit FP32Vec16(const FP32Vec8 &data) +- : reg((__m512)_mm512_inserti32x8( +- _mm512_castsi256_si512((__m256i)data.reg), (__m256i)data.reg, 1)) {} +- +- explicit FP32Vec16(const BF16Vec16 &v) +- : reg(_mm512_castsi512_ps( +- _mm512_bslli_epi128(_mm512_cvtepu16_epi32(v.reg), 2))) {} +- +- explicit FP32Vec16(const BF16Vec8 &v) : FP32Vec16(FP32Vec8(v)) {} +- +- FP32Vec16 operator*(const FP32Vec16 &b) const { +- return FP32Vec16(_mm512_mul_ps(reg, b.reg)); +- } +- +- FP32Vec16 operator+(const FP32Vec16 &b) const { +- return FP32Vec16(_mm512_add_ps(reg, b.reg)); +- } +- +- FP32Vec16 operator-(const FP32Vec16 &b) const { +- return FP32Vec16(_mm512_sub_ps(reg, b.reg)); +- } +- +- FP32Vec16 operator/(const FP32Vec16 &b) const { +- return FP32Vec16(_mm512_div_ps(reg, b.reg)); +- } +- +- float reduce_sum() const { return _mm512_reduce_add_ps(reg); } +- +- template float reduce_sub_sum(int idx) { +- static_assert(VEC_ELEM_NUM % group_size == 0); +- constexpr uint32_t base_mask = (0xFFFF >> (16 - group_size)); +- __mmask16 mask = _cvtu32_mask16(base_mask << (idx * group_size)); +- return _mm512_mask_reduce_add_ps(mask, reg); +- } +- +- void save(float *ptr) const { _mm512_storeu_ps(ptr, reg); } +-}; +- +-template struct VecType { using vec_type = void; }; +- +-template using vec_t = typename VecType::vec_type; +- +-template <> struct VecType { using vec_type = FP32Vec8; }; +- +-#ifdef __AVX512FP16__ +-template <> struct VecType { using vec_type = FP16Vec16; }; +-#endif +- +-template <> struct VecType { using vec_type = BF16Vec8; }; +- +-template void storeFP32(float v, T *ptr) { *ptr = v; } +- +-#ifdef __AVX512FP16__ +-template <> inline void storeFP32(float v, c10::Half *ptr) { +- *reinterpret_cast<_Float16 *>(ptr) = v; +-} +-#endif +- +-inline void fma(FP32Vec16 &acc, FP32Vec16 &a, FP32Vec16 &b) { +- acc = acc + a * b; +-} +- +-#ifdef __AVX512BF16__ +-template <> inline void storeFP32(float v, c10::BFloat16 *ptr) { +- *reinterpret_cast<__bfloat16 *>(ptr) = _mm_cvtness_sbh(v); +-} +- +-inline BF16Vec8::BF16Vec8(const FP32Vec8 &v) +- : reg((__m128i)_mm256_cvtneps_pbh(v.reg)) {} +- +-inline BF16Vec16::BF16Vec16(const FP32Vec16 &v) +- : reg((__m256i)_mm512_cvtneps_pbh(v.reg)) {} +- +-inline void fma(FP32Vec16 &acc, BF16Vec32 &a, BF16Vec32 &b) { +- acc.reg = _mm512_dpbf16_ps(acc.reg, (__m512bh)a.reg, (__m512bh)b.reg); +-} +-#else +-template <> inline void storeFP32(float v, c10::BFloat16 *ptr) { +- c10::BFloat16 __attribute__((__may_alias__)) *v_ptr = +- reinterpret_cast(&v); +- *ptr = *(v_ptr + 1); +-} +- +-inline BF16Vec8::BF16Vec8(const FP32Vec8 &v) +- : reg(_mm256_cvtepi32_epi16( +- _mm256_bsrli_epi128(_mm256_castps_si256(v.reg), 2))) {} +- +-inline BF16Vec16::BF16Vec16(const FP32Vec16 &v) +- : reg(_mm512_cvtepi32_epi16( +- _mm512_bsrli_epi128(_mm512_castps_si512(v.reg), 2))) {} +-#endif +- +-inline void prefetch(const void *addr) { _mm_prefetch(addr, _MM_HINT_T1); } +- +-}; // namespace vec_op +- +-#endif ++#endif +\ No newline at end of file +diff --git a/csrc/cpu/cpu_types_arm.hpp b/csrc/cpu/cpu_types_arm.hpp +new file mode 100644 +index 0000000..ae062a5 +--- /dev/null ++++ b/csrc/cpu/cpu_types_arm.hpp +@@ -0,0 +1,572 @@ ++#include ++#include ++#include ++ ++namespace vec_op { ++ ++#ifdef ARM_BF16_SUPPORT ++ #define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \ ++ AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ ++ AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \ ++ AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) ++#else ++ #define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \ ++ AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ ++ AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) ++#endif ++ ++#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \ ++ AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)) ++ ++#ifndef CPU_OP_GUARD ++#define CPU_KERNEL_GUARD_IN(NAME) ++#define CPU_KERNEL_GUARD_OUT(NAME) ++#else ++#define CPU_KERNEL_GUARD_IN(NAME) \ ++ std::cout << #NAME << " invoked." << std::endl; ++#define CPU_KERNEL_GUARD_OUT(NAME) std::cout << #NAME << " exit." << std::endl; ++#endif ++ ++#define FORCE_INLINE __attribute__((always_inline)) inline ++ ++namespace { ++ template ++ constexpr void unroll_loop_item(std::integer_sequence, F &&f) { ++ (f(std::integral_constant{}), ...); ++ }; ++}; ++ ++template >> ++constexpr void unroll_loop(F &&f) { ++ unroll_loop_item(std::make_integer_sequence{}, std::forward(f)); ++} ++ ++template struct Vec { ++ constexpr static int get_elem_num() { return T::VEC_ELEM_NUM; }; ++}; ++ ++struct FP32Vec8; ++struct FP32Vec16; ++ ++struct FP16Vec8 : public Vec { ++ constexpr static int VEC_ELEM_NUM = 8; ++ ++ float16x8_t reg; ++ ++ explicit FP16Vec8(const void *ptr) ++ : reg(vld1q_f16(static_cast(ptr))) {}; ++ ++ explicit FP16Vec8(const FP32Vec8 &); ++ ++ void save(void *ptr) const { ++ vst1q_f16(static_cast<__fp16 *>(ptr), reg); ++ } ++}; ++ ++struct FP16Vec16 : public Vec { ++ constexpr static int VEC_ELEM_NUM = 16; ++ ++ float16x8x2_t reg; ++ ++ explicit FP16Vec16(const void *ptr) { ++ reg.val[0] = vld1q_f16(reinterpret_cast(ptr)); ++ reg.val[1] = vld1q_f16(reinterpret_cast(ptr) + 8); ++ } ++ ++ explicit FP16Vec16(const FP32Vec16& vec); ++ ++ void save(void *ptr) const { ++ vst1q_f16(reinterpret_cast<__fp16*>(ptr), reg.val[0]); ++ vst1q_f16(reinterpret_cast<__fp16*>(ptr) + 8, reg.val[1]); ++ } ++ ++ void save(void *ptr, const int elem_num) const { ++ int full_blocks = elem_num / 8; ++ int remainder = elem_num % 8; ++ ++ if (full_blocks > 0) { ++ vst1q_f16(reinterpret_cast<__fp16*>(ptr), reg.val[0]); ++ if (full_blocks > 1) { ++ vst1q_f16(reinterpret_cast<__fp16*>(ptr) + 8, reg.val[1]); ++ } ++ } ++ ++ // Note: below is the unrolled version of the following code: ++ // ++ // for (int i = 0; i < remainder; ++i) { ++ // reinterpret_cast<__fp16*>(ptr)[full_blocks * 8 + i] = ++ // vgetq_lane_f16(temp, i); ++ // } ++ // ++ // For macOS build (Clang), the arm/neon intrinsics function ++ // `vgetq_lane_f16` needs the parameter `i` to be constant at compile ++ // time. ++ ++ if (remainder > 0) { ++ float16x8_t temp = reg.val[full_blocks]; ++ __fp16* fp16_ptr = reinterpret_cast<__fp16*>(ptr); ++ switch (remainder) ++ { ++ case 1: ++ fp16_ptr[full_blocks * 8 + 0] = vgetq_lane_f16(temp, 0); ++ break; ++ case 2: ++ fp16_ptr[full_blocks * 8 + 0] = vgetq_lane_f16(temp, 0); ++ fp16_ptr[full_blocks * 8 + 1] = vgetq_lane_f16(temp, 1); ++ break; ++ case 3: ++ fp16_ptr[full_blocks * 8 + 0] = vgetq_lane_f16(temp, 0); ++ fp16_ptr[full_blocks * 8 + 1] = vgetq_lane_f16(temp, 1); ++ fp16_ptr[full_blocks * 8 + 2] = vgetq_lane_f16(temp, 2); ++ break; ++ case 4: ++ fp16_ptr[full_blocks * 8 + 0] = vgetq_lane_f16(temp, 0); ++ fp16_ptr[full_blocks * 8 + 1] = vgetq_lane_f16(temp, 1); ++ fp16_ptr[full_blocks * 8 + 2] = vgetq_lane_f16(temp, 2); ++ fp16_ptr[full_blocks * 8 + 3] = vgetq_lane_f16(temp, 3); ++ break; ++ case 5: ++ fp16_ptr[full_blocks * 8 + 0] = vgetq_lane_f16(temp, 0); ++ fp16_ptr[full_blocks * 8 + 1] = vgetq_lane_f16(temp, 1); ++ fp16_ptr[full_blocks * 8 + 2] = vgetq_lane_f16(temp, 2); ++ fp16_ptr[full_blocks * 8 + 3] = vgetq_lane_f16(temp, 3); ++ fp16_ptr[full_blocks * 8 + 4] = vgetq_lane_f16(temp, 4); ++ break; ++ case 6: ++ fp16_ptr[full_blocks * 8 + 0] = vgetq_lane_f16(temp, 0); ++ fp16_ptr[full_blocks * 8 + 1] = vgetq_lane_f16(temp, 1); ++ fp16_ptr[full_blocks * 8 + 2] = vgetq_lane_f16(temp, 2); ++ fp16_ptr[full_blocks * 8 + 3] = vgetq_lane_f16(temp, 3); ++ fp16_ptr[full_blocks * 8 + 4] = vgetq_lane_f16(temp, 4); ++ fp16_ptr[full_blocks * 8 + 5] = vgetq_lane_f16(temp, 5); ++ break; ++ case 7: ++ fp16_ptr[full_blocks * 8 + 0] = vgetq_lane_f16(temp, 0); ++ fp16_ptr[full_blocks * 8 + 1] = vgetq_lane_f16(temp, 1); ++ fp16_ptr[full_blocks * 8 + 2] = vgetq_lane_f16(temp, 2); ++ fp16_ptr[full_blocks * 8 + 3] = vgetq_lane_f16(temp, 3); ++ fp16_ptr[full_blocks * 8 + 4] = vgetq_lane_f16(temp, 4); ++ fp16_ptr[full_blocks * 8 + 5] = vgetq_lane_f16(temp, 5); ++ fp16_ptr[full_blocks * 8 + 6] = vgetq_lane_f16(temp, 6); ++ break; ++ ++ default: ++ break; ++ } ++ } ++ } ++}; ++ ++ ++#ifdef ARM_BF16_SUPPORT ++struct BF16Vec8 : public Vec { ++ constexpr static int VEC_ELEM_NUM = 8; ++ ++ bfloat16x8_t reg; ++ ++ explicit BF16Vec8(const void *ptr) ++ : reg(*reinterpret_cast(ptr)) {}; ++ ++ explicit BF16Vec8(bfloat16x8_t data) : reg(data) {}; ++ ++ explicit BF16Vec8(const FP32Vec8 &); ++ ++ explicit BF16Vec8(float32x4x2_t v) : reg(vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(v.val[0]), v.val[1])) {}; ++ ++ void save(void *ptr) const { *reinterpret_cast(ptr) = reg; } ++}; ++ ++struct BF16Vec16 : public Vec { ++ constexpr static int VEC_ELEM_NUM = 16; ++ ++ bfloat16x8x2_t reg; ++ ++ explicit BF16Vec16(const void *ptr) ++ : reg(*reinterpret_cast(ptr)) {}; ++ ++ explicit BF16Vec16(bfloat16x8x2_t data) : reg(data) {}; ++ ++ explicit BF16Vec16(const FP32Vec16 &); ++ ++ explicit BF16Vec16(float32x4x4_t v) : reg({ ++ vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(v.val[0]), v.val[1]), ++ vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(v.val[2]), v.val[3]) ++ }){}; ++ ++ void save(void *ptr) const { *reinterpret_cast(ptr) = reg; }; ++}; ++ ++struct BF16Vec32 : public Vec { ++ constexpr static int VEC_ELEM_NUM = 32; ++ ++ bfloat16x8x4_t reg; ++ ++ explicit BF16Vec32(const void *ptr) ++ : reg(*reinterpret_cast(ptr)) {}; ++ ++ explicit BF16Vec32(bfloat16x8x4_t data) : reg(data) {}; ++ ++ explicit BF16Vec32(const BF16Vec8 &vec8_data) : reg({ ++ vec8_data.reg, ++ vec8_data.reg, ++ vec8_data.reg, ++ vec8_data.reg ++ }) {}; ++ ++ void save(void *ptr) const { *reinterpret_cast(ptr) = reg; }; ++}; ++#endif ++ ++struct FP32Vec4 : public Vec { ++ constexpr static int VEC_ELEM_NUM = 4; ++ ++ union AliasReg { ++ float32x4_t reg; ++ float values[VEC_ELEM_NUM]; ++ }; ++ ++ float32x4_t reg; ++ ++ explicit FP32Vec4(float v) : reg(vdupq_n_f32(v)) {}; ++ ++ explicit FP32Vec4() : reg(vdupq_n_f32(0.0f)) {}; ++ ++ explicit FP32Vec4(const float *ptr) : reg(vld1q_f32(ptr)) {}; ++ ++ explicit FP32Vec4(float32x4_t data) : reg(data) {}; ++ ++ explicit FP32Vec4(const FP32Vec4 &data) : reg(data.reg) {}; ++}; ++ ++struct FP32Vec8 : public Vec { ++ constexpr static int VEC_ELEM_NUM = 8; ++ union AliasReg { ++ float32x4x2_t reg; ++ float values[VEC_ELEM_NUM]; ++ }; ++ ++ float32x4x2_t reg; ++ ++ explicit FP32Vec8(float v) : reg({vmovq_n_f32(v), vmovq_n_f32(v)}) {}; ++ ++ explicit FP32Vec8() : reg({vmovq_n_f32(0.0), vmovq_n_f32(0.0)}) {}; ++ ++ explicit FP32Vec8(const float *ptr) : reg({vld1q_f32(ptr), vld1q_f32(ptr + 4)}) {}; ++ ++ explicit FP32Vec8(float32x4x2_t data) : reg(data) {}; ++ ++ explicit FP32Vec8(const FP32Vec8 &data) : reg(data.reg) {}; ++ ++ explicit FP32Vec8(const FP16Vec8 &v) { ++ reg.val[0] = vcvt_f32_f16(vget_low_f16(v.reg)); ++ reg.val[1] = vcvt_f32_f16(vget_high_f16(v.reg)); ++ }; ++ ++ explicit FP32Vec8(float16x8_t v) : reg({vcvt_f32_f16(vget_low_f16(v)), vcvt_f32_f16(vget_high_f16(v))}) {}; ++ ++ #ifdef ARM_BF16_SUPPORT ++ ++ explicit FP32Vec8(bfloat16x8_t v) : reg({vcvtq_low_f32_bf16(v), vcvtq_high_f32_bf16(v)}) {}; ++ ++ explicit FP32Vec8(const BF16Vec8 &v) : reg({vcvtq_low_f32_bf16(v.reg), vcvtq_high_f32_bf16(v.reg)}) {}; ++ ++ #endif ++ ++ float reduce_sum() const { ++ AliasReg ar; ++ ar.reg = reg; ++ float answer = 0; ++ unroll_loop([&answer, &ar](int i) { answer += ar.values[i]; }); ++ ++ return answer; ++ } ++ ++ FP32Vec8 exp() const { ++ AliasReg ar; ++ ar.reg = reg; ++ ++ float32x2_t exp_vec0 = {expf(ar.values[0]), expf(ar.values[1])}; ++ float32x2_t exp_vec1 = {expf(ar.values[2]), expf(ar.values[3])}; ++ float32x2_t exp_vec2 = {expf(ar.values[4]), expf(ar.values[5])}; ++ float32x2_t exp_vec3 = {expf(ar.values[6]), expf(ar.values[7])}; ++ ++ float32x4_t result0 = vcombine_f32(exp_vec0, exp_vec1); ++ float32x4_t result1 = vcombine_f32(exp_vec2, exp_vec3); ++ ++ float32x4x2_t result; ++ result.val[0] = result0; ++ result.val[1] = result1; ++ ++ return FP32Vec8(result); ++ } ++ ++ FP32Vec8 tanh() const { ++ AliasReg ar; ++ ar.reg = reg; ++ ++ float32x2_t tanh_vec0 = {tanhf(ar.values[0]), tanhf(ar.values[1])}; ++ float32x2_t tanh_vec1 = {tanhf(ar.values[2]), tanhf(ar.values[3])}; ++ float32x2_t tanh_vec2 = {tanhf(ar.values[4]), tanhf(ar.values[5])}; ++ float32x2_t tanh_vec3 = {tanhf(ar.values[6]), tanhf(ar.values[7])}; ++ ++ float32x4_t result0 = vcombine_f32(tanh_vec0, tanh_vec1); ++ float32x4_t result1 = vcombine_f32(tanh_vec2, tanh_vec3); ++ ++ float32x4x2_t result; ++ result.val[0] = result0; ++ result.val[1] = result1; ++ ++ return FP32Vec8(result); ++ } ++ ++ FP32Vec8 er() const { ++ AliasReg ar; ++ ar.reg = reg; ++ ++ float32x2_t er_vec0 = {static_cast(erf(ar.values[0])), static_cast(erf(ar.values[1]))}; ++ float32x2_t er_vec1 = {static_cast(erf(ar.values[2])), static_cast(erf(ar.values[3]))}; ++ float32x2_t er_vec2 = {static_cast(erf(ar.values[4])), static_cast(erf(ar.values[5]))}; ++ float32x2_t er_vec3 = {static_cast(erf(ar.values[6])), static_cast(erf(ar.values[7]))}; ++ ++ float32x4_t result0 = vcombine_f32(er_vec0, er_vec1); ++ float32x4_t result1 = vcombine_f32(er_vec2, er_vec3); ++ ++ float32x4x2_t result; ++ result.val[0] = result0; ++ result.val[1] = result1; ++ ++ return FP32Vec8(result); ++ } ++ ++ FP32Vec8 operator*(const FP32Vec8 &b) const { ++ return FP32Vec8(float32x4x2_t({vmulq_f32(reg.val[0], b.reg.val[0]), vmulq_f32(reg.val[1], b.reg.val[1])})); ++ } ++ ++ FP32Vec8 operator+(const FP32Vec8 &b) const { ++ return FP32Vec8(float32x4x2_t({vaddq_f32(reg.val[0], b.reg.val[0]), vaddq_f32(reg.val[1], b.reg.val[1])})); ++ } ++ ++ FP32Vec8 operator-(const FP32Vec8 &b) const { ++ return FP32Vec8(float32x4x2_t({vsubq_f32(reg.val[0], b.reg.val[0]), vsubq_f32(reg.val[1], b.reg.val[1])})); ++ } ++ ++ FP32Vec8 operator/(const FP32Vec8 &b) const { ++ return FP32Vec8(float32x4x2_t({vdivq_f32(reg.val[0], b.reg.val[0]), vdivq_f32(reg.val[1], b.reg.val[1])})); ++ } ++ ++ void save(float *ptr) const { ++ vst1q_f32(ptr, reg.val[0]); ++ vst1q_f32(ptr + 4, reg.val[1]); ++ } ++}; ++ ++struct FP32Vec16 : public Vec { ++ constexpr static int VEC_ELEM_NUM = 16; ++ union AliasReg { ++ float32x4x4_t reg; ++ float values[VEC_ELEM_NUM]; ++ }; ++ ++ float32x4x4_t reg; ++ ++ explicit FP32Vec16(float v) : reg({vmovq_n_f32(v), vmovq_n_f32(v), vmovq_n_f32(v), vmovq_n_f32(v)}) {} ++ ++ explicit FP32Vec16() : reg({vmovq_n_f32(0.0), vmovq_n_f32(0.0), vmovq_n_f32(0.0), vmovq_n_f32(0.0)}) {} ++ ++ explicit FP32Vec16(const float *ptr) : reg({vld1q_f32(ptr), vld1q_f32(ptr + 4), vld1q_f32(ptr + 8), vld1q_f32(ptr + 12)}) {} ++ ++ explicit FP32Vec16(float32x4x4_t data) : reg(data) {} ++ ++ explicit FP32Vec16(const FP32Vec8 &data) { ++ reg.val[0] = data.reg.val[0]; ++ reg.val[1] = data.reg.val[1]; ++ reg.val[2] = data.reg.val[0]; ++ reg.val[3] = data.reg.val[1]; ++ } ++ ++ explicit FP32Vec16(const FP32Vec16 &data) : reg(data.reg) {} ++ ++ explicit FP32Vec16(const FP16Vec8 &v) : FP32Vec16(FP32Vec8(v.reg)) {} ++ ++ #ifdef ARM_BF16_SUPPORT ++ explicit FP32Vec16(bfloat16x8x2_t v) : reg({ ++ vcvtq_low_f32_bf16(v.val[0]), ++ vcvtq_high_f32_bf16(v.val[0]), ++ vcvtq_low_f32_bf16(v.val[1]), ++ vcvtq_high_f32_bf16(v.val[1]) ++ }) {}; ++ #endif ++ ++ explicit FP32Vec16(const FP32Vec4 &data) { ++ reg.val[0] = data.reg; ++ reg.val[1] = data.reg; ++ reg.val[2] = data.reg; ++ reg.val[3] = data.reg; ++ }; ++ ++ #ifdef ARM_BF16_SUPPORT ++ explicit FP32Vec16(const BF16Vec16 &v) : reg({ ++ vcvtq_low_f32_bf16(v.reg.val[0]), ++ vcvtq_high_f32_bf16(v.reg.val[0]), ++ vcvtq_low_f32_bf16(v.reg.val[1]), ++ vcvtq_high_f32_bf16(v.reg.val[1]) ++ }) {}; ++ ++ explicit FP32Vec16(const BF16Vec8 &v) : FP32Vec16(FP32Vec8(v)) {}; ++ #endif ++ ++ explicit FP32Vec16(const FP16Vec16 &v) { ++ reg.val[0] = vcvt_f32_f16(vget_low_f16(v.reg.val[0])); ++ reg.val[1] = vcvt_f32_f16(vget_high_f16(v.reg.val[0])); ++ reg.val[2] = vcvt_f32_f16(vget_low_f16(v.reg.val[1])); ++ reg.val[3] = vcvt_f32_f16(vget_high_f16(v.reg.val[1])); ++ }; ++ ++ FP32Vec16 operator+(const FP32Vec16 &b) const { ++ return FP32Vec16(float32x4x4_t({ ++ vaddq_f32(reg.val[0], b.reg.val[0]), ++ vaddq_f32(reg.val[1], b.reg.val[1]), ++ vaddq_f32(reg.val[2], b.reg.val[2]), ++ vaddq_f32(reg.val[3], b.reg.val[3])})); ++ }; ++ ++ FP32Vec16 operator*(const FP32Vec16 &b) const { ++ return FP32Vec16(float32x4x4_t({ ++ vmulq_f32(reg.val[0], b.reg.val[0]), ++ vmulq_f32(reg.val[1], b.reg.val[1]), ++ vmulq_f32(reg.val[2], b.reg.val[2]), ++ vmulq_f32(reg.val[3], b.reg.val[3])})); ++ }; ++ ++ FP32Vec16 operator-(const FP32Vec16 &b) const { ++ return FP32Vec16(float32x4x4_t({ ++ vsubq_f32(reg.val[0], b.reg.val[0]), ++ vsubq_f32(reg.val[1], b.reg.val[1]), ++ vsubq_f32(reg.val[2], b.reg.val[2]), ++ vsubq_f32(reg.val[3], b.reg.val[3]) ++ })); ++ }; ++ ++ FP32Vec16 operator/(const FP32Vec16 &b) const { ++ return FP32Vec16(float32x4x4_t({ ++ vdivq_f32(reg.val[0], b.reg.val[0]), ++ vdivq_f32(reg.val[1], b.reg.val[1]), ++ vdivq_f32(reg.val[2], b.reg.val[2]), ++ vdivq_f32(reg.val[3], b.reg.val[3]) ++ })); ++ }; ++ ++ float reduce_sum() const { ++ AliasReg ar; ++ ar.reg = reg; ++ float answer = 0; ++ unroll_loop([&answer, &ar](int i) { answer += ar.values[i]; }); ++ ++ return answer; ++ }; ++ ++ template float reduce_sub_sum(int idx) { ++ static_assert(VEC_ELEM_NUM % group_size == 0); ++ ++ AliasReg ar; ++ ar.reg = reg; ++ float answer = 0; ++ const int start = idx * group_size; ++ unroll_loop( ++ [&answer, &start, ar](int i) { answer += ar.values[start + i]; }); ++ ++ return answer; ++ }; ++ ++ void save(float *ptr) const { ++ vst1q_f32(ptr, reg.val[0]); ++ vst1q_f32(ptr + 4, reg.val[1]); ++ vst1q_f32(ptr + 8, reg.val[2]); ++ vst1q_f32(ptr + 12, reg.val[3]); ++ }; ++}; ++ ++template struct VecType { using vec_type = void; }; ++ ++template using vec_t = typename VecType::vec_type; ++ ++template <> struct VecType { using vec_type = FP32Vec8; }; ++ ++template <> struct VecType { using vec_type = FP16Vec8; }; ++ ++#ifdef ARM_BF16_SUPPORT ++template <> struct VecType { using vec_type = BF16Vec8; }; ++#endif ++ ++template void storeFP32(float v, T *ptr) { *ptr = v; } ++ ++template <> inline void storeFP32(float v, c10::Half *ptr) { ++ *reinterpret_cast<__fp16 *>(ptr) = v; ++} ++ ++inline FP16Vec16::FP16Vec16(const FP32Vec16 &v) { ++ float16x4_t low_0 = vcvt_f16_f32(v.reg.val[0]); ++ float16x4_t high_0 = vcvt_f16_f32(v.reg.val[1]); ++ float16x4_t low_1 = vcvt_f16_f32(v.reg.val[2]); ++ float16x4_t high_1 = vcvt_f16_f32(v.reg.val[3]); ++ ++ reg.val[0] = vcombine_f16(low_0, high_0); ++ reg.val[1] = vcombine_f16(low_1, high_1); ++}; ++ ++inline FP16Vec8 :: FP16Vec8(const FP32Vec8 &v) { ++ float16x4_t lower_half = vcvt_f16_f32(v.reg.val[0]); ++ float16x4_t upper_half = vcvt_f16_f32(v.reg.val[1]); ++ ++ reg = vcombine_f16(lower_half, upper_half); ++}; ++ ++inline void fma(FP32Vec16 &acc, FP32Vec16 &a, FP32Vec16 &b) { ++ ++ acc.reg.val[0] = vfmaq_f32(acc.reg.val[0], a.reg.val[0], b.reg.val[0]); ++ acc.reg.val[1] = vfmaq_f32(acc.reg.val[1], a.reg.val[1], b.reg.val[1]); ++ acc.reg.val[2] = vfmaq_f32(acc.reg.val[2], a.reg.val[2], b.reg.val[2]); ++ acc.reg.val[3] = vfmaq_f32(acc.reg.val[3], a.reg.val[3], b.reg.val[3]); ++}; ++ ++#ifdef ARM_BF16_SUPPORT ++inline void fma(FP32Vec16 &acc, BF16Vec32 &a, BF16Vec32 &b) { ++ ++ float32x4_t a0_low = vcvt_f32_bf16(vget_low_bf16(a.reg.val[0])); ++ float32x4_t a0_high = vcvt_f32_bf16(vget_high_bf16(a.reg.val[0])); ++ float32x4_t a1_low = vcvt_f32_bf16(vget_low_bf16(a.reg.val[1])); ++ float32x4_t a1_high = vcvt_f32_bf16(vget_high_bf16(a.reg.val[1])); ++ ++ float32x4_t b0_low = vcvt_f32_bf16(vget_low_bf16(b.reg.val[0])); ++ float32x4_t b0_high = vcvt_f32_bf16(vget_high_bf16(b.reg.val[0])); ++ float32x4_t b1_low = vcvt_f32_bf16(vget_low_bf16(b.reg.val[1])); ++ float32x4_t b1_high = vcvt_f32_bf16(vget_high_bf16(b.reg.val[1])); ++ ++ acc.reg.val[0] = vfmaq_f32(acc.reg.val[0], a0_low, b0_low); ++ acc.reg.val[1] = vfmaq_f32(acc.reg.val[1], a0_high, b0_high); ++ acc.reg.val[2] = vfmaq_f32(acc.reg.val[2], a1_low, b1_low); ++ acc.reg.val[3] = vfmaq_f32(acc.reg.val[3], a1_high, b1_high); ++}; ++#endif ++ ++#ifdef ARM_BF16_SUPPORT ++inline BF16Vec8::BF16Vec8(const FP32Vec8 &v) : reg(vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(v.reg.val[0]), v.reg.val[1])) {}; ++ ++inline BF16Vec16::BF16Vec16(const FP32Vec16 &v) : reg({ ++ vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(v.reg.val[0]), v.reg.val[1]), ++ vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(v.reg.val[2]), v.reg.val[3]) ++ }){}; ++#endif ++ ++inline void prefetch(const void *addr) { ++ __builtin_prefetch(addr, 0, 1); ++}; ++ ++#ifdef ARM_BF16_SUPPORT ++template <> ++inline void storeFP32(float v, c10::BFloat16 *ptr) { ++ *reinterpret_cast<__bf16 *>(ptr) = vcvth_bf16_f32(v); ++}; ++#endif ++}; +\ No newline at end of file +diff --git a/csrc/cpu/cpu_types_vsx.hpp b/csrc/cpu/cpu_types_vsx.hpp +new file mode 100644 +index 0000000..b50bdad +--- /dev/null ++++ b/csrc/cpu/cpu_types_vsx.hpp +@@ -0,0 +1,491 @@ ++ ++#ifndef CPU_TYPES_VSX_HPP ++#define CPU_TYPES_VSX_HPP ++ ++#include ++#include ++#include ++ ++namespace vec_op { ++ ++// FIXME: FP16 is not fully supported in Torch-CPU ++#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \ ++ AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ ++ AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) ++ ++#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \ ++ AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)) ++ ++#ifndef CPU_OP_GUARD ++#define CPU_KERNEL_GUARD_IN(NAME) ++#define CPU_KERNEL_GUARD_OUT(NAME) ++#else ++#define CPU_KERNEL_GUARD_IN(NAME) \ ++ std::cout << #NAME << " invoked." << std::endl; ++#define CPU_KERNEL_GUARD_OUT(NAME) std::cout << #NAME << " exit." << std::endl; ++#endif ++ ++#define FORCE_INLINE __attribute__((always_inline)) inline ++ ++namespace { ++template ++constexpr void unroll_loop_item(std::integer_sequence, F &&f) { ++ (f(std::integral_constant{}), ...); ++} ++}; // namespace ++ ++template >> ++constexpr void unroll_loop(F &&f) { ++ unroll_loop_item(std::make_integer_sequence{}, std::forward(f)); ++} ++ ++template struct Vec { ++ constexpr static int get_elem_num() { return T::VEC_ELEM_NUM; } ++}; ++ ++typedef struct ss16x8x2_t { ++ __vector signed short val[2]; ++} ss16x8x2_t; ++ ++typedef struct ss16x8x4_t { ++ __vector signed short val[4]; ++} ss16x8x4_t; ++ ++typedef struct f32x4x2_t { ++ __vector float val[2]; ++} f32x4x2_t; ++ ++typedef struct f32x4x4_t { ++ __vector float val[4]; ++} f32x4x4_t; ++ ++struct FP32Vec8; ++struct FP32Vec16; ++ ++struct BF16Vec8 : public Vec { ++ constexpr static int VEC_ELEM_NUM = 8; ++ ++ __vector signed short reg; ++ ++ explicit BF16Vec8(const void *ptr) ++ : reg((__vector signed short)vec_xl(0, (__vector signed short *)ptr)) {} ++ ++ explicit BF16Vec8(const FP32Vec8 &); ++ ++ void save(void *ptr) const { *reinterpret_cast<__vector signed short *>(ptr) = reg; } ++}; ++ ++struct BF16Vec16 : public Vec { ++ constexpr static int VEC_ELEM_NUM = 16; ++ ++ ss16x8x2_t reg; ++ ++ explicit BF16Vec16(const void *ptr) { ++ // Load 256 bits in two parts ++ reg.val[0] = (__vector signed short)vec_xl(0, (signed short *)ptr); ++ reg.val[1] = (__vector signed short)vec_xl(16, (signed short *)ptr); ++ } ++ ++ explicit BF16Vec16(const FP32Vec16 &); ++ ++ void save(void *ptr) const { ++ // Save 256 bits in two parts ++ vec_xst(reg.val[0], 0, (signed short *)ptr); ++ vec_xst(reg.val[1], 16, (signed short *)ptr); ++ } ++}; ++ ++const static __vector signed short zero = vec_splats((signed short)0); ++ ++struct BF16Vec32 : public Vec { ++ constexpr static int VEC_ELEM_NUM = 32; ++ ++ ss16x8x4_t reg; ++ explicit BF16Vec32(const void *ptr) ++ : reg(*reinterpret_cast(ptr)) {} ++ ++ explicit BF16Vec32(ss16x8x4_t data) : reg(data) {} ++ ++ explicit BF16Vec32(const BF16Vec8 &vec8_data) : reg({ ++ vec8_data.reg, ++ vec8_data.reg, ++ vec8_data.reg, ++ vec8_data.reg ++ }) {} ++ ++ void save(void *ptr) const { *reinterpret_cast(ptr) = reg; } ++}; ++ ++struct FP32Vec4 : public Vec { ++ constexpr static int VEC_ELEM_NUM = 4; ++ union AliasReg { ++ __vector float reg; ++ float values[VEC_ELEM_NUM]; ++ }; ++ ++ __vector float reg; ++ ++ explicit FP32Vec4(float v) : reg(vec_splats(v)) {} ++ ++ explicit FP32Vec4() : reg(vec_splats(0.0f)) {} ++ ++ explicit FP32Vec4(const float *ptr) : reg(vec_xl(0, ptr)) {} ++ ++ explicit FP32Vec4(__vector float data) : reg(data) {} ++ ++ explicit FP32Vec4(const FP32Vec4 &data) : reg(data.reg) {} ++}; ++ ++struct FP32Vec8 : public Vec { ++ constexpr static int VEC_ELEM_NUM = 8; ++ union AliasReg { ++ f32x4x2_t reg; ++ float values[VEC_ELEM_NUM]; ++ }; ++ ++ f32x4x2_t reg; ++ ++ explicit FP32Vec8(float v) { ++ reg.val[0] = vec_splats(v); ++ reg.val[1] = vec_splats(v); ++ } ++ ++ explicit FP32Vec8() { ++ reg.val[0] = vec_splats(0.0f); ++ reg.val[1] = vec_splats(0.0f); ++ } ++ ++ explicit FP32Vec8(const float *ptr) { ++ reg.val[0] = vec_xl(0, ptr); ++ reg.val[1] = vec_xl(16, ptr); ++ } ++ ++ explicit FP32Vec8(f32x4x2_t data) : reg(data) {} ++ ++ explicit FP32Vec8(const FP32Vec8 &data) { ++ reg.val[0] = data.reg.val[0]; ++ reg.val[1] = data.reg.val[1]; ++ } ++ ++ explicit FP32Vec8(const BF16Vec8 &v) { ++ reg.val[0] = (__vector float)vec_mergeh(zero, v.reg); ++ reg.val[1] = (__vector float)vec_mergel(zero, v.reg); ++ } ++ ++ float reduce_sum() const { ++ AliasReg ar; ++ ar.reg = reg; ++ float result = 0; ++ unroll_loop([&result, &ar](int i) { result += ar.values[i]; }); ++ ++ return result; ++ } ++ ++ FP32Vec8 exp() const { ++ // TODO: Vectorize this ++ AliasReg ar; ++ ar.reg = reg; ++ f32x4x4_t ret; ++ ret.val[0][0] = std::exp(ar.values[0]); ++ ret.val[0][1] = std::exp(ar.values[1]); ++ ret.val[0][2] = std::exp(ar.values[2]); ++ ret.val[0][3] = std::exp(ar.values[3]); ++ ret.val[1][0] = std::exp(ar.values[4]); ++ ret.val[1][1] = std::exp(ar.values[5]); ++ ret.val[1][2] = std::exp(ar.values[6]); ++ ret.val[1][3] = std::exp(ar.values[7]); ++ return FP32Vec8(f32x4x2_t({ret.val[0], ret.val[1]})); ++ } ++ ++ FP32Vec8 tanh() const { ++ // TODO: Vectorize this ++ AliasReg ar; ++ ar.reg = reg; ++ f32x4x4_t ret; ++ ret.val[0][0] = std::tanh(ar.values[0]); ++ ret.val[0][1] = std::tanh(ar.values[1]); ++ ret.val[0][2] = std::tanh(ar.values[2]); ++ ret.val[0][3] = std::tanh(ar.values[3]); ++ ret.val[1][0] = std::tanh(ar.values[4]); ++ ret.val[1][1] = std::tanh(ar.values[5]); ++ ret.val[1][2] = std::tanh(ar.values[6]); ++ ret.val[1][3] = std::tanh(ar.values[7]); ++ return FP32Vec8(f32x4x2_t({ret.val[0], ret.val[1]})); ++ } ++ ++ FP32Vec8 er() const { ++ // TODO: Vectorize this ++ AliasReg ar; ++ ar.reg = reg; ++ f32x4x4_t ret; ++ ret.val[0][0] = std::erf(ar.values[0]); ++ ret.val[0][1] = std::erf(ar.values[1]); ++ ret.val[0][2] = std::erf(ar.values[2]); ++ ret.val[0][3] = std::erf(ar.values[3]); ++ ret.val[1][0] = std::erf(ar.values[4]); ++ ret.val[1][1] = std::erf(ar.values[5]); ++ ret.val[1][2] = std::erf(ar.values[6]); ++ ret.val[1][3] = std::erf(ar.values[7]); ++ return FP32Vec8(f32x4x2_t({ret.val[0], ret.val[1]})); ++ } ++ ++ FP32Vec8 operator*(const FP32Vec8 &b) const { ++ return FP32Vec8({vec_mul(reg.val[0], b.reg.val[0]), vec_mul(reg.val[1], b.reg.val[1])}); ++ } ++ ++ FP32Vec8 operator+(const FP32Vec8 &b) const { ++ return FP32Vec8({vec_add(reg.val[0], b.reg.val[0]), vec_add(reg.val[1], b.reg.val[1])}); ++ } ++ ++ FP32Vec8 operator-(const FP32Vec8 &b) const { ++ return FP32Vec8({vec_sub(reg.val[0], b.reg.val[0]), vec_sub(reg.val[1], b.reg.val[1])}); ++ } ++ ++ FP32Vec8 operator/(const FP32Vec8 &b) const { ++ return FP32Vec8({vec_div(reg.val[0], b.reg.val[0]), vec_div(reg.val[1], b.reg.val[1])}); ++ } ++ ++ void save(float *ptr) const { ++ vec_xst(reg.val[0], 0, ptr); ++ vec_xst(reg.val[1], 16, ptr); ++ } ++}; ++ ++struct FP32Vec16 : public Vec { ++ constexpr static int VEC_ELEM_NUM = 16; ++ union AliasReg { ++ f32x4x4_t reg; ++ float values[VEC_ELEM_NUM]; ++ }; ++ ++ f32x4x4_t reg; ++ ++ explicit FP32Vec16(float v) { ++ reg.val[0] = vec_splats(v); ++ reg.val[1] = vec_splats(v); ++ reg.val[2] = vec_splats(v); ++ reg.val[3] = vec_splats(v); ++ } ++ ++ explicit FP32Vec16() { ++ reg.val[0] = vec_splats(0.0f); ++ reg.val[1] = vec_splats(0.0f); ++ reg.val[2] = vec_splats(0.0f); ++ reg.val[3] = vec_splats(0.0f); ++ } ++ ++ explicit FP32Vec16(const float *ptr) { ++ reg.val[0] = vec_xl(0, ptr); ++ reg.val[1] = vec_xl(16, ptr); ++ reg.val[2] = vec_xl(32, ptr); ++ reg.val[3] = vec_xl(48, ptr); ++ } ++ ++ explicit FP32Vec16(f32x4x4_t data) : reg(data) {} ++ ++ explicit FP32Vec16(const FP32Vec16 &data) { ++ reg.val[0] = data.reg.val[0]; ++ reg.val[1] = data.reg.val[1]; ++ reg.val[2] = data.reg.val[2]; ++ reg.val[3] = data.reg.val[3]; ++ } ++ ++ explicit FP32Vec16(const FP32Vec4 &data) { ++ reg.val[0] = data.reg; ++ reg.val[1] = data.reg; ++ reg.val[2] = data.reg; ++ reg.val[3] = data.reg; ++ } ++ ++ explicit FP32Vec16(const FP32Vec8 &data) { ++ reg.val[0] = data.reg.val[0]; ++ reg.val[1] = data.reg.val[1]; ++ reg.val[2] = data.reg.val[0]; ++ reg.val[3] = data.reg.val[1]; ++ } ++ ++ explicit FP32Vec16(const BF16Vec16 &v) { ++ reg.val[0] = (__vector float)vec_mergeh(zero, v.reg.val[0]); ++ reg.val[1] = (__vector float)vec_mergel(zero, v.reg.val[0]); ++ reg.val[2] = (__vector float)vec_mergeh(zero, v.reg.val[1]); ++ reg.val[3] = (__vector float)vec_mergel(zero, v.reg.val[1]); ++ } ++ ++ explicit FP32Vec16(const BF16Vec8 &v) : FP32Vec16(FP32Vec8(v)) {} ++ ++ FP32Vec16 operator*(const FP32Vec16 &b) const { ++ return FP32Vec16(f32x4x4_t({ ++ vec_mul(reg.val[0], b.reg.val[0]), ++ vec_mul(reg.val[1], b.reg.val[1]), ++ vec_mul(reg.val[2], b.reg.val[2]), ++ vec_mul(reg.val[3], b.reg.val[3])})); ++ } ++ ++ FP32Vec16 operator+(const FP32Vec16 &b) const { ++ return FP32Vec16(f32x4x4_t({ ++ vec_add(reg.val[0], b.reg.val[0]), ++ vec_add(reg.val[1], b.reg.val[1]), ++ vec_add(reg.val[2], b.reg.val[2]), ++ vec_add(reg.val[3], b.reg.val[3])})); ++ } ++ ++ FP32Vec16 operator-(const FP32Vec16 &b) const { ++ return FP32Vec16(f32x4x4_t({ ++ vec_sub(reg.val[0], b.reg.val[0]), ++ vec_sub(reg.val[1], b.reg.val[1]), ++ vec_sub(reg.val[2], b.reg.val[2]), ++ vec_sub(reg.val[3], b.reg.val[3])})); ++ } ++ ++ FP32Vec16 operator/(const FP32Vec16 &b) const { ++ return FP32Vec16(f32x4x4_t({ ++ vec_div(reg.val[0], b.reg.val[0]), ++ vec_div(reg.val[1], b.reg.val[1]), ++ vec_div(reg.val[2], b.reg.val[2]), ++ vec_div(reg.val[3], b.reg.val[3])})); ++ } ++ ++ float reduce_sum() const { ++ AliasReg ar; ++ ar.reg = reg; ++ float result = 0; ++ unroll_loop([&result, &ar](int i) { result += ar.values[i]; }); ++ ++ return result; ++ } ++ ++ template float reduce_sub_sum(int idx) { ++ static_assert(VEC_ELEM_NUM % group_size == 0); ++ ++ AliasReg ar; ++ ar.reg = reg; ++ float result = 0; ++ const int start = idx * group_size; ++ unroll_loop( ++ [&result, &start, ar](int i) { result += ar.values[start + i]; }); ++ ++ return result; ++ } ++ ++ void save(float *ptr) const { ++ vec_xst(reg.val[0], 0, ptr); ++ vec_xst(reg.val[1], 16, ptr); ++ vec_xst(reg.val[2], 32, ptr); ++ vec_xst(reg.val[3], 48, ptr); ++ } ++}; ++ ++template struct VecType { using vec_type = void; }; ++ ++template using vec_t = typename VecType::vec_type; ++ ++template <> struct VecType { using vec_type = FP32Vec8; }; ++ ++template <> struct VecType { using vec_type = BF16Vec8; }; ++ ++template void storeFP32(float v, T *ptr) { *ptr = v; } ++ ++inline void fma(FP32Vec16 &acc, FP32Vec16 &a, FP32Vec16 &b) { ++ acc = acc + a * b; ++} ++ ++template <> inline void storeFP32(float v, c10::BFloat16 *ptr) { ++ c10::BFloat16 __attribute__((__may_alias__)) *v_ptr = ++ reinterpret_cast(&v); ++ *ptr = *(v_ptr + 1); ++} ++ ++#ifndef __VEC_CLASS_FP_NAN ++#define __VEC_CLASS_FP_NAN (1 << 6) ++#endif ++ ++const static __vector unsigned char omask = { 0, 1, 4, 5, 8, 9, 12, 13, 16, 17, 20, 21, 24, 25, 28, 29 }; ++#ifndef _ARCH_PWR10 ++const static __vector unsigned int bias = { 0x00007fff, 0x00007fff, 0x00007fff, 0x00007fff }; ++const static __vector unsigned int nan = { 0x7fc00000, 0x7fc00000, 0x7fc00000, 0x7fc00000 }; ++const static __vector unsigned int sh16 = { 16, 16, 16, 16 }; ++const static __vector unsigned int one = { 1, 1, 1, 1 }; ++#endif ++ ++inline BF16Vec8::BF16Vec8(const FP32Vec8 &v) { ++#ifdef _ARCH_PWR10 ++ __vector signed short ret[2]; ++ ret[0] = (__vector signed short)__builtin_vsx_xvcvspbf16((__vector unsigned char)v.reg.val[0]); ++ ret[1] = (__vector signed short)__builtin_vsx_xvcvspbf16((__vector unsigned char)v.reg.val[1]); ++ reg = vec_perm(ret[0], ret[1], omask); ++#elif defined(_ARCH_PWR9) ++ __vector unsigned int inp0 = (__vector unsigned int)(v.reg.val[0]); ++ __vector unsigned int inp1 = (__vector unsigned int)(v.reg.val[1]); ++ __vector unsigned int lsb0 = vec_sr(inp0, sh16); ++ __vector unsigned int lsb1 = vec_sr(inp1, sh16); ++ lsb0 = vec_and(lsb0, one); ++ lsb1 = vec_and(lsb1, one); ++ __vector unsigned int rnd0 = vec_add(lsb0, bias); ++ __vector unsigned int rnd1 = vec_add(lsb1, bias); ++ inp0 = vec_add(inp0, rnd0); ++ inp1 = vec_add(inp1, rnd1); ++ __vector __bool int sel0 = vec_test_data_class(v.reg.val[0], __VEC_CLASS_FP_NAN); ++ __vector __bool int sel1 = vec_test_data_class(v.reg.val[1], __VEC_CLASS_FP_NAN); ++ inp0 = vec_sel(inp0, nan, sel0); ++ inp1 = vec_sel(inp1, nan, sel1); ++ inp0 = vec_sr(inp0, sh16); ++ inp1 = vec_sr(inp1, sh16); ++ reg = (__vector signed short)vec_perm(inp0, inp1, omask); ++#endif ++} ++ ++inline BF16Vec16::BF16Vec16(const FP32Vec16 &v) { ++#ifdef _ARCH_PWR10 ++ __vector signed short ret[4]; ++ ret[0] = (__vector signed short)__builtin_vsx_xvcvspbf16((__vector unsigned char)v.reg.val[0]); ++ ret[1] = (__vector signed short)__builtin_vsx_xvcvspbf16((__vector unsigned char)v.reg.val[1]); ++ ret[2] = (__vector signed short)__builtin_vsx_xvcvspbf16((__vector unsigned char)v.reg.val[2]); ++ ret[3] = (__vector signed short)__builtin_vsx_xvcvspbf16((__vector unsigned char)v.reg.val[3]); ++ reg.val[0] = vec_perm(ret[0], ret[1], omask); ++ reg.val[1] = vec_perm(ret[2], ret[3], omask); ++#elif defined(_ARCH_PWR9) ++ __vector unsigned int inp0 = (__vector unsigned int)(v.reg.val[0]); ++ __vector unsigned int inp1 = (__vector unsigned int)(v.reg.val[1]); ++ __vector unsigned int inp2 = (__vector unsigned int)(v.reg.val[2]); ++ __vector unsigned int inp3 = (__vector unsigned int)(v.reg.val[3]); ++ __vector unsigned int lsb0 = vec_sr(inp0, sh16); ++ __vector unsigned int lsb1 = vec_sr(inp1, sh16); ++ __vector unsigned int lsb2 = vec_sr(inp2, sh16); ++ __vector unsigned int lsb3 = vec_sr(inp3, sh16); ++ lsb0 = vec_and(lsb0, one); ++ lsb1 = vec_and(lsb1, one); ++ lsb2 = vec_and(lsb2, one); ++ lsb3 = vec_and(lsb3, one); ++ __vector unsigned int rnd0 = vec_add(lsb0, bias); ++ __vector unsigned int rnd1 = vec_add(lsb1, bias); ++ __vector unsigned int rnd2 = vec_add(lsb2, bias); ++ __vector unsigned int rnd3 = vec_add(lsb3, bias); ++ inp0 = vec_add(inp0, rnd0); ++ inp1 = vec_add(inp1, rnd1); ++ inp2 = vec_add(inp2, rnd2); ++ inp3 = vec_add(inp3, rnd3); ++ __vector __bool int sel0 = vec_test_data_class(v.reg.val[0], __VEC_CLASS_FP_NAN); ++ __vector __bool int sel1 = vec_test_data_class(v.reg.val[1], __VEC_CLASS_FP_NAN); ++ __vector __bool int sel2 = vec_test_data_class(v.reg.val[2], __VEC_CLASS_FP_NAN); ++ __vector __bool int sel3 = vec_test_data_class(v.reg.val[3], __VEC_CLASS_FP_NAN); ++ inp0 = vec_sel(inp0, nan, sel0); ++ inp1 = vec_sel(inp1, nan, sel1); ++ inp2 = vec_sel(inp2, nan, sel2); ++ inp3 = vec_sel(inp3, nan, sel3); ++ inp0 = vec_sr(inp0, sh16); ++ inp1 = vec_sr(inp1, sh16); ++ inp2 = vec_sr(inp2, sh16); ++ inp3 = vec_sr(inp3, sh16); ++ reg.val[0] = (__vector signed short)vec_perm(inp0, inp1, omask); ++ reg.val[1] = (__vector signed short)vec_perm(inp2, inp3, omask); ++#endif ++} ++ ++inline void prefetch(const void *addr) { ++ __asm__ __volatile__("dcbt 0, %0" : : "r"(addr) : "memory"); ++} ++ ++}; // namespace vec_op ++ ++#endif +diff --git a/csrc/cpu/cpu_types_x86.hpp b/csrc/cpu/cpu_types_x86.hpp +new file mode 100644 +index 0000000..4bb4eb0 +--- /dev/null ++++ b/csrc/cpu/cpu_types_x86.hpp +@@ -0,0 +1,632 @@ ++ ++#ifndef CPU_TYPES_X86_HPP ++#define CPU_TYPES_X86_HPP ++ ++#include ++#include ++ ++#ifndef __AVX2__ ++static_assert(false, "AVX2 must be supported for the current implementation."); ++#endif ++ ++namespace vec_op { ++ ++#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \ ++ AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ ++ AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \ ++ AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) ++ ++#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \ ++ AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)) ++ ++#ifndef CPU_OP_GUARD ++#define CPU_KERNEL_GUARD_IN(NAME) ++#define CPU_KERNEL_GUARD_OUT(NAME) ++#else ++#define CPU_KERNEL_GUARD_IN(NAME) \ ++ RECORD_FUNCTION(#NAME, c10::ArrayRef({})); ++#define CPU_KERNEL_GUARD_OUT(NAME) ++#endif ++ ++#define FORCE_INLINE __attribute__((always_inline)) inline ++ ++namespace { ++template ++constexpr void unroll_loop_item(std::integer_sequence, F &&f) { ++ (f(std::integral_constant{}), ...); ++} ++}; // namespace ++ ++template >> ++constexpr void unroll_loop(F &&f) { ++ unroll_loop_item(std::make_integer_sequence{}, std::forward(f)); ++} ++ ++template struct Vec { ++ constexpr static int get_elem_num() { return T::VEC_ELEM_NUM; } ++}; ++ ++struct FP32Vec8; ++struct FP32Vec16; ++ ++struct FP16Vec8 : public Vec { ++ constexpr static int VEC_ELEM_NUM = 8; ++ ++ __m128i reg; ++ ++ explicit FP16Vec8(const void *ptr) ++ : reg((__m128i)_mm_loadu_si128((__m128i *)ptr)) {} ++ ++ explicit FP16Vec8(const FP32Vec8 &); ++ ++ void save(void *ptr) const { *reinterpret_cast<__m128i *>(ptr) = reg; } ++}; ++ ++struct FP16Vec16 : public Vec { ++ constexpr static int VEC_ELEM_NUM = 16; ++ ++ __m256i reg; ++ ++ explicit FP16Vec16(const void *ptr) ++ : reg((__m256i)_mm256_loadu_si256((__m256i *)ptr)) {} ++ ++ explicit FP16Vec16(const FP32Vec16 &); ++ ++ void save(void *ptr) const { *reinterpret_cast<__m256i *>(ptr) = reg; } ++ ++ void save(void* ptr, const int elem_num) const { ++ constexpr uint32_t M = 0xFFFFFFFF; ++ __mmask16 mask = _cvtu32_mask16(M >> (32 - elem_num)); ++ _mm256_mask_storeu_epi16(ptr, mask, reg); ++ } ++}; ++ ++struct BF16Vec8 : public Vec { ++ constexpr static int VEC_ELEM_NUM = 8; ++ ++ __m128i reg; ++ ++ explicit BF16Vec8(const void *ptr) ++ : reg((__m128i)_mm_loadu_si128((__m128i *)ptr)) {} ++ ++ explicit BF16Vec8(const FP32Vec8 &); ++ ++ void save(void *ptr) const { *reinterpret_cast<__m128i *>(ptr) = reg; } ++}; ++ ++struct BF16Vec16 : public Vec { ++ constexpr static int VEC_ELEM_NUM = 16; ++ ++ __m256i reg; ++ ++ explicit BF16Vec16(const void *ptr) ++ : reg((__m256i)_mm256_loadu_si256((__m256i *)ptr)) {} ++ ++ explicit BF16Vec16(const FP32Vec16 &); ++ ++ void save(void *ptr) const { *reinterpret_cast<__m256i *>(ptr) = reg; } ++ ++ void save(void* ptr, const int elem_num) const { ++ constexpr uint32_t M = 0xFFFFFFFF; ++ __mmask16 mask = _cvtu32_mask16(M >> (32 - elem_num)); ++ _mm256_mask_storeu_epi16(ptr, mask, reg); ++ } ++}; ++ ++#ifdef __AVX512F__ ++struct BF16Vec32 : public Vec { ++ constexpr static int VEC_ELEM_NUM = 32; ++ ++ __m512i reg; ++ ++ explicit BF16Vec32(const void *ptr) : reg((__m512i)_mm512_loadu_si512(ptr)) {} ++ ++ explicit BF16Vec32(__m512i data) : reg(data) {} ++ ++ explicit BF16Vec32(BF16Vec8 &vec8_data) ++ : reg((__m512i)_mm512_inserti32x4( ++ _mm512_inserti32x4(_mm512_inserti32x4(_mm512_castsi128_si512( ++ (__m128i)vec8_data.reg), ++ (__m128i)vec8_data.reg, 1), ++ (__m128i)vec8_data.reg, 2), ++ (__m128i)vec8_data.reg, 3)) {} ++ ++ void save(void *ptr) const { *reinterpret_cast<__m512i *>(ptr) = reg; } ++}; ++#else ++struct BF16Vec32 : public Vec { ++ constexpr static int VEC_ELEM_NUM = 32; ++ ++ __m256i reg_low; ++ __m256i reg_high; ++ ++ explicit BF16Vec32(const void *ptr) ++ : reg_low(_mm256_loadu_si256((__m256i const *)ptr)), ++ reg_high(_mm256_loadu_si256((__m256i const *)ptr + 1)) {} ++ ++ explicit BF16Vec32(__m256i low, __m256i high) : reg_low(low), ++ reg_high(high) {} ++ ++ explicit BF16Vec32(BF16Vec8 &vec8_data) ++ : reg_low((__m256i)_mm256_inserti32x4( ++ _mm256_castsi128_si256((__m128i)vec8_data.reg), ++ (__m128i)vec8_data.reg, 1)), ++ reg_high((__m256i)_mm256_inserti32x4( ++ _mm256_castsi128_si256((__m128i)vec8_data.reg), ++ (__m128i)vec8_data.reg, 1)) {} ++ ++ void save(void *ptr) const { ++ *reinterpret_cast<__m256i *>(ptr) = reg_low; ++ *reinterpret_cast<__m256i *>((__m256i *)ptr + 1) = reg_high; ++ } ++}; ++#endif ++ ++struct FP32Vec4 : public Vec { ++ constexpr static int VEC_ELEM_NUM = 4; ++ union AliasReg { ++ __m128 reg; ++ float values[VEC_ELEM_NUM]; ++ }; ++ ++ __m128 reg; ++ ++ explicit FP32Vec4(float v) : reg(_mm_set1_ps(v)) {} ++ ++ explicit FP32Vec4() : reg(_mm_set1_ps(0.0)) {} ++ ++ explicit FP32Vec4(const float *ptr) : reg(_mm_loadu_ps(ptr)) {} ++ ++ explicit FP32Vec4(__m128 data) : reg(data) {} ++ ++ explicit FP32Vec4(const FP32Vec4 &data) : reg(data.reg) {} ++}; ++ ++struct FP32Vec8 : public Vec { ++ constexpr static int VEC_ELEM_NUM = 8; ++ union AliasReg { ++ __m256 reg; ++ float values[VEC_ELEM_NUM]; ++ }; ++ ++ __m256 reg; ++ ++ explicit FP32Vec8(float v) : reg(_mm256_set1_ps(v)) {} ++ ++ explicit FP32Vec8() : reg(_mm256_set1_ps(0.0)) {} ++ ++ explicit FP32Vec8(const float *ptr) : reg(_mm256_loadu_ps(ptr)) {} ++ ++ explicit FP32Vec8(__m256 data) : reg(data) {} ++ ++ explicit FP32Vec8(const FP32Vec8 &data) : reg(data.reg) {} ++ ++ explicit FP32Vec8(const FP16Vec8 &v) : reg(_mm256_cvtph_ps(v.reg)) {} ++ ++ explicit FP32Vec8(const BF16Vec8 &v) ++ : reg(_mm256_castsi256_ps( ++ _mm256_bslli_epi128(_mm256_cvtepu16_epi32(v.reg), 2))) {} ++ ++ float reduce_sum() const { ++ AliasReg ar; ++ ar.reg = reg; ++ float result = 0; ++ unroll_loop([&result, &ar](int i) { result += ar.values[i]; }); ++ ++ return result; ++ } ++ ++ FP32Vec8 exp() const { ++ AliasReg ar; ++ ar.reg = reg; ++ return FP32Vec8(_mm256_set_ps(expf(ar.values[7]), expf(ar.values[6]), ++ expf(ar.values[5]), expf(ar.values[4]), ++ expf(ar.values[3]), expf(ar.values[2]), ++ expf(ar.values[1]), expf(ar.values[0]))); ++ } ++ ++ FP32Vec8 tanh() const { ++ AliasReg ar; ++ ar.reg = reg; ++ return FP32Vec8(_mm256_set_ps(tanhf(ar.values[7]), tanhf(ar.values[6]), ++ tanhf(ar.values[5]), tanhf(ar.values[4]), ++ tanhf(ar.values[3]), tanhf(ar.values[2]), ++ tanhf(ar.values[1]), tanhf(ar.values[0]))); ++ } ++ ++ FP32Vec8 er() const { ++ AliasReg ar; ++ ar.reg = reg; ++ return FP32Vec8(_mm256_set_ps(erf(ar.values[7]), erf(ar.values[6]), ++ erf(ar.values[5]), erf(ar.values[4]), ++ erf(ar.values[3]), erf(ar.values[2]), ++ erf(ar.values[1]), erf(ar.values[0]))); ++ } ++ ++ FP32Vec8 operator*(const FP32Vec8 &b) const { ++ return FP32Vec8(_mm256_mul_ps(reg, b.reg)); ++ } ++ ++ FP32Vec8 operator+(const FP32Vec8 &b) const { ++ return FP32Vec8(_mm256_add_ps(reg, b.reg)); ++ } ++ ++ FP32Vec8 operator-(const FP32Vec8 &b) const { ++ return FP32Vec8(_mm256_sub_ps(reg, b.reg)); ++ } ++ ++ FP32Vec8 operator/(const FP32Vec8 &b) const { ++ return FP32Vec8(_mm256_div_ps(reg, b.reg)); ++ } ++ ++ void save(float *ptr) const { _mm256_storeu_ps(ptr, reg); } ++}; ++ ++#ifdef __AVX512F__ ++struct INT32Vec16: public Vec { ++ constexpr static int VEC_ELEM_NUM = 16; ++ union AliasReg { ++ __m512i reg; ++ int32_t values[VEC_ELEM_NUM]; ++ }; ++ ++ __m512i reg; ++ ++ explicit INT32Vec16(const void* data_ptr) : reg(_mm512_loadu_epi32(data_ptr)) {} ++ ++ void save(int32_t* ptr) const { ++ _mm512_storeu_epi32(ptr, reg); ++ } ++ ++ void save(int32_t* ptr, const int elem_num) const { ++ constexpr uint32_t M = 0xFFFFFFFF; ++ __mmask16 mask = _cvtu32_mask16(M >> (32 - elem_num)); ++ _mm512_mask_storeu_epi32(ptr, mask, reg); ++ } ++}; ++#endif ++ ++#ifdef __AVX512F__ ++struct FP32Vec16 : public Vec { ++ constexpr static int VEC_ELEM_NUM = 16; ++ union AliasReg { ++ __m512 reg; ++ float values[VEC_ELEM_NUM]; ++ }; ++ ++ __m512 reg; ++ ++ explicit FP32Vec16(float v) : reg(_mm512_set1_ps(v)) {} ++ ++ explicit FP32Vec16() : reg(_mm512_set1_ps(0.0)) {} ++ ++ explicit FP32Vec16(const float *ptr) : reg(_mm512_loadu_ps(ptr)) {} ++ ++ explicit FP32Vec16(__m512 data) : reg(data) {} ++ ++ explicit FP32Vec16(const FP32Vec4 &data) ++ : reg((__m512)_mm512_inserti32x4( ++ _mm512_inserti32x4( ++ _mm512_inserti32x4(_mm512_castsi128_si512((__m128i)data.reg), ++ (__m128i)data.reg, 1), ++ (__m128i)data.reg, 2), ++ (__m128i)data.reg, 3)) {} ++ ++ explicit FP32Vec16(const FP32Vec8 &data) ++ : reg((__m512)_mm512_inserti32x8( ++ _mm512_castsi256_si512((__m256i)data.reg), (__m256i)data.reg, 1)) {} ++ ++ explicit FP32Vec16(const BF16Vec16 &v) ++ : reg(_mm512_castsi512_ps( ++ _mm512_bslli_epi128(_mm512_cvtepu16_epi32(v.reg), 2))) {} ++ ++ explicit FP32Vec16(const FP16Vec16 &v) : reg(_mm512_cvtph_ps(v.reg)) {} ++ ++ explicit FP32Vec16(const FP16Vec8 &v) : FP32Vec16(FP32Vec8(v)) {} ++ ++ explicit FP32Vec16(const BF16Vec8 &v) : FP32Vec16(FP32Vec8(v)) {} ++ ++ explicit FP32Vec16(const INT32Vec16 &v) ++ : reg(_mm512_cvt_roundepi32_ps(v.reg, _MM_FROUND_TO_NEAREST_INT |_MM_FROUND_NO_EXC)) {} ++ ++ FP32Vec16 operator*(const FP32Vec16 &b) const { ++ return FP32Vec16(_mm512_mul_ps(reg, b.reg)); ++ } ++ ++ FP32Vec16 operator+(const FP32Vec16 &b) const { ++ return FP32Vec16(_mm512_add_ps(reg, b.reg)); ++ } ++ ++ FP32Vec16 operator-(const FP32Vec16 &b) const { ++ return FP32Vec16(_mm512_sub_ps(reg, b.reg)); ++ } ++ ++ FP32Vec16 operator/(const FP32Vec16 &b) const { ++ return FP32Vec16(_mm512_div_ps(reg, b.reg)); ++ } ++ ++ FP32Vec16 clamp(const FP32Vec16& min, const FP32Vec16& max) const { ++ return FP32Vec16(_mm512_min_ps(max.reg, _mm512_max_ps(min.reg, reg))); ++ } ++ ++ FP32Vec16 max(const FP32Vec16& b) const { ++ return FP32Vec16(_mm512_max_ps(reg, b.reg)); ++ } ++ ++ FP32Vec16 max(const FP32Vec16& b, const int elem_num) const { ++ constexpr uint32_t M = 0xFFFFFFFF; ++ __mmask16 mask = _cvtu32_mask16(M >> (32 - elem_num)); ++ return FP32Vec16(_mm512_mask_max_ps(reg, mask, reg, b.reg)); ++ } ++ ++ FP32Vec16 min(const FP32Vec16& b) const { ++ return FP32Vec16(_mm512_min_ps(reg, b.reg)); ++ } ++ ++ FP32Vec16 min(const FP32Vec16& b, const int elem_num) const { ++ constexpr uint32_t M = 0xFFFFFFFF; ++ __mmask16 mask = _cvtu32_mask16(M >> (32 - elem_num)); ++ return FP32Vec16(_mm512_mask_min_ps(reg, mask, reg, b.reg)); ++ } ++ ++ FP32Vec16 abs() const { ++ return FP32Vec16(_mm512_abs_ps(reg)); ++ } ++ ++ float reduce_sum() const { return _mm512_reduce_add_ps(reg); } ++ ++ float reduce_max() const { return _mm512_reduce_max_ps(reg); } ++ ++ float reduce_min() const { return _mm512_reduce_min_ps(reg); } ++ ++ template float reduce_sub_sum(int idx) { ++ static_assert(VEC_ELEM_NUM % group_size == 0); ++ constexpr uint32_t base_mask = (0xFFFF >> (16 - group_size)); ++ __mmask16 mask = _cvtu32_mask16(base_mask << (idx * group_size)); ++ return _mm512_mask_reduce_add_ps(mask, reg); ++ } ++ ++ void save(float *ptr) const { _mm512_storeu_ps(ptr, reg); } ++ ++ void save(float* ptr, const int elem_num) const { ++ constexpr uint32_t M = 0xFFFFFFFF; ++ __mmask16 mask = _cvtu32_mask16(M >> (32 - elem_num)); ++ _mm512_mask_storeu_ps(ptr, mask, reg); ++ } ++}; ++#else ++struct FP32Vec16 : public Vec { ++ constexpr static int VEC_ELEM_NUM = 16; ++ ++ union AliasReg { ++ __m256 reg; ++ float values[8]; ++ }; ++ ++ __m256 reg_low; ++ __m256 reg_high; ++ ++ explicit FP32Vec16(float v) : reg_low(_mm256_set1_ps(v)), ++ reg_high(_mm256_set1_ps(v)) {} ++ ++ explicit FP32Vec16() : reg_low(_mm256_set1_ps(0.0)), ++ reg_high(_mm256_set1_ps(0.0)) {} ++ ++ explicit FP32Vec16(const float *ptr) : reg_low(_mm256_loadu_ps(ptr)), ++ reg_high(_mm256_loadu_ps(ptr + 8)) {} ++ ++ explicit FP32Vec16(__m256 low, __m256 high) : reg_low(low), reg_high(high) {} ++ ++ explicit FP32Vec16(const FP32Vec16 &data) : reg_low(data.reg_low), ++ reg_high(data.reg_high) {} ++ ++ explicit FP32Vec16(const FP32Vec4 &data) ++ : reg_low((__m256)_mm256_inserti128_si256( ++ _mm256_castsi128_si256((__m128i)data.reg), ++ (__m128i)data.reg, 1)), ++ reg_high((__m256)_mm256_inserti128_si256( ++ _mm256_castsi128_si256((__m128i)data.reg), ++ (__m128i)data.reg, 1)) {} ++ ++ explicit FP32Vec16(const FP32Vec8 &data) ++ : reg_low(data.reg), reg_high(data.reg) {} ++ ++ explicit FP32Vec16(const FP16Vec16 &v) { ++ __m128i low = _mm256_extractf128_si256(v.reg, 0); ++ __m128i high = _mm256_extractf128_si256(v.reg, 1); ++ ++ reg_low = _mm256_cvtph_ps(low); ++ reg_high = _mm256_cvtph_ps(high); ++ } ++ ++ explicit FP32Vec16(const FP16Vec8 &v) : FP32Vec16(FP32Vec8(v)) {} ++ ++ explicit FP32Vec16(const BF16Vec16 &v) { ++ __m128i low = _mm256_extractf128_si256(v.reg, 0); ++ __m128i high = _mm256_extractf128_si256(v.reg, 1); ++ ++ __m256i v_low_epi32 = _mm256_cvtepu16_epi32(low); ++ __m256i v_high_epi32 = _mm256_cvtepu16_epi32(high); ++ ++ __m256i v_low_shifted = _mm256_bslli_epi128(v_low_epi32, 2); ++ __m256i v_high_shifted = _mm256_bslli_epi128(v_high_epi32, 2); ++ ++ reg_low = _mm256_castsi256_ps(v_low_shifted); ++ reg_high = _mm256_castsi256_ps(v_high_shifted); ++ } ++ ++ explicit FP32Vec16(const BF16Vec8 &v) : FP32Vec16(FP32Vec8(v)) {} ++ ++ FP32Vec16 operator*(const FP32Vec16 &b) const { ++ return FP32Vec16(_mm256_mul_ps(reg_low, b.reg_low), ++ _mm256_mul_ps(reg_high, b.reg_high)); ++ } ++ ++ FP32Vec16 operator+(const FP32Vec16 &b) const { ++ return FP32Vec16(_mm256_add_ps(reg_low, b.reg_low), ++ _mm256_add_ps(reg_high, b.reg_high)); ++ } ++ ++ FP32Vec16 operator-(const FP32Vec16 &b) const { ++ return FP32Vec16(_mm256_sub_ps(reg_low, b.reg_low), ++ _mm256_sub_ps(reg_high, b.reg_high)); ++ } ++ ++ FP32Vec16 operator/(const FP32Vec16 &b) const { ++ return FP32Vec16(_mm256_div_ps(reg_low, b.reg_low), ++ _mm256_div_ps(reg_high, b.reg_high)); ++ } ++ ++ float reduce_sum() const { ++ FP32Vec8 low = FP32Vec8(reg_low); ++ FP32Vec8 high = FP32Vec8(reg_high); ++ return low.reduce_sum() + high.reduce_sum(); ++ } ++ ++ template float reduce_sub_sum(int idx) { ++ float sum = 0.0; ++ static_assert(VEC_ELEM_NUM % group_size == 0); ++ constexpr uint32_t base_mask = (0xFFFF >> (16 - group_size)); ++ uint32_t mask = base_mask << (idx * group_size); ++ ++ AliasReg ar; ++ ++ auto func = [&sum, &mask, &ar](int i) { ++ int flag = mask & 0x1; ++ mask = mask >> 1; ++ if (flag != 0) sum += ar.values[i]; ++ }; ++ ++ ar.reg = reg_low; ++ unroll_loop(func); ++ ++ ar.reg = reg_high; ++ unroll_loop(func); ++ ++ return sum; ++ } ++ ++ void save(float *ptr) const { ++ _mm256_storeu_ps(ptr, reg_low); ++ _mm256_storeu_ps(ptr + 8, reg_high); ++ } ++}; ++#endif ++ ++#ifdef __AVX512F__ ++struct INT8Vec16: public Vec { ++ constexpr static int VEC_ELEM_NUM = 16; ++ union AliasReg { ++ __m128i reg; ++ int8_t values[VEC_ELEM_NUM]; ++ }; ++ ++ __m128i reg; ++ ++ explicit INT8Vec16(const FP32Vec16& vec) : reg( ++ _mm512_cvtepi32_epi8(_mm512_cvt_roundps_epi32(vec.reg, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)) ++ ) {} ++ ++ void save(int8_t* ptr) const { ++ _mm_storeu_epi8(ptr, reg); ++ } ++ ++ void save(int8_t* ptr, const int elem_num) const { ++ constexpr uint32_t M = 0xFFFFFFFF; ++ __mmask16 mask = _cvtu32_mask16(M >> (32 - elem_num)); ++ _mm_mask_storeu_epi8(ptr, mask, reg); ++ } ++}; ++#endif ++ ++template struct VecType { using vec_type = void; }; ++ ++template using vec_t = typename VecType::vec_type; ++ ++template <> struct VecType { using vec_type = FP32Vec8; }; ++ ++template <> struct VecType { using vec_type = FP16Vec8; }; ++ ++template <> struct VecType { using vec_type = BF16Vec8; }; ++ ++template void storeFP32(float v, T *ptr) { *ptr = v; } ++ ++inline void fma(FP32Vec16 &acc, FP32Vec16 &a, FP32Vec16 &b) { ++ acc = acc + a * b; ++} ++ ++template <> inline void storeFP32(float v, c10::Half *ptr) { ++ *reinterpret_cast(ptr) = ++ _cvtss_sh(v, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC); ++} ++ ++inline FP16Vec8::FP16Vec8(const FP32Vec8 &v) ++ : reg(_mm256_cvtps_ph(v.reg, ++ _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)) {} ++ ++#ifdef __AVX512F__ ++inline FP16Vec16::FP16Vec16(const FP32Vec16 &v) ++ : reg(_mm512_cvtps_ph(v.reg, ++ _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)) {} ++#else ++inline FP16Vec16::FP16Vec16(const FP32Vec16 &v) ++ : reg(_mm256_insertf128_si256(_mm256_castsi128_si256(FP16Vec8(FP32Vec8(v.reg_low)).reg), FP16Vec8(FP32Vec8(v.reg_low)).reg, 1)) {} ++#endif ++ ++#ifdef __AVX512BF16__ ++template <> inline void storeFP32(float v, c10::BFloat16 *ptr) { ++ *reinterpret_cast<__bfloat16 *>(ptr) = _mm_cvtness_sbh(v); ++} ++ ++inline BF16Vec8::BF16Vec8(const FP32Vec8 &v) ++ : reg((__m128i)_mm256_cvtneps_pbh(v.reg)) {} ++ ++inline BF16Vec16::BF16Vec16(const FP32Vec16 &v) ++ : reg((__m256i)_mm512_cvtneps_pbh(v.reg)) {} ++ ++inline void fma(FP32Vec16 &acc, BF16Vec32 &a, BF16Vec32 &b) { ++ acc.reg = _mm512_dpbf16_ps(acc.reg, (__m512bh)a.reg, (__m512bh)b.reg); ++} ++#else ++template <> inline void storeFP32(float v, c10::BFloat16 *ptr) { ++ c10::BFloat16 __attribute__((__may_alias__)) *v_ptr = ++ reinterpret_cast(&v); ++ *ptr = *(v_ptr + 1); ++} ++ ++#ifdef __AVX512F__ ++inline BF16Vec8::BF16Vec8(const FP32Vec8 &v) ++ : reg(_mm256_cvtepi32_epi16( ++ _mm256_bsrli_epi128(_mm256_castps_si256(v.reg), 2))) {} ++ ++inline BF16Vec16::BF16Vec16(const FP32Vec16 &v) ++ : reg(_mm512_cvtepi32_epi16( ++ _mm512_bsrli_epi128(_mm512_castps_si512(v.reg), 2))) {} ++#else ++namespace{ ++__m128i FP32Vec8_to_BF16Vec8_avx2(__m256 a) { ++ __m256i ai = _mm256_castps_si256(a); ++ ai = _mm256_srli_epi32(ai, 16); ++ ai = _mm256_packus_epi32(ai, ai); ++ ai = _mm256_permute4x64_epi64(ai, 0b00111001); ++ return _mm256_extracti128_si256(ai, 0); ++} ++} ++ ++inline BF16Vec8::BF16Vec8(const FP32Vec8 &v) ++ : reg(FP32Vec8_to_BF16Vec8_avx2(v.reg)) {} ++ ++inline BF16Vec16::BF16Vec16(const FP32Vec16 &v) { ++ BF16Vec8 low = BF16Vec8(FP32Vec8(v.reg_low)); ++ BF16Vec8 high = BF16Vec8(FP32Vec8(v.reg_high)); ++ reg = _mm256_insertf128_si256(_mm256_castsi128_si256(low.reg), high.reg, 1); ++} ++#endif // __AVX512F__ ++#endif // __AVX512BF16__ ++ ++inline void prefetch(const void *addr) { _mm_prefetch(addr, _MM_HINT_T1); } ++ ++}; // namespace vec_op ++ ++#endif +diff --git a/csrc/cpu/dnnl_helper.hpp b/csrc/cpu/dnnl_helper.hpp +new file mode 100644 +index 0000000..8b5011d +--- /dev/null ++++ b/csrc/cpu/dnnl_helper.hpp +@@ -0,0 +1,174 @@ ++#ifndef DNNL_HELPER_HPP ++#define DNNL_HELPER_HPP ++ ++#include ++#include ++ ++#include "oneapi/dnnl/dnnl.hpp" ++ ++namespace { ++template ++struct DNNLType { ++ static constexpr dnnl::memory::data_type type = ++ dnnl::memory::data_type::undef; ++}; ++ ++template <> ++struct DNNLType { ++ static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::s8; ++}; ++ ++template <> ++struct DNNLType { ++ static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::s32; ++}; ++ ++template <> ++struct DNNLType { ++ static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::f32; ++}; ++ ++template <> ++struct DNNLType { ++ static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::bf16; ++}; ++ ++template <> ++struct DNNLType { ++ static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::f16; ++}; ++ ++template ++constexpr inline dnnl::memory::data_type get_dnnl_type() { ++ return DNNLType>::type; ++} ++}; // namespace ++ ++template ++class DNNLPrimitiveHelper { ++ public: ++ // I8 input GEMM kernel (C = a_scales * A @ (b_scales * B^T) + bias) ++ // A: [M, K], row-major ++ // B: [K, N], column-major ++ // C: [M, N], row-major ++ // bias: [N], row-major, optional ++ // a_scales: [MS] ++ // b_scales: [NS] ++ // Note: Due to the limitation of oneDNN ++ // (https://github.com/oneapi-src/oneDNN/issues/1636), the quantized bias is ++ // not supported. ++ template ++ static void gemm_s8s8_jit(const int8_t* a, const int8_t* b, OutputT* c, ++ const BiasT* bias, dnnl_dim_t M, dnnl_dim_t N, ++ dnnl_dim_t K, const float* a_scales, ++ const float* b_scales, dnnl_dim_t MS, ++ dnnl_dim_t NS) { ++ auto&& OutputType = get_dnnl_type(); ++ auto&& BiasType = get_dnnl_type(); ++ ++ dnnl::memory::desc a_md({M, K}, dnnl::memory::data_type::s8, {K, 1}); ++ dnnl::memory::desc b_md({K, N}, dnnl::memory::data_type::s8, {1, K}); ++ dnnl::memory::desc c_md({M, N}, OutputType, {N, 1}); ++ ++ dnnl::primitive_attr attr; ++ if constexpr (!InputNoScale) { ++ if (MS == 1) { ++ // per-tensor ++ attr.set_scales_mask(DNNL_ARG_SRC, 0); ++ } else { ++ // per-token ++ TORCH_CHECK(false, "per-token quantization is unsupported."); ++ } ++ } ++ ++ if (NS == 1) { ++ // per-tensor ++ attr.set_scales_mask(DNNL_ARG_WEIGHTS, 0); ++ } else { ++ // per-channel ++ attr.set_scales_mask(DNNL_ARG_WEIGHTS, 2); ++ } ++ ++ dnnl::matmul::primitive_desc matmul_pd; ++ if (bias) { ++ dnnl::memory::desc bias_md({1, N}, BiasType, {N, 1}); ++ matmul_pd = dnnl::matmul::primitive_desc(default_engine(), a_md, b_md, ++ bias_md, c_md, attr); ++ } else { ++ matmul_pd = dnnl::matmul::primitive_desc(default_engine(), a_md, b_md, ++ c_md, attr); ++ } ++ dnnl::matmul matmul(matmul_pd); ++ ++ auto& engine = default_engine(); ++ ++ dnnl::memory a_m(a_md, engine, (void*)a); ++ dnnl::memory b_m(b_md, engine, (void*)b); ++ dnnl::memory c_m(c_md, engine, (void*)c); ++ dnnl::memory a_scales_m({{MS}, dnnl::memory::data_type::f32, {1}}, engine, ++ (void*)a_scales); ++ dnnl::memory b_scales_m({{NS}, dnnl::memory::data_type::f32, {1}}, engine, ++ (void*)b_scales); ++ ++ auto& stream = default_stream(); ++ if constexpr (InputNoScale) { ++ if (bias) { ++ dnnl::memory::desc bias_md({N}, BiasType, {1}); ++ dnnl::memory bias_m(bias_md, engine, (void*)bias); ++ matmul.execute( ++ stream, { ++ {DNNL_ARG_SRC, a_m}, ++ {DNNL_ARG_WEIGHTS, b_m}, ++ {DNNL_ARG_BIAS, bias_m}, ++ {DNNL_ARG_DST, c_m}, ++ {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m}, ++ }); ++ } else { ++ matmul.execute( ++ stream, { ++ {DNNL_ARG_SRC, a_m}, ++ {DNNL_ARG_WEIGHTS, b_m}, ++ {DNNL_ARG_DST, c_m}, ++ {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m}, ++ }); ++ } ++ } else { ++ if (bias) { ++ dnnl::memory::desc bias_md({N}, BiasType, {1}); ++ dnnl::memory bias_m(bias_md, engine, (void*)bias); ++ matmul.execute( ++ stream, { ++ {DNNL_ARG_SRC, a_m}, ++ {DNNL_ARG_WEIGHTS, b_m}, ++ {DNNL_ARG_BIAS, bias_m}, ++ {DNNL_ARG_DST, c_m}, ++ {DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, a_scales_m}, ++ {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m}, ++ }); ++ } else { ++ matmul.execute( ++ stream, { ++ {DNNL_ARG_SRC, a_m}, ++ {DNNL_ARG_WEIGHTS, b_m}, ++ {DNNL_ARG_DST, c_m}, ++ {DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, a_scales_m}, ++ {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m}, ++ }); ++ } ++ } ++ stream.wait(); ++ } ++ ++ private: ++ static dnnl::engine& default_engine() { ++ static dnnl::engine engine(dnnl::engine::kind::cpu, 0); ++ return engine; ++ } ++ ++ static dnnl::stream& default_stream() { ++ static dnnl::stream stream(default_engine()); ++ return stream; ++ } ++}; ++ ++#endif +diff --git a/csrc/cpu/layernorm.cpp b/csrc/cpu/layernorm.cpp +index 467f0dc..a76ad08 100644 +--- a/csrc/cpu/layernorm.cpp ++++ b/csrc/cpu/layernorm.cpp +@@ -2,10 +2,10 @@ + + namespace { + template +-void rms_norm_impl(scalar_t *__restrict__ out, +- const scalar_t *__restrict__ input, +- const scalar_t *__restrict__ weight, const float epsilon, +- const int num_tokens, const int hidden_size) { ++void rms_norm_impl(scalar_t* __restrict__ out, ++ const scalar_t* __restrict__ input, ++ const scalar_t* __restrict__ weight, const float epsilon, ++ const int num_tokens, const int hidden_size) { + using scalar_vec_t = vec_op::vec_t; + constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num(); + TORCH_CHECK(hidden_size % VEC_ELEM_NUM == 0); +@@ -41,11 +41,11 @@ void rms_norm_impl(scalar_t *__restrict__ out, + } + + template +-void fused_add_rms_norm_impl(scalar_t *__restrict__ input, +- scalar_t *__restrict__ residual, +- const scalar_t *__restrict__ weight, +- const float epsilon, const int num_tokens, +- const int hidden_size) { ++void fused_add_rms_norm_impl(scalar_t* __restrict__ input, ++ scalar_t* __restrict__ residual, ++ const scalar_t* __restrict__ weight, ++ const float epsilon, const int num_tokens, ++ const int hidden_size) { + using scalar_vec_t = vec_op::vec_t; + constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num(); + TORCH_CHECK(hidden_size % VEC_ELEM_NUM == 0); +@@ -85,24 +85,24 @@ void fused_add_rms_norm_impl(scalar_t *__restrict__ input, + } + } + } +-} // namespace ++} // namespace + +-void rms_norm(torch::Tensor &out, torch::Tensor &input, +- torch::Tensor &weight, float epsilon) { ++void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight, ++ double epsilon) { + int hidden_size = input.size(-1); + int num_tokens = input.numel() / hidden_size; + + VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_impl", [&] { + CPU_KERNEL_GUARD_IN(rms_norm_impl) + rms_norm_impl(out.data_ptr(), input.data_ptr(), +- weight.data_ptr(), epsilon, num_tokens, +- hidden_size); ++ weight.data_ptr(), epsilon, num_tokens, ++ hidden_size); + CPU_KERNEL_GUARD_OUT(rms_norm_impl) + }); + } + +-void fused_add_rms_norm(torch::Tensor &input, torch::Tensor &residual, +- torch::Tensor &weight, float epsilon) { ++void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual, ++ torch::Tensor& weight, double epsilon) { + int hidden_size = input.size(-1); + int num_tokens = input.numel() / hidden_size; + +diff --git a/csrc/cpu/pos_encoding.cpp b/csrc/cpu/pos_encoding.cpp +index e9b3992..96bce7d 100644 +--- a/csrc/cpu/pos_encoding.cpp ++++ b/csrc/cpu/pos_encoding.cpp +@@ -4,107 +4,107 @@ + namespace { + template + void rotary_embedding_impl( +- const int64_t +- *__restrict__ positions, // [batch_size, seq_len] or [num_tokens] +- scalar_t +- *__restrict__ query, /// [batch_size, seq_len, num_heads, head_size] or +- /// [num_tokens, num_heads, head_size] +- scalar_t +- *__restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or +- // [num_tokens, num_kv_heads, head_size] +- const scalar_t +- *__restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2] ++ const int64_t* __restrict__ positions, // [batch_size, seq_len] or ++ // [num_tokens] ++ scalar_t* __restrict__ query, /// [batch_size, seq_len, num_heads, ++ /// head_size] or [num_tokens, num_heads, ++ /// head_size] ++ scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, ++ // head_size] or [num_tokens, num_kv_heads, ++ // head_size] ++ const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // ++ // 2] + const int rot_dim, const int64_t query_stride, const int64_t key_stride, + const int num_heads, const int num_kv_heads, const int head_size, + const int num_tokens) { + using scalar_vec_t = vec_op::vec_t; + constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num(); +- constexpr int ELEM_SIZE = sizeof(scalar_t); + + const int embed_dim = rot_dim / 2; +- TORCH_CHECK(embed_dim % VEC_ELEM_NUM == 0); ++ bool flag = (embed_dim % VEC_ELEM_NUM == 0); ++ const int loop_upper = flag ? embed_dim : embed_dim - VEC_ELEM_NUM; + +-#pragma omp parallel for +- for (int token_idx = 0; token_idx < num_tokens; ++token_idx) { +- int64_t pos = positions[token_idx]; +- const scalar_t *cache_ptr = cos_sin_cache + pos * rot_dim; +- +- for (int i = 0; i < num_heads; ++i) { +- const int head_idx = i; +- const int64_t token_head = +- token_idx * query_stride + head_idx * head_size; +- for (int j = 0; j < embed_dim; j += VEC_ELEM_NUM) { +- const int rot_offset = j; +- const int x_index = rot_offset; +- const int y_index = embed_dim + rot_offset; ++ auto compute_loop = [&](const int64_t token_head, const scalar_t* cache_ptr, ++ scalar_t* qk) { ++ int j = 0; ++ for (; j < loop_upper; j += VEC_ELEM_NUM) { ++ const int rot_offset = j; ++ const int x_index = rot_offset; ++ const int y_index = embed_dim + rot_offset; + +- const int64_t out_x = token_head + x_index; +- const int64_t out_y = token_head + y_index; ++ const int64_t out_x = token_head + x_index; ++ const int64_t out_y = token_head + y_index; + +- const scalar_vec_t cos(cache_ptr + x_index); +- const scalar_vec_t sin(cache_ptr + y_index); ++ const scalar_vec_t cos(cache_ptr + x_index); ++ const scalar_vec_t sin(cache_ptr + y_index); + +- const scalar_vec_t q_x(query + out_x); +- const scalar_vec_t q_y(query + out_y); ++ const scalar_vec_t q_x(qk + out_x); ++ const scalar_vec_t q_y(qk + out_y); + +- vec_op::FP32Vec8 fp32_cos(cos); +- vec_op::FP32Vec8 fp32_sin(sin); ++ vec_op::FP32Vec8 fp32_cos(cos); ++ vec_op::FP32Vec8 fp32_sin(sin); + +- vec_op::FP32Vec8 fp32_q_x(q_x); +- vec_op::FP32Vec8 fp32_q_y(q_y); ++ vec_op::FP32Vec8 fp32_q_x(q_x); ++ vec_op::FP32Vec8 fp32_q_y(q_y); + +- auto out1 = fp32_q_x * fp32_cos - fp32_q_y * fp32_sin; +- scalar_vec_t(out1).save(query + out_x); ++ auto out1 = fp32_q_x * fp32_cos - fp32_q_y * fp32_sin; ++ scalar_vec_t(out1).save(qk + out_x); + +- auto out2 = fp32_q_y * fp32_cos + fp32_q_x * fp32_sin; +- scalar_vec_t(out2).save(query + out_y); +- } ++ auto out2 = fp32_q_y * fp32_cos + fp32_q_x * fp32_sin; ++ scalar_vec_t(out2).save(qk + out_y); + } +- +- for (int i = 0; i < num_kv_heads; ++i) { +- const int head_idx = i; +- const int64_t token_head = token_idx * key_stride + head_idx * head_size; +- for (int j = 0; j < embed_dim; j += VEC_ELEM_NUM) { +- const int rot_offset = j; +- const int x_index = rot_offset; +- const int y_index = embed_dim + rot_offset; ++ if (!flag) { ++ for (; j < embed_dim; ++j) { ++ const int x_index = j; ++ const int y_index = embed_dim + j; + + const int64_t out_x = token_head + x_index; + const int64_t out_y = token_head + y_index; + +- const scalar_vec_t cos(cache_ptr + x_index); +- const scalar_vec_t sin(cache_ptr + y_index); ++ const float fp32_cos = cache_ptr[x_index]; ++ const float fp32_sin = cache_ptr[y_index]; + +- const scalar_vec_t k_x(key + out_x); +- const scalar_vec_t k_y(key + out_y); ++ const float fp32_q_x = qk[out_x]; ++ const float fp32_q_y = qk[out_y]; + +- vec_op::FP32Vec8 fp32_cos(cos); +- vec_op::FP32Vec8 fp32_sin(sin); ++ qk[out_x] = fp32_q_x * fp32_cos - fp32_q_y * fp32_sin; ++ qk[out_y] = fp32_q_y * fp32_cos + fp32_q_x * fp32_sin; ++ } ++ } ++ }; + +- vec_op::FP32Vec8 fp32_k_x(k_x); +- vec_op::FP32Vec8 fp32_k_y(k_y); ++#pragma omp parallel for ++ for (int token_idx = 0; token_idx < num_tokens; ++token_idx) { ++ int64_t pos = positions[token_idx]; ++ const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim; + +- auto out1 = fp32_k_x * fp32_cos - fp32_k_y * fp32_sin; +- scalar_vec_t(out1).save(key + out_x); +- auto out2 = fp32_k_y * fp32_cos + fp32_k_x * fp32_sin; +- scalar_vec_t(out2).save(key + out_y); +- } ++ for (int i = 0; i < num_heads; ++i) { ++ const int head_idx = i; ++ const int64_t token_head = ++ token_idx * query_stride + head_idx * head_size; ++ compute_loop(token_head, cache_ptr, query); ++ } ++ ++ for (int i = 0; i < num_kv_heads; ++i) { ++ const int head_idx = i; ++ const int64_t token_head = token_idx * key_stride + head_idx * head_size; ++ compute_loop(token_head, cache_ptr, key); + } + } + } + + template + void rotary_embedding_gptj_impl( +- const int64_t +- *__restrict__ positions, // [batch_size, seq_len] or [num_tokens] +- scalar_t +- *__restrict__ query, /// [batch_size, seq_len, num_heads, head_size] or +- /// [num_tokens, num_heads, head_size] +- scalar_t +- *__restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or +- // [num_tokens, num_kv_heads, head_size] +- const scalar_t +- *__restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2] ++ const int64_t* __restrict__ positions, // [batch_size, seq_len] or ++ // [num_tokens] ++ scalar_t* __restrict__ query, /// [batch_size, seq_len, num_heads, ++ /// head_size] or [num_tokens, num_heads, ++ /// head_size] ++ scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, ++ // head_size] or [num_tokens, num_kv_heads, ++ // head_size] ++ const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // ++ // 2] + const int rot_dim, const int64_t query_stride, const int64_t key_stride, + const int num_heads, const int num_kv_heads, const int head_size, + const int num_tokens) { +@@ -114,13 +114,13 @@ void rotary_embedding_gptj_impl( + for (int token_idx = 0; token_idx < num_tokens; ++token_idx) { + for (int i = 0; i < num_heads; ++i) { + int64_t pos = positions[token_idx]; +- const scalar_t *cache_ptr = cos_sin_cache + pos * rot_dim; +- const scalar_t *cos_cache_ptr = cache_ptr; +- const scalar_t *sin_cache_ptr = cache_ptr + embed_dim; ++ const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim; ++ const scalar_t* cos_cache_ptr = cache_ptr; ++ const scalar_t* sin_cache_ptr = cache_ptr + embed_dim; + const int head_idx = i; + const int64_t token_head = + token_idx * query_stride + head_idx * head_size; +- scalar_t *head_query = token_head + query; ++ scalar_t* head_query = token_head + query; + for (int j = 0; j < embed_dim; j += 1) { + const int rot_offset = j; + const int x_index = 2 * rot_offset; +@@ -142,12 +142,12 @@ void rotary_embedding_gptj_impl( + for (int token_idx = 0; token_idx < num_tokens; ++token_idx) { + for (int i = 0; i < num_kv_heads; ++i) { + int64_t pos = positions[token_idx]; +- const scalar_t *cache_ptr = cos_sin_cache + pos * rot_dim; +- const scalar_t *cos_cache_ptr = cache_ptr; +- const scalar_t *sin_cache_ptr = cache_ptr + embed_dim; ++ const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim; ++ const scalar_t* cos_cache_ptr = cache_ptr; ++ const scalar_t* sin_cache_ptr = cache_ptr + embed_dim; + const int head_idx = i; + const int64_t token_head = token_idx * key_stride + head_idx * head_size; +- scalar_t *head_key = key + token_head; ++ scalar_t* head_key = key + token_head; + for (int j = 0; j < embed_dim; j += 1) { + const int rot_offset = j; + const int x_index = 2 * rot_offset; +@@ -165,11 +165,11 @@ void rotary_embedding_gptj_impl( + } + } + } +-}; // namespace ++}; // namespace + +-void rotary_embedding(torch::Tensor &positions, torch::Tensor &query, +- torch::Tensor &key, int head_size, +- torch::Tensor &cos_sin_cache, bool is_neox) { ++void rotary_embedding(torch::Tensor& positions, torch::Tensor& query, ++ torch::Tensor& key, int64_t head_size, ++ torch::Tensor& cos_sin_cache, bool is_neox) { + int num_tokens = query.numel() / query.size(-1); + int rot_dim = cos_sin_cache.size(1); + int num_heads = query.size(-1) / head_size; +diff --git a/csrc/cpu/quant.cpp b/csrc/cpu/quant.cpp +new file mode 100644 +index 0000000..33b1637 +--- /dev/null ++++ b/csrc/cpu/quant.cpp +@@ -0,0 +1,613 @@ ++#include "cpu_types.hpp" ++#include "dnnl_helper.hpp" ++ ++namespace { ++template ++struct KernelVecType { ++ using load_vec_type = void; ++ using azp_adj_load_vec_type = void; ++ using cvt_vec_type = void; ++}; ++ ++template <> ++struct KernelVecType { ++ using load_vec_type = vec_op::FP32Vec16; ++ using azp_adj_load_vec_type = vec_op::INT32Vec16; ++ using cvt_vec_type = vec_op::FP32Vec16; ++}; ++ ++template <> ++struct KernelVecType { ++ using load_vec_type = vec_op::BF16Vec16; ++ using azp_adj_load_vec_type = vec_op::INT32Vec16; ++ using cvt_vec_type = vec_op::FP32Vec16; ++}; ++ ++template <> ++struct KernelVecType { ++#ifdef __powerpc64__ ++ // Power architecture-specific vector type ++ using load_vec_type = vec_op::FP32Vec16; ++#else ++ // Fallback for other architectures ++ using load_vec_type = vec_op::FP16Vec16; ++#endif ++ using azp_adj_load_vec_type = vec_op::INT32Vec16; ++ using cvt_vec_type = vec_op::FP32Vec16; ++}; ++ ++#ifdef __AVX512F__ ++template ++void static_scaled_int8_quant_impl(const scalar_t* input, int8_t* output, ++ const float* scale, const int32_t* azp, ++ const int num_tokens, ++ const int hidden_size) { ++ using load_vec_t = typename KernelVecType::load_vec_type; ++ using cvt_vec_t = typename KernelVecType::cvt_vec_type; ++ constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM; ++ ++ constexpr float i8_min = ++ static_cast(std::numeric_limits::min()); ++ constexpr float i8_max = ++ static_cast(std::numeric_limits::max()); ++ const cvt_vec_t inv_scale(1.0 / *scale); ++ const cvt_vec_t i8_min_vec(i8_min); ++ const cvt_vec_t i8_max_vec(i8_max); ++ ++ cvt_vec_t zp_vec; ++ if constexpr (AZP) { ++ zp_vec = cvt_vec_t(static_cast(*azp)); ++ } ++ ++ #pragma omp parallel for ++ for (int i = 0; i < num_tokens; ++i) { ++ int j = 0; ++ for (; j < hidden_size - vec_elem_num; j += vec_elem_num) { ++ load_vec_t elems(input + i * hidden_size + j); ++ cvt_vec_t elems_fp32(elems); ++ elems_fp32 = elems_fp32 * inv_scale; ++ ++ if constexpr (AZP) { ++ elems_fp32 = elems_fp32 + zp_vec; ++ } ++ ++ elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec); ++ vec_op::INT8Vec16 elems_int8(elems_fp32); ++ elems_int8.save(output + i * hidden_size + j); ++ } ++ ++ load_vec_t elems(input + i * hidden_size + j); ++ cvt_vec_t elems_fp32(elems); ++ elems_fp32 = elems_fp32 * inv_scale; ++ ++ if constexpr (AZP) { ++ elems_fp32 = elems_fp32 + zp_vec; ++ } ++ ++ elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec); ++ vec_op::INT8Vec16 elems_int8(elems_fp32); ++ elems_int8.save(output + i * hidden_size + j, hidden_size - j); ++ } ++} ++ ++template ++void dynamic_scaled_int8_quant_impl(const scalar_t* input, int8_t* output, ++ float* scale, int32_t* azp, ++ const int num_tokens, ++ const int hidden_size) { ++ using load_vec_t = typename KernelVecType::load_vec_type; ++ using cvt_vec_t = typename KernelVecType::cvt_vec_type; ++ constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM; ++ ++ constexpr float i8_min = ++ static_cast(std::numeric_limits::min()); ++ constexpr float i8_max = ++ static_cast(std::numeric_limits::max()); ++ const cvt_vec_t i8_min_vec(i8_min); ++ const cvt_vec_t i8_max_vec(i8_max); ++ ++ #pragma omp parallel for ++ for (int i = 0; i < num_tokens; ++i) { ++ cvt_vec_t max_value(std::numeric_limits::lowest()); ++ cvt_vec_t min_value(std::numeric_limits::max()); ++ { ++ int j = 0; ++ for (; j < hidden_size - vec_elem_num; j += vec_elem_num) { ++ load_vec_t elems(input + i * hidden_size + j); ++ cvt_vec_t elems_fp32(elems); ++ if constexpr (AZP) { ++ max_value = max_value.max(elems_fp32); ++ min_value = min_value.min(elems_fp32); ++ } else { ++ max_value = max_value.max(elems_fp32.abs()); ++ } ++ } ++ ++ load_vec_t elems(input + i * hidden_size + j); ++ cvt_vec_t elems_fp32(elems); ++ ++ if (j + vec_elem_num == hidden_size) { ++ if constexpr (AZP) { ++ max_value = max_value.max(elems_fp32); ++ min_value = min_value.min(elems_fp32); ++ } else { ++ max_value = max_value.max(elems_fp32.abs()); ++ } ++ } else { ++ if constexpr (AZP) { ++ max_value = max_value.max(elems_fp32, hidden_size - j); ++ min_value = min_value.min(elems_fp32, hidden_size - j); ++ } else { ++ max_value = max_value.max(elems_fp32.abs(), hidden_size - j); ++ } ++ } ++ } ++ ++ float scale_val, azp_val; ++ if constexpr (AZP) { ++ float max_scalar = max_value.reduce_max(); ++ float min_scalar = min_value.reduce_min(); ++ scale_val = (max_scalar - min_scalar) / 255.0f; ++ azp_val = std::nearbyint(-128.0f - min_scalar / scale_val); ++ azp[i] = static_cast(azp_val); ++ scale[i] = scale_val; ++ } else { ++ scale_val = max_value.reduce_max() / 127.0f; ++ scale[i] = scale_val; ++ } ++ ++ const cvt_vec_t inv_scale(1.0 / scale_val); ++ const cvt_vec_t azp_vec(azp_val); ++ ++ { ++ int j = 0; ++ for (; j < hidden_size - vec_elem_num; j += vec_elem_num) { ++ load_vec_t elems(input + i * hidden_size + j); ++ cvt_vec_t elems_fp32(elems); ++ elems_fp32 = (elems_fp32 * inv_scale); ++ ++ if constexpr (AZP) { ++ elems_fp32 = elems_fp32 + azp_vec; ++ } ++ elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec); ++ vec_op::INT8Vec16 elems_int8(elems_fp32); ++ elems_int8.save(output + i * hidden_size + j); ++ } ++ ++ load_vec_t elems(input + i * hidden_size + j); ++ cvt_vec_t elems_fp32(elems); ++ elems_fp32 = (elems_fp32 * inv_scale); ++ ++ if constexpr (AZP) { ++ elems_fp32 = elems_fp32 + azp_vec; ++ } ++ elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec); ++ vec_op::INT8Vec16 elems_int8(elems_fp32); ++ elems_int8.save(output + i * hidden_size + j, hidden_size - j); ++ } ++ } ++} ++ ++template ++void static_quant_epilogue(const float* input, scalar_t* output, ++ const float a_scale, const float* b_scale, ++ const int32_t* azp_with_adj, const int num_tokens, ++ const int hidden_size) { ++ CPU_KERNEL_GUARD_IN(dynamic_output_scale_impl) ++ using load_vec_t = typename KernelVecType::load_vec_type; ++ using azp_adj_load_vec_t = ++ typename KernelVecType::azp_adj_load_vec_type; ++ using cvt_vec_t = typename KernelVecType::cvt_vec_type; ++ constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM; ++ ++ #pragma omp parallel for ++ for (int i = 0; i < num_tokens; ++i) { ++ cvt_vec_t a_scale_vec(a_scale); ++ cvt_vec_t b_scale_vec(*b_scale); ++ cvt_vec_t scale_vec = a_scale_vec * b_scale_vec; ++ ++ int j = 0; ++ for (; j < hidden_size - vec_elem_num; j += vec_elem_num) { ++ cvt_vec_t elems_fp32(input + i * hidden_size + j); ++ azp_adj_load_vec_t azp_adj_vec(azp_with_adj + j); ++ cvt_vec_t azp_adj_fp32(azp_adj_vec); ++ ++ if constexpr (PerChannel) { ++ b_scale_vec = cvt_vec_t(b_scale + j); ++ scale_vec = b_scale_vec * a_scale_vec; ++ } ++ ++ elems_fp32 = elems_fp32 - scale_vec * azp_adj_fp32; ++ ++ load_vec_t elems_out(elems_fp32); ++ elems_out.save(output + i * hidden_size + j); ++ } ++ ++ cvt_vec_t elems_fp32(input + i * hidden_size + j); ++ azp_adj_load_vec_t azp_adj_vec(azp_with_adj + j); ++ cvt_vec_t azp_adj_fp32(azp_adj_vec); ++ ++ if constexpr (PerChannel) { ++ b_scale_vec = cvt_vec_t(b_scale + j); ++ scale_vec = b_scale_vec * a_scale_vec; ++ } ++ ++ elems_fp32 = elems_fp32 - scale_vec * azp_adj_fp32; ++ ++ load_vec_t elems_out(elems_fp32); ++ elems_out.save(output + i * hidden_size + j, hidden_size - j); ++ } ++} ++ ++template ++void dynamic_quant_epilogue(const float* input, scalar_t* output, ++ const float* a_scale, const float* b_scale, ++ const int32_t* azp, const int32_t* azp_adj, ++ const scalar_t* bias, const int num_tokens, ++ const int hidden_size) { ++ CPU_KERNEL_GUARD_IN(dynamic_quant_epilogue) ++ using load_vec_t = typename KernelVecType::load_vec_type; ++ using azp_adj_load_vec_t = ++ typename KernelVecType::azp_adj_load_vec_type; ++ using cvt_vec_t = typename KernelVecType::cvt_vec_type; ++ constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM; ++ ++ #pragma omp parallel for ++ for (int i = 0; i < num_tokens; ++i) { ++ int j = 0; ++ cvt_vec_t token_scale_vec(a_scale[i]); ++ cvt_vec_t token_zp_scale_vec; ++ if constexpr (AZP) { ++ float zp_scale_val = a_scale[i] * static_cast(azp[i]); ++ if constexpr (!PerChannel) { ++ zp_scale_val *= *b_scale; ++ } ++ token_zp_scale_vec = cvt_vec_t(zp_scale_val); ++ } ++ ++ for (; j < hidden_size - vec_elem_num; j += vec_elem_num) { ++ cvt_vec_t elems_fp32(input + i * hidden_size + j); ++ elems_fp32 = elems_fp32 * token_scale_vec; ++ ++ if constexpr (AZP) { ++ azp_adj_load_vec_t azp_adj_vec(azp_adj + j); ++ cvt_vec_t azp_adj_fp32(azp_adj_vec); ++ azp_adj_fp32 = azp_adj_fp32 * token_zp_scale_vec; ++ ++ if constexpr (PerChannel) { ++ cvt_vec_t b_scale_vec(b_scale + j); ++ azp_adj_fp32 = azp_adj_fp32 * b_scale_vec; ++ } ++ ++ elems_fp32 = elems_fp32 - azp_adj_fp32; ++ } ++ ++ if constexpr (Bias) { ++ load_vec_t bias_vec(bias + j); ++ cvt_vec_t bias_vec_fp32(bias_vec); ++ elems_fp32 = elems_fp32 + bias_vec_fp32; ++ } ++ ++ load_vec_t elems_out(elems_fp32); ++ elems_out.save(output + i * hidden_size + j); ++ } ++ ++ cvt_vec_t elems_fp32(input + i * hidden_size + j); ++ elems_fp32 = elems_fp32 * token_scale_vec; ++ ++ if constexpr (AZP) { ++ azp_adj_load_vec_t azp_adj_vec(azp_adj + j); ++ cvt_vec_t azp_adj_fp32(azp_adj_vec); ++ azp_adj_fp32 = azp_adj_fp32 * token_zp_scale_vec; ++ ++ if constexpr (PerChannel) { ++ cvt_vec_t b_scale_vec(b_scale + j); ++ azp_adj_fp32 = azp_adj_fp32 * b_scale_vec; ++ } ++ ++ elems_fp32 = elems_fp32 - azp_adj_fp32; ++ } ++ ++ if constexpr (Bias) { ++ load_vec_t bias_vec(bias + j); ++ cvt_vec_t bias_vec_fp32(bias_vec); ++ elems_fp32 = elems_fp32 + bias_vec_fp32; ++ } ++ ++ load_vec_t elems_out(elems_fp32); ++ elems_out.save(output + i * hidden_size + j, hidden_size - j); ++ } ++} ++#else ++template ++void static_scaled_int8_quant_impl(const scalar_t* input, int8_t* output, ++ const float* scale, const int32_t* azp, ++ const int num_tokens, ++ const int hidden_size) { ++ TORCH_CHECK(false, "static_scaled_int8_quant_impl requires AVX512 support.") ++} ++ ++template ++void dynamic_scaled_int8_quant_impl(const scalar_t* input, int8_t* output, ++ float* scale, int32_t* azp, ++ const int num_tokens, ++ const int hidden_size) { ++ TORCH_CHECK(false, "dynamic_scaled_int8_quant_impl requires AVX512 support.") ++} ++ ++template ++void static_quant_epilogue(const float* input, scalar_t* output, ++ const float a_scale, const float* b_scale, ++ const int32_t* azp_with_adj, const int num_tokens, ++ const int hidden_size) { ++ TORCH_CHECK(false, "static_quant_epilogue requires AVX512 support.") ++} ++ ++template ++void dynamic_quant_epilogue(const float* input, scalar_t* output, ++ const float* a_scale, const float* b_scale, ++ const int32_t* azp, const int32_t* azp_with_adj, ++ const scalar_t* bias, const int num_tokens, ++ const int hidden_size) { ++ TORCH_CHECK(false, "dynamic_quant_epilogue requires AVX512 support.") ++} ++#endif ++} // namespace ++ ++void int8_scaled_mm(torch::Tensor& c, // [M, OC], row-major ++ const torch::Tensor& a, // [M, IC], row-major ++ const torch::Tensor& b, // [IC, OC], column-major ++ const torch::Tensor& a_scales, // [1] or [M] ++ const torch::Tensor& b_scales, // [1] or [OC] ++ const std::optional& bias // [OC] ++) { ++ CPU_KERNEL_GUARD_IN(cutlass_scaled_mm) ++ // Checks for conformality ++ TORCH_CHECK(a.dtype() == torch::kInt8 && b.dtype() == torch::kInt8, ++ "int8_scaled_mm only supports INT8 inputs.") ++ TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2); ++ TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) && ++ b.size(1) == c.size(1)); ++ TORCH_CHECK(a_scales.numel() == 1 || a_scales.numel() == a.size(0)); ++ TORCH_CHECK(b_scales.numel() == 1 || b_scales.numel() == b.size(1)); ++ ++ // Check for strides and alignment ++ TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1); // Row-major ++ TORCH_CHECK(b.stride(0) == 1); // Column-major ++ TORCH_CHECK(c.stride(0) % 16 == 0 && ++ b.stride(1) % 16 == 0); // 16 Byte Alignment ++ TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous()); ++ ++ if (bias) { ++ TORCH_CHECK(bias->numel() == b.size(1) && bias->is_contiguous() && ++ bias->dim() == 1); ++ } ++ ++ VLLM_DISPATCH_FLOATING_TYPES(c.scalar_type(), "int8_scaled_mm", [&] { ++ if (a_scales.numel() != 1) { ++ // per-token ++ // Note: oneDNN doesn't support per-token activation quantization ++ // Ideally we want to fuse the GEMM and the scale procedure with oneDNN ++ // JIT, the intermediate data is cached in registers or L1. But for now ++ // the oneDNN GEMM code generation only supports two quantization ++ // patterns: per-tensor or per-output-channel of weight. ++ // So we have to apply the per-token scale with a 'epilogue'. In C=s_a * ++ // s_b * (A@B) + bias, the C_inter = s_b * (A@B) is computed by oneDNN ++ // GEMM, then the per-token scale (and bias) is applied with the epilogue ++ // C=s_a * C_inter + bias. ++ torch::Tensor tmp_fp32_out = ++ torch::empty_like(c, ::at::ScalarType::Float); ++ // Compute C_inter=s_b * (A@B) ++ DNNLPrimitiveHelper::gemm_s8s8_jit( ++ a.data_ptr(), b.data_ptr(), ++ tmp_fp32_out.data_ptr(), nullptr, a.size(0), b.size(1), ++ a.size(1), nullptr, b_scales.data_ptr(), 0, b_scales.numel()); ++ if (bias.has_value()) { ++ // Compute C=s_a * C_inter + bias ++ dynamic_quant_epilogue( ++ tmp_fp32_out.data_ptr(), c.data_ptr(), ++ a_scales.data_ptr(), nullptr, nullptr, nullptr, ++ bias->data_ptr(), c.size(0), c.size(1)); ++ } else { ++ // Compute C=s_a * C_inter ++ dynamic_quant_epilogue( ++ tmp_fp32_out.data_ptr(), c.data_ptr(), ++ a_scales.data_ptr(), nullptr, nullptr, nullptr, nullptr, ++ c.size(0), c.size(1)); ++ } ++ } else { ++ // per-tensor ++ if (bias.has_value()) { ++ // Compute C=s_a * s_b * (A@B) + bias ++ DNNLPrimitiveHelper::gemm_s8s8_jit( ++ a.data_ptr(), b.data_ptr(), c.data_ptr(), ++ bias->data_ptr(), a.size(0), b.size(1), a.size(1), ++ a_scales.data_ptr(), b_scales.data_ptr(), ++ a_scales.numel(), b_scales.numel()); ++ } else { ++ // Compute C=s_a * s_b * (A@B) ++ DNNLPrimitiveHelper::gemm_s8s8_jit( ++ a.data_ptr(), b.data_ptr(), c.data_ptr(), ++ nullptr, a.size(0), b.size(1), a.size(1), ++ a_scales.data_ptr(), b_scales.data_ptr(), ++ a_scales.numel(), b_scales.numel()); ++ } ++ } ++ }); ++} ++ ++void int8_scaled_mm_azp(torch::Tensor& c, // [M, OC], row-major ++ const torch::Tensor& a, // [M, IC], row-major ++ const torch::Tensor& b, // [IC, OC], column-major ++ const torch::Tensor& a_scales, // [1] or [M] ++ const torch::Tensor& b_scales, // [1] or [OC] ++ const torch::Tensor& azp_adj, // [OC] ++ const std::optional& azp, // [1] or [M] ++ const std::optional& bias // [OC] ++) { ++ CPU_KERNEL_GUARD_IN(cutlass_scaled_mm_azp) ++ // Checks for conformality ++ TORCH_CHECK(a.dtype() == torch::kInt8 && b.dtype() == torch::kInt8, ++ "int8_scaled_mm_azp only supports INT8 inputs.") ++ TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2); ++ TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) && ++ b.size(1) == c.size(1)); ++ TORCH_CHECK(a_scales.numel() == 1 || a_scales.numel() == a.size(0)); ++ TORCH_CHECK(b_scales.numel() == 1 || b_scales.numel() == b.size(1)); ++ ++ // Check for strides and alignment ++ TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1); // Row-major ++ TORCH_CHECK(b.stride(0) == 1); // Column-major ++ TORCH_CHECK(c.stride(0) % 16 == 0 && ++ b.stride(1) % 16 == 0); // 16 Byte Alignment ++ TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous()); ++ ++ if (bias) { ++ TORCH_CHECK(bias->numel() == b.size(1) && bias->is_contiguous()); ++ } ++ if (azp) { ++ TORCH_CHECK(azp->numel() == a.size(0) && azp->is_contiguous()); ++ } ++ TORCH_CHECK(azp_adj.numel() == b.size(1) && azp_adj.is_contiguous()); ++ ++ // azp & bias types ++ TORCH_CHECK(azp_adj.dtype() == torch::kInt32); ++ TORCH_CHECK(!azp || azp->dtype() == torch::kInt32); ++ TORCH_CHECK(!bias || bias->dtype() == c.dtype(), ++ "currently bias dtype must match output dtype ", c.dtype()); ++ ++ VLLM_DISPATCH_FLOATING_TYPES(c.scalar_type(), "int8_scaled_mm_azp", [&] { ++ torch::Tensor tmp_fp32_out = torch::empty_like(c, ::at::ScalarType::Float); ++ if (a_scales.numel() != 1) { ++ // per-token ++ // Note: oneDNN doesn't support per-token activation quantization ++ // Compute C_inter=s_b * (A@B) ++ DNNLPrimitiveHelper::gemm_s8s8_jit( ++ a.data_ptr(), b.data_ptr(), ++ tmp_fp32_out.data_ptr(), nullptr, a.size(0), b.size(1), ++ a.size(1), nullptr, b_scales.data_ptr(), 0, b_scales.numel()); ++ if (bias.has_value()) { ++ // Compute C=s_a * C_inter - s_a * s_b * azp * azp_adj + bias ++ if (b_scales.numel() != 1) { ++ // Per-Channel ++ dynamic_quant_epilogue( ++ tmp_fp32_out.data_ptr(), c.data_ptr(), ++ a_scales.data_ptr(), b_scales.data_ptr(), ++ azp->data_ptr(), azp_adj.data_ptr(), ++ bias->data_ptr(), c.size(0), c.size(1)); ++ } else { ++ // Per-Tensor ++ dynamic_quant_epilogue( ++ tmp_fp32_out.data_ptr(), c.data_ptr(), ++ a_scales.data_ptr(), b_scales.data_ptr(), ++ azp->data_ptr(), azp_adj.data_ptr(), ++ bias->data_ptr(), c.size(0), c.size(1)); ++ } ++ } else { ++ // Compute C=s_a * C_inter - s_a * s_b * azp * azp_adj ++ if (b_scales.numel() != 1) { ++ // Per-Channel ++ dynamic_quant_epilogue( ++ tmp_fp32_out.data_ptr(), c.data_ptr(), ++ a_scales.data_ptr(), b_scales.data_ptr(), ++ azp->data_ptr(), azp_adj.data_ptr(), nullptr, ++ c.size(0), c.size(1)); ++ } else { ++ // Per-Tensor ++ dynamic_quant_epilogue( ++ tmp_fp32_out.data_ptr(), c.data_ptr(), ++ a_scales.data_ptr(), b_scales.data_ptr(), ++ azp->data_ptr(), azp_adj.data_ptr(), nullptr, ++ c.size(0), c.size(1)); ++ } ++ } ++ } else { ++ // per-tensor ++ if (bias.has_value()) { ++ // Compute C_inter=s_a * s_b * (A@B) + bias ++ DNNLPrimitiveHelper::gemm_s8s8_jit( ++ a.data_ptr(), b.data_ptr(), ++ tmp_fp32_out.data_ptr(), bias->data_ptr(), ++ a.size(0), b.size(1), a.size(1), a_scales.data_ptr(), ++ b_scales.data_ptr(), a_scales.numel(), b_scales.numel()); ++ } else { ++ // Compute C_inter=s_a * s_b * (A@B) ++ DNNLPrimitiveHelper::gemm_s8s8_jit( ++ a.data_ptr(), b.data_ptr(), ++ tmp_fp32_out.data_ptr(), nullptr, a.size(0), b.size(1), ++ a.size(1), a_scales.data_ptr(), b_scales.data_ptr(), ++ a_scales.numel(), b_scales.numel()); ++ } ++ ++ // Compute C=C_inter - s_a * s_b * azp_adj ++ if (b_scales.numel() != 1) { ++ // Per-Channel ++ static_quant_epilogue( ++ tmp_fp32_out.data_ptr(), c.data_ptr(), ++ *a_scales.data_ptr(), b_scales.data_ptr(), ++ azp_adj.data_ptr(), a.size(0), b.size(1)); ++ } else { ++ // Per-Tensor ++ static_quant_epilogue( ++ tmp_fp32_out.data_ptr(), c.data_ptr(), ++ *a_scales.data_ptr(), b_scales.data_ptr(), ++ azp_adj.data_ptr(), a.size(0), b.size(1)); ++ } ++ } ++ }); ++} ++ ++// static-per-tensor quantization. ++void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size] ++ const torch::Tensor& input, // [..., hidden_size] ++ const torch::Tensor& scale, ++ std::optional const& azp) { ++ CPU_KERNEL_GUARD_IN(static_scaled_int8_quant) ++ TORCH_CHECK(input.is_contiguous()); ++ TORCH_CHECK(out.is_contiguous()); ++ TORCH_CHECK(scale.numel() == 1); ++ TORCH_CHECK(!azp.has_value() || azp->numel() == 1); ++ ++ const int hidden_size = input.size(-1); ++ const int num_tokens = input.numel() / hidden_size; ++ VLLM_DISPATCH_FLOATING_TYPES( ++ input.scalar_type(), "static_scaled_int8_quant_impl", [&] { ++ if (azp.has_value()) { ++ static_scaled_int8_quant_impl( ++ input.data_ptr(), out.data_ptr(), ++ scale.data_ptr(), azp->data_ptr(), num_tokens, ++ hidden_size); ++ } else { ++ static_scaled_int8_quant_impl( ++ input.data_ptr(), out.data_ptr(), ++ scale.data_ptr(), nullptr, num_tokens, hidden_size); ++ } ++ }); ++} ++ ++// dynamic-per-token quantization. ++void dynamic_scaled_int8_quant( ++ torch::Tensor& out, // [..., hidden_size] ++ const torch::Tensor& input, // [..., hidden_size] ++ torch::Tensor& scale, // [..., 1] ++ std::optional const& azp) { ++ CPU_KERNEL_GUARD_IN(dynamic_scaled_int8_quant) ++ TORCH_CHECK(input.is_contiguous()); ++ TORCH_CHECK(out.is_contiguous()); ++ ++ int const hidden_size = input.size(-1); ++ int const num_tokens = input.numel() / hidden_size; ++ VLLM_DISPATCH_FLOATING_TYPES( ++ input.scalar_type(), "dynamic_scaled_int8_quant_impl", [&] { ++ if (azp.has_value()) { ++ dynamic_scaled_int8_quant_impl( ++ input.data_ptr(), out.data_ptr(), ++ scale.data_ptr(), azp->data_ptr(), num_tokens, ++ hidden_size); ++ } else { ++ dynamic_scaled_int8_quant_impl( ++ input.data_ptr(), out.data_ptr(), ++ scale.data_ptr(), nullptr, num_tokens, hidden_size); ++ } ++ }); ++} +diff --git a/csrc/cpu/torch_bindings.cpp b/csrc/cpu/torch_bindings.cpp +new file mode 100644 +index 0000000..74e4d81 +--- /dev/null ++++ b/csrc/cpu/torch_bindings.cpp +@@ -0,0 +1,160 @@ ++#include "cache.h" ++#include "ops.h" ++#include "core/registration.h" ++ ++#include ++ ++std::string init_cpu_threads_env(const std::string& cpu_ids); ++ ++void int8_scaled_mm(torch::Tensor& c, const torch::Tensor& a, ++ const torch::Tensor& b, const torch::Tensor& a_scales, ++ const torch::Tensor& b_scales, ++ const std::optional& bias); ++ ++void int8_scaled_mm_azp(torch::Tensor& c, const torch::Tensor& a, ++ const torch::Tensor& b, const torch::Tensor& a_scales, ++ const torch::Tensor& b_scales, ++ const torch::Tensor& azp_adj, ++ const std::optional& azp, ++ const std::optional& bias); ++ ++TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ++ // vLLM custom ops ++ ++ // Attention ops ++ // Compute the attention between an input query and the cached keys/values ++ // using PagedAttention. ++ ops.def( ++ "paged_attention_v1(" ++ " Tensor! out, Tensor query, Tensor key_cache," ++ " Tensor value_cache, int num_kv_heads, float scale," ++ " Tensor block_tables, Tensor seq_lens, int block_size," ++ " int max_seq_len, Tensor? alibi_slopes," ++ " str kv_cache_dtype, float k_scale, float v_scale," ++ " int tp_rank, int blocksparse_local_blocks," ++ " int blocksparse_vert_stride, int blocksparse_block_size," ++ " int blocksparse_head_sliding_step) -> ()"); ++ ops.impl("paged_attention_v1", torch::kCPU, &paged_attention_v1); ++ ++ // PagedAttention V2. ++ ops.def( ++ "paged_attention_v2(" ++ " Tensor! out, Tensor! exp_sums, Tensor! max_logits," ++ " Tensor! tmp_out, Tensor query, Tensor key_cache," ++ " Tensor value_cache, int num_kv_heads, float scale," ++ " Tensor block_tables, Tensor seq_lens, int block_size," ++ " int max_seq_len, Tensor? alibi_slopes," ++ " str kv_cache_dtype, float k_scale, float v_scale," ++ " int tp_rank, int blocksparse_local_blocks," ++ " int blocksparse_vert_stride, int blocksparse_block_size," ++ " int blocksparse_head_sliding_step) -> ()"); ++ ops.impl("paged_attention_v2", torch::kCPU, &paged_attention_v2); ++ ++ // Activation ops ++ ++ // Activation function used in SwiGLU. ++ ops.def("silu_and_mul(Tensor! out, Tensor input) -> ()"); ++ ops.impl("silu_and_mul", torch::kCPU, &silu_and_mul); ++ ++ // Activation function used in GeGLU with `none` approximation. ++ ops.def("gelu_and_mul(Tensor! out, Tensor input) -> ()"); ++ ops.impl("gelu_and_mul", torch::kCPU, &gelu_and_mul); ++ ++ // Activation function used in GeGLU with `tanh` approximation. ++ ops.def("gelu_tanh_and_mul(Tensor! out, Tensor input) -> ()"); ++ ops.impl("gelu_tanh_and_mul", torch::kCPU, &gelu_tanh_and_mul); ++ ++ // GELU implementation used in GPT-2. ++ ops.def("gelu_new(Tensor! out, Tensor input) -> ()"); ++ ops.impl("gelu_new", torch::kCPU, &gelu_new); ++ ++ // Approximate GELU implementation. ++ ops.def("gelu_fast(Tensor! out, Tensor input) -> ()"); ++ ops.impl("gelu_fast", torch::kCPU, &gelu_fast); ++ ++ // Quick GELU implementation. ++ ops.def("gelu_quick(Tensor! out, Tensor input) -> ()"); ++ ops.impl("gelu_quick", torch::kCPU, &gelu_quick); ++ ++ // Layernorm ++ // Apply Root Mean Square (RMS) Normalization to the input tensor. ++ ops.def( ++ "rms_norm(Tensor! out, Tensor input, Tensor weight, float epsilon) -> " ++ "()"); ++ ops.impl("rms_norm", torch::kCPU, &rms_norm); ++ ++ // In-place fused Add and RMS Normalization. ++ ops.def( ++ "fused_add_rms_norm(Tensor! input, Tensor! residual, Tensor weight, " ++ "float epsilon) -> ()"); ++ ops.impl("fused_add_rms_norm", torch::kCPU, &fused_add_rms_norm); ++ ++ // Rotary embedding ++ // Apply GPT-NeoX or GPT-J style rotary embedding to query and key. ++ ops.def( ++ "rotary_embedding(Tensor positions, Tensor! query," ++ " Tensor! key, int head_size," ++ " Tensor cos_sin_cache, bool is_neox) -> ()"); ++ ops.impl("rotary_embedding", torch::kCPU, &rotary_embedding); ++ ++ // Quantization ++#ifdef __AVX512F__ ++ // Compute int8 quantized tensor for given scaling factor. ++ ops.def( ++ "static_scaled_int8_quant(Tensor! out, Tensor input, Tensor scale," ++ "Tensor? azp) -> ()"); ++ ops.impl("static_scaled_int8_quant", torch::kCPU, &static_scaled_int8_quant); ++ ++ // Compute int8 quantized tensor and scaling factor ++ ops.def( ++ "dynamic_scaled_int8_quant(Tensor! out, Tensor input, Tensor! scale, " ++ "Tensor!? azp) -> ()"); ++ ops.impl("dynamic_scaled_int8_quant", torch::kCPU, ++ &dynamic_scaled_int8_quant); ++ // W8A8 GEMM, supporting symmetric per-tensor or per-row/column ++ // quantization. ++ ops.def( ++ "cutlass_scaled_mm(Tensor! out, Tensor a," ++ " Tensor b, Tensor a_scales," ++ " Tensor b_scales, Tensor? bias) -> ()"); ++ ops.impl("cutlass_scaled_mm", torch::kCPU, &int8_scaled_mm); ++ // w8a8 GEMM, supporting asymmetric per-tensor or per-row/column ++ // quantization. ++ ops.def( ++ "cutlass_scaled_mm_azp(Tensor! out, Tensor a," ++ " Tensor b, Tensor a_scales," ++ " Tensor b_scales, Tensor azp_adj," ++ " Tensor? azp, Tensor? bias) -> ()"); ++ ops.impl("cutlass_scaled_mm_azp", torch::kCPU, &int8_scaled_mm_azp); ++#endif ++} ++ ++TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) { ++ // Cache ops ++ // Swap in (out) the cache blocks from src to dst. ++ cache_ops.def( ++ "swap_blocks(Tensor src, Tensor! dst, Tensor block_mapping) -> ()"); ++ cache_ops.impl("swap_blocks", torch::kCPU, &swap_blocks); ++ ++ // Copy the cache blocks from src to dst. ++ cache_ops.def( ++ "copy_blocks(Tensor(a!)[] key_caches, Tensor[](b!) value_caches, " ++ "Tensor block_mapping) -> ()"); ++ cache_ops.impl("copy_blocks", torch::kCPU, ©_blocks); ++ ++ // Reshape the key and value tensors and cache them. ++ cache_ops.def( ++ "reshape_and_cache(Tensor key, Tensor value," ++ " Tensor! key_cache, Tensor! value_cache," ++ " Tensor slot_mapping," ++ " str kv_cache_dtype," ++ " float k_scale, float v_scale) -> ()"); ++ cache_ops.impl("reshape_and_cache", torch::kCPU, &reshape_and_cache); ++} ++ ++TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _utils), utils) { ++ // CPU utils ++ utils.def("init_cpu_threads_env(str cpu_ids) -> str", &init_cpu_threads_env); ++} ++ ++REGISTER_EXTENSION(TORCH_EXTENSION_NAME) +diff --git a/csrc/cpu/utils.cpp b/csrc/cpu/utils.cpp +new file mode 100644 +index 0000000..42a1c1d +--- /dev/null ++++ b/csrc/cpu/utils.cpp +@@ -0,0 +1,103 @@ ++#ifndef VLLM_NUMA_DISABLED ++ #include ++ #include ++ #include ++ #include ++#endif ++ ++#include "cpu_types.hpp" ++ ++#ifdef VLLM_NUMA_DISABLED ++std::string init_cpu_threads_env(const std::string& cpu_ids) { ++ return std::string( ++ "Warning: NUMA is not enabled in this build. `init_cpu_threads_env` has " ++ "no effect to setup thread affinity."); ++} ++ ++#endif ++ ++#ifndef VLLM_NUMA_DISABLED ++std::string init_cpu_threads_env(const std::string& cpu_ids) { ++ bitmask* omp_cpu_mask = numa_parse_cpustring(cpu_ids.c_str()); ++ TORCH_CHECK(omp_cpu_mask->size > 0); ++ std::vector omp_cpu_ids; ++ omp_cpu_ids.reserve(omp_cpu_mask->size); ++ ++ constexpr int group_size = 8 * sizeof(*omp_cpu_mask->maskp); ++ ++ for (int offset = 0; offset < omp_cpu_mask->size; offset += group_size) { ++ unsigned long group_mask = omp_cpu_mask->maskp[offset / group_size]; ++ int i = 0; ++ while (group_mask) { ++ if (group_mask & 1) { ++ omp_cpu_ids.emplace_back(offset + i); ++ } ++ ++i; ++ group_mask >>= 1; ++ } ++ } ++ ++ // Memory node binding ++ if (numa_available() != -1) { ++ int mem_node_id = numa_node_of_cpu(omp_cpu_ids.front()); ++ bitmask* mask = numa_parse_nodestring(std::to_string(mem_node_id).c_str()); ++ bitmask* src_mask = numa_get_membind(); ++ ++ int pid = getpid(); ++ ++ // move all existing pages to the specified numa node. ++ *(src_mask->maskp) = *(src_mask->maskp) ^ *(mask->maskp); ++ int page_num = numa_migrate_pages(pid, src_mask, mask); ++ if (page_num == -1) { ++ TORCH_CHECK(false, ++ "numa_migrate_pages failed. errno: " + std::to_string(errno)); ++ } ++ ++ // restrict memory allocation node. ++ numa_set_membind(mask); ++ numa_set_strict(1); ++ } ++ ++ // OMP threads binding ++ omp_set_num_threads((int)omp_cpu_ids.size()); ++ torch::set_num_threads((int)omp_cpu_ids.size()); ++ TORCH_CHECK_EQ(omp_cpu_ids.size(), torch::get_num_threads()); ++ TORCH_CHECK_EQ(omp_cpu_ids.size(), omp_get_max_threads()); ++ ++ std::vector> thread_core_mapping; ++ thread_core_mapping.reserve(omp_cpu_ids.size()); ++ omp_lock_t writelock; ++ omp_init_lock(&writelock); ++ ++ #pragma omp parallel for schedule(static, 1) ++ for (size_t i = 0; i < omp_cpu_ids.size(); ++i) { ++ cpu_set_t mask; ++ CPU_ZERO(&mask); ++ CPU_SET(omp_cpu_ids[i], &mask); ++ int ret = sched_setaffinity(0, sizeof(cpu_set_t), &mask); ++ if (ret == -1) { ++ TORCH_CHECK(false, ++ "sched_setaffinity failed. errno: " + std::to_string(errno)); ++ } ++ ++ omp_set_lock(&writelock); ++ thread_core_mapping.emplace_back(gettid(), omp_cpu_ids[i]); ++ omp_unset_lock(&writelock); ++ } ++ ++ omp_destroy_lock(&writelock); ++ ++ numa_free_nodemask(omp_cpu_mask); ++ ++ std::stringstream ss; ++ ss << "OMP threads binding of Process " << getpid() << ":\n"; ++ std::sort(thread_core_mapping.begin(), thread_core_mapping.end(), ++ [](auto&& a, auto&& b) { return a.second < b.second; }); ++ for (auto&& item : thread_core_mapping) { ++ ss << "\t" ++ << "OMP tid: " << item.first << ", core " << item.second << "\n"; ++ } ++ ++ return ss.str(); ++} ++#endif +\ No newline at end of file +diff --git a/csrc/cuda_compat.h b/csrc/cuda_compat.h +index c711d8d..82e5561 100644 +--- a/csrc/cuda_compat.h ++++ b/csrc/cuda_compat.h +@@ -1,7 +1,7 @@ + #pragma once + + #ifdef USE_ROCM +-#include ++ #include + #endif + + #ifndef USE_ROCM +@@ -17,9 +17,14 @@ + #endif + + #ifndef USE_ROCM +- #define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor_sync(uint32_t(-1), var, lane_mask) ++ #define VLLM_SHFL_XOR_SYNC(var, lane_mask) \ ++ __shfl_xor_sync(uint32_t(-1), var, lane_mask) ++ #define VLLM_SHFL_XOR_SYNC_WIDTH(var, lane_mask, width) \ ++ __shfl_xor_sync(uint32_t(-1), var, lane_mask, width) + #else + #define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor(var, lane_mask) ++ #define VLLM_SHFL_XOR_SYNC_WIDTH(var, lane_mask, width) \ ++ __shfl_xor(var, lane_mask, width) + #endif + + #ifndef USE_ROCM +@@ -28,6 +33,13 @@ + #define VLLM_SHFL_SYNC(var, src_lane) __shfl(var, src_lane) + #endif + ++#ifndef USE_ROCM ++ #define VLLM_SHFL_DOWN_SYNC(var, lane_delta) \ ++ __shfl_down_sync(uint32_t(-1), var, lane_delta) ++#else ++ #define VLLM_SHFL_DOWN_SYNC(var, lane_delta) __shfl_down(var, lane_delta) ++#endif ++ + #ifndef USE_ROCM + #define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \ + cudaFuncSetAttribute(FUNC, cudaFuncAttributeMaxDynamicSharedMemorySize, VAL) +@@ -35,4 +47,3 @@ + #define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \ + hipFuncSetAttribute(FUNC, hipFuncAttributeMaxDynamicSharedMemorySize, VAL) + #endif +- +diff --git a/csrc/cuda_utils.h b/csrc/cuda_utils.h +index 1483484..c352242 100644 +--- a/csrc/cuda_utils.h ++++ b/csrc/cuda_utils.h +@@ -1,10 +1,15 @@ + #pragma once + +-#include ++#if defined(__CUDACC__) || defined(_NVHPC_CUDA) ++ #define HOST_DEVICE_INLINE __forceinline__ __host__ __device__ ++ #define DEVICE_INLINE __forceinline__ __device__ ++ #define HOST_INLINE __forceinline__ __host__ ++#else ++ #define HOST_DEVICE_INLINE inline ++ #define DEVICE_INLINE inline ++ #define HOST_INLINE inline ++#endif + +-int get_device_attribute( +- int attribute, +- int device_id); ++int64_t get_device_attribute(int64_t attribute, int64_t device_id); + +-int get_max_shared_memory_per_block_device_attribute( +- int device_id); ++int64_t get_max_shared_memory_per_block_device_attribute(int64_t device_id); +diff --git a/csrc/cuda_utils_kernels.cu b/csrc/cuda_utils_kernels.cu +index 1a443ef..d6f9eb6 100644 +--- a/csrc/cuda_utils_kernels.cu ++++ b/csrc/cuda_utils_kernels.cu +@@ -2,34 +2,28 @@ + #include + #include + #endif +-int get_device_attribute( +- int attribute, +- int device_id) +-{ +- int device, value; +- if (device_id < 0) { +- cudaGetDevice(&device); +- } +- else { +- device = device_id; +- } +- cudaDeviceGetAttribute(&value, static_cast(attribute), device); +- return value; ++int64_t get_device_attribute(int64_t attribute, int64_t device_id) { ++ int device, value; ++ if (device_id < 0) { ++ cudaGetDevice(&device); ++ } else { ++ device = device_id; ++ } ++ cudaDeviceGetAttribute(&value, static_cast(attribute), ++ device); ++ return value; + } + +- +-int get_max_shared_memory_per_block_device_attribute( +- int device_id) +-{ +-int attribute; +-// https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html +-// cudaDevAttrMaxSharedMemoryPerBlockOptin = 97 if not is_hip() else 74 ++int64_t get_max_shared_memory_per_block_device_attribute(int64_t device_id) { ++ int64_t attribute; ++ // https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html ++ // cudaDevAttrMaxSharedMemoryPerBlockOptin = 97 if not is_hip() else 74 + + #ifdef USE_ROCM +- attribute = hipDeviceAttributeMaxSharedMemoryPerBlock; ++ attribute = hipDeviceAttributeMaxSharedMemoryPerBlock; + #else +- attribute = cudaDevAttrMaxSharedMemoryPerBlockOptin; ++ attribute = cudaDevAttrMaxSharedMemoryPerBlockOptin; + #endif + +- return get_device_attribute(attribute, device_id); ++ return get_device_attribute(attribute, device_id); + } +diff --git a/csrc/custom_all_reduce.cu b/csrc/custom_all_reduce.cu +index 3906dcf..123278b 100644 +--- a/csrc/custom_all_reduce.cu ++++ b/csrc/custom_all_reduce.cu +@@ -1,36 +1,33 @@ + #include + #include + #include +-#include ++#include + + #include "custom_all_reduce.cuh" + +-// fake pointer type +-using fptr_t = uint64_t; +-static_assert(sizeof(void *) == sizeof(fptr_t)); ++// Fake pointer type, must match fptr_t type in ops.h. ++// We use this type alias to indicate when pointers are passed in as int64_t. ++using fptr_t = int64_t; ++static_assert(sizeof(void*) == sizeof(fptr_t)); + +-fptr_t init_custom_ar(torch::Tensor &meta, torch::Tensor &rank_data, +- const std::vector &handles, +- const std::vector &offsets, int rank, ++fptr_t init_custom_ar(const std::vector& fake_ipc_ptrs, ++ torch::Tensor& rank_data, int64_t rank, + bool full_nvlink) { +- int world_size = offsets.size(); ++ int world_size = fake_ipc_ptrs.size(); + if (world_size > 8) + throw std::invalid_argument("world size > 8 is not supported"); + if (world_size % 2 != 0) + throw std::invalid_argument("Odd num gpus is not supported for now"); +- if (world_size != handles.size()) +- throw std::invalid_argument( +- "handles length should equal to offsets length"); + if (rank < 0 || rank >= world_size) + throw std::invalid_argument("invalid rank passed in"); + +- cudaIpcMemHandle_t ipc_handles[8]; ++ vllm::Signal* ipc_ptrs[8]; + for (int i = 0; i < world_size; i++) { +- std::memcpy(&ipc_handles[i], handles[i].data(), sizeof(cudaIpcMemHandle_t)); ++ ipc_ptrs[i] = reinterpret_cast(fake_ipc_ptrs[i]); + } +- return (fptr_t) new vllm::CustomAllreduce( +- reinterpret_cast(meta.data_ptr()), rank_data.data_ptr(), +- rank_data.numel(), ipc_handles, offsets, rank, full_nvlink); ++ return (fptr_t) new vllm::CustomAllreduce(ipc_ptrs, rank_data.data_ptr(), ++ rank_data.numel(), rank, world_size, ++ full_nvlink); + } + + /** +@@ -49,46 +46,55 @@ fptr_t init_custom_ar(torch::Tensor &meta, torch::Tensor &rank_data, + * 5. A[None].expand(2, -1, -1, -1): Not OK + * 6. A[:, 1:, 1:]: Not OK + */ +-bool _is_weak_contiguous(torch::Tensor &t) { ++bool _is_weak_contiguous(torch::Tensor& t) { + return t.is_contiguous() || + (t.storage().nbytes() - t.storage_offset() * t.element_size() == + t.numel() * t.element_size()); + } + +-bool should_custom_ar(torch::Tensor &inp, int max_size, int world_size, +- bool full_nvlink) { +- auto inp_size = inp.numel() * inp.element_size(); +- // custom allreduce requires input byte size to be multiples of 16 +- if (inp_size % 16 != 0) return false; +- if (!_is_weak_contiguous(inp)) return false; +- if (world_size == 2 || full_nvlink) return inp_size <= max_size; +- // for 4 or more non NVLink-capable GPUs, custom allreduce provides little +- // performance improvement over NCCL. +- return false; +-} ++/** ++ * Performs an out-of-place allreduce and stores result in out. ++ * ++ * If _reg_buffer is null, assumes inp.data_ptr() is already IPC-registered. ++ * Otherwise, _reg_buffer is assumed to be IPC-registered and inp is first ++ * copied into _reg_buffer. ++ */ ++void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out, ++ fptr_t _reg_buffer, int64_t reg_buffer_sz_bytes) { ++ auto fa = reinterpret_cast(_fa); ++ const at::cuda::OptionalCUDAGuard device_guard(device_of(inp)); ++ auto stream = c10::cuda::getCurrentCUDAStream().stream(); + +-void _all_reduce(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out, +- cudaStream_t stream) { +- auto fa = reinterpret_cast(_fa); ++ TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type()); ++ TORCH_CHECK_EQ(inp.numel(), out.numel()); + TORCH_CHECK(_is_weak_contiguous(out)); ++ TORCH_CHECK(_is_weak_contiguous(inp)); ++ auto input_size = inp.numel() * inp.element_size(); ++ auto reg_buffer = reinterpret_cast(_reg_buffer); ++ if (reg_buffer) { ++ TORCH_CHECK_LE(input_size, reg_buffer_sz_bytes); ++ AT_CUDA_CHECK(cudaMemcpyAsync(reg_buffer, inp.data_ptr(), input_size, ++ cudaMemcpyDeviceToDevice, stream)); ++ } else { ++ reg_buffer = inp.data_ptr(); ++ } + switch (out.scalar_type()) { + case at::ScalarType::Float: { +- fa->allreduce(stream, reinterpret_cast(inp.data_ptr()), +- reinterpret_cast(out.data_ptr()), ++ fa->allreduce(stream, reinterpret_cast(reg_buffer), ++ reinterpret_cast(out.data_ptr()), + out.numel()); + break; + } + case at::ScalarType::Half: { +- fa->allreduce(stream, reinterpret_cast(inp.data_ptr()), +- reinterpret_cast(out.data_ptr()), +- out.numel()); ++ fa->allreduce(stream, reinterpret_cast(reg_buffer), ++ reinterpret_cast(out.data_ptr()), out.numel()); + break; + } + #if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__)) + case at::ScalarType::BFloat16: { + fa->allreduce( +- stream, reinterpret_cast(inp.data_ptr()), +- reinterpret_cast(out.data_ptr()), out.numel()); ++ stream, reinterpret_cast(reg_buffer), ++ reinterpret_cast(out.data_ptr()), out.numel()); + break; + } + #endif +@@ -98,51 +104,41 @@ void _all_reduce(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out, + } + } + +-void all_reduce_reg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out) { +- const at::cuda::OptionalCUDAGuard device_guard(device_of(inp)); +- auto stream = c10::cuda::getCurrentCUDAStream().stream(); +- TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type()); +- TORCH_CHECK_EQ(inp.numel(), out.numel()); +- _all_reduce(_fa, inp, out, stream); +-} +- +-void all_reduce_unreg(fptr_t _fa, torch::Tensor &inp, torch::Tensor ®_buffer, +- torch::Tensor &out) { +- const at::cuda::OptionalCUDAGuard device_guard(device_of(inp)); +- auto stream = c10::cuda::getCurrentCUDAStream().stream(); +- +- auto input_size = inp.numel() * inp.element_size(); +- TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type()); +- TORCH_CHECK_EQ(inp.numel(), out.numel()); +- TORCH_CHECK(input_size <= reg_buffer.numel() * reg_buffer.element_size(), +- "registered buffer is too small to contain the input"); +- AT_CUDA_CHECK(cudaMemcpyAsync(reg_buffer.data_ptr(), inp.data_ptr(), +- input_size, cudaMemcpyDeviceToDevice, stream)); +- _all_reduce(_fa, reg_buffer, out, stream); +-} +- + void dispose(fptr_t _fa) { +- auto fa = reinterpret_cast(_fa); +- delete fa; ++ delete reinterpret_cast(_fa); + } + +-int meta_size() { return sizeof(vllm::Signal); } ++int64_t meta_size() { return sizeof(vllm::Signal); } + +-void register_buffer(fptr_t _fa, torch::Tensor &t, +- const std::vector &handles, +- const std::vector &offsets) { +- auto fa = reinterpret_cast(_fa); +- fa->register_buffer(handles, offsets, t.data_ptr()); ++void register_buffer(fptr_t _fa, const std::vector& fake_ipc_ptrs) { ++ auto fa = reinterpret_cast(_fa); ++ TORCH_CHECK(fake_ipc_ptrs.size() == fa->world_size_); ++ void* ipc_ptrs[8]; ++ for (int i = 0; i < fake_ipc_ptrs.size(); i++) { ++ ipc_ptrs[i] = reinterpret_cast(fake_ipc_ptrs[i]); ++ } ++ fa->register_buffer(ipc_ptrs); + } + +-std::pair, std::vector> get_graph_buffer_ipc_meta( +- fptr_t _fa) { +- auto fa = reinterpret_cast(_fa); +- return fa->get_graph_buffer_ipc_meta(); ++// Use vector to represent byte data for python binding compatibility. ++std::tuple, std::vector> ++get_graph_buffer_ipc_meta(fptr_t _fa) { ++ auto fa = reinterpret_cast(_fa); ++ auto [handle, offsets] = fa->get_graph_buffer_ipc_meta(); ++ std::vector bytes(handle.begin(), handle.end()); ++ return std::make_tuple(bytes, offsets); + } + +-void register_graph_buffers(fptr_t _fa, const std::vector &handles, +- const std::vector> &offsets) { +- auto fa = reinterpret_cast(_fa); +- fa->register_graph_buffers(handles, offsets); ++// Use vector to represent byte data for python binding compatibility. ++void register_graph_buffers(fptr_t _fa, ++ const std::vector>& handles, ++ const std::vector>& offsets) { ++ auto fa = reinterpret_cast(_fa); ++ std::vector bytes; ++ bytes.reserve(handles.size()); ++ for (int i = 0; i < handles.size(); i++) { ++ bytes.emplace_back(handles[i].begin(), handles[i].end()); ++ } ++ bytes.reserve(handles.size()); ++ fa->register_graph_buffers(bytes, offsets); + } +diff --git a/csrc/custom_all_reduce.cuh b/csrc/custom_all_reduce.cuh +index 750e68d..6be4d4f 100644 +--- a/csrc/custom_all_reduce.cuh ++++ b/csrc/custom_all_reduce.cuh +@@ -6,6 +6,7 @@ + #include + + #include ++#include + #include + #include + #include +@@ -23,17 +24,23 @@ + + namespace vllm { + +-constexpr int kMaxBlocks = 64; +-// note: we don't want to use atomics for signals because peer atomics are no +-// supported on PCIe links ++constexpr int kMaxBlocks = 36; ++// Counter may overflow, but it's fine since unsigned int overflow is ++// well-defined behavior. ++using FlagType = uint32_t; + struct Signal { +- alignas(128) uint32_t start[kMaxBlocks][8]; +- alignas(128) uint32_t end[kMaxBlocks][8]; ++ alignas(128) FlagType self_counter[kMaxBlocks][8]; ++ // Two sets of peer counters are needed for two syncs. The reason is that ++ // it's possible for peer GPU block to arrive at the second sync point while ++ // the current GPU block haven't passed the first sync point. Thus, peer GPU ++ // may write counter+1 while current GPU is busy waiting for counter. We use ++ // alternating counter array to avoid this possibility. ++ alignas(128) FlagType peer_counter[2][kMaxBlocks][8]; + }; + +-struct __align__(16) RankData { const void *__restrict__ ptrs[8]; }; ++struct __align__(16) RankData { const void* __restrict__ ptrs[8]; }; + +-struct __align__(16) RankSignals { volatile Signal *signals[8]; }; ++struct __align__(16) RankSignals { Signal* signals[8]; }; + + // like std::array, but aligned + template +@@ -68,11 +75,11 @@ DINLINE half downcast_s(float val) { + // scalar add functions + // for some reason when compiling with Pytorch, the + operator for half and + // bfloat is disabled so we call the intrinsics directly +-DINLINE half &assign_add(half &a, half b) { ++DINLINE half& assign_add(half& a, half b) { + a = __hadd(a, b); + return a; + } +-DINLINE float &assign_add(float &a, float b) { return a += b; } ++DINLINE float& assign_add(float& a, float b) { return a += b; } + + #if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__)) + DINLINE float upcast_s(nv_bfloat16 val) { return __bfloat162float(val); } +@@ -80,14 +87,14 @@ template <> + DINLINE nv_bfloat16 downcast_s(float val) { + return __float2bfloat16(val); + } +-DINLINE nv_bfloat16 &assign_add(nv_bfloat16 &a, nv_bfloat16 b) { ++DINLINE nv_bfloat16& assign_add(nv_bfloat16& a, nv_bfloat16 b) { + a = __hadd(a, b); + return a; + } + #endif + + template +-DINLINE array_t &packed_assign_add(array_t &a, array_t b) { ++DINLINE array_t& packed_assign_add(array_t& a, array_t b) { + #pragma unroll + for (int i = 0; i < N; i++) { + assign_add(a.data[i], b.data[i]); +@@ -123,53 +130,75 @@ DINLINE O downcast(array_t val) { + } + } + +-// This function is meant to be used as the first synchronization in the all +-// reduce kernel. Thus, it doesn't need to make any visibility guarantees for +-// prior memory accesses. Note: volatile writes will not be reordered against +-// other volatile writes. +-template +-DINLINE void start_sync(const RankSignals &sg, volatile Signal *self_sg, +- int rank) { +- if (threadIdx.x < ngpus) { +- // reset flag for next time +- self_sg->end[blockIdx.x][threadIdx.x] = 0; +- // simultaneously write to the corresponding flag of all ranks. +- // Latency = 1 p2p write +- sg.signals[threadIdx.x]->start[blockIdx.x][rank] = 1; +- // wait until we got true from all ranks +- while (!self_sg->start[blockIdx.x][threadIdx.x]) +- ; +- } +- __syncthreads(); ++static DINLINE void st_flag_release(FlagType* flag_addr, FlagType flag) { ++#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700 ++ asm volatile("st.release.sys.global.u32 [%1], %0;" ::"r"(flag), ++ "l"(flag_addr)); ++#else ++ asm volatile("membar.sys; st.volatile.global.u32 [%1], %0;" ::"r"(flag), ++ "l"(flag_addr)); ++#endif ++} ++ ++static DINLINE FlagType ld_flag_acquire(FlagType* flag_addr) { ++ FlagType flag; ++#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700 ++ asm volatile("ld.acquire.sys.global.u32 %0, [%1];" ++ : "=r"(flag) ++ : "l"(flag_addr)); ++#else ++ asm volatile("ld.volatile.global.u32 %0, [%1]; membar.gl;" ++ : "=r"(flag) ++ : "l"(flag_addr)); ++#endif ++ return flag; ++} ++ ++static DINLINE void st_flag_volatile(FlagType* flag_addr, FlagType flag) { ++ asm volatile("st.volatile.global.u32 [%1], %0;" ::"r"(flag), "l"(flag_addr)); ++} ++ ++static DINLINE FlagType ld_flag_volatile(FlagType* flag_addr) { ++ FlagType flag; ++ asm volatile("ld.volatile.global.u32 %0, [%1];" ++ : "=r"(flag) ++ : "l"(flag_addr)); ++ return flag; + } + +-// This function is meant to be used as the second or the final synchronization +-// barrier in the all reduce kernel. If it's the final synchronization barrier, +-// we don't need to make any visibility guarantees for prior memory accesses. +-template +-DINLINE void end_sync(const RankSignals &sg, volatile Signal *self_sg, +- int rank) { +- __syncthreads(); +- // eliminate the case that prior writes are not visible after signals become +- // visible. Note that I did not managed to make this happen through a lot of +- // testing. Might be the case that hardware provides stronger guarantee than +- // the memory model. +- if constexpr (!final_sync) __threadfence_system(); ++// is_start: whether this is the very first synchronization barrier. ++// need_fence: whether a memory fence is needed. If true, a release-acquire ++// semantic is used to enforce memory access order before and after this ++// barrier. ++template ++DINLINE void multi_gpu_barrier(const RankSignals& sg, Signal* self_sg, ++ int rank) { ++ if constexpr (!is_start) __syncthreads(); ++ static_assert( ++ !(is_start && need_fence)); // Start barrier shouldn't need fence. + if (threadIdx.x < ngpus) { +- // reset flag for next time +- self_sg->start[blockIdx.x][threadIdx.x] = 0; +- // simultaneously write to the corresponding flag of all ranks. +- // Latency = 1 p2p write +- sg.signals[threadIdx.x]->end[blockIdx.x][rank] = 1; +- // wait until we got true from all ranks +- while (!self_sg->end[blockIdx.x][threadIdx.x]) +- ; ++ // Increment the counter. Technically we only need one counter, but we use ++ // multiple per block to eliminate the need to share the counter via smem. ++ auto val = self_sg->self_counter[blockIdx.x][threadIdx.x] += 1; ++ // Write the expected counter value to peer and wait for correct value from ++ // peer. ++ auto peer_counter_ptr = ++ &sg.signals[threadIdx.x]->peer_counter[val % 2][blockIdx.x][rank]; ++ auto self_counter_ptr = ++ &self_sg->peer_counter[val % 2][blockIdx.x][threadIdx.x]; ++ if constexpr (need_fence) { ++ st_flag_release(peer_counter_ptr, val); ++ while (ld_flag_acquire(self_counter_ptr) != val); ++ } else { ++ st_flag_volatile(peer_counter_ptr, val); ++ while (ld_flag_volatile(self_counter_ptr) != val); ++ } + } +- if constexpr (!final_sync) __syncthreads(); ++ if constexpr (is_start || need_fence) __syncthreads(); + } + + template +-DINLINE P packed_reduce(const P *ptrs[], int idx) { ++DINLINE P packed_reduce(const P* ptrs[], int idx) { + A tmp = upcast(ptrs[0][idx]); + #pragma unroll + for (int i = 1; i < ngpus; i++) { +@@ -180,34 +209,31 @@ DINLINE P packed_reduce(const P *ptrs[], int idx) { + + template + __global__ void __launch_bounds__(512, 1) +- cross_device_reduce_1stage(RankData *_dp, RankSignals sg, +- volatile Signal *self_sg, T *__restrict__ result, +- int rank, int size) { ++ cross_device_reduce_1stage(RankData* _dp, RankSignals sg, Signal* self_sg, ++ T* __restrict__ result, int rank, int size) { + using P = typename packed_t::P; + using A = typename packed_t::A; + // note: we don't reorder the address so the accumulation order is the same + // for all ranks, ensuring bitwise identical results + auto dp = *_dp; +- start_sync(sg, self_sg, rank); ++ multi_gpu_barrier(sg, self_sg, rank); + // do the actual reduction + for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; + idx += gridDim.x * blockDim.x) { +- ((P *)result)[idx] = +- packed_reduce((const P **)&dp.ptrs[0], idx); ++ ((P*)result)[idx] = packed_reduce((const P**)&dp.ptrs[0], idx); + } +- end_sync(sg, self_sg, rank); ++ multi_gpu_barrier(sg, self_sg, rank); + } + + template +-DINLINE P *get_tmp_buf(volatile Signal *sg) { +- return (P *)(((Signal *)sg) + 1); ++DINLINE P* get_tmp_buf(Signal* sg) { ++ return (P*)(((Signal*)sg) + 1); + } + + template + __global__ void __launch_bounds__(512, 1) +- cross_device_reduce_2stage(RankData *_dp, RankSignals sg, +- volatile Signal *self_sg, T *__restrict__ result, +- int rank, int size) { ++ cross_device_reduce_2stage(RankData* _dp, RankSignals sg, Signal* self_sg, ++ T* __restrict__ result, int rank, int size) { + int tid = blockIdx.x * blockDim.x + threadIdx.x; + int stride = gridDim.x * blockDim.x; + using P = typename packed_t::P; +@@ -216,21 +242,21 @@ __global__ void __launch_bounds__(512, 1) + int start = rank * part; + int end = rank == ngpus - 1 ? size : start + part; + int largest_part = part + size % ngpus; +- const P *ptrs[ngpus]; +- P *tmps[ngpus]; ++ const P* ptrs[ngpus]; ++ P* tmps[ngpus]; + #pragma unroll + for (int i = 0; i < ngpus; i++) { + int target = (rank + i) % ngpus; +- ptrs[i] = (const P *)_dp->ptrs[target]; ++ ptrs[i] = (const P*)_dp->ptrs[target]; + tmps[i] = get_tmp_buf

(sg.signals[target]); + } + auto tmp_out = tmps[0]; +- start_sync(sg, self_sg, rank); ++ multi_gpu_barrier(sg, self_sg, rank); + // stage 1: reduce scatter + for (int idx = start + tid; idx < end; idx += stride) { + tmp_out[idx - start] = packed_reduce(ptrs, idx); + } +- end_sync(sg, self_sg, rank); ++ multi_gpu_barrier(sg, self_sg, rank); + + // stage 2: allgather. Note: it's important to match the tid between + // the two stages, because visibility across devices is only guaranteed +@@ -243,7 +269,7 @@ __global__ void __launch_bounds__(512, 1) + int gather_from_rank = ((rank + i) % ngpus); + if (gather_from_rank == ngpus - 1 || idx < part) { + int dst_idx = gather_from_rank * part + idx; +- ((P *)result)[dst_idx] = tmps[i][idx]; ++ ((P*)result)[dst_idx] = tmps[i][idx]; + } + } + } +@@ -259,71 +285,76 @@ class CustomAllreduce { + int world_size_; + bool full_nvlink_; + +- // below are device pointers + RankSignals sg_; +- std::unordered_map buffers_; +- Signal *self_sg_; +- +- // stores the registered device pointers from all ranks ++ // Stores an map from a pointer to its peer pointters from all ranks. ++ std::unordered_map buffers_; ++ Signal* self_sg_; ++ ++ // Stores rank data from all ranks. This is mainly for cuda graph purposes. ++ // For cuda graph to work, all kernel arguments must be fixed during graph ++ // capture time. However, the peer pointers are not known during graph capture ++ // time. Therefore, during capture, we increment the rank data pointer and use ++ // that as the argument to the kernel. The kernel arguments are stored in ++ // graph_unreg_buffers_. The actual peer pointers will be filled in at the ++ // memory pointed to by the pointers in graph_unreg_buffers_ when ++ // the IPC handles are exchanged between ranks. ++ // ++ // The overall process looks like this: ++ // 1. Graph capture. ++ // 2. Each rank obtains the IPC handles for each addresses used during cuda ++ // graph capture using get_graph_buffer_ipc_meta. ++ // 3. (In Python) all gather the IPC handles. ++ // 4. Obtain the peer pointers by opening the IPC handles, and store them in ++ // the rank data array at corresponding positions. + RankData *d_rank_data_base_, *d_rank_data_end_; +- std::vector graph_unreg_buffers_; ++ std::vector graph_unreg_buffers_; + // a map from IPC handles to opened IPC pointers +- std::map ipc_handles_; ++ std::map ipc_handles_; + + /** +- * meta is a pointer to device metadata and temporary buffer for allreduce. ++ * Signals are an array of ipc-enabled buffers from all ranks. ++ * For each of the buffer, the layout is as follows: ++ * | -- sizeof(Signal) -- | ------ a few MB ----- | ++ * The first section is for allreduce synchronization, and the second section ++ * is for storing the intermediate results required by some allreduce algos. + * +- * There's a total of sizeof(Signal) of prefix before the actual data, +- * so meta + 1 points to actual temporary buffer. +- * +- * note: this class does not own any device memory. Any required buffers +- * are passed in from the constructor ++ * Note: this class does not own any device memory. Any required buffers ++ * are passed in from the constructor. + */ +- CustomAllreduce(Signal *meta, void *rank_data, size_t rank_data_sz, +- const cudaIpcMemHandle_t *handles, +- const std::vector &offsets, int rank, +- bool full_nvlink = true) ++ CustomAllreduce(Signal** signals, void* rank_data, size_t rank_data_sz, ++ int rank, int world_size, bool full_nvlink = true) + : rank_(rank), +- world_size_(offsets.size()), ++ world_size_(world_size), + full_nvlink_(full_nvlink), +- self_sg_(meta), +- d_rank_data_base_(reinterpret_cast(rank_data)), ++ self_sg_(signals[rank]), ++ d_rank_data_base_(reinterpret_cast(rank_data)), + d_rank_data_end_(d_rank_data_base_ + rank_data_sz / sizeof(RankData)) { + for (int i = 0; i < world_size_; i++) { +- Signal *rank_sg; +- if (i != rank_) { +- char *handle = open_ipc_handle(&handles[i]); +- handle += offsets[i]; +- rank_sg = (Signal *)handle; +- } else { +- rank_sg = self_sg_; +- } +- sg_.signals[i] = rank_sg; ++ sg_.signals[i] = signals[i]; + } + } + +- char *open_ipc_handle(const void *ipc_handle) { ++ char* open_ipc_handle(const void* ipc_handle) { + auto [it, new_handle] = +- ipc_handles_.insert({*((IPC_KEY *)ipc_handle), nullptr}); ++ ipc_handles_.insert({*((IPC_KEY*)ipc_handle), nullptr}); + if (new_handle) { +- char *ipc_ptr; +- CUDACHECK(cudaIpcOpenMemHandle((void **)&ipc_ptr, +- *((const cudaIpcMemHandle_t *)ipc_handle), ++ char* ipc_ptr; ++ CUDACHECK(cudaIpcOpenMemHandle((void**)&ipc_ptr, ++ *((const cudaIpcMemHandle_t*)ipc_handle), + cudaIpcMemLazyEnablePeerAccess)); + it->second = ipc_ptr; + } + return it->second; + } + +- std::pair, std::vector> +- get_graph_buffer_ipc_meta() { ++ std::pair> get_graph_buffer_ipc_meta() { + auto num_buffers = graph_unreg_buffers_.size(); + auto handle_sz = sizeof(cudaIpcMemHandle_t); +- std::vector handles(handle_sz * num_buffers, 0); ++ std::string handles(handle_sz * num_buffers, static_cast(0)); + std::vector offsets(num_buffers); + for (int i = 0; i < num_buffers; i++) { + auto ptr = graph_unreg_buffers_[i]; +- void *base_ptr; ++ void* base_ptr; + // note: must share the base address of each allocation, or we get wrong + // address + if (cuPointerGetAttribute(&base_ptr, +@@ -331,8 +362,8 @@ class CustomAllreduce { + (CUdeviceptr)ptr) != CUDA_SUCCESS) + throw std::runtime_error("failed to get pointer attr"); + CUDACHECK(cudaIpcGetMemHandle( +- (cudaIpcMemHandle_t *)&handles[i * handle_sz], base_ptr)); +- offsets[i] = ((char *)ptr) - ((char *)base_ptr); ++ (cudaIpcMemHandle_t*)&handles[i * handle_sz], base_ptr)); ++ offsets[i] = ((char*)ptr) - ((char*)base_ptr); + } + return std::make_pair(handles, offsets); + } +@@ -344,26 +375,22 @@ class CustomAllreduce { + std::to_string(d_rank_data_base_ + num - d_rank_data_end_)); + } + +- void register_buffer(const std::vector &handles, +- const std::vector &offsets, void *self) { ++ /** ++ * Register already-shared IPC pointers. ++ */ ++ void register_buffer(void** ptrs) { + check_rank_data_capacity(); + RankData data; + for (int i = 0; i < world_size_; i++) { +- if (i != rank_) { +- char *handle = open_ipc_handle(handles[i].data()); +- handle += offsets[i]; +- data.ptrs[i] = handle; +- } else { +- data.ptrs[i] = self; +- } ++ data.ptrs[i] = ptrs[i]; + } + auto d_data = d_rank_data_base_++; + CUDACHECK( + cudaMemcpy(d_data, &data, sizeof(RankData), cudaMemcpyHostToDevice)); +- buffers_[self] = d_data; ++ buffers_[ptrs[rank_]] = d_data; + } + +- // note: when registering graph buffers, we intentionally choose to not ++ // Note: when registering graph buffers, we intentionally choose to not + // deduplicate the addresses. That means if the allocator reuses some + // addresses, they will be registered again. This is to account for the remote + // possibility of different allocation patterns between ranks. For example, +@@ -371,17 +398,17 @@ class CustomAllreduce { + // got a different address. IPC handles have internal reference counting + // mechanism so overhead should be small. + void register_graph_buffers( +- const std::vector &handles, +- const std::vector> &offsets) { ++ const std::vector& handles, ++ const std::vector>& offsets) { + auto num_buffers = graph_unreg_buffers_.size(); + check_rank_data_capacity(num_buffers); + std::vector rank_data(num_buffers); + for (int i = 0; i < num_buffers; i++) { + auto self_ptr = graph_unreg_buffers_[i]; +- auto &rd = rank_data[i]; ++ auto& rd = rank_data[i]; + for (int j = 0; j < world_size_; j++) { + if (j != rank_) { +- char *handle = ++ char* handle = + open_ipc_handle(&handles[j][i * sizeof(cudaIpcMemHandle_t)]); + handle += offsets[j][i]; + rd.ptrs[j] = handle; +@@ -398,14 +425,16 @@ class CustomAllreduce { + } + + /** +- * This is the result after careful grid search. Using 36 blocks give the best +- * or close to the best runtime on the devices I tried: A100, A10, A30, T4, +- * V100. You'll notice that NCCL kernels also only take a small amount of SMs. +- * Not quite sure the underlying reason, but my guess is that too many SMs +- * will cause contention on NVLink bus. ++ * Performs allreduce, assuming input has already been registered. ++ * ++ * Block and grid default configs are results after careful grid search. Using ++ * 36 blocks give the best or close to the best runtime on the devices I ++ * tried: A100, A10, A30, T4, V100. You'll notice that NCCL kernels also only ++ * take a small amount of SMs. Not quite sure the underlying reason, but my ++ * guess is that too many SMs will cause contention on NVLink bus. + */ + template +- void allreduce(cudaStream_t stream, T *input, T *output, int size, ++ void allreduce(cudaStream_t stream, T* input, T* output, int size, + int threads = 512, int block_limit = 36) { + auto d = packed_t::P::size; + if (size % d != 0) +@@ -418,7 +447,7 @@ class CustomAllreduce { + std::to_string(kMaxBlocks) + ". Got " + + std::to_string(block_limit)); + +- RankData *ptrs; ++ RankData* ptrs; + cudaStreamCaptureStatus status; + CUDACHECK(cudaStreamIsCapturing(stream, &status)); + if (status == cudaStreamCaptureStatusActive) { +@@ -440,6 +469,8 @@ class CustomAllreduce { + #define KL(ngpus, name) \ + name<<>>(ptrs, sg_, self_sg_, output, \ + rank_, size); ++ // TODO(hanzhi713): Threshold is different for A100 and H100. ++ // Add per device threshold. + #define REDUCE_CASE(ngpus) \ + case ngpus: { \ + if (world_size_ == 2) { \ +diff --git a/csrc/custom_all_reduce_test.cu b/csrc/custom_all_reduce_test.cu +index c34a503..b59ea40 100644 +--- a/csrc/custom_all_reduce_test.cu ++++ b/csrc/custom_all_reduce_test.cu +@@ -1,15 +1,15 @@ + /** + * This is a standalone test for custom allreduce. + * To compile, make sure you have MPI and NCCL installed in your system. +- * export MPI_HOME=XXX ++ * export MPI_HOME=xxx + * nvcc -O2 -arch=native -std=c++17 custom_all_reduce_test.cu -o +- * custom_all_reduce_test -lnccl -I${MPI_HOME}/include -lmpi ++ * custom_all_reduce_test -lnccl -I${MPI_HOME} -lmpi + * + * Warning: this C++ test is not designed to be very readable and was used + * during the rapid prototyping process. + * + * To run: +- * mpirun -np 8 ./custom_all_reduce_test ++ * mpirun --allow-run-as-root -np 8 ./custom_all_reduce_test + */ + #include + #include +@@ -44,11 +44,18 @@ + } while (0) + + __global__ void dummy_kernel() { ++#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700 + for (int i = 0; i < 100; i++) __nanosleep(1000000); // 100ms ++#else ++ for (int i = 0; i < 100; i++) { ++ long long int start = clock64(); ++ while (clock64() - start < 150000000); // approximately 98.4ms on P40 ++ } ++#endif + } + + template +-__global__ void set_data(T *data, int size, int myRank) { ++__global__ void set_data(T* data, int size, int myRank) { + for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; + idx += gridDim.x * blockDim.x) { + data[idx] = myRank * 0.11f; +@@ -56,8 +63,8 @@ __global__ void set_data(T *data, int size, int myRank) { + } + + template +-__global__ void convert_data(const T *data1, const T *data2, double *fdata1, +- double *fdata2, int size) { ++__global__ void convert_data(const T* data1, const T* data2, double* fdata1, ++ double* fdata2, int size) { + for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; + idx += gridDim.x * blockDim.x) { + fdata1[idx] = data1[idx]; +@@ -65,7 +72,7 @@ __global__ void convert_data(const T *data1, const T *data2, double *fdata1, + } + } + +-__global__ void init_rand(curandState_t *state, int size, int nRanks) { ++__global__ void init_rand(curandState_t* state, int size, int nRanks) { + for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; + idx += gridDim.x * blockDim.x) { + for (int i = 0; i < nRanks; i++) { +@@ -75,7 +82,7 @@ __global__ void init_rand(curandState_t *state, int size, int nRanks) { + } + + template +-__global__ void gen_data(curandState_t *state, T *data, double *ground_truth, ++__global__ void gen_data(curandState_t* state, T* data, double* ground_truth, + int myRank, int nRanks, int size) { + for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; + idx += gridDim.x * blockDim.x) { +@@ -91,9 +98,9 @@ __global__ void gen_data(curandState_t *state, T *data, double *ground_truth, + } + + template +-void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit, ++void run(int myRank, int nRanks, ncclComm_t& comm, int threads, int block_limit, + int data_size, bool performance_test) { +- T *result; ++ T* result; + cudaStream_t stream; + CUDACHECK(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking)); + CUDACHECK(cudaMalloc(&result, data_size * sizeof(T))); +@@ -101,8 +108,8 @@ void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit, + + cudaIpcMemHandle_t self_data_handle; + cudaIpcMemHandle_t data_handles[8]; +- vllm::Signal *buffer; +- T *self_data_copy; ++ vllm::Signal* buffer; ++ T* self_data_copy; + /** + * Allocate IPC buffer + * +@@ -125,32 +132,34 @@ void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit, + MPI_BYTE, data_handles, sizeof(cudaIpcMemHandle_t), + MPI_BYTE, MPI_COMM_WORLD)); + +- void *rank_data; ++ void* rank_data; + size_t rank_data_sz = 16 * 1024 * 1024; + CUDACHECK(cudaMalloc(&rank_data, rank_data_sz)); +- std::vector offsets(nRanks, 0); +- vllm::CustomAllreduce fa(buffer, rank_data, rank_data_sz, data_handles, +- offsets, myRank); +- auto *self_data = +- reinterpret_cast(reinterpret_cast(buffer) + +- sizeof(vllm::Signal) + data_size * sizeof(T)); ++ vllm::Signal* ipc_ptrs[8]; ++ for (int i = 0; i < nRanks; i++) { ++ if (i == myRank) ++ ipc_ptrs[i] = buffer; ++ else ++ CUDACHECK(cudaIpcOpenMemHandle((void**)&ipc_ptrs[i], data_handles[i], ++ cudaIpcMemLazyEnablePeerAccess)); ++ } ++ vllm::CustomAllreduce fa(ipc_ptrs, rank_data, rank_data_sz, myRank, nRanks); ++ auto* self_data = ++ reinterpret_cast(reinterpret_cast(buffer) + ++ sizeof(vllm::Signal) + data_size * sizeof(T)); + // hack buffer registration + { +- std::vector handles; +- handles.reserve(nRanks); ++ void* data[8]; + for (int i = 0; i < nRanks; i++) { +- char *begin = (char *)&data_handles[i]; +- char *end = (char *)&data_handles[i + 1]; +- handles.emplace_back(begin, end); ++ data[i] = ++ ((char*)ipc_ptrs[i]) + sizeof(vllm::Signal) + data_size * sizeof(T); + } +- std::vector offsets(nRanks, +- sizeof(vllm::Signal) + data_size * sizeof(T)); +- fa.register_buffer(handles, offsets, self_data); ++ fa.register_buffer(data); + } + +- double *ground_truth; ++ double* ground_truth; + CUDACHECK(cudaMallocHost(&ground_truth, data_size * sizeof(double))); +- curandState_t *states; ++ curandState_t* states; + CUDACHECK(cudaMalloc(&states, sizeof(curandState_t) * nRanks * data_size)); + init_rand<<<108, 1024, 0, stream>>>(states, data_size, nRanks); + gen_data<<<108, 1024, 0, stream>>>(states, self_data, ground_truth, myRank, +@@ -287,7 +296,7 @@ void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit, + CUDACHECK(cudaStreamDestroy(stream)); + } + +-int main(int argc, char **argv) { ++int main(int argc, char** argv) { + int nRanks, myRank; + MPICHECK(MPI_Init(&argc, &argv)); + MPICHECK(MPI_Comm_rank(MPI_COMM_WORLD, &myRank)); +@@ -296,21 +305,25 @@ int main(int argc, char **argv) { + ncclUniqueId id; + ncclComm_t comm; + if (myRank == 0) ncclGetUniqueId(&id); +- MPICHECK(MPI_Bcast(static_cast(&id), sizeof(id), MPI_BYTE, 0, ++ MPICHECK(MPI_Bcast(static_cast(&id), sizeof(id), MPI_BYTE, 0, + MPI_COMM_WORLD)); + NCCLCHECK(ncclCommInitRank(&comm, nRanks, id, myRank)); + + bool performance_test = true; + cudaProfilerStart(); +- // for (int threads : {256, 512}) { ++ // Uncomment to scan through different block size configs. ++ // for (int threads : {256, 512, 1024}) { + // for (int block_limit = 16; block_limit < 112; block_limit += 4) { +- // run(myRank, nRanks, comm, threads, block_limit, 4096 * 1024); ++ // run(myRank, nRanks, comm, threads, block_limit, 1024 * 1024, ++ // performance_test); + // } + // } ++ // Scan through different sizes to test performance. + for (int sz = 512; sz <= (8 << 20); sz *= 2) { + run(myRank, nRanks, comm, 512, 36, sz + 8 * 47, performance_test); + } + + cudaProfilerStop(); ++ MPICHECK(MPI_Finalize()); + return EXIT_SUCCESS; + } +diff --git a/csrc/cutlass_extensions/common.cpp b/csrc/cutlass_extensions/common.cpp +new file mode 100644 +index 0000000..3d2093a +--- /dev/null ++++ b/csrc/cutlass_extensions/common.cpp +@@ -0,0 +1,11 @@ ++#include "cutlass_extensions/common.hpp" ++ ++int32_t get_sm_version_num() { ++ int32_t major_capability, minor_capability; ++ cudaDeviceGetAttribute(&major_capability, cudaDevAttrComputeCapabilityMajor, ++ 0); ++ cudaDeviceGetAttribute(&minor_capability, cudaDevAttrComputeCapabilityMinor, ++ 0); ++ int32_t version_num = major_capability * 10 + minor_capability; ++ return version_num; ++} +\ No newline at end of file +diff --git a/csrc/cutlass_extensions/common.hpp b/csrc/cutlass_extensions/common.hpp +new file mode 100644 +index 0000000..85e359a +--- /dev/null ++++ b/csrc/cutlass_extensions/common.hpp +@@ -0,0 +1,35 @@ ++#pragma once ++ ++#include "cutlass/cutlass.h" ++#include ++#include "cuda_runtime.h" ++#include ++ ++/** ++ * Helper function for checking CUTLASS errors ++ */ ++#define CUTLASS_CHECK(status) \ ++ { \ ++ cutlass::Status error = status; \ ++ TORCH_CHECK(error == cutlass::Status::kSuccess, \ ++ cutlassGetStatusString(error)); \ ++ } ++ ++/** ++ * Panic wrapper for unwinding CUDA runtime errors ++ */ ++#define CUDA_CHECK(status) \ ++ { \ ++ cudaError_t error = status; \ ++ TORCH_CHECK(error == cudaSuccess, cudaGetErrorString(error)); \ ++ } ++ ++inline int get_cuda_max_shared_memory_per_block_opt_in(int const device) { ++ int max_shared_mem_per_block_opt_in = 0; ++ cudaDeviceGetAttribute(&max_shared_mem_per_block_opt_in, ++ cudaDevAttrMaxSharedMemoryPerBlockOptin, ++ device); ++ return max_shared_mem_per_block_opt_in; ++} ++ ++int32_t get_sm_version_num(); +diff --git a/csrc/cutlass_extensions/cute_utils.cuh b/csrc/cutlass_extensions/cute_utils.cuh +new file mode 100644 +index 0000000..f61fe3c +--- /dev/null ++++ b/csrc/cutlass_extensions/cute_utils.cuh +@@ -0,0 +1,68 @@ ++#pragma once ++ ++#include ++#include ++namespace cute { ++ ++//////////////////////////////////////////////////////////////////// ++// layout utils ++//////////////////////////////////////////////////////////////////// ++ ++// Permute layout based on indices, example: ++// permute_layout<1, 0>(layout) will swap the two dimensions ++// permute_layout<0, 2, 1>(layout) will swap the last two dimensions ++template ++CUTE_HOST_DEVICE static constexpr auto permute_layout(Layout l) { ++ static_assert(rank(l) == sizeof...(I), "Invalid permutation, rank mismatch"); ++ return cute::make_layout(cute::get(l)...); ++} ++ ++// is the layout f(x) = x ++template ++CUTE_HOST_DEVICE static constexpr bool is_identity_layout() { ++ if constexpr (std::is_same_v) { ++ return true; ++ } else { ++ constexpr auto coalesced_layout = coalesce(Layout{}); ++ if constexpr (rank(coalesced_layout) == 1 && ++ stride<0>(coalesced_layout) == 1) { ++ return true; ++ } ++ return false; ++ } ++} ++ ++//////////////////////////////////////////////////////////////////// ++// Pointer utils ++//////////////////////////////////////////////////////////////////// ++ ++template ++static constexpr auto get_logical_ptr(PointerType* ptr) { ++ if constexpr (cute::sizeof_bits_v < 8) { ++ return cute::subbyte_iterator(ptr); ++ } else { ++ return ptr; ++ } ++} ++ ++//////////////////////////////////////////////////////////////////// ++// Misc utils ++//////////////////////////////////////////////////////////////////// ++ ++template ++CUTE_HOST_DEVICE static constexpr auto create_auto_vectorizing_copy() { ++ constexpr auto bits = sizeof_bits_v * Elements{}; ++ if constexpr (bits % 128 == 0) { ++ return AutoVectorizingCopyWithAssumedAlignment<128>{}; ++ } else if constexpr (bits % 64 == 0) { ++ return AutoVectorizingCopyWithAssumedAlignment<64>{}; ++ } else if constexpr (bits % 32 == 0) { ++ return AutoVectorizingCopyWithAssumedAlignment<32>{}; ++ } else if constexpr (bits % 16 == 0) { ++ return AutoVectorizingCopyWithAssumedAlignment<16>{}; ++ } else { ++ return AutoVectorizingCopyWithAssumedAlignment<8>{}; ++ } ++} ++ ++}; // namespace cute +diff --git a/csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_c2x.hpp b/csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_c2x.hpp +new file mode 100644 +index 0000000..7aa87fe +--- /dev/null ++++ b/csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_c2x.hpp +@@ -0,0 +1,497 @@ ++/*************************************************************************************************** ++ * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights ++ *reserved. SPDX-License-Identifier: BSD-3-Clause ++ * ++ * Redistribution and use in source and binary forms, with or without ++ * modification, are permitted provided that the following conditions are met: ++ * ++ * 1. Redistributions of source code must retain the above copyright notice, ++ *this list of conditions and the following disclaimer. ++ * ++ * 2. Redistributions in binary form must reproduce the above copyright notice, ++ * this list of conditions and the following disclaimer in the documentation ++ * and/or other materials provided with the distribution. ++ * ++ * 3. Neither the name of the copyright holder nor the names of its ++ * contributors may be used to endorse or promote products derived from ++ * this software without specific prior written permission. ++ * ++ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" ++ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE ++ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ++ *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE ++ *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR ++ *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF ++ *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS ++ *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN ++ *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ++ *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE ++ *POSSIBILITY OF SUCH DAMAGE. ++ * ++ **************************************************************************************************/ ++ ++// ++// This file is a modified excerpt of ++// include/cutlass/epilogue/fusion/visitor_load.hpp from ++// https://github.com/NVIDIA/cutlass v3.5.0 ++// It has been modified to support either ++// row/column or scalar broadcasting where the tensor being loaded from is ++// always passed in via a device pointer. This lets one compiled kernel handle ++// all cases of per-tensor or per-channel/per-token quantization. ++// ++// This interface also allows the scales to be passed in as tensors that ++// consistently reside on the device, which avoids an issue with a previous ++// implementation where scalars needed to be on the CPU since they ++// were passed in via float values. This created a potential performance hazard ++// if scales were initially on the device, and caused torch.compile graph ++// breaks when moving scales to the CPU. ++// ++#pragma once ++ ++// Turn off clang-format for the entire file to keep it close to upstream ++// clang-format off ++ ++#include "cutlass/epilogue/threadblock/fusion/visitor_2x.hpp" ++#include "cutlass/epilogue/threadblock/fusion/visitors.hpp" ++#include "cute/tensor.hpp" ++ ++namespace cutlass::epilogue::threadblock { ++ ++using namespace cute; ++using namespace detail; ++ ++template< ++ class ThreadMap, ++ class Element, ++ class StrideMNL ++> ++struct VisitorRowOrScalarBroadcast { ++ ++ // This struct has been modified to have a bool indicating that ptr_row is a ++ // scalar that must be broadcast. ++ struct Arguments { ++ Element const* ptr_row = nullptr; ++ bool row_broadcast = true; ++ StrideMNL dRow = {}; ++ }; ++ ++ using Params = Arguments; ++ ++ template ++ static constexpr Params ++ to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { ++ return args; ++ } ++ ++ template ++ static size_t ++ get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { ++ return 0; ++ } ++ ++ struct SharedStorage {}; ++ ++ // Global load type ++ static int constexpr vec_bits = ThreadMap::kElementsPerAccess * sizeof_bits::value; ++ using VecType = uint_bit_t; ++ static int constexpr VecLength = sizeof(VecType) / sizeof(Element); ++ ++ CUTLASS_HOST_DEVICE ++ VisitorRowOrScalarBroadcast() { } ++ ++ CUTLASS_HOST_DEVICE ++ VisitorRowOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage) ++ : params_ptr(¶ms) { } ++ ++ Params const* params_ptr; ++ ++ template ++ struct Callbacks : EmptyCallbacks { ++ CUTLASS_DEVICE ++ Callbacks( ++ GTensor&& tC_gRow, ++ RTensor&& tC_rRow, ++ CTensor&& tC_cRow, ++ ProblemShape problem_shape, ++ Params const* params_ptr ++ ): ++ tC_gRow(cute::forward(tC_gRow)), ++ tC_rRow(cute::forward(tC_rRow)), ++ tC_cRow(cute::forward(tC_cRow)), ++ n(get<1>(problem_shape)), ++ params_ptr(params_ptr) { } ++ ++ GTensor tC_gRow; ++ RTensor tC_rRow; ++ CTensor tC_cRow; ++ Params const* params_ptr; ++ int n; ++ ++ // This function is modified from VisitorRowBroadcast ++ CUTLASS_DEVICE void ++ begin_epilogue() { ++ clear(tC_rRow); ++ auto src_v = filter(tC_gRow); ++ auto coord_v = filter(tC_cRow); ++ auto dst_v = filter(tC_rRow); ++ ++ if (params_ptr->row_broadcast) { ++ // In this case we are loading from a row vector and broadcasting ++ CUTLASS_PRAGMA_UNROLL ++ for (int i = 0; i < size(src_v); ++i) { ++ bool guard = get<1>(coord_v(i)) < n; ++ cutlass::arch::global_load( ++ dst_v(i), (void const*)&src_v(i), guard); ++ } ++ } else { ++ // In this case we are loading from a scalar and broadcasting ++ VecType filled_vec; ++ CUTLASS_PRAGMA_UNROLL ++ for (int i = 0; i < VecLength; i++) { ++ reinterpret_cast(&filled_vec)[i] = *(params_ptr->ptr_row); ++ } ++ ++ CUTLASS_PRAGMA_UNROLL ++ for (int i = 0; i < size(src_v); ++i) { ++ if (get<1>(coord_v(i)) < n) { ++ dst_v(i) = filled_vec; ++ } ++ } ++ } ++ } ++ ++ template ++ CUTLASS_DEVICE auto // returns an Array ++ visit(int iter_idx, int row_idx, int column_idx, int frg_idx, ++ Array const& frg_acc) { ++ Tensor rRow_frg = recast>(coalesce(tC_rRow)); ++ return rRow_frg(column_idx); ++ } ++ }; ++ ++ template ++ CUTLASS_DEVICE auto ++ get_callbacks( ++ gemm::GemmCoord threadblock_tile_offset, ++ int thread_idx, ++ ProblemShape problem_shape ++ ) { ++ Tensor mRow = make_tensor( ++ make_gmem_ptr(params_ptr->ptr_row), ++ problem_shape, ++ params_ptr->dRow); ++ ++ // VECTOR, FRAGMENT_COLUMN ++ Tensor tC_gRow = recast( ++ ThreadMap::partition(mRow, thread_idx, threadblock_tile_offset) ++ )(_,_,_0{},_0{},_0{},_0{}); ++ Tensor tC_rRow = make_tensor_like(tC_gRow); ++ ++ // Generate the pred tensor ++ Tensor cRow = make_identity_tensor(mRow.shape()); ++ Tensor tC_cRow = outer_partition( ++ ThreadMap::partition(cRow, thread_idx, threadblock_tile_offset)(_,_,_0{},_0{},_0{},_0{}), ++ Shape>{}, ++ (_0{}) ++ ); ++ ++ return Callbacks< ++ decltype(tC_gRow), decltype(tC_rRow), ++ decltype(tC_cRow), ProblemShape>( ++ cute::move(tC_gRow), ++ cute::move(tC_rRow), ++ cute::move(tC_cRow), ++ problem_shape, ++ params_ptr ++ ); ++ } ++ ++}; ++ ++///////////////////////////////////////////////////////////////////////////////////////////////// ++ ++// This is a modified RowBroadcast that will broadcast 0 if ptr_row is null ++template< ++ class ThreadMap, ++ class Element, ++ class StrideMNL ++> ++struct VisitorRowOrZeroBroadcast { ++ ++ // This struct has been modified to remove null_default (because it's always 0) ++ struct Arguments { ++ Element const* ptr_row = nullptr; ++ StrideMNL dRow = {}; ++ }; ++ ++ using Params = Arguments; ++ ++ template ++ static constexpr Params ++ to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { ++ return args; ++ } ++ ++ template ++ static size_t ++ get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { ++ return 0; ++ } ++ ++ struct SharedStorage {}; ++ ++ // Global load type ++ static int constexpr vec_bits = ThreadMap::kElementsPerAccess * sizeof_bits::value; ++ using VecType = uint_bit_t; ++ static int constexpr VecLength = sizeof(VecType) / sizeof(Element); ++ ++ CUTLASS_HOST_DEVICE ++ VisitorRowOrZeroBroadcast() { } ++ ++ CUTLASS_HOST_DEVICE ++ VisitorRowOrZeroBroadcast(Params const& params, SharedStorage const& shared_storage) ++ : params_ptr(¶ms) { } ++ ++ Params const* params_ptr; ++ ++ template ++ struct Callbacks : EmptyCallbacks { ++ CUTLASS_DEVICE ++ Callbacks( ++ GTensor&& tC_gRow, ++ RTensor&& tC_rRow, ++ CTensor&& tC_cRow, ++ ProblemShape problem_shape, ++ Params const* params_ptr ++ ): ++ tC_gRow(cute::forward(tC_gRow)), ++ tC_rRow(cute::forward(tC_rRow)), ++ tC_cRow(cute::forward(tC_cRow)), ++ n(get<1>(problem_shape)), ++ params_ptr(params_ptr) { } ++ ++ GTensor tC_gRow; ++ RTensor tC_rRow; ++ CTensor tC_cRow; ++ Params const* params_ptr; ++ int n; ++ ++ // This function is modified from VisitorRowBroadcast ++ CUTLASS_DEVICE void ++ begin_epilogue() { ++ clear(tC_rRow); ++ auto src_v = filter(tC_gRow); ++ auto coord_v = filter(tC_cRow); ++ auto dst_v = filter(tC_rRow); ++ ++ if (params_ptr->ptr_row != nullptr) { ++ // In this case we are loading from a row vector and broadcasting ++ CUTLASS_PRAGMA_UNROLL ++ for (int i = 0; i < size(src_v); ++i) { ++ bool guard = get<1>(coord_v(i)) < n; ++ cutlass::arch::global_load( ++ dst_v(i), (void const*)&src_v(i), guard); ++ } ++ } else { ++ // In this case we are broadcasting 0 ++ VecType filled_vec; ++ CUTLASS_PRAGMA_UNROLL ++ for (int i = 0; i < VecLength; i++) { ++ reinterpret_cast(&filled_vec)[i] = Element{0}; ++ } ++ ++ CUTLASS_PRAGMA_UNROLL ++ for (int i = 0; i < size(src_v); ++i) { ++ if (get<1>(coord_v(i)) < n) { ++ dst_v(i) = filled_vec; ++ } ++ } ++ } ++ } ++ ++ template ++ CUTLASS_DEVICE auto // returns an Array ++ visit(int iter_idx, int row_idx, int column_idx, int frg_idx, ++ Array const& frg_acc) { ++ Tensor rRow_frg = recast>(coalesce(tC_rRow)); ++ return rRow_frg(column_idx); ++ } ++ }; ++ ++ template ++ CUTLASS_DEVICE auto ++ get_callbacks( ++ gemm::GemmCoord threadblock_tile_offset, ++ int thread_idx, ++ ProblemShape problem_shape ++ ) { ++ Tensor mRow = make_tensor( ++ make_gmem_ptr(params_ptr->ptr_row), ++ problem_shape, ++ params_ptr->dRow); ++ ++ // VECTOR, FRAGMENT_COLUMN ++ Tensor tC_gRow = recast( ++ ThreadMap::partition(mRow, thread_idx, threadblock_tile_offset) ++ )(_,_,_0{},_0{},_0{},_0{}); ++ Tensor tC_rRow = make_tensor_like(tC_gRow); ++ ++ // Generate the pred tensor ++ Tensor cRow = make_identity_tensor(mRow.shape()); ++ Tensor tC_cRow = outer_partition( ++ ThreadMap::partition(cRow, thread_idx, threadblock_tile_offset)(_,_,_0{},_0{},_0{},_0{}), ++ Shape>{}, ++ (_0{}) ++ ); ++ ++ return Callbacks< ++ decltype(tC_gRow), decltype(tC_rRow), ++ decltype(tC_cRow), ProblemShape>( ++ cute::move(tC_gRow), ++ cute::move(tC_rRow), ++ cute::move(tC_cRow), ++ problem_shape, ++ params_ptr ++ ); ++ } ++ ++}; ++ ++ ++///////////////////////////////////////////////////////////////////////////////////////////////// ++ ++// Column vector broadcast ++template< ++ class ThreadMap, ++ class Element, ++ class StrideMNL = Stride<_1,_0,_0> ++> ++struct VisitorColOrScalarBroadcast { ++ ++ // This struct has been modified to have a bool indicating that ptr_col is a ++ // scalar that must be broadcast. ++ struct Arguments { ++ Element const* ptr_col = nullptr; ++ bool col_broadcast = true; ++ StrideMNL dCol = {}; ++ }; ++ ++ using Params = Arguments; ++ ++ template ++ static constexpr Params ++ to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { ++ return args; ++ } ++ ++ template ++ static size_t ++ get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { ++ return 0; ++ } ++ ++ struct SharedStorage { }; ++ ++ CUTLASS_HOST_DEVICE ++ VisitorColOrScalarBroadcast() { } ++ ++ CUTLASS_HOST_DEVICE ++ VisitorColOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage) ++ : params_ptr(¶ms) { } ++ ++ Params const* params_ptr; ++ ++ template ++ struct Callbacks : EmptyCallbacks { ++ CUTLASS_DEVICE ++ Callbacks( ++ GTensor&& tC_gCol, ++ RTensor&& tC_rCol, ++ CTensor&& tC_cCol, ++ ProblemShape problem_shape, ++ Params const* params_ptr ++ ): ++ tC_gCol(cute::forward(tC_gCol)), ++ tC_rCol(cute::forward(tC_rCol)), ++ tC_cCol(cute::forward(tC_cCol)), ++ m(get<0>(problem_shape)), ++ params_ptr(params_ptr) { } ++ ++ GTensor tC_gCol; ++ RTensor tC_rCol; ++ CTensor tC_cCol; ++ Params const* params_ptr; ++ int m; ++ ++ // This function is modified from VisitorColBroadcast ++ CUTLASS_DEVICE void ++ begin_epilogue() { ++ clear(tC_rCol); ++ ++ Tensor pred = make_tensor(shape(tC_gCol)); ++ CUTLASS_PRAGMA_UNROLL ++ for (int i = 0; i < size(pred); ++i) { ++ pred(i) = get<0>(tC_cCol(i)) < m; ++ } ++ ++ if (params_ptr->col_broadcast) { ++ // In this case we are loading from a column vector and broadcasting ++ copy_if(pred, tC_gCol, tC_rCol); ++ } else { ++ // In this case we are loading from a scalar and broadcasting ++ auto dst_v = filter(tC_rCol); ++ ++ CUTLASS_PRAGMA_UNROLL ++ for (int i = 0; i < size(dst_v); ++i) { ++ if (pred(i)) { ++ dst_v(i) = *(params_ptr->ptr_col); ++ } ++ } ++ } ++ } ++ ++ template ++ CUTLASS_DEVICE auto // returns an Array ++ visit(int iter_idx, int row_idx, int column_idx, int frg_idx, ++ Array const& frg_acc) { ++ Array frg_col; ++ frg_col.fill(tC_rCol(row_idx,iter_idx)); ++ return frg_col; ++ } ++ }; ++ ++ template ++ CUTLASS_DEVICE auto ++ get_callbacks( ++ gemm::GemmCoord threadblock_tile_offset, ++ int thread_idx, ++ ProblemShape problem_shape ++ ) { ++ Tensor mCol = make_tensor( ++ make_gmem_ptr(params_ptr->ptr_col), ++ problem_shape, ++ params_ptr->dCol); ++ ++ // VECTOR, FRAGMENT_COLUMN, FRAGMENT_ROW, ITERATION_ROW, ITERATION_GROUP, ITERATION_CLUSTER ++ Tensor tC_gCol = group_modes<1,4>( ++ ThreadMap::partition(mCol, thread_idx, threadblock_tile_offset)(_0{},_0{},_,_,_,_)); ++ Tensor tC_rCol = make_tensor_like(tC_gCol); ++ ++ // Generate the pred tensor ++ Tensor cCol = make_identity_tensor(mCol.shape()); ++ Tensor tC_cCol = group_modes<1,4>( ++ ThreadMap::partition(cCol, thread_idx, threadblock_tile_offset)(_0{},_0{},_,_,_,_)); ++ ++ return Callbacks< ++ decltype(tC_gCol), decltype(tC_rCol), ++ decltype(tC_cCol), ProblemShape>( ++ cute::move(tC_gCol), ++ cute::move(tC_rCol), ++ cute::move(tC_cCol), ++ problem_shape, ++ params_ptr ++ ); ++ } ++}; ++ ++} +diff --git a/csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp b/csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp +new file mode 100644 +index 0000000..58b1e8f +--- /dev/null ++++ b/csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp +@@ -0,0 +1,447 @@ ++/*************************************************************************************************** ++ * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights ++ *reserved. SPDX-License-Identifier: BSD-3-Clause ++ * ++ * Redistribution and use in source and binary forms, with or without ++ * modification, are permitted provided that the following conditions are met: ++ * ++ * 1. Redistributions of source code must retain the above copyright notice, ++ *this list of conditions and the following disclaimer. ++ * ++ * 2. Redistributions in binary form must reproduce the above copyright notice, ++ * this list of conditions and the following disclaimer in the documentation ++ * and/or other materials provided with the distribution. ++ * ++ * 3. Neither the name of the copyright holder nor the names of its ++ * contributors may be used to endorse or promote products derived from ++ * this software without specific prior written permission. ++ * ++ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" ++ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE ++ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ++ *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE ++ *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR ++ *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF ++ *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS ++ *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN ++ *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ++ *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE ++ *POSSIBILITY OF SUCH DAMAGE. ++ * ++ **************************************************************************************************/ ++ ++// ++// This file is a modified excerpt of ++// include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp ++// from https://github.com/NVIDIA/cutlass v3.5.0 ++// It has been modified to support either row/column or scalar broadcasting ++// where the tensor being loaded from is always passed in via a device pointer. ++// This lets one compiled kernel handle all cases of per-tensor or ++// per-channel/per-token quantization. ++// ++// This interface also allows the scales to be passed in as tensors that ++// consistently reside on the device, which avoids an issue with a previous ++// implementation where scalars needed to be on the CPU since they ++// were passed in via float values. This created a potential performance hazard ++// if scales were initially on the device, and caused torch.compile graphs ++// breaks when moving scales to the CPU. ++// ++#pragma once ++ ++// Turn off clang-format for the entire file to keep it close to upstream ++// clang-format off ++ ++#include "cutlass/cutlass.h" ++#include "cutlass/arch/barrier.h" ++ ++#include "cute/tensor.hpp" ++#include "cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp" ++ ++namespace cutlass::epilogue::fusion { ++ ++using namespace cute; ++using namespace detail; ++ ++// Row vector broadcast ++template< ++ int Stages, ++ class CtaTileShapeMNK, ++ class Element, ++ class StrideMNL = Stride<_0,_1,_0>, ++ int Alignment = 128 / sizeof_bits_v ++> ++struct Sm90RowOrScalarBroadcast { ++ static_assert(Stages == 0, "Row broadcast doesn't support smem usage"); ++ static_assert(is_static_v(StrideMNL{}))>); // batch stride can be dynamic or static ++ static_assert(take<0,2>(StrideMNL{}) == Stride<_0,_1>{}); ++ ++ struct SharedStorage { ++ array_aligned(CtaTileShapeMNK{})> smem; ++ }; ++ ++ // This struct has been modified to have a bool indicating that ptr_row is a ++ // scalar that must be broadcast, instead of containing a scalar that is ++ // valid if ptr_row is null. ++ struct Arguments { ++ Element const* ptr_row = nullptr; ++ bool row_broadcast = true; ++ StrideMNL dRow = {}; ++ }; ++ ++ using Params = Arguments; ++ ++ template ++ static constexpr Params ++ to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { ++ return args; ++ } ++ ++ template ++ static bool ++ can_implement(ProblemShape const& problem_shape, Arguments const& args) { ++ return true; ++ } ++ ++ template ++ static size_t ++ get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { ++ return 0; ++ } ++ ++ template ++ static cutlass::Status ++ initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, ++ CudaHostAdapter* cuda_adapter = nullptr) { ++ return cutlass::Status::kSuccess; ++ } ++ ++ CUTLASS_HOST_DEVICE ++ Sm90RowOrScalarBroadcast() { } ++ ++ CUTLASS_HOST_DEVICE ++ Sm90RowOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage) ++ : params(params) ++ , smem(const_cast(shared_storage.smem.data())) { } ++ ++ Params params; ++ Element *smem = nullptr; ++ ++ CUTLASS_DEVICE bool ++ is_producer_load_needed() const { ++ return false; ++ } ++ ++ CUTLASS_DEVICE bool ++ is_C_load_needed() const { ++ return false; ++ } ++ ++ CUTLASS_DEVICE bool ++ is_zero() const { ++ return (!params.row_broadcast && *(params.ptr_row) == Element(0)); ++ } ++ ++ template ++ CUTLASS_DEVICE auto ++ get_producer_load_callbacks(ProducerLoadArgs const& args) { ++ return EmptyProducerLoadCallbacks{}; ++ } ++ ++ template ++ struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { ++ CUTLASS_DEVICE ++ ConsumerStoreCallbacks( ++ GS_GTensor tGS_gRow_, GS_STensor tGS_sRow_, ++ GS_CTensor tGS_cRow_, Tiled_G2S tiled_g2s_, ++ SR_STensor tSR_sRow_, SR_RTensor tSR_rRow_, ++ CTensor tCcRow_, ThrResidue residue_tCcRow_, ThrNum thr_num_, Params const& params_) ++ : tGS_gRow(tGS_gRow_) ++ , tGS_sRow(tGS_sRow_) ++ , tGS_cRow(tGS_cRow_) ++ , tiled_G2S(tiled_g2s_) ++ , tSR_sRow(tSR_sRow_) ++ , tSR_rRow(tSR_rRow_) ++ , tCcRow(tCcRow_) ++ , residue_tCcRow(residue_tCcRow_) ++ , params(params_) {} ++ ++ GS_GTensor tGS_gRow; // (CPY,CPY_M,CPY_N) ++ GS_STensor tGS_sRow; // (CPY,CPY_M,CPY_N) ++ GS_CTensor tGS_cRow; // (CPY,CPY_M,CPY_N) ++ Tiled_G2S tiled_G2S; ++ ++ SR_STensor tSR_sRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) ++ SR_RTensor tSR_rRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) ++ ++ CTensor tCcRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) ++ ThrResidue residue_tCcRow; // (m, n) ++ ThrNum thr_num; ++ Params const& params; ++ ++ CUTLASS_DEVICE void ++ begin() { ++ if (!params.row_broadcast) { ++ fill(tSR_rRow, *(params.ptr_row)); ++ return; ++ } ++ ++ auto synchronize = [&] () { cutlass::arch::NamedBarrier::sync(thr_num, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); }; ++ Tensor tGS_gRow_flt = filter_zeros(tGS_gRow); ++ Tensor tGS_sRow_flt = filter_zeros(tGS_sRow); ++ Tensor tGS_cRow_flt = make_tensor(tGS_cRow.data(), make_layout(tGS_gRow_flt.shape(), tGS_cRow.stride())); ++ ++ for (int i = 0; i < size(tGS_gRow_flt); ++i) { ++ if (get<1>(tGS_cRow_flt(i)) >= size<1>(CtaTileShapeMNK{})) { ++ continue; // OOB of SMEM, ++ } ++ if (elem_less(tGS_cRow_flt(i), make_coord(get<0>(residue_tCcRow), get<1>(residue_tCcRow)))) { ++ tGS_sRow_flt(i) = tGS_gRow_flt(i); ++ } ++ else { ++ tGS_sRow_flt(i) = Element(0); // Set to Zero when OOB so LDS could be issue without any preds. ++ } ++ } ++ synchronize(); ++ } ++ ++ CUTLASS_DEVICE void ++ begin_loop(int epi_m, int epi_n) { ++ if (epi_m == 0) { // Assumes M-major subtile loop ++ if (!params.row_broadcast) return; // Do not issue LDS when row is scalar ++ Tensor tSR_sRow_flt = filter_zeros(tSR_sRow(_,_,_,epi_m,epi_n)); ++ Tensor tSR_rRow_flt = filter_zeros(tSR_rRow); ++ copy(tSR_sRow_flt, tSR_rRow_flt); ++ } ++ } ++ ++ template ++ CUTLASS_DEVICE Array ++ visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { ++ Array frg_row; ++ ++ CUTLASS_PRAGMA_UNROLL ++ for (int i = 0; i < FragmentSize; ++i) { ++ frg_row[i] = tSR_rRow(epi_v * FragmentSize + i); ++ } ++ ++ return frg_row; ++ } ++ }; ++ ++ template < ++ bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy ++ class... Args ++ > ++ CUTLASS_DEVICE auto ++ get_consumer_store_callbacks(ConsumerStoreArgs const& args) { ++ auto [M, N, K, L] = args.problem_shape_mnkl; ++ auto [m, n, k, l] = args.tile_coord_mnkl; ++ using ThreadCount = decltype(size(args.tiled_copy)); ++ ++ Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row), make_shape(M,N,L), params.dRow); ++ Tensor gRow = local_tile(mRow(_,_,l), take<0,2>(args.tile_shape_mnk), make_coord(m, n)); // (CTA_M, CTA_N) ++ Tensor sRow = make_tensor(make_smem_ptr(smem), ++ make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})), make_shape(_0{}, _1{})); // (CTA_M, CTA_N) ++ //// G2S: Gmem to Smem ++ auto tiled_g2s = make_tiled_copy(Copy_Atom{}, ++ Layout< Shape<_1, ThreadCount>, ++ Stride<_0, _1>>{}, ++ Layout<_1>{}); ++ auto thr_g2s = tiled_g2s.get_slice(args.thread_idx); ++ Tensor tGS_gRow = thr_g2s.partition_S(gRow); ++ Tensor tGS_sRow = thr_g2s.partition_D(sRow); ++ ++ //// G2S: Coord ++ auto cRow = make_identity_tensor(make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{}))); ++ Tensor tGS_cRow = thr_g2s.partition_S(cRow); ++ ++ //// S2R: Smem to Reg ++ Tensor tSR_sRow = sm90_partition_for_epilogue(sRow, args.epi_tile, args.tiled_copy, args.thread_idx); ++ Tensor tSR_rRow = make_tensor_like(take<0,3>(tSR_sRow)); // (CPY,CPY_M,CPY_N) ++ ++ return ConsumerStoreCallbacks( ++ tGS_gRow, ++ tGS_sRow, ++ tGS_cRow, tiled_g2s, ++ tSR_sRow, ++ tSR_rRow, ++ args.tCcD, ++ args.residue_cD, ++ ThreadCount{}, ++ params); ++ } ++}; ++ ++///////////////////////////////////////////////////////////////////////////////////////////////// ++ ++// Column vector broadcast ++template< ++ int Stages, ++ class CtaTileShapeMNK, ++ class Element, ++ class StrideMNL = Stride<_1,_0,_0>, ++ int Alignment = 128 / sizeof_bits_v ++> ++struct Sm90ColOrScalarBroadcast { ++ static_assert(Stages == 0, "Column broadcast doesn't support smem usage yet"); ++ static_assert(Alignment * sizeof_bits_v % 128 == 0, "sub-16B alignment not supported yet"); ++ static_assert( ++ (cute::is_same_v>) || // col vector broadcast, e.g. per-row alpha/bias ++ (cute::is_same_v>)); // batched col vector broadcast, e.g. batched per-row bias ++ ++ // Accumulator distributes col elements evenly amongst threads so we can just directly load from gmem ++ struct SharedStorage { }; ++ ++ // This struct has been modified to have a bool indicating that ptr_col is a ++ // scalar that must be broadcast, instead of containing a scalar that is ++ // valid if ptr_col is null. ++ struct Arguments { ++ Element const* ptr_col = nullptr; ++ bool col_broadcast = true; ++ StrideMNL dCol = {}; ++ }; ++ ++ using Params = Arguments; ++ ++ template ++ static constexpr Params ++ to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { ++ return args; ++ } ++ ++ template ++ static bool ++ can_implement(ProblemShape const& problem_shape, Arguments const& args) { ++ return true; ++ } ++ ++ template ++ static size_t ++ get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { ++ return 0; ++ } ++ ++ template ++ static cutlass::Status ++ initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, ++ CudaHostAdapter* cuda_adapter = nullptr) { ++ return cutlass::Status::kSuccess; ++ } ++ ++ CUTLASS_DEVICE bool ++ is_producer_load_needed() const { ++ return false; ++ } ++ ++ CUTLASS_DEVICE bool ++ is_C_load_needed() const { ++ return false; ++ } ++ ++ CUTLASS_DEVICE bool ++ is_zero() const { ++ return (!params.col_broadcast && *(params.ptr_col) == Element(0)); ++ } ++ ++ CUTLASS_HOST_DEVICE ++ Sm90ColOrScalarBroadcast() { } ++ ++ CUTLASS_HOST_DEVICE ++ Sm90ColOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage) ++ : params(params) { } ++ ++ Params params; ++ ++ template ++ CUTLASS_DEVICE auto ++ get_producer_load_callbacks(ProducerLoadArgs const& args) { ++ return EmptyProducerLoadCallbacks{}; ++ } ++ ++ template ++ struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { ++ CUTLASS_DEVICE ++ ConsumerStoreCallbacks( ++ GTensor&& tCgCol, ++ RTensor&& tCrCol, ++ CTensor&& tCcCol, ++ ProblemShape problem_shape, ++ Params const& params ++ ): ++ tCgCol(cute::forward(tCgCol)), ++ tCrCol(cute::forward(tCrCol)), ++ tCcCol(cute::forward(tCcCol)), ++ m(get<0>(problem_shape)), ++ params(params) {} ++ ++ GTensor tCgCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) ++ RTensor tCrCol; ++ CTensor tCcCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) ++ Params const& params; ++ int m; ++ ++ CUTLASS_DEVICE void ++ begin() { ++ Tensor pred = make_tensor(shape(tCgCol)); ++ CUTLASS_PRAGMA_UNROLL ++ for (int i = 0; i < size(pred); ++i) { ++ pred(i) = get<0>(tCcCol(i)) < m; ++ } ++ ++ if (!params.col_broadcast) { ++ fill(tCrCol, *(params.ptr_col)); ++ return; ++ } ++ ++ // Filter so we don't issue redundant copies over stride-0 modes ++ // (only works if 0-strides are in same location, which is by construction) ++ copy_if(pred, filter(tCgCol), filter(tCrCol)); ++ } ++ ++ template ++ CUTLASS_DEVICE Array ++ visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { ++ Array frg_col; ++ Tensor tCrCol_mn = tCrCol(_,_,_,epi_m,epi_n); ++ ++ CUTLASS_PRAGMA_UNROLL ++ for (int i = 0; i < FragmentSize; ++i) { ++ frg_col[i] = tCrCol_mn(epi_v * FragmentSize + i); ++ } ++ ++ return frg_col; ++ } ++ ++ }; ++ ++ template < ++ bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy ++ class... Args ++ > ++ CUTLASS_DEVICE auto ++ get_consumer_store_callbacks(ConsumerStoreArgs const& args) { ++ ++ auto [M, N, K, L] = args.problem_shape_mnkl; ++ Tensor mCol = make_tensor(make_gmem_ptr(params.ptr_col), make_shape(M,N,L), params.dCol); ++ Tensor tCgCol = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) ++ mCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx); ++ Tensor tCrCol = make_tensor_like(tCgCol); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) ++ ++ // Generate an identity tensor matching the shape of the global tensor and ++ // partition the same way, this will be used to generate the predicate ++ // tensor for loading ++ Tensor cCol = make_identity_tensor(mCol.shape()); ++ Tensor tCcCol = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) ++ cCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx); ++ ++ return ConsumerStoreCallbacks( ++ cute::move(tCgCol), ++ cute::move(tCrCol), ++ cute::move(tCcCol), ++ args.problem_shape_mnkl, ++ params ++ ); ++ } ++}; ++ ++} +diff --git a/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp b/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp +new file mode 100644 +index 0000000..ef413e6 +--- /dev/null ++++ b/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp +@@ -0,0 +1,319 @@ ++#pragma once ++ ++#include "cutlass_extensions/epilogue/broadcast_load_epilogue_c2x.hpp" ++ ++/* ++ This file defines custom epilogues for fusing channel scales, token scales, ++ bias, and activation zero-points onto a GEMM operation using the ++ CUTLASS 2.x API, for sm80 (Ampere) NVIDIA GPUs. ++ ++ Epilogues must contain a public type named EVTCompute of type Sm80EVT, ++ as well as a static prepare_args function that constructs an ++ EVTCompute::Arguments struct. ++*/ ++ ++namespace vllm::c2x { ++ ++using namespace cute; ++ ++/* ++ * This class provides the common load descriptors for the ++ * ScaledEpilogue[...] classes ++ */ ++template ++struct ScaledEpilogueBase { ++ protected: ++ using Accum = cutlass::epilogue::threadblock::VisitorAccFetch; ++ ++ template ++ using ColOrScalarLoad = ++ cutlass::epilogue::threadblock::VisitorColOrScalarBroadcast< ++ OutputTileThreadMap, T, Stride, Int<0>, Int<0>>>; ++ ++ template ++ using RowOrScalarLoad = ++ cutlass::epilogue::threadblock::VisitorRowOrScalarBroadcast< ++ OutputTileThreadMap, T, Stride, Int<1>, Int<0>>>; ++ ++ template ++ using ColLoad = cutlass::epilogue::threadblock::VisitorColBroadcast< ++ OutputTileThreadMap, T, Stride, Int<0>, Int<0>>>; ++ ++ template ++ using RowLoad = cutlass::epilogue::threadblock::VisitorRowBroadcast< ++ OutputTileThreadMap, T, Stride, Int<1>, Int<0>>>; ++ ++ template ++ using RowOrZeroLoad = ++ cutlass::epilogue::threadblock::VisitorRowOrZeroBroadcast< ++ OutputTileThreadMap, T, Stride, Int<1>, Int<0>>>; ++ ++ // This utility function constructs the arguments for the load descriptors ++ // from a tensor. It can handle both row and column, as well as row/column or ++ // scalar cases. ++ template ++ static auto args_from_tensor(torch::Tensor const& tensor) { ++ using Arguments = typename Descriptor::Arguments; ++ auto* data_ptr = static_cast(tensor.data_ptr()); ++ if constexpr (std::is_same_v> || ++ std::is_same_v>) { ++ return Arguments{data_ptr, tensor.numel() != 1}; ++ } else { ++ // it would technically work but no use case as data_ptr is never nullptr ++ static_assert(!std::is_same_v>); ++ return Arguments{data_ptr}; ++ } ++ } ++ ++ // This overload handles the case where there might not be a tensor, in which ++ // case a nullptr is passed and a constant (0) is used. ++ template ++ static auto args_from_tensor(std::optional const& tensor) { ++ static_assert(std::is_same_v>); ++ using Arguments = typename Descriptor::Arguments; ++ auto* data_ptr = tensor ? static_cast(tensor->data_ptr()) : nullptr; ++ return Arguments{data_ptr}; ++ } ++}; ++ ++/* ++ This epilogue function defines a quantized GEMM operation similar to ++ torch._scaled_mm. ++ ++ A and B may be both either int8 or fp8_e4m3. A can be quantized per-tensor or ++ per-row. B can be quantized per-tensor or per-column. ++ Any combination of per-tensor and per-row or column is supported. ++ A and B must have symmetric quantization (zero point == 0). ++ ++ So the GEMM operation is D = (a_scales * A) (b_scales * B), where the ++ scales are applied elementwise with numpy-style broadcasting. ++ ++ ScaleA and ScaleB define the epilogue functions that apply the scales for ++ the A and B operands respectively. These scales may be either per-tensor or ++ per row or column. ++*/ ++template ++struct ScaledEpilogue ++ : private ScaledEpilogueBase { ++ private: ++ using SUPER = ScaledEpilogueBase; ++ using Accum = typename SUPER::Accum; ++ using ScaleA = typename SUPER::template ColOrScalarLoad; ++ using ScaleB = typename SUPER::template RowOrScalarLoad; ++ ++ using Compute0 = cutlass::epilogue::threadblock::VisitorCompute< ++ cutlass::multiplies, float, float, ++ cutlass::FloatRoundStyle::round_to_nearest>; ++ ++ using EVTCompute0 = ++ cutlass::epilogue::threadblock::Sm80EVT; ++ ++ using Compute1 = cutlass::epilogue::threadblock::VisitorCompute< ++ cutlass::multiplies, ElementD, float, ++ cutlass::FloatRoundStyle::round_to_nearest>; ++ ++ public: ++ using EVTCompute = ++ cutlass::epilogue::threadblock::Sm80EVT; ++ using ArgumentType = typename EVTCompute::Arguments; ++ ++ static ArgumentType prepare_args(torch::Tensor const& a_scales, ++ torch::Tensor const& b_scales) { ++ auto a_args = SUPER::template args_from_tensor(a_scales); ++ auto b_args = SUPER::template args_from_tensor(b_scales); ++ ++ typename EVTCompute0::Arguments evt0_args{b_args}; ++ return ArgumentType{a_args, evt0_args}; ++ } ++}; ++ ++/* ++ * This epilogue performs the same operation as ScaledEpilogue, but adds a bias. ++ * This bias can also be used in the per-tensor azp case, where the activation ++ * zero point (azp) is used to compute an azp correction term, ++ * which is folded into the bias. ++ * ++ * The bias tensor must be per-output channel. ++ * ScaleA and ScaleB can be per-tensor or per-token/per-channel. ++ */ ++template ++struct ScaledEpilogueBias ++ : protected ScaledEpilogueBase { ++ protected: ++ using SUPER = ScaledEpilogueBase; ++ using Accum = typename SUPER::Accum; ++ using ScaleA = typename SUPER::template ColOrScalarLoad; ++ using ScaleB = typename SUPER::template RowOrScalarLoad; ++ using Bias = typename SUPER::template RowLoad; ++ using Compute0 = cutlass::epilogue::threadblock::VisitorCompute< ++ cutlass::multiplies, float, float, ++ cutlass::FloatRoundStyle::round_to_nearest>; ++ ++ using EVTCompute0 = ++ cutlass::epilogue::threadblock::Sm80EVT; ++ ++ using Compute1 = cutlass::epilogue::threadblock::VisitorCompute< ++ cutlass::multiply_add, ElementD, float, ++ cutlass::FloatRoundStyle::round_to_nearest>; ++ ++ public: ++ using EVTCompute = cutlass::epilogue::threadblock::Sm80EVT; ++ using ArgumentType = typename EVTCompute::Arguments; ++ static ArgumentType prepare_args(torch::Tensor const& a_scales, ++ torch::Tensor const& b_scales, ++ torch::Tensor const& bias) { ++ auto a_args = SUPER::template args_from_tensor(a_scales); ++ auto b_args = SUPER::template args_from_tensor(b_scales); ++ auto bias_args = SUPER::template args_from_tensor(bias); ++ ++ typename EVTCompute0::Arguments evt0_args{b_args}; ++ return ArgumentType{a_args, evt0_args, bias_args}; ++ } ++}; ++ ++/* ++ * This epilogue directly supports per-tensor azp in int32 form. ++ * As opposed to the per-token epilogue below, this epilogue only has an azp_adj ++ * term, which should already be multiplied with the scalar azp. ++ * The azp_adj term is a 1D tensor of shape (1,n), computed as azp * J @ B. ++ * ++ * This epilogue also supports bias, which remains per-channel. ++ */ ++template ++struct ScaledEpilogueBiasAzp ++ : protected ScaledEpilogueBase { ++ private: ++ using SUPER = ScaledEpilogueBase; ++ using Accum = typename SUPER::Accum; ++ using ScaleA = typename SUPER::template ColOrScalarLoad; ++ using ScaleB = typename SUPER::template RowOrScalarLoad; ++ using Bias = typename SUPER::template RowOrZeroLoad; ++ ++ // This is the full AZP term, azp * J @ B, shape (1,n) ++ using AzpWithAdj = typename SUPER::template RowLoad; ++ ++ // Compute float(accum - azp_adj), both operands are int32_t ++ using ComputeAzp = cutlass::epilogue::threadblock::VisitorCompute< ++ cutlass::minus, float, int32_t, ++ cutlass::FloatRoundStyle::round_to_nearest>; ++ ++ using EVTComputeAzp = ++ cutlass::epilogue::threadblock::Sm80EVT; ++ ++ using ComputeScaleB = cutlass::epilogue::threadblock::VisitorCompute< ++ cutlass::multiplies, float, float, ++ cutlass::FloatRoundStyle::round_to_nearest>; ++ ++ using EVTComputeScaleB = ++ cutlass::epilogue::threadblock::Sm80EVT; ++ ++ using ComputeScaleBiasA = cutlass::epilogue::threadblock::VisitorCompute< ++ cutlass::multiply_add, ElementD, float, ++ cutlass::FloatRoundStyle::round_to_nearest>; ++ ++ public: ++ using EVTCompute = ++ cutlass::epilogue::threadblock::Sm80EVT; ++ ++ using ArgumentType = typename EVTCompute::Arguments; ++ ++ static ArgumentType prepare_args(torch::Tensor const& a_scales, ++ torch::Tensor const& b_scales, ++ torch::Tensor const& azp_adj, ++ std::optional const& bias) { ++ auto a_args = SUPER::template args_from_tensor(a_scales); ++ auto b_args = SUPER::template args_from_tensor(b_scales); ++ auto bias_args = SUPER::template args_from_tensor(bias); ++ auto azp_adj_args = ++ SUPER::template args_from_tensor(azp_adj); ++ ++ typename EVTComputeAzp::Arguments evt_azp_args{{}, azp_adj_args}; ++ typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_azp_args}; ++ return ArgumentType{a_args, evt_scale_b_args, bias_args}; ++ } ++}; ++ ++/* ++ * This epilogue supports per-token azp by computing and applying ++ * the correction term using a rank-1 update. If the term were materialized, ++ * it would require O(m*n) space, and this way it only requires O(m+n) space. ++ * The azp term is a 1D tensor of shape (m,1), and represents the unscaled zero ++ * point for each row of A. ++ * The azp_adj term is a 1D tensor of shape (1,n), computed as J @ B. ++ * ++ * This epilogue also supports bias, which remains per-channel. ++ */ ++template ++struct ScaledEpilogueBiasAzpToken ++ : protected ScaledEpilogueBase { ++ private: ++ using SUPER = ScaledEpilogueBase; ++ using Accum = typename SUPER::Accum; ++ using ScaleA = typename SUPER::template ColOrScalarLoad; ++ using ScaleB = typename SUPER::template RowOrScalarLoad; ++ using Bias = typename SUPER::template RowOrZeroLoad; ++ ++ // Per-token azp term, shape (m,1) ++ using Azp = typename SUPER::template ColLoad; ++ ++ // This is the AZP adjustment term, J @ B, shape (1,n) ++ using AzpAdj = typename SUPER::template RowLoad; ++ ++ // Compute azp * azp_adj ++ using ComputeAzp = cutlass::epilogue::threadblock::VisitorCompute< ++ cutlass::multiplies, int32_t, int32_t, ++ cutlass::FloatRoundStyle::round_to_nearest>; ++ ++ using EVTComputeAzp = ++ cutlass::epilogue::threadblock::Sm80EVT; ++ ++ // Compute float(accum - azp*azp_adj), all operands are int32_t ++ using ComputeAcc = cutlass::epilogue::threadblock::VisitorCompute< ++ cutlass::minus, float, int32_t, ++ cutlass::FloatRoundStyle::round_to_nearest>; ++ ++ using EVTComputeAcc = ++ cutlass::epilogue::threadblock::Sm80EVT; ++ ++ using ComputeScaleB = cutlass::epilogue::threadblock::VisitorCompute< ++ cutlass::multiplies, float, float, ++ cutlass::FloatRoundStyle::round_to_nearest>; ++ ++ using EVTComputeScaleB = ++ cutlass::epilogue::threadblock::Sm80EVT; ++ ++ using ComputeScaleBiasA = cutlass::epilogue::threadblock::VisitorCompute< ++ cutlass::multiply_add, ElementD, float, ++ cutlass::FloatRoundStyle::round_to_nearest>; ++ ++ public: ++ using EVTCompute = ++ cutlass::epilogue::threadblock::Sm80EVT; ++ ++ using ArgumentType = typename EVTCompute::Arguments; ++ ++ static ArgumentType prepare_args(torch::Tensor const& a_scales, ++ torch::Tensor const& b_scales, ++ torch::Tensor const& azp_adj, ++ torch::Tensor const& azp, ++ std::optional const& bias) { ++ auto a_args = SUPER::template args_from_tensor(a_scales); ++ auto b_args = SUPER::template args_from_tensor(b_scales); ++ auto bias_args = SUPER::template args_from_tensor(bias); ++ auto azp_args = SUPER::template args_from_tensor(azp); ++ auto azp_adj_args = ++ SUPER::template args_from_tensor(azp_adj); ++ ++ typename EVTComputeAzp::Arguments evt_azp_args{azp_args, azp_adj_args}; ++ typename EVTComputeAcc::Arguments evt_acc_args{{}, evt_azp_args}; ++ typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_acc_args}; ++ return ArgumentType{a_args, evt_scale_b_args, bias_args}; ++ } ++}; ++ ++}; // namespace vllm::c2x +\ No newline at end of file +diff --git a/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp b/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp +new file mode 100644 +index 0000000..c590c66 +--- /dev/null ++++ b/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp +@@ -0,0 +1,317 @@ ++#pragma once ++ ++#include "cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp" ++ ++/* ++ This file defines custom epilogues for fusing channel scales, token scales, ++ bias, and activation zero-points onto a GEMM operation using the ++ CUTLASS 3.x API, for NVIDIA GPUs with sm90a (Hopper) or later. ++ ++ Epilogues must contain a public type named EVTCompute of type Sm90EVT, ++ as well as a static prepare_args function that constructs an ++ EVTCompute::Arguments struct. ++*/ ++ ++namespace vllm::c3x { ++ ++using namespace cute; ++ ++/* ++ * This class provides the common load descriptors for the ++ * ScaledEpilogue[...] classes ++ */ ++template ++struct ScaledEpilogueBase { ++ protected: ++ using Accum = cutlass::epilogue::fusion::Sm90AccFetch; ++ ++ template ++ using ColOrScalarLoad = cutlass::epilogue::fusion::Sm90ColOrScalarBroadcast< ++ 0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, ++ Stride, Int<0>, Int<0>>>; ++ ++ template ++ using RowOrScalarLoad = cutlass::epilogue::fusion::Sm90RowOrScalarBroadcast< ++ 0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, ++ Stride, Int<1>, Int<0>>>; ++ ++ // Don't want to support nullptr by default ++ template ++ using ColLoad = cutlass::epilogue::fusion::Sm90ColBroadcast< ++ 0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, T, ++ Stride, Int<0>, Int<0>>, 128 / sizeof_bits_v, EnableNullPtr>; ++ ++ // Don't want to support nullptr by default ++ template ++ using RowLoad = cutlass::epilogue::fusion::Sm90RowBroadcast< ++ 0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, T, ++ Stride, Int<1>, Int<0>>, 128 / sizeof_bits_v, EnableNullPtr>; ++ ++ // This utility function constructs the arguments for the load descriptors ++ // from a tensor. It can handle both row and column, as well as row/column or ++ // scalar cases. ++ template ++ static auto args_from_tensor(torch::Tensor const& tensor) { ++ using Arguments = typename Descriptor::Arguments; ++ auto* data_ptr = static_cast(tensor.data_ptr()); ++ if constexpr (std::is_same_v> || ++ std::is_same_v>) { ++ return Arguments{data_ptr, tensor.numel() != 1}; ++ } else { ++ static_assert(!std::is_same_v> && ++ !std::is_same_v>); ++ return Arguments{data_ptr}; ++ } ++ } ++ ++ // This overload handles the case where there might not be a tensor, in which ++ // case a nullptr is passed and a constant (0) is used. ++ template ++ static auto args_from_tensor(std::optional const& tensor) { ++ using Arguments = typename Descriptor::Arguments; ++ auto* data_ptr = tensor ? static_cast(tensor->data_ptr()) : nullptr; ++ static_assert(std::is_same_v> || ++ std::is_same_v>); ++ return Arguments{data_ptr}; ++ } ++}; ++ ++/* ++ This epilogue function defines a quantized GEMM operation similar to ++ torch.scaled_mm_. ++ ++ A and B may be both either int8 or fp8_e4m3. A can be ++ quantized per-tensor or per-row. B can be quantized per-tensor or per-column. ++ Any combination of per-tensor and per-row or column is supported. ++ A and B must have symmetric quantization (zero point == 0). ++ ++ So the GEMM operation is D = (a_scales * A) (b_scales * B), where the ++ scales are applied elementwise with numpy-style broadcasting. ++ ++ ScaleA and ScaleB define the epilogue functions that apply the scales for ++ the A and B operands respectively. These scales may be either per-tensor or ++ per row or column. ++*/ ++template ++struct ScaledEpilogue ++ : private ScaledEpilogueBase { ++ private: ++ using SUPER = ScaledEpilogueBase; ++ using Accum = typename SUPER::Accum; ++ using ScaleA = typename SUPER::template ColOrScalarLoad; ++ using ScaleB = typename SUPER::template RowOrScalarLoad; ++ ++ using Compute0 = cutlass::epilogue::fusion::Sm90Compute< ++ cutlass::multiplies, float, float, ++ cutlass::FloatRoundStyle::round_to_nearest>; ++ ++ using EVTCompute0 = ++ cutlass::epilogue::fusion::Sm90EVT; ++ ++ using Compute1 = cutlass::epilogue::fusion::Sm90Compute< ++ cutlass::multiplies, ElementD, float, ++ cutlass::FloatRoundStyle::round_to_nearest>; ++ ++ public: ++ using EVTCompute = ++ cutlass::epilogue::fusion::Sm90EVT; ++ using ArgumentType = typename EVTCompute::Arguments; ++ ++ static ArgumentType prepare_args(torch::Tensor const& a_scales, ++ torch::Tensor const& b_scales) { ++ auto a_args = SUPER::template args_from_tensor(a_scales); ++ auto b_args = SUPER::template args_from_tensor(b_scales); ++ ++ typename EVTCompute0::Arguments evt0_args{b_args}; ++ return ArgumentType{a_args, evt0_args}; ++ } ++}; ++ ++/* ++ * This epilogue performs the same operation as ScaledEpilogue, but adds a bias. ++ * This bias can also be used in the per-tensor azp case, where the activation ++ * zero point (azp) is used to compute an azp correction term, ++ * which is folded into the bias. ++ * ++ * The bias tensor must be per-output channel. ++ * ScaleA and ScaleB can be per-tensor or per-token/per-channel. ++ */ ++template ++struct ScaledEpilogueBias ++ : private ScaledEpilogueBase { ++ private: ++ using SUPER = ScaledEpilogueBase; ++ using Accum = typename SUPER::Accum; ++ using ScaleA = typename SUPER::template ColOrScalarLoad; ++ using ScaleB = typename SUPER::template RowOrScalarLoad; ++ using Bias = typename SUPER::template RowLoad; ++ ++ using Compute0 = cutlass::epilogue::fusion::Sm90Compute< ++ cutlass::multiplies, float, float, ++ cutlass::FloatRoundStyle::round_to_nearest>; ++ ++ using EVTCompute0 = ++ cutlass::epilogue::fusion::Sm90EVT; ++ ++ using Compute1 = cutlass::epilogue::fusion::Sm90Compute< ++ cutlass::multiply_add, ElementD, float, ++ cutlass::FloatRoundStyle::round_to_nearest>; ++ ++ public: ++ using EVTCompute = ++ cutlass::epilogue::fusion::Sm90EVT; ++ ++ using ArgumentType = typename EVTCompute::Arguments; ++ static ArgumentType prepare_args(torch::Tensor const& a_scales, ++ torch::Tensor const& b_scales, ++ torch::Tensor const& bias) { ++ auto a_args = SUPER::template args_from_tensor(a_scales); ++ auto b_args = SUPER::template args_from_tensor(b_scales); ++ auto bias_args = SUPER::template args_from_tensor(bias); ++ ++ typename EVTCompute0::Arguments evt0_args{b_args}; ++ return ArgumentType{a_args, evt0_args, bias_args}; ++ } ++}; ++ ++/* ++ * This epilogue directly supports per-tensor azp in int32 form. ++ * As opposed to the per-token epilogue below, this epilogue only has an azp_adj ++ * term, which should already be multiplied with the scalar azp. ++ * The azp_adj term is a 1D tensor of shape (1,n), computed as azp * J @ B. ++ * ++ * This epilogue also supports bias, which remains per-channel. ++ */ ++template ++struct ScaledEpilogueBiasAzp ++ : private ScaledEpilogueBase { ++ private: ++ using SUPER = ScaledEpilogueBase; ++ using Accum = typename SUPER::Accum; ++ using ScaleA = typename SUPER::template ColOrScalarLoad; ++ using ScaleB = typename SUPER::template RowOrScalarLoad; ++ using Bias = typename SUPER::template RowLoad; ++ ++ // This is the full AZP term, azp * J @ B, shape (1,n) ++ using AzpWithAdj = typename SUPER::template RowLoad; ++ ++ // Compute float(accum - azp_adj), both operands are int32_t ++ using ComputeAzp = cutlass::epilogue::fusion::Sm90Compute< ++ cutlass::minus, float, int32_t, ++ cutlass::FloatRoundStyle::round_to_nearest>; ++ ++ using EVTComputeAzp = ++ cutlass::epilogue::fusion::Sm90EVT; ++ ++ using ComputeScaleB = cutlass::epilogue::fusion::Sm90Compute< ++ cutlass::multiplies, float, float, ++ cutlass::FloatRoundStyle::round_to_nearest>; ++ ++ using EVTComputeScaleB = ++ cutlass::epilogue::fusion::Sm90EVT; ++ ++ using ComputeScaleBiasA = cutlass::epilogue::fusion::Sm90Compute< ++ cutlass::multiply_add, ElementD, float, ++ cutlass::FloatRoundStyle::round_to_nearest>; ++ ++ public: ++ using EVTCompute = ++ cutlass::epilogue::fusion::Sm90EVT; ++ using ArgumentType = typename EVTCompute::Arguments; ++ ++ static ArgumentType prepare_args(torch::Tensor const& a_scales, ++ torch::Tensor const& b_scales, ++ torch::Tensor const& azp_adj, ++ std::optional const& bias) { ++ auto a_args = SUPER::template args_from_tensor(a_scales); ++ auto b_args = SUPER::template args_from_tensor(b_scales); ++ auto bias_args = SUPER::template args_from_tensor(bias); ++ auto azp_adj_args = ++ SUPER::template args_from_tensor(azp_adj); ++ ++ typename EVTComputeAzp::Arguments evt_azp_args{{}, azp_adj_args}; ++ typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_azp_args}; ++ return ArgumentType{a_args, evt_scale_b_args, bias_args}; ++ } ++}; ++ ++/* ++ * This epilogue supports per-token azp by computing and applying ++ * the correction term using a rank-1 update. If the term were materialized, ++ * it would require O(m*n) space, and this way it only requires O(m+n) space. ++ * The azp term is a 1D tensor of shape (m,1), and represents the unscaled zero ++ * point for each row of A. ++ * The azp_adj term is a 1D tensor of shape (1,n), computed as J @ B. ++ * ++ * This epilogue also supports bias, which remains per-channel. ++ */ ++template ++struct ScaledEpilogueBiasAzpToken ++ : private ScaledEpilogueBase { ++ private: ++ using SUPER = ScaledEpilogueBase; ++ using Accum = typename SUPER::Accum; ++ using ScaleA = typename SUPER::template ColOrScalarLoad; ++ using ScaleB = typename SUPER::template RowOrScalarLoad; ++ using Bias = typename SUPER::template RowLoad; ++ ++ // Per-token azp term, shape (m,1) ++ using Azp = typename SUPER::template ColLoad; ++ ++ // This is the AZP adjustment term, J @ B, shape (1,n) ++ using AzpAdj = typename SUPER::template RowLoad; ++ ++ // Compute azp * azp_adj ++ using ComputeAzp = cutlass::epilogue::fusion::Sm90Compute< ++ cutlass::multiplies, int32_t, int32_t, ++ cutlass::FloatRoundStyle::round_to_nearest>; ++ ++ using EVTComputeAzp = ++ cutlass::epilogue::fusion::Sm90EVT; ++ ++ // Compute float(accum - azp*azp_adj), all operands are int32_t ++ using ComputeAcc = cutlass::epilogue::fusion::Sm90Compute< ++ cutlass::minus, float, int32_t, ++ cutlass::FloatRoundStyle::round_to_nearest>; ++ ++ using EVTComputeAcc = ++ cutlass::epilogue::fusion::Sm90EVT; ++ ++ using ComputeScaleB = cutlass::epilogue::fusion::Sm90Compute< ++ cutlass::multiplies, float, float, ++ cutlass::FloatRoundStyle::round_to_nearest>; ++ ++ using EVTComputeScaleB = ++ cutlass::epilogue::fusion::Sm90EVT; ++ ++ using ComputeScaleBiasA = cutlass::epilogue::fusion::Sm90Compute< ++ cutlass::multiply_add, ElementD, float, ++ cutlass::FloatRoundStyle::round_to_nearest>; ++ ++ public: ++ using EVTCompute = ++ cutlass::epilogue::fusion::Sm90EVT; ++ using ArgumentType = typename EVTCompute::Arguments; ++ ++ static ArgumentType prepare_args(torch::Tensor const& a_scales, ++ torch::Tensor const& b_scales, ++ torch::Tensor const& azp_adj, ++ torch::Tensor const& azp, ++ std::optional const& bias) { ++ auto a_args = SUPER::template args_from_tensor(a_scales); ++ auto b_args = SUPER::template args_from_tensor(b_scales); ++ auto bias_args = SUPER::template args_from_tensor(bias); ++ auto azp_args = SUPER::template args_from_tensor(azp); ++ auto azp_adj_args = ++ SUPER::template args_from_tensor(azp_adj); ++ ++ typename EVTComputeAzp::Arguments evt_azp_args{azp_args, azp_adj_args}; ++ typename EVTComputeAcc::Arguments evt_acc_args{{}, evt_azp_args}; ++ typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_acc_args}; ++ return ArgumentType{a_args, evt_scale_b_args, bias_args}; ++ } ++}; ++ ++}; // namespace vllm::c3x +\ No newline at end of file +diff --git a/csrc/cutlass_extensions/torch_utils.hpp b/csrc/cutlass_extensions/torch_utils.hpp +new file mode 100644 +index 0000000..a1ff933 +--- /dev/null ++++ b/csrc/cutlass_extensions/torch_utils.hpp +@@ -0,0 +1,160 @@ ++#pragma once ++ ++#include ++ ++#include "cute/layout.hpp" ++#include "cutlass/layout/matrix.h" ++#include "cutlass/bfloat16.h" ++#include "cutlass/half.h" ++ ++using ColumnMajor = typename cutlass::layout::ColumnMajor; ++using RowMajor = typename cutlass::layout::RowMajor; ++ ++namespace cute { ++ ++namespace detail { ++ ++template ++CUTE_HOST_DEVICE constexpr auto tapply_with_idx(T&& t, F&& f, G&& g, ++ seq) { ++ return g(f(cute::get(static_cast(t)), I)...); ++} ++ ++template ++CUTE_HOST_DEVICE constexpr auto make_shape_from_idx(F&& f, seq) { ++ return make_shape(f(I)...); ++} ++ ++}; // namespace detail ++ ++template ++CUTE_HOST_DEVICE constexpr auto transform_with_idx(T const& t, F&& f) { ++ if constexpr (cute::is_tuple::value) { ++ return detail::tapply_with_idx( ++ t, f, [](auto const&... a) { return cute::make_tuple(a...); }, ++ tuple_seq{}); ++ } else { ++ return f(t); ++ } ++ ++ CUTE_GCC_UNREACHABLE; ++} ++ ++// calls: make_shape(f(0), f(1), ..., f(N-1)) ++template ++CUTE_HOST_DEVICE constexpr auto make_shape_from_idx(F&& f) { ++ return detail::make_shape_from_idx(f, make_seq{}); ++} ++ ++}; // namespace cute ++ ++// Make a layout from a tensor with `rank(Stride{})`, where the shape is the ++// shape of the passed in tensor and the strides are of type `Stride` and ++// contain the strides of the passed in tensor, checking that any static strides ++// in `Stride{}` match the strides of the passed in tensor. ++// If `tensor.dim() < rank(Stride{})`, the shape is padded with 1s and the extra ++// strides are set to be 0 or 1. ++template ++static inline auto make_cute_layout(torch::Tensor const& tensor, ++ std::string_view name = "tensor") { ++ TORCH_CHECK(tensor.dim() <= rank(Stride{})); ++ auto stride = cute::transform_with_idx( ++ Stride{}, [&](auto const& stride_ele, auto const& idx) { ++ using StrideEle = std::decay_t; ++ ++ if (idx < tensor.dim()) { ++ if constexpr (cute::is_static_v) { ++ TORCH_CHECK(StrideEle::value == tensor.stride(idx), "Expected ", ++ name, ".stride(", idx, ") to be ", StrideEle::value); ++ return StrideEle{}; ++ } else { ++ if (tensor.size(idx) == 1) { ++ // use 0 stride for dim with size 1, this is easier for ++ // cute/cutlass to optimize (helps the TMA code flatten dims) ++ return StrideEle{0}; ++ } else { ++ return tensor.stride(idx); ++ } ++ } ++ } else { ++ // Extra strides are assumed to be 0 or 1 ++ if constexpr (cute::is_static_v) { ++ static_assert(StrideEle::value == 0 || StrideEle::value == 1); ++ } ++ return StrideEle{}; ++ } ++ }); ++ ++ auto shape = cute::make_shape_from_idx([&](auto const& idx) { ++ if (idx < tensor.dim()) ++ return tensor.size(idx); ++ else ++ return int64_t(1); ++ }); ++ ++ return make_layout(shape, stride); ++} ++ ++template ++static inline auto maybe_make_cute_layout( ++ std::optional const& tensor, ++ std::string_view name = "tensor") { ++ using Layout = decltype(make_cute_layout(*tensor)); ++ ++ if (tensor) { ++ return std::optional{make_cute_layout(*tensor, name)}; ++ } else { ++ return std::optional{}; ++ } ++} ++ ++// ++// Torch Type to Cutlass Type (equivalent_cutlass_type) ++// ++ ++template ++struct equivalent_cutlass_type { ++ using type = T; ++}; ++ ++template ++using equivalent_cutlass_type_t = typename equivalent_cutlass_type::type; ++ ++template <> ++struct equivalent_cutlass_type { ++ using type = cutlass::half_t; ++}; ++ ++template <> ++struct equivalent_cutlass_type { ++ using type = cutlass::bfloat16_t; ++}; ++ ++// ++// equivalent_scalar_t (basically inverse of equivalent_cutlass_type) ++// ++ ++// Return a `c10::CppTypeToScalarType` compatible type, i.e. get the C++ from ++// c10 that is equivalent to T, e.g.: `cutlass::half_t -> c10::Half` ++template ++struct equivalent_scalar_type { ++ using type = T; ++}; ++ ++template ++using equivalent_scalar_type_t = typename equivalent_scalar_type::type; ++ ++template <> ++struct equivalent_scalar_type { ++ using type = c10::Half; ++}; ++ ++template <> ++struct equivalent_scalar_type { ++ using type = c10::BFloat16; ++}; ++ ++// get equivalent c10::ScalarType tag from compile time type ++template ++static inline constexpr c10::ScalarType equivalent_scalar_type_v = ++ c10::CppTypeToScalarType>::value; +\ No newline at end of file +diff --git a/csrc/cutlass_extensions/vllm_collective_builder.cuh b/csrc/cutlass_extensions/vllm_collective_builder.cuh +new file mode 100644 +index 0000000..085ee12 +--- /dev/null ++++ b/csrc/cutlass_extensions/vllm_collective_builder.cuh +@@ -0,0 +1,43 @@ ++#pragma once ++ ++#include "cutlass/gemm/collective/collective_builder.hpp" ++ ++namespace cutlass::gemm::collective { ++using namespace cute; ++ ++// ++// VLLMCollectiveBuilder is a wrapper around CollectiveBuilder that allows for ++// for custom kernel tags, allowing you to build custom collectives. Without ++// touching the cutlass library headers, using `CutlassKernelTag` will mean it ++// will resort to using the standard cutlass collective builder. ++// ++ ++// Use the default Cutlass collective builder, i.e. use an unmodified cutless ++// collective ++struct CutlassKernelTag {}; ++ ++template ++struct VLLMCollectiveBuilder { ++ static_assert(sizeof(ElementA) == 0, ++ "Could not build a collective for given parameters."); ++}; ++ ++template ++struct VLLMCollectiveBuilder< ++ CutlassKernelTag, ArchTag, OpClass, ElementA, GmemLayoutA, AlignmentA, ++ ElementB, GmemLayoutB, AlignmentB, ElementAccumulator, TileShape_MNK, ++ ClusterShape_MNK, StageCountType, KernelScheduleType> { ++ using CollectiveOp = typename CollectiveBuilder< ++ ArchTag, OpClass, ElementA, GmemLayoutA, AlignmentA, ElementB, ++ GmemLayoutB, AlignmentB, ElementAccumulator, TileShape_MNK, ++ ClusterShape_MNK, StageCountType, KernelScheduleType>::CollectiveOp; ++}; ++ ++}; // namespace cutlass::gemm::collective +\ No newline at end of file +diff --git a/csrc/cutlass_extensions/vllm_custom_types.cuh b/csrc/cutlass_extensions/vllm_custom_types.cuh +new file mode 100644 +index 0000000..6146bdc +--- /dev/null ++++ b/csrc/cutlass_extensions/vllm_custom_types.cuh +@@ -0,0 +1,50 @@ ++#pragma once ++ ++#include "cutlass/integer_subbyte.h" ++ ++namespace cutlass { ++ ++/////////////////////////////////////////////////////////////////////////////////////////////////// ++ ++template ++struct vllm_biased_integer_subbyte : public integer_subbyte { ++ using Base = integer_subbyte; ++ ++ using Storage = typename Base::Storage; ++ using xint_t = typename Base::xint_t; ++ ++ using Base::bits_mask_; ++ using Base::sign_mask_; ++ using Base::storage; ++ ++ // ++ // Methods ++ // ++ ++ /// No operation ++ vllm_biased_integer_subbyte() = default; ++ ++ /// Conversion from integer type ++ CUTLASS_HOST_DEVICE explicit vllm_biased_integer_subbyte(int value) ++ : Base(value) {} ++ CUTLASS_HOST_DEVICE explicit vllm_biased_integer_subbyte(unsigned value) ++ : Base(value) {} ++ CUTLASS_HOST_DEVICE explicit vllm_biased_integer_subbyte(double value) ++ : Base(value) {} ++}; ++/////////////////////////////////////////////////////////////////////////////////////////////////// ++ ++// "GPTQ" types, i.e. symmetric quantization ++using vllm_uint4b8_t = vllm_biased_integer_subbyte<4, 8>; // u4b8 ++using vllm_uint8b128_t = vllm_biased_integer_subbyte<8, 128>; // u8b128 ++ ++/////////////////////////////////////////////////////////////////////////////////////////////////// ++ ++template ++struct sizeof_bits> { ++ static constexpr int value = Bits; ++}; ++ ++/////////////////////////////////////////////////////////////////////////////////////////////////// ++ ++} // namespace cutlass +diff --git a/csrc/cutlass_extensions/vllm_cutlass_library_extension.py b/csrc/cutlass_extensions/vllm_cutlass_library_extension.py +new file mode 100644 +index 0000000..b401736 +--- /dev/null ++++ b/csrc/cutlass_extensions/vllm_cutlass_library_extension.py +@@ -0,0 +1,78 @@ ++import enum ++from typing import Dict, Union ++ ++from cutlass_library import * ++ ++# ++# Extend cutlass library with custom types, and missing values ++# ++ ++ ++class VLLMDataType(enum.Enum): ++ u4b8 = enum_auto() ++ u8b128 = enum_auto() ++ ++ ++class MixedInputKernelScheduleType(enum.Enum): ++ TmaWarpSpecialized = enum_auto() ++ TmaWarpSpecializedPingpong = enum_auto() ++ TmaWarpSpecializedCooperative = enum_auto() ++ ++ ++VLLMDataTypeNames: Dict[Union[VLLMDataType, DataType], str] = { ++ **DataTypeNames, # type: ignore ++ **{ ++ VLLMDataType.u4b8: "u4b8", ++ VLLMDataType.u8b128: "u8b128", ++ } ++} ++ ++VLLMDataTypeTag: Dict[Union[VLLMDataType, DataType], str] = { ++ **DataTypeTag, # type: ignore ++ **{ ++ VLLMDataType.u4b8: "cutlass::vllm_uint4b8_t", ++ VLLMDataType.u8b128: "cutlass::vllm_uint8b128_t", ++ } ++} ++ ++VLLMDataTypeSize: Dict[Union[VLLMDataType, DataType], int] = { ++ **DataTypeSize, # type: ignore ++ **{ ++ VLLMDataType.u4b8: 4, ++ VLLMDataType.u8b128: 8, ++ } ++} ++ ++VLLMDataTypeVLLMScalarTypeTag: Dict[Union[VLLMDataType, DataType], str] = { ++ VLLMDataType.u4b8: "vllm::kU4B8", ++ VLLMDataType.u8b128: "vllm::kU8B128", ++ DataType.u4: "vllm::kU4", ++ DataType.u8: "vllm::kU8", ++ DataType.s4: "vllm::kS4", ++ DataType.s8: "vllm::kS8", ++ DataType.f16: "vllm::kFloat16", ++ DataType.bf16: "vllm::kBfloat16", ++} ++ ++VLLMDataTypeTorchDataTypeTag: Dict[Union[VLLMDataType, DataType], str] = { ++ DataType.u8: "at::ScalarType::Byte", ++ DataType.s8: "at::ScalarType::Char", ++ DataType.e4m3: "at::ScalarType::Float8_e4m3fn", ++ DataType.s32: "at::ScalarType::Int", ++ DataType.f16: "at::ScalarType::Half", ++ DataType.bf16: "at::ScalarType::BFloat16", ++ DataType.f32: "at::ScalarType::Float", ++} ++ ++VLLMKernelScheduleTag: Dict[Union[ ++ MixedInputKernelScheduleType, KernelScheduleType], str] = { ++ **KernelScheduleTag, # type: ignore ++ **{ ++ MixedInputKernelScheduleType.TmaWarpSpecialized: ++ "cutlass::gemm::KernelTmaWarpSpecialized", ++ MixedInputKernelScheduleType.TmaWarpSpecializedPingpong: ++ "cutlass::gemm::KernelTmaWarpSpecializedPingpong", ++ MixedInputKernelScheduleType.TmaWarpSpecializedCooperative: ++ "cutlass::gemm::KernelTmaWarpSpecializedCooperative", ++ } ++ } +diff --git a/csrc/cutlass_extensions/vllm_numeric_conversion.cuh b/csrc/cutlass_extensions/vllm_numeric_conversion.cuh +new file mode 100644 +index 0000000..90f226c +--- /dev/null ++++ b/csrc/cutlass_extensions/vllm_numeric_conversion.cuh +@@ -0,0 +1,992 @@ ++#pragma once ++ ++#include "cutlass/numeric_conversion.h" ++#include "cutlass_extensions/vllm_custom_types.cuh" ++#include "cutlass_extensions/cute_utils.cuh" ++#include "cutlass_extensions/vllm_type_utils.cuh" ++ ++// this file extends: ++// https://github.com/NVIDIA/cutlass/blob/cutlass-3.5.0/include/cutlass/numeric_conversion.h ++// with vllm specific type conversions, namely: vllm_uint4b8_t, vllm_uint8b128_t ++// as well as adds interleaved numeric array converters for specific types. ++// (interleaved numeric array converters can be more efficient for subbyte ++// types) ++ ++namespace cutlass { ++ ++// InterleavedNumericArrayConverter is like NumericArrayConverter but also ++// deinterleaves converted elements based on IlvBlkLayout, interleaving can ++// make subbyte converts more efficient by allowing for efficient extraction ++// of subbyte elements from a 32bit register. ++template ++struct InterleavedNumericArrayConverter { ++ using Converter = NumericArrayConverter; ++ ++ using result_type = typename Converter::result_type; ++ using source_type = typename Converter::source_type; ++ ++ CUTLASS_DEVICE ++ static result_type convert(source_type const& source) { ++ if (cute::elect_one_sync()) { ++ if constexpr (std::is_same_v) { ++ printf( ++ "Convert %s <= %s (N = %d, IlvBlkLayout = void), not implemented\n", ++ nameof_v, nameof_v, N); ++ } else { ++ printf( ++ "Convert %s <= %s (N = %d, size(IlvBlkLayout{}) = %d), not " ++ "implemented\n", ++ nameof_v, nameof_v, N, size(IlvBlkLayout{})); ++ } ++ __brkpt(); ++ } ++ return {}; ++ } ++ ++ CUTLASS_DEVICE ++ result_type operator()(source_type const& s) const { return convert(s); } ++}; ++ ++template ++struct InterleavedNumericArrayConverter< ++ IlvBlkLayout, T, S, N, Round, ++ std::enable_if_t()>> { ++ using Converter = NumericArrayConverter; ++ ++ using result_type = typename Converter::result_type; ++ using source_type = typename Converter::source_type; ++ ++ CUTLASS_DEVICE ++ static result_type convert(source_type const& source) { ++ return Converter::convert(source); ++ } ++ ++ CUTLASS_DEVICE ++ result_type operator()(source_type const& s) const { return convert(s); } ++}; ++ ++template ++struct ArrayConverterPacked32Bit { ++ using result_type = Array; ++ using source_type = Array; ++ ++ using result_packed_8_t = Array; ++ using result_packed_4_t = Array; ++ using result_packed_2_t = Array; ++ using src_packed_8_t = Array; ++ using src_packed_4_t = Array; ++ using src_packed_2_t = Array; ++ ++ static_assert(N % 2 == 0, "N must be a multiple of 2"); ++ static_assert(cutlass::sizeof_bits_v >= 4); // TODO: add 16 packed sources ++ static_assert(32 % cutlass::sizeof_bits_v == 0); ++ static constexpr auto src_elems_per_32bit_reg = ++ 32 / cutlass::sizeof_bits_v; ++ ++ // Maybe not Valid. ScalarConverter will not actually work unless ++ // NumericConverter is implemented. However it won't be used ++ // anyways since we assert N % 2 == 0, just here for compliance with ++ // VectorizedConverter. ++ using ScalarConverter = NumericConverter; ++ ++ template ++ CUTLASS_DEVICE static auto to_regs(PackedSrc const& src) { ++ if constexpr (sizeof(PackedSrc) == 1) { ++ return Array{reinterpret_cast(src)}; ++ } else if constexpr (sizeof(PackedSrc) == 2) { ++ return Array{reinterpret_cast(src)}; ++ } else if constexpr (sizeof(PackedSrc) == 4) { ++ return Array{reinterpret_cast(src)}; ++ } else { ++ static_assert(sizeof(PackedSrc) == 8); ++ return reinterpret_cast const&>(src); ++ } ++ } ++ ++ // The core converter uses bit tricks to construct a known FP16 number, then ++ // does a subtraction in FP16 for the final result. ++ template ++ CUTLASS_DEVICE static PackedResultType packed_convert( ++ PackedSrcType const& source) { ++ static_assert(PackedSrcType::kElements == PackedResultType::kElements); ++ static_assert(PackedResultType::kElements == 2 || ++ PackedResultType::kElements == 4 || ++ PackedResultType::kElements == 8, ++ "Invalid PackedResultType must be 2, 4 or 8."); ++ static_assert(std::is_same_v); ++ static_assert(std::is_same_v); ++ ++ return RegConvert32bit::template convert(to_regs(source)); ++ } ++ ++ friend class detail::VectorizedConverter; ++ ++ public: ++ CUTLASS_DEVICE static result_type convert(source_type const& source) { ++ result_type result; ++ using ConverterType = ++ ArrayConverterPacked32Bit; ++ ++ if constexpr (src_elems_per_32bit_reg >= 8) { ++ detail::VectorizedConverter::convert< ++ ConverterType, result_packed_8_t, src_packed_8_t, result_packed_4_t, ++ src_packed_4_t, result_packed_2_t, src_packed_2_t>(result, source); ++ } else if constexpr (src_elems_per_32bit_reg >= 4) { ++ detail::VectorizedConverter::convert(result, source); ++ } else { ++ detail::VectorizedConverter::convert(result, source); ++ } ++ ++ return result; ++ } ++}; ++ ++// Convert 8 4bit values packed into a 32bit register to 8 8bit values packed ++// into 2 32bit register. ++template ++CUTLASS_DEVICE cutlass::AlignedArray lut_4bit_to_8bit_convert( ++ uint32_t src) { ++ cutlass::AlignedArray r; ++ // Determines if the value is in the top half of the LUT if set or ++ // (i.e. LUT[8:15]) in the bottom half (i.e. LUT[0:7]) if not set. Then move ++ // into bit position 0x4 of each nibble so when or'd with final_prmt_base it ++ // selects the correct candidate. When elements in final_prmt_base ++ // are >= 0x4, the high candidate is selected (i.e. LUT[8:15]), when elements ++ // are < 0x4, the low candidate is selected (i.e. LUT[0:7]) ++ uint32_t high_bit = (src & 0x88888888) >> 1; ++ ++ // `high_bit` is OR'd with 0x31203120 to find the correct value in the LUT ++ // (selects correct high or low candidate) ++ const uint32_t final_prmt_base = 0x32103210; ++ ++ // Ignore the high bit when indexing into LUT, for each 4bit value ++ // we index into both the high and low candidates then use ++ // high_bit | final_prmt_base to select the correct candidate ++ uint32_t lut_idx = (src & 0x77777777); ++ ++ auto pack = [](uint8_t a, uint8_t b, uint8_t c, uint8_t d) { ++ return uint32_t(a) | (uint32_t(b) << 8) | (uint32_t(c) << 16) | ++ (uint32_t(d) << 24); ++ }; ++ ++ static constexpr uint32_t LOW_0 = pack(LUT0, LUT1, LUT2, LUT3); ++ static constexpr uint32_t LOW_1 = pack(LUT4, LUT5, LUT6, LUT7); ++ static constexpr uint32_t HIGH_0 = pack(LUT8, LUT9, LUT10, LUT11); ++ static constexpr uint32_t HIGH_1 = pack(LUT12, LUT13, LUT14, LUT15); ++ ++ CUTLASS_PRAGMA_UNROLL ++ for (int ii = 0; ii < 2; ++ii, lut_idx >>= 16, high_bit >>= 16) { ++ uint32_t final_prmt_idx = final_prmt_base | high_bit; ++ ++ // This uses a look up table to convert packed int4s to packed int8s, ++ // using the int4 value as the index to prmt. It first select both the ++ // high and low candidates, then uses the high bit (i.e. `high_bit`) to ++ // select the correct candidate. ++ asm volatile( ++ "{\n" ++ " .reg .b32 low, high;\n" ++ " prmt.b32 low, %1, %2, %5;\n" ++ " prmt.b32 high, %3, %4, %5;\n" ++ " prmt.b32 %0, low, high, %6;\n" ++ "}\n" ++ : "=r"(r[ii]) ++ : "n"(LOW_0), "n"(LOW_1), "n"(HIGH_0), "n"(HIGH_1), "r"(lut_idx), ++ "r"(final_prmt_idx)); ++ } ++ ++ return r; ++}; ++ ++// for Array <= Array ++template ++struct NumericArrayConverter { ++ using result_type = Array; ++ using source_type = Array; ++ ++ static FloatRoundStyle const round_style = Round; ++ ++ private: ++ struct RegConvert { ++ template ++ CUTLASS_DEVICE static PackedResultType convert(Array src_) { ++ // [-8, -7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7] as int8s ++ auto r = lut_4bit_to_8bit_convert<0xF8, 0xF9, 0xFA, 0xFB, // ++ 0xFC, 0xFD, 0xFE, 0xFF, // ++ 0x00, 0x01, 0x02, 0x03, // ++ 0x04, 0x05, 0x06, 0x07>(src_[0]); ++ return reinterpret_cast(r); ++ }; ++ }; ++ ++ public: ++ CUTLASS_DEVICE ++ static result_type convert(source_type const& source) { ++ return ArrayConverterPacked32Bit::convert(source); ++ } ++ ++ CUTLASS_DEVICE ++ result_type operator()(source_type const& s) const { return convert(s); } ++}; ++ ++// for Array <= Array ++template ++struct NumericArrayConverter { ++ using result_type = Array; ++ using source_type = Array; ++ ++ static FloatRoundStyle const round_style = Round; ++ ++ private: ++ struct RegConvert { ++ template ++ CUTLASS_DEVICE static PackedResultType convert(Array src_) { ++ // [-8, -7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7] as fp8s ++ auto r = lut_4bit_to_8bit_convert<0xD0, 0xCE, 0xCC, 0xCA, // ++ 0xC8, 0xC4, 0xC0, 0xB8, // ++ 0x00, 0x38, 0x40, 0x44, // ++ 0x48, 0x4A, 0x4C, 0x4E>(src_[0]); ++ return reinterpret_cast(r); ++ }; ++ }; ++ ++ public: ++ CUTLASS_DEVICE ++ static result_type convert(source_type const& source) { ++ return ArrayConverterPacked32Bit::convert(source); ++ } ++ ++ CUTLASS_DEVICE ++ result_type operator()(source_type const& s) const { return convert(s); } ++}; ++ ++// for Array <= Array ++template ++struct NumericArrayConverter { ++ using result_type = Array; ++ using source_type = Array; ++ ++ struct RegConvert { ++ template ++ CUTLASS_DEVICE static PackedResultType convert(Array src_) { ++ uint32_t src = src_[0]; ++ using RegArray = ++ cutlass::AlignedArray; ++ RegArray r; ++ ++ // Below constructs the following temporary: ++ // fp16s_01 = {0x00, i4_01, 0x00, i4_01} ++ // fp16s_23 = {0x00, i4_23, 0x00, i4_23} ++ // fp16s_45 = {0x00, i4_45, 0x00, i4_45} ++ // fp16s_67 = {0x00, i4_67, 0x00, i4_67} ++ // We use inline asm instead of __byte_perm intrinsic since we don't want ++ // the documented (& 0x7) on the index. NVCC might be able to optimize it ++ // out since the index is a constexpr, but we choose to be safe about it ++ // here. ++ uint32_t prmt_indices[4] = {0x4040, 0x4141, 0x4242, 0x4343}; ++ static_assert(RegArray::kElements <= 4, ++ "Too many inputs for F16 -> I4 vector converter"); ++ CUTLASS_PRAGMA_UNROLL ++ for (int ii = 0; ii < RegArray::kElements; ++ii) { ++ asm volatile( ++ "{\n" ++ " prmt.b32 %0, %1, %2, %3;\n" ++ "}\n" ++ : "=r"(r[ii]) ++ : "r"(src), "n"(0), "r"(prmt_indices[ii])); ++ } ++ ++ // Since the stored 4bit values are biased by 8 we get stored_val = (x+8) ++ // we are trying to construct x and a fp16 value ++ // The below XOR does the following: ++ // 1) Sets the exponent bits of the FP16 to the correct value for the ++ // FP16 magic_num. We will be constructing {1024+16*(x1+8), 1024+(x0+8)}, ++ // where x1 in the high nibble and x0 is the low nibble then using hfma ++ // to subtract 1032 from that ++ // The AND does the following: ++ // 1) Clear the set bits for the int4 we will ignore. ++ // We use lop3 so that we can use 1 instruction for AND and XOR. ++ static constexpr uint32_t xor_mask = 0x64006400; ++ static constexpr uint32_t and_mask = 0xFFF0FF0F; ++ static constexpr uint32_t immLut = (0xf0 & 0xcc) ^ 0xaa; ++ ++ // For each operand, computes: ++ // r[i] = (r[i] & and_mask) ^ xor_mask ++ CUTLASS_PRAGMA_UNROLL ++ for (int ii = 0; ii < RegArray::kElements; ++ii) { ++ asm volatile( ++ "{\n" ++ " lop3.b32 %0, %0, %1, %2, %3;\n" ++ "}\n" ++ : "+r"(r[ii]) ++ : "n"(and_mask), "n"(xor_mask), "n"(immLut)); ++ } ++ ++ // We will issue 2 hfmas that do the following: ++ // {x1, x0} = {1024+16*(x1+8), 1024+(x0+8)} * {1/16, 1} - {72, 1032} ++ // = {x1 + 1152, x0 + 1032} * {1/16, 1} - {72, 1032} ++ static constexpr uint32_t hfma_bias_rep = 0xD480E408; // {72, 1032} ++ static constexpr uint32_t hfma_scale_rep = 0x2C003C00; // {1 / 16, 1} ++ ++ const half2& hfma_bias = reinterpret_cast(hfma_bias_rep); ++ const half2& hfma_scale = reinterpret_cast(hfma_scale_rep); ++ CUTLASS_PRAGMA_UNROLL ++ for (int ii = 0; ii < RegArray::kElements; ++ii) { ++ half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii]); ++ fp16x2_val = __hfma2(hfma_scale, fp16x2_val, hfma_bias); ++ } ++ ++ return reinterpret_cast(r); ++ }; ++ }; ++ ++ public: ++ CUTLASS_DEVICE ++ static result_type convert(source_type const& source) { ++ return ArrayConverterPacked32Bit::convert(source); ++ } ++ ++ CUTLASS_DEVICE ++ result_type operator()(source_type const& s) const { return convert(s); } ++}; ++ ++// for Array <= Array ++// for IlvdLayout: (2, 4):(4, 1) ++template ++struct InterleavedNumericArrayConverter, Stride<_4, _1>>, ++ cutlass::half_t, vllm_uint4b8_t, N, ++ Round, void> { ++ using IlvdLayout = Layout, Stride<_4, _1>>; ++ static_assert(N % size(IlvdLayout{}) == 0); ++ ++ using result_type = Array; ++ using source_type = Array; ++ ++ static FloatRoundStyle const round_style = Round; ++ ++ private: ++ struct RegConvert { ++ template ++ CUTLASS_DEVICE static PackedResultType convert(Array src_) { ++ uint32_t src = src_[0]; ++ using RegArray = ++ cutlass::AlignedArray; ++ RegArray r; ++ ++ static_assert(PackedResultType::kElements <= size(IlvdLayout{})); ++ static constexpr uint32_t xor_mask = 0x64006400; ++ ++ for (int ii = 0; ii < RegArray::kElements; ii += 2) { ++ auto src_ = src >> (4 * (ii)); ++ r[ii + 0] = src_; ++ r[ii + 1] = src_; ++ ++ static constexpr uint32_t and_xor_imm_lut = (0xf0 & 0xcc) ^ 0xaa; ++ ++ static constexpr uint32_t low_nib_mask = 0x000F000F; ++ static constexpr uint32_t high_nib_mask = 0x00F000F0; ++ ++ asm volatile( ++ "{\n" ++ " lop3.b32 %0, %0, %1, %2, %3;\n" ++ "}\n" ++ : "+r"(r[ii + 0]) ++ : "n"(low_nib_mask), "n"(xor_mask), "n"(and_xor_imm_lut)); ++ ++ asm volatile( ++ "{\n" ++ " lop3.b32 %0, %0, %1, %2, %3;\n" ++ "}\n" ++ : "+r"(r[ii + 1]) ++ : "n"(high_nib_mask), "n"(xor_mask), "n"(and_xor_imm_lut)); ++ ++ // For low nibble: ++ // {x1, x0} = {1024+(x1+8), 1024+(x0+8)} * {1, 1} - {1032, 1032} ++ // For high nibble: ++ // {x1, x0} = {1024+16*(x1+8), 1024+16*(x0+8)} * {1/16, 1/16} ++ // - {72, 72} ++ static constexpr uint32_t low_nib_bias = 0x64086408; // {1032, 1032} ++ static constexpr uint32_t high_nib_scale = 0x2C002C00; // {1/16, 1/16} ++ static constexpr uint32_t high_nib_bias = 0xD480D480; // {-72, -72} ++ ++ { ++ half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii + 0]); ++ fp16x2_val = ++ __hsub2(fp16x2_val, reinterpret_cast(low_nib_bias)); ++ } ++ ++ { ++ half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii + 1]); ++ fp16x2_val = __hfma2(fp16x2_val, ++ reinterpret_cast(high_nib_scale), ++ reinterpret_cast(high_nib_bias)); ++ } ++ } ++ ++ return reinterpret_cast(r); ++ }; ++ }; ++ ++ public: ++ CUTLASS_DEVICE ++ static result_type convert(source_type const& source) { ++ return ArrayConverterPacked32Bit::convert(source); ++ } ++ ++ CUTLASS_DEVICE ++ result_type operator()(source_type const& s) const { return convert(s); } ++}; ++ ++// for Array <= Array ++// for IlvdLayout: (2, 4):(4, 1) ++template ++struct InterleavedNumericArrayConverter, Stride<_4, _1>>, ++ cutlass::half_t, uint4_t, N, Round, ++ void> { ++ using IlvdLayout = Layout, Stride<_4, _1>>; ++ static_assert(N % size(IlvdLayout{}) == 0); ++ ++ using result_type = Array; ++ using source_type = Array; ++ ++ static FloatRoundStyle const round_style = Round; ++ ++ private: ++ struct RegConvert { ++ template ++ CUTLASS_DEVICE static PackedResultType convert(Array src_) { ++ uint32_t src = src_[0]; ++ using RegArray = ++ cutlass::AlignedArray; ++ RegArray r; ++ ++ static_assert(PackedResultType::kElements <= size(IlvdLayout{})); ++ static constexpr uint32_t xor_mask = 0x64006400; ++ ++ for (int ii = 0; ii < RegArray::kElements; ii += 2) { ++ auto src_ = src >> (4 * (ii)); ++ r[ii + 0] = src_; ++ r[ii + 1] = src_; ++ ++ static constexpr uint32_t and_xor_imm_lut = (0xf0 & 0xcc) ^ 0xaa; ++ ++ static constexpr uint32_t low_nib_mask = 0x000F000F; ++ static constexpr uint32_t high_nib_mask = 0x00F000F0; ++ ++ asm volatile( ++ "{\n" ++ " lop3.b32 %0, %0, %1, %2, %3;\n" ++ "}\n" ++ : "+r"(r[ii + 0]) ++ : "n"(low_nib_mask), "n"(xor_mask), "n"(and_xor_imm_lut)); ++ ++ asm volatile( ++ "{\n" ++ " lop3.b32 %0, %0, %1, %2, %3;\n" ++ "}\n" ++ : "+r"(r[ii + 1]) ++ : "n"(high_nib_mask), "n"(xor_mask), "n"(and_xor_imm_lut)); ++ ++ // For low nibble: ++ // {x1, x0} = {1024+x1, 1024+x0} - {1024, 1024} ++ // For high nibble: ++ // {x1, x0} = {1024+16*x1, 1024+16*x0} * {1/16, 1/16} - {64, 64} ++ static constexpr uint32_t low_nib_bias = 0x64006400; // {1024, 1024} ++ static constexpr uint32_t high_nib_scale = 0x2C002C00; // {1/16, 1/16} ++ static constexpr uint32_t high_nib_bias = 0xD400D400; // {-64, -64} ++ ++ { ++ half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii + 0]); ++ fp16x2_val = ++ __hsub2(fp16x2_val, reinterpret_cast(low_nib_bias)); ++ } ++ ++ { ++ half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii + 1]); ++ fp16x2_val = __hfma2(fp16x2_val, ++ reinterpret_cast(high_nib_scale), ++ reinterpret_cast(high_nib_bias)); ++ } ++ } ++ ++ return reinterpret_cast(r); ++ }; ++ }; ++ ++ public: ++ CUTLASS_DEVICE ++ static result_type convert(source_type const& source) { ++ return ArrayConverterPacked32Bit::convert(source); ++ } ++ ++ CUTLASS_DEVICE ++ result_type operator()(source_type const& s) const { return convert(s); } ++}; ++ ++// for Array <= Array ++template ++struct NumericArrayConverter { ++ using result_type = Array; ++ using source_type = Array; ++ ++ struct RegConvert { ++ template ++ CUTLASS_DEVICE static PackedResultType convert(Array src_) { ++ uint32_t src = src_[0]; ++ // Hold output FP16s in reg. We need 1 reg for every 2 elements ++ using RegArray = ++ cutlass::AlignedArray; ++ RegArray r; ++ ++ uint32_t const prmt_indices[2] = {0x5150, 0x5352}; ++ static constexpr uint32_t start_byte_for_fp16 = 0x64646464; ++ ++ for (int ii = 0; ii < RegArray::kElements; ++ii) { ++ asm volatile("prmt.b32 %0,%1,%2,%3;\n" ++ : "=r"(r[ii]) ++ : "r"(src), "n"(start_byte_for_fp16), ++ "r"(prmt_indices[ii])); ++ } ++ ++ // -128 is folded into bias subtraction, i.e. the 0x80 in the low bytes ++ static constexpr uint32_t bias_rep = 0x64806480; ++ const half2& bias = reinterpret_cast(bias_rep); ++ CUTLASS_PRAGMA_UNROLL ++ for (int ii = 0; ii < RegArray::kElements; ++ii) { ++ half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii]); ++ fp16x2_val = __hsub2(fp16x2_val, bias); ++ } ++ ++ return reinterpret_cast(r); ++ }; ++ }; ++ ++ public: ++ CUTLASS_DEVICE ++ static result_type convert(source_type const& source) { ++ return ArrayConverterPacked32Bit::convert(source); ++ } ++ ++ CUTLASS_DEVICE ++ result_type operator()(source_type const& s) const { return convert(s); } ++}; ++ ++// for Array <= Array ++template ++struct NumericArrayConverter { ++ using result_type = Array; ++ using source_type = Array; ++ static FloatRoundStyle const round_style = Round; ++ ++ private: ++ struct RegConvert { ++ template ++ CUTLASS_DEVICE static PackedResultType convert(Array src_) { ++ uint32_t src = src_[0]; ++ PackedResultType r; ++ ++ // __byte_perm simulates the add.u32 0x4B000000 to every u8 element of ++ // u8x4 source and stores the result in r (without introducing extra ++ // cvt.u32.u8 instruction) ++ uint32_t const prmt_indices[4] = {0x7650, 0x7651, 0x7652, 0x7653}; ++ uint32_t* result_as_int = reinterpret_cast(&r); ++ for (int ii = 0; ii < PackedResultType::kElements; ++ii) { ++ result_as_int[ii] = __byte_perm(src, 0x4B000000, prmt_indices[ii]); ++ // Subtract the magic number 0x4B000000 from tmp in floating-point ++ // arithmetic to obtain final result ++ r[ii] -= (8388608.f + 128.f); // fold in -128 bias ++ } ++ ++ return r; ++ }; ++ }; ++ ++ public: ++ CUTLASS_DEVICE ++ static result_type convert(source_type const& source) { ++ return ArrayConverterPacked32Bit::convert(source); ++ } ++ ++ CUTLASS_DEVICE ++ result_type operator()(source_type const& s) const { return convert(s); } ++}; ++ ++#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) ++ ++// for Array <= Array ++template ++struct NumericArrayConverter { ++ using result_type = Array; ++ using source_type = Array; ++ ++ static FloatRoundStyle const round_style = Round; ++ ++ private: ++ struct RegConvert { ++ template ++ CUTLASS_DEVICE static PackedResultType convert(Array src_) { ++ uint32_t src_reg = src_[0]; ++ // Hold output BF16s in reg. We need 1 reg for every 2 elements ++ using RegArray = ++ cutlass::AlignedArray; ++ RegArray r; ++ uint32_t src_reg_shifted = src_reg >> 4; ++ ++ // Below constructs the following temporary: ++ uint32_t const prmt_indices[4] = {0xF4F0, 0xF5F1, 0xF6F2, 0xF7F3}; ++ static_assert(RegArray::kElements <= 4, ++ "Too many inputs for uint4b8_t -> BF16 vector converter"); ++ CUTLASS_PRAGMA_UNROLL ++ for (int ii = 0; ii < RegArray::kElements; ++ii) { ++ asm volatile( ++ "{\n" ++ " prmt.b32 %0, %1, %2, %3;\n" ++ "}\n" ++ : "=r"(r[ii]) ++ : "r"(src_reg), "r"(src_reg_shifted), "r"(prmt_indices[ii])); ++ } ++ ++ // Since the stored 4bit values are biased by 8 we get stored_val = (x+8) ++ // we are trying to construct x and a BF16 value ++ // The below XOR does the following: ++ // 1) Sets the exponent bits of the BF16 to the correct value for the ++ // BF16 magic_num. We will be constructing {128 + (x1+8), 128 + (x0+8)} ++ // and subtracting 136 to get {x1, x0} ++ static constexpr uint32_t xor_mask = 0x43004300; ++ static constexpr uint32_t and_mask = 0x000F000F; ++ static constexpr uint32_t immLut = (0xf0 & 0xcc) ^ 0xaa; ++ ++ // For each operand, computes: ++ // r[i] = (r[i] & and_mask) ^ xor_mask ++ CUTLASS_PRAGMA_UNROLL ++ for (int ii = 0; ii < RegArray::kElements; ++ii) { ++ asm volatile( ++ "{\n" ++ " lop3.b32 %0, %0, %1, %2, %3;\n" ++ "}\n" ++ : "+r"(r[ii]) ++ : "n"(and_mask), "n"(xor_mask), "n"(immLut)); ++ } ++ ++ // We will issue 2 bfmas that do the following: ++ // high BF16: ++ // hi_bf16 - 136, lo_bf16 - 136 ++ ++ // This is the BF16 {136, 136} represented as an integer. ++ static constexpr uint32_t bias_rep = 0x43084308; ++ const __nv_bfloat162& bias = ++ reinterpret_cast(bias_rep); ++ ++ CUTLASS_PRAGMA_UNROLL ++ for (int ii = 0; ii < RegArray::kElements; ++ii) { ++ __nv_bfloat162& bf16x2_val = reinterpret_cast<__nv_bfloat162&>(r[ii]); ++ bf16x2_val = __hsub2(bf16x2_val, bias); ++ } ++ ++ return reinterpret_cast(r); ++ } ++ }; ++ ++ public: ++ CUTLASS_DEVICE ++ static result_type convert(source_type const& source) { ++ return ArrayConverterPacked32Bit::convert(source); ++ } ++ ++ CUTLASS_DEVICE ++ result_type operator()(source_type const& s) const { return convert(s); } ++}; ++ ++// for Array <= Array ++// for IlvdLayout: (2, 4):(4, 1) ++template ++struct InterleavedNumericArrayConverter, Stride<_4, _1>>, ++ cutlass::bfloat16_t, vllm_uint4b8_t, N, ++ Round, void> { ++ using IlvdLayout = Layout, Stride<_4, _1>>; ++ static_assert(N % size(IlvdLayout{}) == 0); ++ ++ using result_type = Array; ++ using source_type = Array; ++ ++ private: ++ struct RegConvert { ++ template ++ CUTLASS_DEVICE static PackedResultType convert(Array src_) { ++ uint32_t src = src_[0]; ++ using RegArray = ++ cutlass::AlignedArray; ++ RegArray r; ++ ++ static_assert(PackedResultType::kElements <= size(IlvdLayout{})); ++ static constexpr uint32_t or_mask = 0x43004300; ++ ++ // Unlike float16 where the mantissa is large enough to contain 2 ++ // nibbles, bfloat16 can only fit one, so we can only convert one ++ // nibble at a time ++ for (int ii = 0; ii < RegArray::kElements; ++ii) { ++ r[ii] = src >> (4 * ii); ++ ++ static constexpr uint32_t and_or_imm_lut = (0xf0 & 0xcc) | 0xaa; ++ static constexpr uint32_t low_nib_mask = 0x000F000F; ++ ++ asm volatile( ++ "{\n" ++ " lop3.b32 %0, %0, %1, %2, %3;\n" ++ "}\n" ++ : "+r"(r[ii + 0]) ++ : "n"(low_nib_mask), "n"(or_mask), "n"(and_or_imm_lut)); ++ ++ // For low nibble: ++ // {x1, x0} = {128+(x1+8), 128+(x0+8)} * {1, 1} - {136, 136} ++ static constexpr uint32_t low_nib_bias = 0x43084308; // {136, 136} ++ ++ { ++ __nv_bfloat162& fp16x2_val = reinterpret_cast<__nv_bfloat162&>(r[ii]); ++ fp16x2_val = ++ __hsub2(fp16x2_val, ++ reinterpret_cast(low_nib_bias)); ++ } ++ } ++ ++ return reinterpret_cast(r); ++ }; ++ }; ++ ++ public: ++ CUTLASS_DEVICE ++ static result_type convert(source_type const& source) { ++ return ArrayConverterPacked32Bit::convert(source); ++ } ++ ++ CUTLASS_DEVICE ++ result_type operator()(source_type const& s) const { return convert(s); } ++}; ++ ++// for Array <= Array ++// for IlvdLayout: (2, 4):(4, 1) ++template ++struct InterleavedNumericArrayConverter, Stride<_4, _1>>, ++ cutlass::bfloat16_t, uint4_t, N, Round, ++ void> { ++ using IlvdLayout = Layout, Stride<_4, _1>>; ++ static_assert(N % size(IlvdLayout{}) == 0); ++ ++ using result_type = Array; ++ using source_type = Array; ++ ++ private: ++ struct RegConvert { ++ template ++ CUTLASS_DEVICE static PackedResultType convert(Array src_) { ++ uint32_t src = src_[0]; ++ using RegArray = ++ cutlass::AlignedArray; ++ RegArray r; ++ ++ static_assert(PackedResultType::kElements <= size(IlvdLayout{})); ++ static constexpr uint32_t or_mask = 0x43004300; ++ ++ // Unlike float16 where the mantissa is large enough to contain 2 ++ // nibbles, bfloat16 can only fit one, so we can only convert one ++ // nibble at a time ++ for (int ii = 0; ii < RegArray::kElements; ++ii) { ++ r[ii] = src >> (4 * ii); ++ ++ static constexpr uint32_t and_or_imm_lut = (0xf0 & 0xcc) | 0xaa; ++ static constexpr uint32_t low_nib_mask = 0x000F000F; ++ ++ asm volatile( ++ "{\n" ++ " lop3.b32 %0, %0, %1, %2, %3;\n" ++ "}\n" ++ : "+r"(r[ii]) ++ : "n"(low_nib_mask), "n"(or_mask), "n"(and_or_imm_lut)); ++ ++ // For low nibble: ++ // {x1, x0} = {128 + x1, 128 + x0} * {1, 1} - {128, 128} ++ static constexpr uint32_t low_nib_bias = 0x43004300; // {128, 128} ++ ++ { ++ __nv_bfloat162& fp16x2_val = reinterpret_cast<__nv_bfloat162&>(r[ii]); ++ fp16x2_val = ++ __hsub2(fp16x2_val, ++ reinterpret_cast(low_nib_bias)); ++ } ++ } ++ ++ return reinterpret_cast(r); ++ }; ++ }; ++ ++ public: ++ CUTLASS_DEVICE ++ static result_type convert(source_type const& source) { ++ return ArrayConverterPacked32Bit::convert(source); ++ } ++ ++ CUTLASS_DEVICE ++ result_type operator()(source_type const& s) const { return convert(s); } ++}; ++ ++// for Array <= Array ++template ++struct NumericArrayConverter { ++ using result_type = Array; ++ using source_type = Array; ++ static FloatRoundStyle const round_style = Round; ++ ++ private: ++ using result_packed_4_t = Array; ++ using result_packed_2_t = Array; ++ using src_packed_4_t = Array; ++ using src_packed_2_t = Array; ++ ++ // Not Valid, not supported, only here to satisfy the interface and to avoid ++ // a compile error. ScalarConverter will not actually work until ++ // NumericConverter is ++ // implemented ++ using ScalarConverter = ++ NumericConverter; ++ ++ template ++ CUTLASS_DEVICE static PackedResultType packed_convert( ++ PackedSrcType const& source) { ++ static_assert( ++ (platform::is_same::value && ++ platform::is_same::value) || ++ (platform::is_same::value && ++ platform::is_same::value), ++ "Invalid PackedSrcType/PackedResultType must be 2 or 4 to use private " ++ "convert dispatch."); ++ ++ NumericArrayConverter ++ convert_uint8_to_f32; ++ Array tmp = ++ convert_uint8_to_f32(source); ++ NumericArrayConverter ++ convert_f32_to_bf16_; ++ return convert_f32_to_bf16_(tmp); ++ } ++ ++ friend class detail::VectorizedConverter; ++ ++ public: ++ CUTLASS_DEVICE ++ static result_type convert(source_type const& source) { ++ result_type result; ++ using ConverterType = ++ NumericArrayConverter; ++ detail::VectorizedConverter::convert(result, source); ++ ++ return result; ++ } ++ ++ CUTLASS_DEVICE ++ result_type operator()(source_type const& s) const { return convert(s); } ++}; ++ ++#endif ++ ++// for Array <= Array ++// FastFP16toINT8 from https://arxiv.org/pdf/2406.09904 ++template ++struct NumericArrayConverter { ++ using result_type = Array; ++ using source_type = Array; ++ ++ struct RegConvert { ++ // FastFP16toINT8 from https://arxiv.org/pdf/2406.09904 ++ template ++ CUTLASS_DEVICE static PackedResultType convert( ++ Array src) { ++ // Hold output int8s in reg. We need 1 reg for every 4 elements ++ using RegArray = cutlass::AlignedArray< ++ uint32_t, std::max(PackedResultType::kElements / 4, size_t(1))>; ++ RegArray r; ++ ++ static constexpr uint32_t MAGIC_BIAS_ = 0x64806480; ++ auto MAGIC_BIAS = *reinterpret_cast(&MAGIC_BIAS_); ++ ++ *reinterpret_cast(&src[0]) = ++ __hadd2(*reinterpret_cast(&src[0]), MAGIC_BIAS); ++ ++ if constexpr (src_regs > 1) { ++ *reinterpret_cast(&src[1]) = ++ __hadd2(*reinterpret_cast(&src[1]), MAGIC_BIAS); ++ } ++ ++ static_assert(PackedResultType::kElements <= 4); ++ uint32_t uint8s; ++ static constexpr uint32_t MASK_0246 = 0x6420; ++ static constexpr uint32_t UINT8s_TO_INT8s_MASK = 0x80808080; ++ asm volatile("prmt.b32 %0,%1,%2,%3;\n" ++ : "=r"(uint8s) ++ : "r"(src[0]), "r"((src_regs > 1) ? src[1] : src[0]), ++ "n"(MASK_0246)); ++ ++ uint32_t int8s = (uint8s ^ UINT8s_TO_INT8s_MASK); ++ ++ return reinterpret_cast(int8s); ++ }; ++ }; ++ ++ public: ++ CUTLASS_DEVICE ++ static result_type convert(source_type const& source) { ++ return ArrayConverterPacked32Bit::convert(source); ++ } ++ ++ CUTLASS_DEVICE ++ result_type operator()(source_type const& s) const { return convert(s); } ++}; ++ ++///////////////////////////////////////////////////////////////////////////////////////////////// ++ ++} // namespace cutlass ++ ++///////////////////////////////////////////////////////////////////////////////////////////////// +diff --git a/csrc/cutlass_extensions/vllm_type_utils.cuh b/csrc/cutlass_extensions/vllm_type_utils.cuh +new file mode 100644 +index 0000000..500ed50 +--- /dev/null ++++ b/csrc/cutlass_extensions/vllm_type_utils.cuh +@@ -0,0 +1,42 @@ ++#include "cutlass/bfloat16.h" ++#include "cutlass/half.h" ++#include "cuda_bf16.h" ++ ++#include "cutlass_extensions/vllm_custom_types.cuh" ++ ++namespace cutlass { ++ ++template ++struct nameof { ++ static constexpr char const* value = "unknown"; ++}; ++ ++template ++inline constexpr auto nameof_v = nameof::value; ++ ++#define NAMEOF_TYPE(T) \ ++ template <> \ ++ struct nameof { \ ++ static constexpr char const* value = #T; \ ++ }; ++ ++NAMEOF_TYPE(float_e4m3_t) ++NAMEOF_TYPE(float_e5m2_t) ++NAMEOF_TYPE(half_t) ++NAMEOF_TYPE(nv_bfloat16) ++NAMEOF_TYPE(bfloat16_t) ++NAMEOF_TYPE(float) ++ ++NAMEOF_TYPE(int4b_t) ++NAMEOF_TYPE(int8_t) ++NAMEOF_TYPE(int32_t) ++NAMEOF_TYPE(int64_t) ++ ++NAMEOF_TYPE(vllm_uint4b8_t) ++NAMEOF_TYPE(uint4b_t) ++NAMEOF_TYPE(uint8_t) ++NAMEOF_TYPE(vllm_uint8b128_t) ++NAMEOF_TYPE(uint32_t) ++NAMEOF_TYPE(uint64_t) ++ ++}; // namespace cutlass +\ No newline at end of file +diff --git a/csrc/dispatch_utils.h b/csrc/dispatch_utils.h +index 91abd9e..03414b7 100644 +--- a/csrc/dispatch_utils.h ++++ b/csrc/dispatch_utils.h +@@ -4,34 +4,46 @@ + */ + #pragma once + +-#include ++#include + +-#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \ +- AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ +- AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \ ++#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \ ++ AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ ++ AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) + +-#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \ +- AT_DISPATCH_SWITCH( \ +- TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)) ++#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \ ++ AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)) + +-#define VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(...) \ +- AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ +- AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \ +- AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \ ++// TODO(luka/varun): use FP8_TYPE macro after refactoring ++#ifndef USE_ROCM ++ #define VLLM_DISPATCH_CASE_QUANT_TYPES(...) \ ++ AT_DISPATCH_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__) \ ++ AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) ++#else ++ #define VLLM_DISPATCH_CASE_QUANT_TYPES(...) \ ++ AT_DISPATCH_CASE(at::ScalarType::Float8_e4m3fnuz, __VA_ARGS__) \ ++ AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) ++#endif ++ ++#define VLLM_DISPATCH_QUANT_TYPES(TYPE, NAME, ...) \ ++ AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_QUANT_TYPES(__VA_ARGS__)) ++ ++#define VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(...) \ ++ AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ ++ AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \ ++ AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) + +-#define VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(TYPE, NAME, ...) \ +- AT_DISPATCH_SWITCH( \ +- TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(__VA_ARGS__)) +- +-#define VLLM_DISPATCH_CASE_INTEGRAL_TYPES(...) \ +- AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \ +- AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \ +- AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \ +- AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \ ++#define VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(TYPE, NAME, ...) \ ++ AT_DISPATCH_SWITCH(TYPE, NAME, \ ++ VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(__VA_ARGS__)) ++ ++#define VLLM_DISPATCH_CASE_INTEGRAL_TYPES(...) \ ++ AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \ ++ AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \ ++ AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \ ++ AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__) + +-#define VLLM_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \ +- AT_DISPATCH_SWITCH( \ +- TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__)) ++#define VLLM_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \ ++ AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__)) +diff --git a/csrc/layernorm_kernels.cu b/csrc/layernorm_kernels.cu +index e56b4d2..fb6882f 100644 +--- a/csrc/layernorm_kernels.cu ++++ b/csrc/layernorm_kernels.cu +@@ -1,198 +1,59 @@ +-#include +-#include ++#include "type_convert.cuh" ++#include "dispatch_utils.h" ++ ++#include + #include + +-#include "dispatch_utils.h" +-#include "reduction_utils.cuh" + #ifndef USE_ROCM +- #include +- #include ++ #include + #else +- #include +- #include +- +- using __nv_bfloat16 = __hip_bfloat16; +- using __nv_bfloat162 = __hip_bfloat162; ++ #include + #endif + + namespace vllm { + + // TODO(woosuk): Further optimize this kernel. +-template ++template + __global__ void rms_norm_kernel( +- scalar_t* __restrict__ out, // [..., hidden_size] +- const scalar_t* __restrict__ input, // [..., hidden_size] +- const scalar_t* __restrict__ weight, // [hidden_size] +- const float epsilon, +- const int num_tokens, +- const int hidden_size) { ++ scalar_t* __restrict__ out, // [..., hidden_size] ++ const scalar_t* __restrict__ input, // [..., hidden_size] ++ const scalar_t* __restrict__ weight, // [hidden_size] ++ const float epsilon, const int num_tokens, const int hidden_size) { + __shared__ float s_variance; + float variance = 0.0f; + + for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { +- const float x = (float) input[blockIdx.x * hidden_size + idx]; ++ const float x = (float)input[blockIdx.x * hidden_size + idx]; + variance += x * x; + } +- variance = blockReduceSum(variance); ++ ++ using BlockReduce = cub::BlockReduce; ++ __shared__ typename BlockReduce::TempStorage reduceStore; ++ variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x); ++ + if (threadIdx.x == 0) { + s_variance = rsqrtf(variance / hidden_size + epsilon); + } + __syncthreads(); + + for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { +- float x = (float) input[blockIdx.x * hidden_size + idx]; +- out[blockIdx.x * hidden_size + idx] = ((scalar_t) (x * s_variance)) * weight[idx]; ++ float x = (float)input[blockIdx.x * hidden_size + idx]; ++ out[blockIdx.x * hidden_size + idx] = ++ ((scalar_t)(x * s_variance)) * weight[idx]; + } + } + +- +-/* Converter structs for the conversion from torch types to HIP/CUDA types, +- and the associated type conversions within HIP/CUDA. These helpers need +- to be implemented for now because the relevant type conversion +- operators/constructors are not consistently implemented by HIP/CUDA, so +- a generic conversion via type casts cannot be implemented. +- +- Each struct should have the member static constexpr bool `exists`: +- If false, the optimized kernel is not used for the corresponding torch type. +- If true, the struct should be fully defined as shown in the examples below. +- */ +-template +-struct _typeConvert { static constexpr bool exists = false; }; +- +-#if defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= 12000)) +-// CUDA < 12.0 runs into issues with packed type conversion +-template<> +-struct _typeConvert { +- static constexpr bool exists = true; +- using hip_type = __half; +- using packed_hip_type = __half2; +- +- __device__ static inline float convert(hip_type x) { return __half2float(x); } +- __device__ static inline float2 convert(packed_hip_type x) { return __half22float2(x); } +- __device__ static inline hip_type convert(float x) { return __float2half_rn(x); } +- __device__ static inline packed_hip_type convert(float2 x) { return __float22half2_rn(x); } +-}; +- +-#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +-// CUDA_ARCH < 800 does not have BF16 support +-// TODO: Add in ROCm support once public headers handle bf16 maturely +-template<> +-struct _typeConvert { +- static constexpr bool exists = true; +- using hip_type = __nv_bfloat16; +- using packed_hip_type = __nv_bfloat162; +- +- __device__ static inline float convert(hip_type x) { return __bfloat162float(x); } +- __device__ static inline float2 convert(packed_hip_type x) { return __bfloat1622float2(x); } +- __device__ static inline hip_type convert(float x) { return __float2bfloat16(x); } +- __device__ static inline packed_hip_type convert(float2 x) { return __float22bfloat162_rn(x); } +-}; +-#endif // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +-#endif // defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= 12000)) +- +-/* Vector POD struct to generate vectorized and packed FP16/BF16 ops +- for appropriate specializations of fused_add_rms_norm_kernel. +- Only functions that are necessary in that kernel are implemented. +- Alignment to 16 bytes is required to use 128-bit global memory ops. +- */ +-template +-struct alignas(16) _f16Vec { +- /* Not theoretically necessary that width is a power of 2 but should +- almost always be the case for optimization purposes */ +- static_assert(width > 0 && (width & (width - 1)) == 0, +- "Width is not a positive power of 2!"); +- using Converter = _typeConvert; +- using T1 = typename Converter::hip_type; +- using T2 = typename Converter::packed_hip_type; +- T1 data[width]; +- +- __device__ _f16Vec& operator+=(const _f16Vec& other) { +- if constexpr (width % 2 == 0) { +- #pragma unroll +- for (int i = 0; i < width; i += 2) { +- T2 temp{data[i], data[i+1]}; +- temp += T2{other.data[i], other.data[i+1]}; +- data[i] = temp.x; +- data[i+1] = temp.y; +- } +- } else { +- #pragma unroll +- for (int i = 0; i < width; ++i) +- data[i] += other.data[i]; +- } +- return *this; +- } +- +- __device__ _f16Vec& operator*=(const _f16Vec& other) { +- if constexpr (width % 2 == 0) { +- #pragma unroll +- for (int i = 0; i < width; i += 2) { +- T2 temp{data[i], data[i+1]}; +- temp *= T2{other.data[i], other.data[i+1]}; +- data[i] = temp.x; +- data[i+1] = temp.y; +- } +- } else { +- #pragma unroll +- for (int i = 0; i < width; ++i) +- data[i] *= other.data[i]; +- } +- return *this; +- } +- +- __device__ _f16Vec& operator*=(const float scale) { +- if constexpr (width % 2 == 0) { +- #pragma unroll +- for (int i = 0; i < width; i += 2) { +- float2 temp_f = Converter::convert(T2{data[i], data[i+1]}); +- temp_f.x *= scale; +- temp_f.y *= scale; +- T2 temp = Converter::convert(temp_f); +- data[i] = temp.x; +- data[i+1] = temp.y; +- } +- } else { +- #pragma unroll +- for (int i = 0; i < width; ++i) { +- float temp = Converter::convert(data[i]) * scale; +- data[i] = Converter::convert(temp); +- } +- } +- return *this; +- } +- +- __device__ float sum_squares() const { +- float result = 0.0f; +- if constexpr (width % 2 == 0) { +- #pragma unroll +- for (int i = 0; i < width; i += 2) { +- float2 z = Converter::convert(T2{data[i], data[i+1]}); +- result += z.x * z.x + z.y * z.y; +- } +- } else { +- #pragma unroll +- for (int i = 0; i < width; ++i) { +- float x = Converter::convert(data[i]); +- result += x * x; +- } +- } +- return result; +- } +-}; +- + /* Function specialization in the case of FP16/BF16 tensors. + Additional optimizations we can make in this case are + packed and vectorized operations, which help with the + memory latency bottleneck. */ +-template +-__global__ std::enable_if_t< +- (width > 0) && _typeConvert::exists> fused_add_rms_norm_kernel( +- scalar_t* __restrict__ input, // [..., hidden_size] +- scalar_t* __restrict__ residual, // [..., hidden_size] +- const scalar_t* __restrict__ weight, // [hidden_size] +- const float epsilon, +- const int num_tokens, +- const int hidden_size) { ++template ++__global__ std::enable_if_t<(width > 0) && _typeConvert::exists> ++fused_add_rms_norm_kernel( ++ scalar_t* __restrict__ input, // [..., hidden_size] ++ scalar_t* __restrict__ residual, // [..., hidden_size] ++ const scalar_t* __restrict__ weight, // [hidden_size] ++ const float epsilon, const int num_tokens, const int hidden_size) { + // Sanity checks on our vector struct and type-punned pointer arithmetic + static_assert(std::is_pod_v<_f16Vec>); + static_assert(sizeof(_f16Vec) == sizeof(scalar_t) * width); +@@ -203,9 +64,12 @@ __global__ std::enable_if_t< + /* These and the argument pointers are all declared `restrict` as they are + not aliased in practice. Argument pointers should not be dereferenced + in this kernel as that would be undefined behavior */ +- auto* __restrict__ input_v = reinterpret_cast<_f16Vec*>(input); +- auto* __restrict__ residual_v = reinterpret_cast<_f16Vec*>(residual); +- auto* __restrict__ weight_v = reinterpret_cast*>(weight); ++ auto* __restrict__ input_v = ++ reinterpret_cast<_f16Vec*>(input); ++ auto* __restrict__ residual_v = ++ reinterpret_cast<_f16Vec*>(residual); ++ auto* __restrict__ weight_v = ++ reinterpret_cast*>(weight); + + for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) { + int id = blockIdx.x * vec_hidden_size + idx; +@@ -214,11 +78,11 @@ __global__ std::enable_if_t< + variance += temp.sum_squares(); + residual_v[id] = temp; + } +- /* Keep the following if-else block in sync with the +- calculation of max_block_size in fused_add_rms_norm */ +- if (num_tokens < 256) { +- variance = blockReduceSum(variance); +- } else variance = blockReduceSum(variance); ++ ++ using BlockReduce = cub::BlockReduce; ++ __shared__ typename BlockReduce::TempStorage reduceStore; ++ variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x); ++ + if (threadIdx.x == 0) { + s_variance = rsqrtf(variance / hidden_size + epsilon); + } +@@ -233,52 +97,49 @@ __global__ std::enable_if_t< + } + } + +- + /* Generic fused_add_rms_norm_kernel + The width field is not used here but necessary for other specializations. + */ +-template +-__global__ std::enable_if_t< +- (width == 0) || !_typeConvert::exists> fused_add_rms_norm_kernel( +- scalar_t* __restrict__ input, // [..., hidden_size] +- scalar_t* __restrict__ residual, // [..., hidden_size] +- const scalar_t* __restrict__ weight, // [hidden_size] +- const float epsilon, +- const int num_tokens, +- const int hidden_size) { ++template ++__global__ std::enable_if_t<(width == 0) || !_typeConvert::exists> ++fused_add_rms_norm_kernel( ++ scalar_t* __restrict__ input, // [..., hidden_size] ++ scalar_t* __restrict__ residual, // [..., hidden_size] ++ const scalar_t* __restrict__ weight, // [hidden_size] ++ const float epsilon, const int num_tokens, const int hidden_size) { + __shared__ float s_variance; + float variance = 0.0f; + + for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { + scalar_t z = input[blockIdx.x * hidden_size + idx]; + z += residual[blockIdx.x * hidden_size + idx]; +- float x = (float) z; ++ float x = (float)z; + variance += x * x; + residual[blockIdx.x * hidden_size + idx] = z; + } +- /* Keep the following if-else block in sync with the +- calculation of max_block_size in fused_add_rms_norm */ +- if (num_tokens < 256) { +- variance = blockReduceSum(variance); +- } else variance = blockReduceSum(variance); ++ ++ using BlockReduce = cub::BlockReduce; ++ __shared__ typename BlockReduce::TempStorage reduceStore; ++ variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x); ++ + if (threadIdx.x == 0) { + s_variance = rsqrtf(variance / hidden_size + epsilon); + } + __syncthreads(); + + for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { +- float x = (float) residual[blockIdx.x * hidden_size + idx]; +- input[blockIdx.x * hidden_size + idx] = ((scalar_t) (x * s_variance)) * weight[idx]; ++ float x = (float)residual[blockIdx.x * hidden_size + idx]; ++ input[blockIdx.x * hidden_size + idx] = ++ ((scalar_t)(x * s_variance)) * weight[idx]; + } + } + +-} // namespace vllm ++} // namespace vllm + +-void rms_norm( +- torch::Tensor& out, // [..., hidden_size] +- torch::Tensor& input, // [..., hidden_size] +- torch::Tensor& weight, // [hidden_size] +- float epsilon) { ++void rms_norm(torch::Tensor& out, // [..., hidden_size] ++ torch::Tensor& input, // [..., hidden_size] ++ torch::Tensor& weight, // [hidden_size] ++ double epsilon) { + int hidden_size = input.size(-1); + int num_tokens = input.numel() / hidden_size; + +@@ -286,40 +147,27 @@ void rms_norm( + dim3 block(std::min(hidden_size, 1024)); + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); +- VLLM_DISPATCH_FLOATING_TYPES( +- input.scalar_type(), +- "rms_norm_kernel", +- [&] { +- vllm::rms_norm_kernel<<>>( +- out.data_ptr(), +- input.data_ptr(), +- weight.data_ptr(), +- epsilon, +- num_tokens, +- hidden_size); +- }); ++ VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_kernel", [&] { ++ vllm::rms_norm_kernel<<>>( ++ out.data_ptr(), input.data_ptr(), ++ weight.data_ptr(), epsilon, num_tokens, hidden_size); ++ }); + } + +-#define LAUNCH_FUSED_ADD_RMS_NORM(width) \ +- VLLM_DISPATCH_FLOATING_TYPES( \ +- input.scalar_type(), \ +- "fused_add_rms_norm_kernel", \ +- [&] { \ +- vllm::fused_add_rms_norm_kernel \ +- <<>>( \ +- input.data_ptr(), \ +- residual.data_ptr(), \ +- weight.data_ptr(), \ +- epsilon, \ +- num_tokens, \ +- hidden_size); \ +- }); +- +-void fused_add_rms_norm( +- torch::Tensor& input, // [..., hidden_size] +- torch::Tensor& residual, // [..., hidden_size] +- torch::Tensor& weight, // [hidden_size] +- float epsilon) { ++#define LAUNCH_FUSED_ADD_RMS_NORM(width) \ ++ VLLM_DISPATCH_FLOATING_TYPES( \ ++ input.scalar_type(), "fused_add_rms_norm_kernel", [&] { \ ++ vllm::fused_add_rms_norm_kernel \ ++ <<>>(input.data_ptr(), \ ++ residual.data_ptr(), \ ++ weight.data_ptr(), epsilon, \ ++ num_tokens, hidden_size); \ ++ }); ++ ++void fused_add_rms_norm(torch::Tensor& input, // [..., hidden_size] ++ torch::Tensor& residual, // [..., hidden_size] ++ torch::Tensor& weight, // [hidden_size] ++ double epsilon) { + int hidden_size = input.size(-1); + int num_tokens = input.numel() / hidden_size; + +@@ -342,8 +190,8 @@ void fused_add_rms_norm( + auto inp_ptr = reinterpret_cast(input.data_ptr()); + auto res_ptr = reinterpret_cast(residual.data_ptr()); + auto wt_ptr = reinterpret_cast(weight.data_ptr()); +- bool ptrs_are_aligned = inp_ptr % 16 == 0 && res_ptr % 16 == 0 \ +- && wt_ptr % 16 == 0; ++ bool ptrs_are_aligned = ++ inp_ptr % 16 == 0 && res_ptr % 16 == 0 && wt_ptr % 16 == 0; + if (ptrs_are_aligned && hidden_size % 8 == 0) { + LAUNCH_FUSED_ADD_RMS_NORM(8); + } else { +diff --git a/csrc/layernorm_quant_kernels.cu b/csrc/layernorm_quant_kernels.cu +new file mode 100644 +index 0000000..c18e2a4 +--- /dev/null ++++ b/csrc/layernorm_quant_kernels.cu +@@ -0,0 +1,234 @@ ++/* ++ * This file contains the CUDA kernels for the fused quantized layernorm. ++ * The kernels correspond to the kernels in layernorm_kernels.cu, except they ++ * also produce quantized output directly. ++ * Currently, only static fp8 quantization is supported. ++ */ ++ ++#include "type_convert.cuh" ++#include "quantization/fp8/common.cuh" ++#include "dispatch_utils.h" ++ ++#include ++#include ++ ++#ifndef USE_ROCM ++ #include ++#else ++ #include ++#endif ++ ++namespace vllm { ++ ++// TODO(woosuk): Further optimize this kernel. ++template ++__global__ void rms_norm_static_fp8_quant_kernel( ++ FP8_TYPE* __restrict__ out, // [..., hidden_size] ++ const scalar_t* __restrict__ input, // [..., hidden_size] ++ const scalar_t* __restrict__ weight, // [hidden_size] ++ const float* __restrict__ scale, // [1] ++ const float epsilon, const int num_tokens, const int hidden_size) { ++ __shared__ float s_variance; ++ float variance = 0.0f; ++ ++ for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { ++ const float x = (float)input[blockIdx.x * hidden_size + idx]; ++ variance += x * x; ++ } ++ ++ using BlockReduce = cub::BlockReduce; ++ __shared__ typename BlockReduce::TempStorage reduceStore; ++ variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x); ++ ++ if (threadIdx.x == 0) { ++ s_variance = rsqrtf(variance / hidden_size + epsilon); ++ } ++ __syncthreads(); ++ ++ // invert scale to avoid division ++ float const scale_inv = 1.0f / *scale; ++ ++ for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { ++ float x = (float)input[blockIdx.x * hidden_size + idx]; ++ float const out_norm = ((scalar_t)(x * s_variance)) * weight[idx]; ++ out[blockIdx.x * hidden_size + idx] = ++ scaled_fp8_conversion(out_norm, scale_inv); ++ } ++} ++ ++/* Function specialization in the case of FP16/BF16 tensors. ++ Additional optimizations we can make in this case are ++ packed and vectorized operations, which help with the ++ memory latency bottleneck. */ ++template ++__global__ std::enable_if_t<(width > 0) && _typeConvert::exists> ++fused_add_rms_norm_static_fp8_quant_kernel( ++ FP8_TYPE* __restrict__ out, // [..., hidden_size] ++ scalar_t* __restrict__ input, // [..., hidden_size] ++ scalar_t* __restrict__ residual, // [..., hidden_size] ++ const scalar_t* __restrict__ weight, // [hidden_size] ++ const float* __restrict__ scale, // [1] ++ const float epsilon, const int num_tokens, const int hidden_size) { ++ // Sanity checks on our vector struct and type-punned pointer arithmetic ++ static_assert(std::is_pod_v<_f16Vec>); ++ static_assert(sizeof(_f16Vec) == sizeof(scalar_t) * width); ++ ++ const int vec_hidden_size = hidden_size / width; ++ __shared__ float s_variance; ++ float variance = 0.0f; ++ /* These and the argument pointers are all declared `restrict` as they are ++ not aliased in practice. Argument pointers should not be dereferenced ++ in this kernel as that would be undefined behavior */ ++ auto* __restrict__ input_v = ++ reinterpret_cast<_f16Vec*>(input); ++ auto* __restrict__ residual_v = ++ reinterpret_cast<_f16Vec*>(residual); ++ auto* __restrict__ weight_v = ++ reinterpret_cast*>(weight); ++ ++ for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) { ++ int id = blockIdx.x * vec_hidden_size + idx; ++ _f16Vec temp = input_v[id]; ++ temp += residual_v[id]; ++ variance += temp.sum_squares(); ++ residual_v[id] = temp; ++ } ++ ++ using BlockReduce = cub::BlockReduce; ++ __shared__ typename BlockReduce::TempStorage reduceStore; ++ variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x); ++ ++ if (threadIdx.x == 0) { ++ s_variance = rsqrtf(variance / hidden_size + epsilon); ++ } ++ __syncthreads(); ++ ++ // invert scale to avoid division ++ float const scale_inv = 1.0f / *scale; ++ ++ for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) { ++ int id = blockIdx.x * vec_hidden_size + idx; ++ _f16Vec temp = residual_v[id]; ++ temp *= s_variance; ++ temp *= weight_v[idx]; ++#pragma unroll ++ for (int i = 0; i < width; ++i) { ++ out[id * width + i] = ++ scaled_fp8_conversion(float(temp.data[i]), scale_inv); ++ } ++ } ++} ++ ++/* Generic fused_add_rms_norm_kernel ++ The width field is not used here but necessary for other specializations. ++ */ ++template ++__global__ std::enable_if_t<(width == 0) || !_typeConvert::exists> ++fused_add_rms_norm_static_fp8_quant_kernel( ++ FP8_TYPE* __restrict__ out, // [..., hidden_size] ++ scalar_t* __restrict__ input, // [..., hidden_size] ++ scalar_t* __restrict__ residual, // [..., hidden_size] ++ const scalar_t* __restrict__ weight, // [hidden_size] ++ const float* __restrict__ scale, // [1] ++ const float epsilon, const int num_tokens, const int hidden_size) { ++ __shared__ float s_variance; ++ float variance = 0.0f; ++ ++ for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { ++ scalar_t z = input[blockIdx.x * hidden_size + idx]; ++ z += residual[blockIdx.x * hidden_size + idx]; ++ float x = (float)z; ++ variance += x * x; ++ residual[blockIdx.x * hidden_size + idx] = z; ++ } ++ ++ using BlockReduce = cub::BlockReduce; ++ __shared__ typename BlockReduce::TempStorage reduceStore; ++ variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x); ++ ++ if (threadIdx.x == 0) { ++ s_variance = rsqrtf(variance / hidden_size + epsilon); ++ } ++ __syncthreads(); ++ ++ // invert scale to avoid division ++ float const scale_inv = 1.0f / *scale; ++ ++ for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { ++ float x = (float)residual[blockIdx.x * hidden_size + idx]; ++ float const out_norm = ((scalar_t)(x * s_variance)) * weight[idx]; ++ out[blockIdx.x * hidden_size + idx] = ++ scaled_fp8_conversion(out_norm, scale_inv); ++ } ++} ++ ++} // namespace vllm ++ ++void rms_norm_static_fp8_quant(torch::Tensor& out, // [..., hidden_size] ++ torch::Tensor& input, // [..., hidden_size] ++ torch::Tensor& weight, // [hidden_size] ++ torch::Tensor& scale, // [1] ++ double epsilon) { ++ int hidden_size = input.size(-1); ++ int num_tokens = input.numel() / hidden_size; ++ ++ dim3 grid(num_tokens); ++ dim3 block(std::min(hidden_size, 1024)); ++ const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); ++ const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); ++ VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_kernel", [&] { ++ vllm::rms_norm_static_fp8_quant_kernel ++ <<>>( ++ out.data_ptr(), input.data_ptr(), ++ weight.data_ptr(), scale.data_ptr(), epsilon, ++ num_tokens, hidden_size); ++ }); ++} ++ ++#define LAUNCH_FUSED_ADD_RMS_NORM(width) \ ++ VLLM_DISPATCH_FLOATING_TYPES( \ ++ input.scalar_type(), "fused_add_rms_norm_kernel", [&] { \ ++ vllm::fused_add_rms_norm_static_fp8_quant_kernel \ ++ <<>>( \ ++ out.data_ptr(), input.data_ptr(), \ ++ residual.data_ptr(), weight.data_ptr(), \ ++ scale.data_ptr(), epsilon, num_tokens, hidden_size); \ ++ }); ++ ++void fused_add_rms_norm_static_fp8_quant( ++ torch::Tensor& out, // [..., hidden_size], ++ torch::Tensor& input, // [..., hidden_size] ++ torch::Tensor& residual, // [..., hidden_size] ++ torch::Tensor& weight, // [hidden_size] ++ torch::Tensor& scale, // [1] ++ double epsilon) { ++ int hidden_size = input.size(-1); ++ int num_tokens = input.numel() / hidden_size; ++ ++ dim3 grid(num_tokens); ++ /* This kernel is memory-latency bound in many scenarios. ++ When num_tokens is large, a smaller block size allows ++ for increased block occupancy on CUs and better latency ++ hiding on global mem ops. */ ++ const int max_block_size = (num_tokens < 256) ? 1024 : 256; ++ dim3 block(std::min(hidden_size, max_block_size)); ++ const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); ++ const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); ++ /*If the tensor types are FP16/BF16, try to use the optimized kernel ++ with packed + vectorized ops. ++ Max optimization is achieved with a width-8 vector of FP16/BF16s ++ since we can load at most 128 bits at once in a global memory op. ++ However, this requires each tensor's data to be aligned to 16 ++ bytes. ++ */ ++ auto inp_ptr = reinterpret_cast(input.data_ptr()); ++ auto res_ptr = reinterpret_cast(residual.data_ptr()); ++ auto wt_ptr = reinterpret_cast(weight.data_ptr()); ++ bool ptrs_are_aligned = ++ inp_ptr % 16 == 0 && res_ptr % 16 == 0 && wt_ptr % 16 == 0; ++ if (ptrs_are_aligned && hidden_size % 8 == 0) { ++ LAUNCH_FUSED_ADD_RMS_NORM(8); ++ } else { ++ LAUNCH_FUSED_ADD_RMS_NORM(0); ++ } ++} +diff --git a/csrc/mamba/causal_conv1d/causal_conv1d.cu b/csrc/mamba/causal_conv1d/causal_conv1d.cu +new file mode 100644 +index 0000000..f0e5533 +--- /dev/null ++++ b/csrc/mamba/causal_conv1d/causal_conv1d.cu +@@ -0,0 +1,662 @@ ++// clang-format off ++// adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/causal_conv1d_fwd.cu ++// and https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/causal_conv1d_update.cu ++#include ++#include ++#include ++ ++#include "causal_conv1d.h" ++#include ++#include ++#include // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK ++ ++#include ++#include ++ ++#include "static_switch.h" ++ ++ ++ ++#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") ++ ++#define DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...) \ ++ if (ITYPE == at::ScalarType::Half) { \ ++ using input_t = at::Half; \ ++ using weight_t = at::Half; \ ++ __VA_ARGS__(); \ ++ } else if (ITYPE == at::ScalarType::BFloat16) { \ ++ using input_t = at::BFloat16; \ ++ using weight_t = at::BFloat16; \ ++ __VA_ARGS__(); \ ++ } else if (ITYPE == at::ScalarType::Float) { \ ++ using input_t = float; \ ++ using weight_t = float; \ ++ __VA_ARGS__(); \ ++ } else { \ ++ AT_ERROR(#NAME, " not implemented for input type '", toString(ITYPE), "'"); \ ++ } ++ ++ ++template ++void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); ++ ++template ++void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream); ++ ++void set_conv_params_fwd(ConvParamsBase ¶ms, ++ // sizes ++ const size_t batch, ++ const size_t dim, ++ const size_t seqlen, ++ const size_t width, ++ // device pointers ++ const at::Tensor x, ++ const at::Tensor weight, ++ const at::Tensor out, ++ const std::optional& bias, ++ bool silu_activation, ++ int64_t pad_slot_id, ++ const std::optional& query_start_loc = std::nullopt, ++ const std::optional& cache_indices = std::nullopt, ++ const std::optional& has_initial_state = std::nullopt) { ++ ++ // Reset the parameters ++ memset(¶ms, 0, sizeof(params)); ++ ++ params.batch = batch; ++ params.dim = dim; ++ params.seqlen = seqlen; ++ params.width = width; ++ params.pad_slot_id = pad_slot_id; ++ ++ params.silu_activation = silu_activation; ++ ++ // Set the pointers and strides. ++ params.x_ptr = x.data_ptr(); ++ params.weight_ptr = weight.data_ptr(); ++ params.bias_ptr = bias.has_value() ? bias.value().data_ptr() : nullptr; ++ params.out_ptr = out.data_ptr(); ++ // All stride are in elements, not bytes. ++ params.query_start_loc_ptr = query_start_loc.has_value() ? query_start_loc.value().data_ptr() : nullptr; ++ params.cache_indices_ptr = cache_indices.has_value() ? cache_indices.value().data_ptr() : nullptr; ++ params.has_initial_state_ptr = has_initial_state.has_value() ? has_initial_state.value().data_ptr() : nullptr; ++ const bool varlen = params.query_start_loc_ptr != nullptr; ++ params.x_batch_stride = x.stride(varlen ? 1 : 0); ++ params.x_c_stride = x.stride(varlen ? 0 : 1); ++ params.x_l_stride = x.stride(varlen ? 1 : -1); ++ params.weight_c_stride = weight.stride(0); ++ params.weight_width_stride = weight.stride(1); ++ params.out_batch_stride = out.stride(varlen ? 1 : 0); ++ params.out_c_stride = out.stride(varlen ? 0 : 1); ++ params.out_l_stride = out.stride(varlen ? 1 : -1); ++} ++ ++ ++void causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight, ++ const std::optional &bias_, ++ const std::optional &conv_states, ++ const std::optional &query_start_loc, ++ const std::optional &cache_indices, ++ const std::optional &has_initial_state, ++ bool silu_activation, ++ // used to identify padding entries if cache_indices provided ++ // in case of padding, the kernel will return early ++ int64_t pad_slot_id) { ++ auto input_type = x.scalar_type(); ++ auto weight_type = weight.scalar_type(); ++ TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16); ++ TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::Half || weight_type == at::ScalarType::BFloat16); ++ ++ TORCH_CHECK(x.is_cuda()); ++ TORCH_CHECK(weight.is_cuda()); ++ ++ const bool varlen = query_start_loc.has_value() ? true : false; ++ const auto sizes = x.sizes(); ++ const int batch_size = varlen ? query_start_loc.value().sizes()[0] - 1 : sizes[0]; ++ const int dim = varlen ? sizes[0] : sizes[1]; ++ const int seqlen = varlen ? sizes[1] : sizes[2]; ++ const int width = weight.size(-1); ++ if (varlen){ ++ CHECK_SHAPE(x, dim, seqlen); ++ } ++ else { ++ CHECK_SHAPE(x, batch_size, dim, seqlen); ++ } ++ CHECK_SHAPE(weight, dim, width); ++ ++ ++ ++ if (bias_.has_value()) { ++ auto bias = bias_.value(); ++ TORCH_CHECK(bias.scalar_type() == weight_type); ++ TORCH_CHECK(bias.is_cuda()); ++ TORCH_CHECK(bias.stride(-1) == 1); ++ CHECK_SHAPE(bias, dim); ++ } ++ ++ ++ if (has_initial_state.has_value()) { ++ auto has_initial_state_ = has_initial_state.value(); ++ TORCH_CHECK(has_initial_state_.scalar_type() == at::ScalarType::Bool); ++ TORCH_CHECK(has_initial_state_.is_cuda()); ++ CHECK_SHAPE(has_initial_state_, batch_size); ++ } ++ ++ ++ if (query_start_loc.has_value()) { ++ auto query_start_loc_ = query_start_loc.value(); ++ TORCH_CHECK(query_start_loc_.scalar_type() == at::ScalarType::Int); ++ TORCH_CHECK(query_start_loc_.is_cuda()); ++ } ++ ++ ++ if (cache_indices.has_value()) { ++ auto cache_indices_ = cache_indices.value(); ++ TORCH_CHECK(cache_indices_.scalar_type() == at::ScalarType::Int); ++ TORCH_CHECK(cache_indices_.is_cuda()); ++ CHECK_SHAPE(cache_indices_, batch_size); ++ } ++ ++ at::Tensor out = x; ++ ++ ConvParamsBase params; ++ set_conv_params_fwd(params, batch_size, dim, seqlen, width, x, weight, out, ++ bias_, ++ silu_activation, ++ pad_slot_id, ++ query_start_loc, ++ cache_indices, ++ has_initial_state ++ ); ++ ++ if (conv_states.has_value()) { ++ auto conv_states_ = conv_states.value(); ++ TORCH_CHECK(conv_states_.scalar_type() == input_type); ++ TORCH_CHECK(conv_states_.is_cuda()); ++ params.conv_states_ptr = conv_states_.data_ptr(); ++ params.conv_states_batch_stride = conv_states_.stride(0); ++ params.conv_states_c_stride = conv_states_.stride(1); ++ params.conv_states_l_stride = conv_states_.stride(2); ++ } else { ++ params.conv_states_ptr = nullptr; ++ } ++ ++ // Otherwise the kernel will be launched from cuda:0 device ++ // Cast to char to avoid compiler warning about narrowing ++ at::cuda::CUDAGuard device_guard{(char)x.get_device()}; ++ auto stream = at::cuda::getCurrentCUDAStream().stream(); ++ DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_fwd", [&] { ++ causal_conv1d_fwd_cuda(params, stream); ++ }); ++} ++ ++ ++void causal_conv1d_update(const at::Tensor &x, ++ const at::Tensor &conv_state, ++ const at::Tensor &weight, ++ const std::optional &bias_, ++ bool silu_activation, ++ const std::optional &cache_seqlens_, ++ const std::optional &conv_state_indices_, ++ // used to identify padding entries if cache_indices provided ++ // in case of padding, the kernel will return early ++ int64_t pad_slot_id) { ++ auto input_type = x.scalar_type(); ++ auto weight_type = weight.scalar_type(); ++ TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16); ++ TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::Half || weight_type == at::ScalarType::BFloat16); ++ TORCH_CHECK(weight_type == input_type, "weight type must equal to input type, other variations are disabled due to binary size limitations"); ++ TORCH_CHECK(conv_state.scalar_type() == input_type); ++ ++ TORCH_CHECK(x.is_cuda()); ++ TORCH_CHECK(conv_state.is_cuda()); ++ TORCH_CHECK(weight.is_cuda()); ++ ++ const auto sizes = x.sizes(); ++ const int batch_size = sizes[0]; ++ const int dim = sizes[1]; ++ const int seqlen = sizes[2]; ++ const int width = weight.size(-1); ++ const int conv_state_len = conv_state.size(2); ++ TORCH_CHECK(conv_state_len >= width - 1); ++ ++ CHECK_SHAPE(x, batch_size, dim, seqlen); ++ CHECK_SHAPE(weight, dim, width); ++ ++ TORCH_CHECK(width >= 2 && width <= 4, "causal_conv1d only supports width between 2 and 4"); ++ ++ if (bias_.has_value()) { ++ auto bias = bias_.value(); ++ TORCH_CHECK(bias.scalar_type() == weight_type); ++ TORCH_CHECK(bias.is_cuda()); ++ TORCH_CHECK(bias.stride(-1) == 1); ++ CHECK_SHAPE(bias, dim); ++ } ++ ++ at::Tensor out = x; ++ ++ ConvParamsBase params; ++ set_conv_params_fwd(params, batch_size, dim, seqlen, width, x, weight, out, ++ bias_, ++ silu_activation, ++ pad_slot_id); ++ params.conv_state_ptr = conv_state.data_ptr(); ++ params.conv_state_len = conv_state_len; ++ // All stride are in elements, not bytes. ++ params.conv_state_batch_stride = conv_state.stride(0); ++ params.conv_state_c_stride = conv_state.stride(1); ++ params.conv_state_l_stride = conv_state.stride(2); ++ ++ if (cache_seqlens_.has_value()) { ++ auto cache_seqlens = cache_seqlens_.value(); ++ TORCH_CHECK(cache_seqlens.scalar_type() == torch::kInt32); ++ TORCH_CHECK(cache_seqlens.is_cuda()); ++ TORCH_CHECK(cache_seqlens.stride(-1) == 1); ++ CHECK_SHAPE(cache_seqlens, batch_size); ++ params.cache_seqlens = cache_seqlens.data_ptr(); ++ } else { ++ params.cache_seqlens = nullptr; ++ } ++ ++ if (conv_state_indices_.has_value()) { ++ auto conv_state_indices = conv_state_indices_.value(); ++ TORCH_CHECK(conv_state_indices.scalar_type() == torch::kInt32) ++ TORCH_CHECK(conv_state_indices.is_cuda()); ++ TORCH_CHECK(conv_state_indices.stride(0) == 1) ++ CHECK_SHAPE(conv_state_indices, batch_size); ++ ++ int conv_state_entries = conv_state.size(0); ++ CHECK_SHAPE(conv_state, conv_state_entries, dim, conv_state_len); ++ ++ params.conv_state_indices_ptr = conv_state_indices.data_ptr(); ++ } else { ++ CHECK_SHAPE(conv_state, batch_size, dim, conv_state_len); ++ params.conv_state_indices_ptr = nullptr; ++ } ++ ++ // Otherwise the kernel will be launched from cuda:0 device ++ // Cast to char to avoid compiler warning about narrowing ++ at::cuda::CUDAGuard device_guard{(char)x.get_device()}; ++ auto stream = at::cuda::getCurrentCUDAStream().stream(); ++ DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_update", [&] { ++ causal_conv1d_update_cuda(params, stream); ++ }); ++} ++ ++template ++struct Causal_conv1d_fwd_kernel_traits { ++ using input_t = input_t_; ++ using weight_t = weight_t_; ++ static constexpr int kNThreads = kNThreads_; ++ static constexpr int kWidth = kWidth_; ++ static constexpr int kNBytes = sizeof(input_t); ++ static_assert(kNBytes == 2 || kNBytes == 4); ++ static constexpr int kNElts = kNBytes == 4 ? 4 : 8; ++ static_assert(kWidth <= kNElts); ++ static constexpr bool kIsVecLoad = kIsVecLoad_; ++ using vec_t = typename BytesToType::Type; ++ using BlockLoadT = cub::BlockLoad; ++ using BlockLoadVecT = cub::BlockLoad; ++ using BlockStoreT = cub::BlockStore; ++ using BlockStoreVecT = cub::BlockStore; ++ static constexpr int kSmemIOSize = kIsVecLoad ++ ? 0 ++ : custom_max({sizeof(typename BlockLoadT::TempStorage), sizeof(typename BlockStoreT::TempStorage)}); ++ static constexpr int kSmemExchangeSize = kNThreads * kNBytes * kNElts; ++ static constexpr int kSmemSize = kSmemIOSize + kSmemExchangeSize; ++}; ++ ++template ++__global__ __launch_bounds__(Ktraits::kNThreads) ++void causal_conv1d_fwd_kernel(ConvParamsBase params) { ++ constexpr int kWidth = Ktraits::kWidth; ++ constexpr int kNThreads = Ktraits::kNThreads; ++ constexpr int kNElts = Ktraits::kNElts; ++ constexpr bool kIsVecLoad = Ktraits::kIsVecLoad; ++ using input_t = typename Ktraits::input_t; ++ using vec_t = typename Ktraits::vec_t; ++ using weight_t = typename Ktraits::weight_t; ++ ++ // Shared memory. ++ extern __shared__ char smem_[]; ++ auto& smem_load = reinterpret_cast(smem_); ++ auto& smem_load_vec = reinterpret_cast(smem_); ++ auto& smem_store = reinterpret_cast(smem_); ++ auto& smem_store_vec = reinterpret_cast(smem_); ++ vec_t *smem_exchange = reinterpret_cast(smem_ + Ktraits::kSmemIOSize); ++ ++ const bool kVarlen = params.query_start_loc_ptr != nullptr; ++ const int tidx = threadIdx.x; ++ const int batch_id = blockIdx.x; ++ const int channel_id = blockIdx.y; ++ const int *query_start_loc = kVarlen ? reinterpret_cast(params.query_start_loc_ptr) : nullptr; ++ const int sequence_start_index = kVarlen ? query_start_loc[batch_id] : batch_id; ++ const int seqlen = kVarlen ? query_start_loc[batch_id + 1] - sequence_start_index : params.seqlen; ++ ++ input_t *x = reinterpret_cast(params.x_ptr) + sequence_start_index * params.x_batch_stride ++ + channel_id * params.x_c_stride; ++ weight_t *weight = reinterpret_cast(params.weight_ptr) + channel_id * params.weight_c_stride; ++ input_t *out = reinterpret_cast(params.out_ptr) + sequence_start_index * params.out_batch_stride ++ + channel_id * params.out_c_stride; ++ float bias_val = params.bias_ptr == nullptr ? 0.f : float(reinterpret_cast(params.bias_ptr)[channel_id]); ++ ++ bool has_initial_state = params.has_initial_state_ptr == nullptr ? false ++ : reinterpret_cast(params.has_initial_state_ptr)[batch_id]; ++ ++ int* cache_indices = params.cache_indices_ptr == nullptr ? nullptr ++ : reinterpret_cast(params.cache_indices_ptr); ++ int cache_index = cache_indices == nullptr ? batch_id : cache_indices[batch_id]; ++ // cache_index == params.pad_slot_id is defined as padding, so we exit early ++ if (cache_index == params.pad_slot_id){ ++ return; ++ } ++ input_t *conv_states = params.conv_states_ptr == nullptr ? nullptr ++ : reinterpret_cast(params.conv_states_ptr) + cache_index * params.conv_states_batch_stride + channel_id * params.conv_states_c_stride; ++ ++ // Thread 0 will load the last elements of the previous chunk, so we initialize those to 0. ++ if (tidx == 0) { ++ input_t initial_state[kNElts] = {0}; ++ if (has_initial_state) { ++ #pragma unroll ++ for (int w = 0; w < kWidth - 1; ++w){ initial_state[kNElts - 1 - (kWidth - 2) + w ] = conv_states[w]; } ++ } ++ smem_exchange[kNThreads - 1] = reinterpret_cast(initial_state)[0]; ++ } ++ ++ float weight_vals[kWidth]; ++ #pragma unroll ++ for (int i = 0; i < kWidth; ++i) { weight_vals[i] = float(weight[i * params.weight_width_stride]); } ++ ++ constexpr int kChunkSize = kNThreads * kNElts; ++ const int n_chunks = (seqlen + kChunkSize - 1) / kChunkSize; ++ for (int chunk = 0; chunk < n_chunks; ++chunk) { ++ input_t x_vals_load[2 * kNElts] = {0}; ++ if constexpr(kIsVecLoad) { ++ typename Ktraits::BlockLoadVecT(smem_load_vec).Load(reinterpret_cast(x), *reinterpret_cast(&x_vals_load[kNElts]), (seqlen - chunk * kChunkSize) / kNElts); ++ } else { ++ __syncthreads(); ++ typename Ktraits::BlockLoadT(smem_load).Load(x, *reinterpret_cast(&x_vals_load[kNElts]), seqlen - chunk * kChunkSize); ++ } ++ x += kChunkSize; ++ __syncthreads(); ++ // Thread kNThreads - 1 don't write yet, so that thread 0 can read ++ // the last elements of the previous chunk. ++ if (tidx < kNThreads - 1) { smem_exchange[tidx] = reinterpret_cast(x_vals_load)[1]; } ++ __syncthreads(); ++ reinterpret_cast(x_vals_load)[0] = smem_exchange[tidx > 0 ? tidx - 1 : kNThreads - 1]; ++ __syncthreads(); ++ // Now thread kNThreads - 1 can write the last elements of the current chunk. ++ if (tidx == kNThreads - 1) { smem_exchange[tidx] = reinterpret_cast(x_vals_load)[1]; } ++ ++ float x_vals[2 * kNElts]; ++ #pragma unroll ++ for (int i = 0; i < 2 * kNElts; ++i) { x_vals[i] = float(x_vals_load[i]); } ++ ++ float out_vals[kNElts]; ++ #pragma unroll ++ for (int i = 0; i < kNElts; ++i) { ++ out_vals[i] = bias_val; ++ #pragma unroll ++ for (int w = 0; w < kWidth; ++w) { ++ out_vals[i] += weight_vals[w] * x_vals[kNElts + i - (kWidth - w - 1)]; ++ } ++ } ++ ++ if (params.silu_activation) { ++ #pragma unroll ++ for (int i = 0; i < kNElts; ++i) { ++ out_vals[i] = out_vals[i] / (1 + expf(-out_vals[i])); ++ } ++ } ++ ++ input_t out_vals_store[kNElts]; ++ #pragma unroll ++ for (int i = 0; i < kNElts; ++i) { out_vals_store[i] = out_vals[i]; } ++ if constexpr(kIsVecLoad) { ++ typename Ktraits::BlockStoreVecT(smem_store_vec).Store(reinterpret_cast(out), reinterpret_cast(out_vals_store), (seqlen - chunk * kChunkSize) / kNElts); ++ } else { ++ typename Ktraits::BlockStoreT(smem_store).Store(out, out_vals_store, seqlen - chunk * kChunkSize); ++ } ++ out += kChunkSize; ++ ++ int final_state_position = ((seqlen - (kWidth - 1)) - (n_chunks - 1) * kChunkSize); ++ // in case the final state is separated between the last "smem_exchange" and ++ // and the one before it (chunk = n_chunks - 1 and chunk = n_chunks - 2), ++ // (which occurs when `final_state_position` is a non-positivie index) ++ // we load the correct data from smem_exchange from both chunks, the last chunk iteration and the one before it ++ if (conv_states != nullptr && final_state_position < 0 && seqlen > kWidth){ ++ input_t vals_load[kNElts] = {0}; ++ if ((chunk == n_chunks - 2) && (tidx == kNThreads - 1)){ ++ // chunk = n_chunks - 2, a segment of the final state sits in the last index ++ reinterpret_cast(vals_load)[0] = smem_exchange[kNThreads - 1]; ++ #pragma unroll ++ for (int w = 0; w < -final_state_position; ++w){ ++ conv_states[w] = vals_load[kNElts + final_state_position + w]; ++ } ++ } ++ if ((chunk == n_chunks - 1) && tidx == 0){ ++ // chunk = n_chunks - 1, the second segment of the final state first positions ++ reinterpret_cast(vals_load)[0] = smem_exchange[0]; ++ for (int w = -final_state_position; w < kWidth - 1; ++w){ ++ conv_states[w] = vals_load[w + final_state_position]; ++ } ++ return; ++ } ++ } ++ } ++ // Final state is stored in the smem_exchange last token slot, ++ // in case seqlen < kWidth, we would need to take the final state from the ++ // initial state which is stored in conv_states ++ // in case seqlen > kWidth, we would need to load the last kWidth - 1 data ++ // and load it into conv_state accordingly ++ int last_thread = ((seqlen - (kWidth - 1)) - (n_chunks - 1) * kChunkSize) / kNElts; ++ if (conv_states != nullptr && tidx == last_thread) { ++ input_t x_vals_load[kNElts * 2] = {0}; ++ // in case we are on the first kWidth tokens ++ if (last_thread == 0 && seqlen < kWidth){ ++ // Need to take the initial state ++ reinterpret_cast(x_vals_load)[0] = smem_exchange[0]; ++ const int offset = seqlen - (kWidth - 1); ++ #pragma unroll ++ for (int w = 0; w < kWidth - 1; ++w){ ++ // pad the existing state ++ if ((w - seqlen) >= 0 && has_initial_state) { conv_states[w - seqlen] = conv_states[w]; } ++ else if ((w - seqlen) >= 0 && !has_initial_state) { conv_states[w - seqlen] = input_t(0.0f); } ++ } ++ #pragma unroll ++ for (int w = 0; w < kWidth - 1; ++w){ ++ if (offset + w >= 0) ++ conv_states[w] = x_vals_load[offset + w ]; ++ } ++ } ++ else { ++ // in case the final state is in between the threads data ++ const int offset = ((seqlen - (kWidth - 1)) % (kNElts)); ++ if ((offset + kWidth - 2) >= kNElts && (last_thread + 1 < kNThreads)){ ++ // In case last_thread == kNThreads - 1, accessing last_thread + 1 will result in a ++ // illegal access error on H100. ++ // Therefore, we access last_thread + 1, only if the final state data sits there ++ reinterpret_cast(x_vals_load)[1] = smem_exchange[last_thread + 1]; ++ } ++ reinterpret_cast(x_vals_load)[0] = smem_exchange[last_thread]; ++ #pragma unroll ++ for (int w = 0; w < kWidth - 1; ++w){ ++ conv_states[w] = x_vals_load[offset + w ]; ++ } ++ } ++ ++ } ++} ++ ++ ++template ++void causal_conv1d_fwd_launch(ConvParamsBase ¶ms, cudaStream_t stream) { ++ static constexpr int kNElts = sizeof(input_t) == 4 ? 4 : 8; ++ const bool kVarlen = params.query_start_loc_ptr != nullptr; ++ BOOL_SWITCH(params.seqlen % kNElts == 0 && !kVarlen, kIsVecLoad, [&] { ++ using Ktraits = Causal_conv1d_fwd_kernel_traits; ++ constexpr int kSmemSize = Ktraits::kSmemSize; ++ dim3 grid(params.batch, params.dim); ++ ++ auto kernel = &causal_conv1d_fwd_kernel; ++ ++ if (kSmemSize >= 48 * 1024) { ++ #ifndef USE_ROCM ++ C10_CUDA_CHECK(cudaFuncSetAttribute( ++ kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize)); ++ #else ++ // There is a slight signature discrepancy in HIP and CUDA "FuncSetAttribute" function. ++ C10_CUDA_CHECK(cudaFuncSetAttribute( ++ (void *) kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize)); ++ std::cerr << "Warning (causal_conv1d fwd launch): attempting to set maxDynamicSharedMemorySize on an AMD GPU which is currently a non-op (in ROCm versions <= 6.1). This might lead to undefined behavior. \n" << std::endl; ++ #endif ++ } ++ kernel<<>>(params); ++ ++ C10_CUDA_KERNEL_LAUNCH_CHECK(); ++ }); ++} ++ ++template ++void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream) { ++ if (params.width == 2) { ++ causal_conv1d_fwd_launch<128, 2, input_t, weight_t>(params, stream); ++ } else if (params.width == 3) { ++ causal_conv1d_fwd_launch<128, 3, input_t, weight_t>(params, stream); ++ } else if (params.width == 4) { ++ causal_conv1d_fwd_launch<128, 4, input_t, weight_t>(params, stream); ++ } ++} ++ ++ ++template void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); ++template void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); ++template void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); ++ ++ ++ ++ ++template ++struct Causal_conv1d_update_kernel_traits { ++ using input_t = input_t_; ++ using weight_t = weight_t_; ++ static constexpr int kNThreads = kNThreads_; ++ static constexpr int kWidth = kWidth_; ++ static constexpr int kNBytes = sizeof(input_t); ++ static_assert(kNBytes == 2 || kNBytes == 4); ++}; ++ ++template ++__global__ __launch_bounds__(Ktraits::kNThreads) ++void causal_conv1d_update_kernel(ConvParamsBase params) { ++ constexpr int kWidth = Ktraits::kWidth; ++ constexpr int kNThreads = Ktraits::kNThreads; ++ using input_t = typename Ktraits::input_t; ++ using weight_t = typename Ktraits::weight_t; ++ ++ const int tidx = threadIdx.x; ++ const int batch_id = blockIdx.x; ++ const int channel_id = blockIdx.y * kNThreads + tidx; ++ if (channel_id >= params.dim) return; ++ ++ input_t *x = reinterpret_cast(params.x_ptr) + batch_id * params.x_batch_stride ++ + channel_id * params.x_c_stride; ++ ++ // If params.conv_state_batch_indices is set, then the conv state is gathered from the conv state tensor ++ // along the batch axis. Otherwise, the conv state coordinate is the same as the batch id. ++ const int conv_state_batch_coord = params.conv_state_indices_ptr == nullptr ++ ? batch_id ++ : params.conv_state_indices_ptr[batch_id]; ++ // conv_state_batch_coord == params.pad_slot_id is defined as padding so we exit early ++ if (conv_state_batch_coord == params.pad_slot_id){ ++ return; ++ } ++ input_t *conv_state = reinterpret_cast(params.conv_state_ptr) ++ + conv_state_batch_coord * params.conv_state_batch_stride ++ + channel_id * params.conv_state_c_stride; ++ ++ weight_t *weight = reinterpret_cast(params.weight_ptr) + channel_id * params.weight_c_stride; ++ input_t *out = reinterpret_cast(params.out_ptr) + batch_id * params.out_batch_stride ++ + channel_id * params.out_c_stride; ++ float bias_val = params.bias_ptr == nullptr ? 0.f : float(reinterpret_cast(params.bias_ptr)[channel_id]); ++ ++ int state_len = params.conv_state_len; ++ int advance_len = params.seqlen; ++ int cache_seqlen = kIsCircularBuffer ? params.cache_seqlens[batch_id] % state_len : 0; ++ int update_idx = cache_seqlen - (kWidth - 1); ++ update_idx = update_idx < 0 ? update_idx + state_len : update_idx; ++ ++ float weight_vals[kWidth] = {0}; ++ #pragma unroll ++ for (int i = 0; i < kWidth; ++i) { weight_vals[i] = float(weight[i * params.weight_width_stride]); } ++ ++ float x_vals[kWidth] = {0}; ++ if constexpr (!kIsCircularBuffer) { ++ #pragma unroll 2 ++ for (int i = 0; i < state_len - advance_len - (kWidth - 1); ++i) { ++ conv_state[i * params.conv_state_l_stride] = conv_state[(i + advance_len) * params.conv_state_l_stride]; ++ } ++ #pragma unroll ++ for (int i = 0; i < kWidth - 1; ++i) { ++ input_t state_val = conv_state[(state_len - (kWidth - 1) + i) * params.conv_state_l_stride]; ++ if (i < advance_len + (kWidth - 1) && state_len - advance_len - (kWidth - 1) + i >= 0) { ++ conv_state[(state_len - advance_len - (kWidth - 1) + i) * params.conv_state_l_stride] = state_val; ++ } ++ x_vals[i] = float(state_val); ++ } ++ } else { ++ #pragma unroll ++ for (int i = 0; i < kWidth - 1; ++i, update_idx = update_idx + 1 >= state_len ? update_idx + 1 - state_len : update_idx + 1) { ++ input_t state_val = conv_state[update_idx * params.conv_state_l_stride]; ++ x_vals[i] = float(state_val); ++ } ++ } ++ #pragma unroll 2 ++ for (int i = 0; i < params.seqlen; ++i) { ++ input_t x_val = x[i * params.x_l_stride]; ++ if constexpr (!kIsCircularBuffer) { ++ if (i < advance_len && state_len - advance_len + i >= 0) { ++ conv_state[(state_len - advance_len + i) * params.conv_state_l_stride] = x_val; ++ } ++ } else { ++ conv_state[update_idx * params.conv_state_l_stride] = x_val; ++ ++update_idx; ++ update_idx = update_idx >= state_len ? update_idx - state_len : update_idx; ++ } ++ x_vals[kWidth - 1] = float(x_val); ++ float out_val = bias_val; ++ #pragma unroll ++ for (int j = 0; j < kWidth; ++j) { out_val += weight_vals[j] * x_vals[j]; } ++ if (params.silu_activation) { out_val = out_val / (1 + expf(-out_val)); } ++ out[i * params.out_l_stride] = input_t(out_val); ++ // Shift the input buffer by 1 ++ #pragma unroll ++ for (int i = 0; i < kWidth - 1; ++i) { x_vals[i] = x_vals[i + 1]; } ++ } ++} ++ ++template ++void causal_conv1d_update_launch(ConvParamsBase ¶ms, cudaStream_t stream) { ++ using Ktraits = Causal_conv1d_update_kernel_traits; ++ dim3 grid(params.batch, (params.dim + kNThreads - 1) / kNThreads); ++ auto kernel = params.cache_seqlens == nullptr ++ ? &causal_conv1d_update_kernel ++ : &causal_conv1d_update_kernel; ++ kernel<<>>(params); ++ C10_CUDA_KERNEL_LAUNCH_CHECK(); ++} ++ ++template ++void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream) { ++ if (params.width == 2) { ++ causal_conv1d_update_launch<64, 2, input_t, weight_t>(params, stream); ++ } else if (params.width == 3) { ++ causal_conv1d_update_launch<64, 3, input_t, weight_t>(params, stream); ++ } else if (params.width == 4) { ++ causal_conv1d_update_launch<64, 4, input_t, weight_t>(params, stream); ++ } ++} ++ ++template void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream); ++template void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream); ++template void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream); +diff --git a/csrc/mamba/causal_conv1d/causal_conv1d.h b/csrc/mamba/causal_conv1d/causal_conv1d.h +new file mode 100644 +index 0000000..e26684a +--- /dev/null ++++ b/csrc/mamba/causal_conv1d/causal_conv1d.h +@@ -0,0 +1,159 @@ ++/****************************************************************************** ++ * Copyright (c) 2024, Tri Dao. ++ ******************************************************************************/ ++// clang-format off ++// adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/causal_conv1d.h ++#pragma once ++ ++#include ++#include ++//////////////////////////////////////////////////////////////////////////////////////////////////// ++ ++struct ConvParamsBase { ++ using index_t = uint32_t; ++ ++ int batch, dim, seqlen, width; ++ int64_t pad_slot_id; ++ bool silu_activation; ++ ++ index_t x_batch_stride; ++ index_t x_c_stride; ++ index_t x_l_stride; ++ index_t weight_c_stride; ++ index_t weight_width_stride; ++ index_t out_batch_stride; ++ index_t out_c_stride; ++ index_t out_l_stride; ++ ++ int conv_state_len; ++ index_t conv_state_batch_stride; ++ index_t conv_state_c_stride; ++ index_t conv_state_l_stride; ++ ++ // Common data pointers. ++ void *__restrict__ x_ptr; ++ void *__restrict__ weight_ptr; ++ void *__restrict__ bias_ptr; ++ void *__restrict__ out_ptr; ++ ++ void *__restrict__ conv_state_ptr; ++ void *__restrict__ query_start_loc_ptr; ++ void *__restrict__ has_initial_state_ptr; ++ void *__restrict__ cache_indices_ptr; ++ int32_t *__restrict__ cache_seqlens; ++ ++ // For the continuous batching case. Makes it so that the mamba state for ++ // the current batch doesn't need to be a contiguous tensor. ++ int32_t *__restrict__ conv_state_indices_ptr; ++ ++ void *__restrict__ seq_idx_ptr; ++ ++ // No __restrict__ since initial_states could be the same as final_states. ++ void * initial_states_ptr; ++ index_t initial_states_batch_stride; ++ index_t initial_states_l_stride; ++ index_t initial_states_c_stride; ++ ++ void * final_states_ptr; ++ index_t final_states_batch_stride; ++ index_t final_states_l_stride; ++ index_t final_states_c_stride; ++ ++ void * conv_states_ptr; ++ index_t conv_states_batch_stride; ++ index_t conv_states_l_stride; ++ index_t conv_states_c_stride; ++}; ++ ++ ++#ifndef USE_ROCM ++ #include ++ ++ template ++ __device__ inline T shuffle_xor(T val, int offset) { ++ return __shfl_xor_sync(uint32_t(-1), val, offset); ++ } ++ ++ constexpr size_t custom_max(std::initializer_list ilist) ++ { ++ return std::max(ilist); ++ } ++ ++ template ++ constexpr T constexpr_min(T a, T b) { ++ return std::min(a, b); ++ } ++ ++#else ++ #include ++ ++ template ++ __device__ inline T shuffle_xor(T val, int offset) { ++ return __shfl_xor(val, offset); ++ } ++ constexpr size_t custom_max(std::initializer_list ilist) ++ { ++ return *std::max_element(ilist.begin(), ilist.end()); ++ } ++ ++ template ++ constexpr T constexpr_min(T a, T b) { ++ return a < b ? a : b; ++ } ++#endif ++ ++//////////////////////////////////////////////////////////////////////////////////////////////////// ++ ++template struct BytesToType {}; ++ ++template<> struct BytesToType<16> { ++ using Type = uint4; ++ static_assert(sizeof(Type) == 16); ++}; ++ ++template<> struct BytesToType<8> { ++ using Type = uint64_t; ++ static_assert(sizeof(Type) == 8); ++}; ++ ++template<> struct BytesToType<4> { ++ using Type = uint32_t; ++ static_assert(sizeof(Type) == 4); ++}; ++ ++template<> struct BytesToType<2> { ++ using Type = uint16_t; ++ static_assert(sizeof(Type) == 2); ++}; ++ ++template<> struct BytesToType<1> { ++ using Type = uint8_t; ++ static_assert(sizeof(Type) == 1); ++}; ++ ++//////////////////////////////////////////////////////////////////////////////////////////////////// ++ ++template ++struct SumOp { ++__device__ inline T operator()(T const & x, T const & y) { return x + y; } ++}; ++ ++template ++struct Allreduce { ++ static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4); ++ template ++ static __device__ inline T run(T x, Operator &op) { ++ constexpr int OFFSET = THREADS / 2; ++ x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET)); ++ return Allreduce::run(x, op); ++ } ++}; ++ ++template<> ++struct Allreduce<2> { ++template ++static __device__ inline T run(T x, Operator &op) { ++ x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1)); ++ return x; ++} ++}; +diff --git a/csrc/mamba/causal_conv1d/static_switch.h b/csrc/mamba/causal_conv1d/static_switch.h +new file mode 100644 +index 0000000..ef74bf4 +--- /dev/null ++++ b/csrc/mamba/causal_conv1d/static_switch.h +@@ -0,0 +1,28 @@ ++// Inspired by ++// https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h ++// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h ++// clang-format off ++// adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/static_switch.h ++ ++#pragma once ++ ++/// @param COND - a boolean expression to switch by ++/// @param CONST_NAME - a name given for the constexpr bool variable. ++/// @param ... - code to execute for true and false ++/// ++/// Usage: ++/// ``` ++/// BOOL_SWITCH(flag, BoolConst, [&] { ++/// some_function(...); ++/// }); ++/// ``` ++#define BOOL_SWITCH(COND, CONST_NAME, ...) \ ++ [&] { \ ++ if (COND) { \ ++ static constexpr bool CONST_NAME = true; \ ++ return __VA_ARGS__(); \ ++ } else { \ ++ static constexpr bool CONST_NAME = false; \ ++ return __VA_ARGS__(); \ ++ } \ ++ }() +diff --git a/csrc/mamba/mamba_ssm/selective_scan.h b/csrc/mamba/mamba_ssm/selective_scan.h +new file mode 100644 +index 0000000..563d2fe +--- /dev/null ++++ b/csrc/mamba/mamba_ssm/selective_scan.h +@@ -0,0 +1,266 @@ ++/****************************************************************************** ++ * Copyright (c) 2023, Tri Dao. ++ ******************************************************************************/ ++// clang-format off ++// adapted from https://github.com/state-spaces/mamba/blob/main/csrc/selective_scan/selective_scan.h ++ ++#pragma once ++ ++#ifndef USE_ROCM ++ #include ++#else ++ #include ++#endif ++#include ++//////////////////////////////////////////////////////////////////////////////////////////////////// ++ ++struct SSMParamsBase { ++ using index_t = uint32_t; ++ ++ int batch, dim, seqlen, dstate, n_groups, n_chunks; ++ int dim_ngroups_ratio; ++ bool is_variable_B; ++ bool is_variable_C; ++ int64_t pad_slot_id; ++ ++ bool delta_softplus; ++ ++ index_t A_d_stride; ++ index_t A_dstate_stride; ++ index_t B_batch_stride; ++ index_t B_d_stride; ++ index_t B_dstate_stride; ++ index_t B_group_stride; ++ index_t C_batch_stride; ++ index_t C_d_stride; ++ index_t C_dstate_stride; ++ index_t C_group_stride; ++ index_t u_batch_stride; ++ index_t u_d_stride; ++ index_t delta_batch_stride; ++ index_t delta_d_stride; ++ index_t z_batch_stride; ++ index_t z_d_stride; ++ index_t out_batch_stride; ++ index_t out_d_stride; ++ index_t out_z_batch_stride; ++ index_t out_z_d_stride; ++ ++ // Common data pointers. ++ void *__restrict__ A_ptr; ++ void *__restrict__ B_ptr; ++ void *__restrict__ C_ptr; ++ void *__restrict__ D_ptr; ++ void *__restrict__ u_ptr; ++ void *__restrict__ delta_ptr; ++ void *__restrict__ delta_bias_ptr; ++ void *__restrict__ out_ptr; ++ void *__restrict__ ssm_states_ptr; ++ void *__restrict__ z_ptr; ++ void *__restrict__ out_z_ptr; ++ ++ void *__restrict__ query_start_loc_ptr; ++ void *__restrict__ cache_indices_ptr; ++ void *__restrict__ has_initial_state_ptr; ++ ++}; ++ ++ ++ ++ ++#ifndef USE_ROCM ++ ++ constexpr size_t custom_max(std::initializer_list ilist) ++ { ++ return std::max(ilist); ++ } ++ ++ template ++ constexpr T constexpr_min(T a, T b) { ++ return std::min(a, b); ++ } ++ ++#else ++ constexpr size_t custom_max(std::initializer_list ilist) ++ { ++ return *std::max_element(ilist.begin(), ilist.end()); ++ } ++ ++ template ++ constexpr T constexpr_min(T a, T b) { ++ return a < b ? a : b; ++ } ++#endif ++ ++ ++#define MAX_DSTATE 256 ++ ++ ++inline __device__ float2 operator+(const float2 & a, const float2 & b){ ++ return {a.x + b.x, a.y + b.y}; ++} ++ ++inline __device__ float3 operator+(const float3 &a, const float3 &b) { ++ return {a.x + b.x, a.y + b.y, a.z + b.z}; ++} ++ ++inline __device__ float4 operator+(const float4 & a, const float4 & b){ ++ return {a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w}; ++} ++ ++//////////////////////////////////////////////////////////////////////////////////////////////////// ++ ++template struct BytesToType {}; ++ ++template<> struct BytesToType<16> { ++ using Type = uint4; ++ static_assert(sizeof(Type) == 16); ++}; ++ ++template<> struct BytesToType<8> { ++ using Type = uint64_t; ++ static_assert(sizeof(Type) == 8); ++}; ++ ++template<> struct BytesToType<4> { ++ using Type = uint32_t; ++ static_assert(sizeof(Type) == 4); ++}; ++ ++template<> struct BytesToType<2> { ++ using Type = uint16_t; ++ static_assert(sizeof(Type) == 2); ++}; ++ ++template<> struct BytesToType<1> { ++ using Type = uint8_t; ++ static_assert(sizeof(Type) == 1); ++}; ++ ++//////////////////////////////////////////////////////////////////////////////////////////////////// ++ ++template ++struct Converter{ ++ static inline __device__ void to_float(const scalar_t (&src)[N], float (&dst)[N]) { ++ #pragma unroll ++ for (int i = 0; i < N; ++i) { dst[i] = src[i]; } ++ } ++}; ++ ++template ++struct Converter{ ++ static inline __device__ void to_float(const at::Half (&src)[N], float (&dst)[N]) { ++ static_assert(N % 2 == 0); ++ auto &src2 = reinterpret_cast(src); ++ auto &dst2 = reinterpret_cast(dst); ++ #pragma unroll ++ for (int i = 0; i < N / 2; ++i) { dst2[i] = __half22float2(src2[i]); } ++ } ++}; ++ ++#if __CUDA_ARCH__ >= 800 ++template ++struct Converter{ ++ static inline __device__ void to_float(const at::BFloat16 (&src)[N], float (&dst)[N]) { ++ static_assert(N % 2 == 0); ++ auto &src2 = reinterpret_cast(src); ++ auto &dst2 = reinterpret_cast(dst); ++ #pragma unroll ++ for (int i = 0; i < N / 2; ++i) { dst2[i] = __bfloat1622float2(src2[i]); } ++ } ++}; ++#endif ++ ++//////////////////////////////////////////////////////////////////////////////////////////////////// ++ ++ ++template struct SSMScanOp; ++ ++template<> ++struct SSMScanOp { ++ __device__ __forceinline__ float2 operator()(const float2 &ab0, const float2 &ab1) const { ++ return make_float2(ab1.x * ab0.x, ab1.x * ab0.y + ab1.y); ++ } ++}; ++ ++// A stateful callback functor that maintains a running prefix to be applied ++// during consecutive scan operations. ++template struct SSMScanPrefixCallbackOp { ++ using scan_t = std::conditional_t, float2, float4>; ++ scan_t running_prefix; ++ // Constructor ++ __device__ SSMScanPrefixCallbackOp(scan_t running_prefix_) : running_prefix(running_prefix_) {} ++ // Callback operator to be entered by the first warp of threads in the block. ++ // Thread-0 is responsible for returning a value for seeding the block-wide scan. ++ __device__ scan_t operator()(scan_t block_aggregate) { ++ scan_t old_prefix = running_prefix; ++ running_prefix = SSMScanOp()(running_prefix, block_aggregate); ++ return old_prefix; ++ } ++}; ++ ++//////////////////////////////////////////////////////////////////////////////////////////////////// ++ ++template ++inline __device__ void load_input(typename Ktraits::input_t *u, ++ typename Ktraits::input_t (&u_vals)[Ktraits::kNItems], ++ typename Ktraits::BlockLoadT::TempStorage &smem_load, ++ int seqlen) { ++ if constexpr (Ktraits::kIsEvenLen && !Ktraits::kVarlen) { ++ auto& smem_load_vec = reinterpret_cast(smem_load); ++ using vec_t = typename Ktraits::vec_t; ++ typename Ktraits::BlockLoadVecT(smem_load_vec).Load( ++ reinterpret_cast(u), ++ reinterpret_cast(u_vals) ++ #ifdef USE_ROCM ++ , Ktraits::kNThreads * Ktraits::kNLoads ++ #endif ++ ++ ); ++ } else { ++ typename Ktraits::BlockLoadT(smem_load).Load(u, u_vals, seqlen, 0.f); ++ } ++} ++ ++ ++template ++inline __device__ void load_weight(typename Ktraits::input_t *Bvar, ++ typename Ktraits::weight_t (&B_vals)[Ktraits::kNItems], ++ typename Ktraits::BlockLoadWeightT::TempStorage &smem_load_weight, ++ int seqlen) { ++ constexpr int kNItems = Ktraits::kNItems; ++ typename Ktraits::input_t B_vals_load[kNItems]; ++ if constexpr (Ktraits::kIsEvenLen && !Ktraits::kVarlen) { ++ auto& smem_load_weight_vec = reinterpret_cast(smem_load_weight); ++ using vec_t = typename Ktraits::vec_t; ++ typename Ktraits::BlockLoadWeightVecT(smem_load_weight_vec).Load( ++ reinterpret_cast(Bvar), ++ reinterpret_cast(B_vals_load) ++ ); ++ } else { ++ typename Ktraits::BlockLoadWeightT(smem_load_weight).Load(Bvar, B_vals_load, seqlen, 0.f); ++ } ++ // #pragma unroll ++ // for (int i = 0; i < kNItems; ++i) { B_vals[i] = B_vals_load[i]; } ++ Converter::to_float(B_vals_load, B_vals); ++} ++ ++template ++inline __device__ void store_output(typename Ktraits::input_t *out, ++ const float (&out_vals)[Ktraits::kNItems], ++ typename Ktraits::BlockStoreT::TempStorage &smem_store, ++ int seqlen) { ++ typename Ktraits::input_t write_vals[Ktraits::kNItems]; ++ #pragma unroll ++ for (int i = 0; i < Ktraits::kNItems; ++i) { write_vals[i] = out_vals[i]; } ++ if constexpr (Ktraits::kIsEvenLen && !Ktraits::kVarlen) { ++ auto& smem_store_vec = reinterpret_cast(smem_store); ++ using vec_t = typename Ktraits::vec_t; ++ typename Ktraits::BlockStoreVecT(smem_store_vec).Store( ++ reinterpret_cast(out), ++ reinterpret_cast(write_vals) ++ ); ++ } else { ++ typename Ktraits::BlockStoreT(smem_store).Store(out, write_vals, seqlen); ++ } ++} +diff --git a/csrc/mamba/mamba_ssm/selective_scan_fwd.cu b/csrc/mamba/mamba_ssm/selective_scan_fwd.cu +new file mode 100644 +index 0000000..bd0a341 +--- /dev/null ++++ b/csrc/mamba/mamba_ssm/selective_scan_fwd.cu +@@ -0,0 +1,658 @@ ++// clang-format off ++// adapted from https://github.com/state-spaces/mamba/blob/main/csrc/selective_scan/selective_scan_fwd_kernel.cuh ++#include ++#include ++#include ++#include "selective_scan.h" ++ ++#include ++#include ++#include // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK ++ ++#ifndef USE_ROCM ++ #include ++ #include ++ #include ++#else ++ #include ++ namespace cub = hipcub; ++#endif ++ ++#include "selective_scan.h" ++#include "static_switch.h" ++ ++template ++struct Selective_Scan_fwd_kernel_traits { ++ static_assert(kNItems_ % 4 == 0); ++ using input_t = input_t_; ++ using weight_t = weight_t_; ++ static constexpr int kNThreads = kNThreads_; ++ // Setting MinBlocksPerMP to be 3 (instead of 2) for 128 threads improves occupancy. ++ static constexpr int kMinBlocks = kNThreads < 128 ? 5 : 3; ++ static constexpr int kNItems = kNItems_; ++ static constexpr int kNRows = kNRows_; ++ static constexpr int kNBytes = sizeof(input_t); ++ static_assert(kNBytes == 2 || kNBytes == 4); ++ static constexpr int kNElts = kNBytes == 4 ? 4 : constexpr_min(8, kNItems); ++ static_assert(kNItems % kNElts == 0); ++ static constexpr int kNLoads = kNItems / kNElts; ++ static constexpr bool kIsEvenLen = kVarlen_ ? false : kIsEvenLen_; ++ static constexpr bool kIsVariableB = kIsVariableB_; ++ static constexpr bool kIsVariableC = kIsVariableC_; ++ static constexpr bool kHasZ = kHasZ_; ++ static constexpr bool kVarlen = kVarlen_; ++ ++ static constexpr bool kDirectIO = kVarlen_ ? false : kIsEvenLen && kNLoads == 1; ++ static constexpr int kNLoadsIndex = kNItems / 4; ++ using vec_t = typename BytesToType::Type; ++ using scan_t = float2; ++ using BlockLoadT = cub::BlockLoad; ++ using BlockLoadVecT = cub::BlockLoad; ++ using BlockLoadWeightT = cub::BlockLoad; ++ using BlockLoadWeightVecT = cub::BlockLoad; ++ using BlockStoreT = cub::BlockStore; ++ using BlockStoreVecT = cub::BlockStore; ++ // using BlockScanT = cub::BlockScan; ++ // using BlockScanT = cub::BlockScan; ++ using BlockScanT = cub::BlockScan; ++ static constexpr int kSmemIOSize = custom_max({sizeof(typename BlockLoadT::TempStorage), ++ sizeof(typename BlockLoadVecT::TempStorage), ++ (int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightT::TempStorage), ++ (int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightVecT::TempStorage), ++ sizeof(typename BlockStoreT::TempStorage), ++ sizeof(typename BlockStoreVecT::TempStorage)}); ++ static constexpr int kSmemSize = kSmemIOSize + sizeof(typename BlockScanT::TempStorage); ++}; ++ ++template ++__global__ __launch_bounds__(Ktraits::kNThreads, Ktraits::kMinBlocks) ++void selective_scan_fwd_kernel(SSMParamsBase params) { ++ constexpr bool kIsVariableB = Ktraits::kIsVariableB; ++ constexpr bool kIsVariableC = Ktraits::kIsVariableC; ++ constexpr bool kHasZ = Ktraits::kHasZ; ++ constexpr bool kVarlen = Ktraits::kVarlen; ++ constexpr int kNThreads = Ktraits::kNThreads; ++ constexpr int kNItems = Ktraits::kNItems; ++ constexpr int kNRows = Ktraits::kNRows; ++ constexpr bool kDirectIO = Ktraits::kDirectIO; ++ using input_t = typename Ktraits::input_t; ++ using weight_t = typename Ktraits::weight_t; ++ using scan_t = typename Ktraits::scan_t; ++ ++ // Shared memory. ++ extern __shared__ char smem_[]; ++ // cast to lvalue reference of expected type ++ // char *smem_loadstorescan = smem_ + 2 * MAX_DSTATE * sizeof(weight_t); ++ // auto& smem_load = reinterpret_cast(smem_ + 2 * MAX_DSTATE * sizeof(weight_t)); ++ // auto& smem_load = reinterpret_cast(smem_loadstorescan); ++ auto& smem_load = reinterpret_cast(smem_); ++ auto& smem_load_weight = reinterpret_cast(smem_); ++ auto& smem_load_weight1 = *reinterpret_cast(smem_ + sizeof(typename Ktraits::BlockLoadWeightT::TempStorage)); ++ auto& smem_store = reinterpret_cast(smem_); ++ auto& smem_scan = *reinterpret_cast(smem_ + Ktraits::kSmemIOSize); ++ // weight_t *smem_a = reinterpret_cast(smem_ + smem_loadstorescan_size); ++ // weight_t *smem_bc = reinterpret_cast(smem_a + MAX_DSTATE); ++ scan_t *smem_running_prefix = reinterpret_cast(smem_ + Ktraits::kSmemSize); ++ ++ const int batch_id = blockIdx.x; ++ const int dim_id = blockIdx.y; ++ const int group_id = dim_id / (params.dim_ngroups_ratio); ++ int seqlen = params.seqlen; ++ int sequence_start_index = batch_id; ++ if constexpr (kVarlen){ ++ int *query_start_loc = reinterpret_cast(params.query_start_loc_ptr); ++ sequence_start_index = query_start_loc[batch_id]; ++ seqlen = query_start_loc[batch_id + 1] - sequence_start_index; ++ } ++ const bool has_initial_state = params.has_initial_state_ptr == nullptr ? false ++ : reinterpret_cast(params.has_initial_state_ptr)[batch_id]; ++ ++ const int* cache_indices = params.cache_indices_ptr == nullptr ? nullptr ++ : reinterpret_cast(params.cache_indices_ptr); ++ const int cache_index = cache_indices == nullptr ? batch_id : cache_indices[batch_id]; ++ // cache_index == params.pad_slot_id is defined as padding, so we exit early ++ if (cache_index == params.pad_slot_id){ ++ return; ++ } ++ input_t *u = reinterpret_cast(params.u_ptr) + sequence_start_index * params.u_batch_stride ++ + dim_id * kNRows * params.u_d_stride; ++ input_t *delta = reinterpret_cast(params.delta_ptr) + sequence_start_index * params.delta_batch_stride ++ + dim_id * kNRows * params.delta_d_stride; ++ weight_t *A = reinterpret_cast(params.A_ptr) + dim_id * kNRows * params.A_d_stride; ++ weight_t *B = reinterpret_cast(params.B_ptr) + dim_id * kNRows * params.B_d_stride; ++ input_t *Bvar = reinterpret_cast(params.B_ptr) + sequence_start_index * params.B_batch_stride + group_id * params.B_group_stride; ++ weight_t *C = reinterpret_cast(params.C_ptr) + dim_id * kNRows * params.C_d_stride; ++ input_t *Cvar = reinterpret_cast(params.C_ptr) + sequence_start_index * params.C_batch_stride + group_id * params.C_group_stride; ++ input_t *ssm_states = reinterpret_cast(params.ssm_states_ptr) + (cache_index * params.dim + dim_id * kNRows) * params.dstate; ++ ++ float D_val[kNRows] = {0}; ++ if (params.D_ptr != nullptr) { ++ #pragma unroll ++ for (int r = 0; r < kNRows; ++r) { ++ D_val[r] = reinterpret_cast(params.D_ptr)[dim_id * kNRows + r]; ++ } ++ } ++ float delta_bias[kNRows] = {0}; ++ if (params.delta_bias_ptr != nullptr) { ++ #pragma unroll ++ for (int r = 0; r < kNRows; ++r) { ++ delta_bias[r] = reinterpret_cast(params.delta_bias_ptr)[dim_id * kNRows + r]; ++ } ++ } ++ ++ ++ // for (int state_idx = threadIdx.x; state_idx < params.dstate; state_idx += blockDim.x) { ++ // smem_a[state_idx] = A[state_idx * params.A_dstate_stride]; ++ // smem_bc[state_idx] = B[state_idx * params.B_dstate_stride] * C[state_idx * params.C_dstate_stride]; ++ // } ++ ++ constexpr int kChunkSize = kNThreads * kNItems; ++ const int n_chunks = (seqlen + 2048 - 1) / 2048; ++ for (int chunk = 0; chunk < n_chunks; ++chunk) { ++ input_t u_vals[kNRows][kNItems], delta_vals_load[kNRows][kNItems]; ++ ++ __syncthreads(); ++ #pragma unroll ++ for (int r = 0; r < kNRows; ++r) { ++ if constexpr (!kDirectIO) { ++ if (r > 0) { __syncthreads(); } ++ } ++ load_input(u + r * params.u_d_stride, u_vals[r], smem_load, seqlen - chunk * kChunkSize); ++ if constexpr (!kDirectIO) { __syncthreads(); } ++ load_input(delta + r * params.delta_d_stride, delta_vals_load[r], smem_load, seqlen - chunk * kChunkSize); ++ } ++ u += kChunkSize; ++ delta += kChunkSize; ++ ++ float delta_vals[kNRows][kNItems], delta_u_vals[kNRows][kNItems], out_vals[kNRows][kNItems]; ++ #pragma unroll ++ for (int r = 0; r < kNRows; ++r) { ++ #pragma unroll ++ for (int i = 0; i < kNItems; ++i) { ++ float u_val = float(u_vals[r][i]); ++ delta_vals[r][i] = float(delta_vals_load[r][i]) + delta_bias[r]; ++ if (params.delta_softplus) { ++ delta_vals[r][i] = delta_vals[r][i] <= 20.f ? log1pf(expf(delta_vals[r][i])) : delta_vals[r][i]; ++ } ++ delta_u_vals[r][i] = delta_vals[r][i] * u_val; ++ out_vals[r][i] = D_val[r] * u_val; ++ } ++ } ++ ++ __syncthreads(); ++ for (int state_idx = 0; state_idx < params.dstate; ++state_idx) { ++ weight_t A_val[kNRows]; ++ #pragma unroll ++ for (int r = 0; r < kNRows; ++r) { ++ A_val[r] = A[state_idx * params.A_dstate_stride + r * params.A_d_stride]; ++ // Multiply the real part of A with LOG2E so we can use exp2f instead of expf. ++ constexpr float kLog2e = M_LOG2E; ++ A_val[r] *= kLog2e; ++ } ++ // This variable holds B * C if both B and C are constant across seqlen. If only B varies ++ // across seqlen, this holds C. If only C varies across seqlen, this holds B. ++ // If both B and C vary, this is unused. ++ weight_t BC_val[kNRows]; ++ weight_t B_vals[kNItems], C_vals[kNItems]; ++ if constexpr (kIsVariableB) { ++ load_weight(Bvar + state_idx * params.B_dstate_stride, B_vals, ++ smem_load_weight, (seqlen - chunk * kChunkSize) * (1)); ++ if constexpr (!kIsVariableC) { ++ #pragma unroll ++ for (int r = 0; r < kNRows; ++r) { ++ BC_val[r] = C[state_idx * params.C_dstate_stride + r * params.C_d_stride]; ++ } ++ } ++ } ++ if constexpr (kIsVariableC) { ++ auto &smem_load_weight_C = !kIsVariableB ? smem_load_weight : smem_load_weight1; ++ load_weight(Cvar + state_idx * params.C_dstate_stride, C_vals, ++ smem_load_weight_C, (seqlen - chunk * kChunkSize) * (1 )); ++ if constexpr (!kIsVariableB) { ++ #pragma unroll ++ for (int r = 0; r < kNRows; ++r) { ++ BC_val[r] = B[state_idx * params.B_dstate_stride + r * params.B_d_stride]; ++ } ++ } ++ } ++ if constexpr (!kIsVariableB && !kIsVariableC) { ++ #pragma unroll ++ for (int r = 0; r < kNRows; ++r) { ++ BC_val[r] = B[state_idx * params.B_dstate_stride + r * params.B_d_stride] * C[state_idx * params.C_dstate_stride + r * params.C_d_stride]; ++ } ++ } ++ ++ #pragma unroll ++ for (int r = 0; r < kNRows; ++r) { ++ if (r > 0) { __syncthreads(); } // Scan could be using the same smem ++ scan_t thread_data[kNItems]; ++ #pragma unroll ++ for (int i = 0; i < kNItems; ++i) { ++ thread_data[i] = make_float2(exp2f(delta_vals[r][i] * A_val[r]), ++ !kIsVariableB ? delta_u_vals[r][i] : B_vals[i] * delta_u_vals[r][i]); ++ ++ if (seqlen % (kNItems * kNThreads) != 0) { // So that the last state is correct ++ if (threadIdx.x * kNItems + i >= seqlen - chunk * kChunkSize) { ++ thread_data[i] = make_float2(1.f, 0.f); ++ } ++ } ++ } ++ // Initialize running total ++ ++ scan_t running_prefix = chunk > 0 ? smem_running_prefix[state_idx + r * MAX_DSTATE] : make_float2(1.0, has_initial_state ? float(ssm_states[state_idx]): 0.0); ++ ++ SSMScanPrefixCallbackOp prefix_op(running_prefix); ++ typename Ktraits::BlockScanT(smem_scan).InclusiveScan( ++ thread_data, thread_data, SSMScanOp(), prefix_op ++ ); ++ // There's a syncthreads in the scan op, so we don't need to sync here. ++ // Unless there's only 1 warp, but then it's the same thread (0) reading and writing. ++ if (threadIdx.x == 0) { ++ smem_running_prefix[state_idx] = prefix_op.running_prefix; ++ if (chunk == n_chunks - 1) { ++ ssm_states[state_idx] = input_t(prefix_op.running_prefix.y); ++ } ++ } ++ #pragma unroll ++ for (int i = 0; i < kNItems; ++i) { ++ const weight_t C_val = !kIsVariableC ++ ? BC_val[r] ++ : (!kIsVariableB ? BC_val[r] * C_vals[i] : C_vals[i]); ++ out_vals[r][i] += thread_data[i].y * C_val; ++ } ++ } ++ } ++ ++ input_t *out = reinterpret_cast(params.out_ptr) + sequence_start_index * params.out_batch_stride ++ + dim_id * kNRows * params.out_d_stride + chunk * kChunkSize; ++ __syncthreads(); ++ #pragma unroll ++ for (int r = 0; r < kNRows; ++r) { ++ if constexpr (!kDirectIO) { ++ if (r > 0) { __syncthreads(); } ++ } ++ store_output(out + r * params.out_d_stride, out_vals[r], smem_store, seqlen - chunk * kChunkSize); ++ } ++ ++ if constexpr (kHasZ) { ++ input_t *z = reinterpret_cast(params.z_ptr) + sequence_start_index * params.z_batch_stride ++ + dim_id * kNRows * params.z_d_stride + chunk * kChunkSize; ++ input_t *out_z = reinterpret_cast(params.out_z_ptr) + sequence_start_index * params.out_z_batch_stride ++ + dim_id * kNRows * params.out_z_d_stride + chunk * kChunkSize; ++ #pragma unroll ++ for (int r = 0; r < kNRows; ++r) { ++ input_t z_vals[kNItems]; ++ __syncthreads(); ++ load_input(z + r * params.z_d_stride, z_vals, smem_load, seqlen - chunk * kChunkSize); ++ #pragma unroll ++ for (int i = 0; i < kNItems; ++i) { ++ float z_val = z_vals[i]; ++ out_vals[r][i] *= z_val / (1 + expf(-z_val)); ++ } ++ __syncthreads(); ++ store_output(out_z + r * params.out_z_d_stride, out_vals[r], smem_store, seqlen - chunk * kChunkSize); ++ } ++ } ++ ++ Bvar += kChunkSize * 1; ++ Cvar += kChunkSize * 1; ++ } ++} ++ ++template ++void selective_scan_fwd_launch(SSMParamsBase ¶ms, cudaStream_t stream) { ++ // Only kNRows == 1 is tested for now, which ofc doesn't differ from previously when we had each block ++ // processing 1 row. ++ constexpr int kNRows = 1; ++ // kIsVariableB, kIsVariableC and kHasZ are all set to True to reduce binary size ++ constexpr bool kIsVariableB = true; ++ constexpr bool kIsVariableC = true; ++ constexpr bool kHasZ = true; ++ BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] { ++ BOOL_SWITCH(params.query_start_loc_ptr != nullptr , kVarlen, [&] { ++ using Ktraits = Selective_Scan_fwd_kernel_traits; ++ constexpr int kSmemSize = Ktraits::kSmemSize + kNRows * MAX_DSTATE * sizeof(typename Ktraits::scan_t); ++ dim3 grid(params.batch, params.dim / kNRows); ++ auto kernel = &selective_scan_fwd_kernel; ++ if (kSmemSize >= 48 * 1024) { ++ C10_CUDA_CHECK(cudaFuncSetAttribute( ++ kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize)); ++ } ++ kernel<<>>(params); ++ C10_CUDA_KERNEL_LAUNCH_CHECK(); ++ }); ++ }); ++} ++ ++template ++void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream) { ++ ++ #ifndef USE_ROCM ++ if (params.seqlen <= 128) { ++ selective_scan_fwd_launch<32, 4, input_t, weight_t>(params, stream); ++ } else if (params.seqlen <= 256) { ++ selective_scan_fwd_launch<32, 8, input_t, weight_t>(params, stream); ++ } else if (params.seqlen <= 512) { ++ selective_scan_fwd_launch<32, 16, input_t, weight_t>(params, stream); ++ } else if (params.seqlen <= 1024) { ++ selective_scan_fwd_launch<64, 16, input_t, weight_t>(params, stream); ++ } else { ++ selective_scan_fwd_launch<128, 16, input_t, weight_t>(params, stream); ++ } ++ #else ++ if (params.seqlen <= 256) { ++ selective_scan_fwd_launch<64, 4, input_t, weight_t>(params, stream); ++ } else if (params.seqlen <= 512) { ++ selective_scan_fwd_launch<64, 8, input_t, weight_t>(params, stream); ++ } else if (params.seqlen <= 1024) { ++ selective_scan_fwd_launch<64, 16, input_t, weight_t>(params, stream); ++ } else { ++ selective_scan_fwd_launch<128, 16, input_t, weight_t>(params, stream); ++ } ++ #endif ++} ++ ++template void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream); ++template void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream); ++template void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream); ++ ++#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") ++ ++#define DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...) \ ++ if (ITYPE == at::ScalarType::Half) { \ ++ using input_t = at::Half; \ ++ using weight_t = float; \ ++ __VA_ARGS__(); \ ++ } else if (ITYPE == at::ScalarType::BFloat16) { \ ++ using input_t = at::BFloat16; \ ++ using weight_t = float; \ ++ __VA_ARGS__(); \ ++ } else if (ITYPE == at::ScalarType::Float) { \ ++ using input_t = float; \ ++ using weight_t = float; \ ++ __VA_ARGS__(); \ ++ } else { \ ++ AT_ERROR(#NAME, " not implemented for input type '", toString(ITYPE), "'"); \ ++ } ++ ++ ++template ++void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream); ++ ++void set_ssm_params_fwd(SSMParamsBase ¶ms, ++ // sizes ++ const size_t batch, ++ const size_t dim, ++ const size_t seqlen, ++ const size_t dstate, ++ const size_t n_groups, ++ const bool is_variable_B, ++ const bool is_variable_C, ++ // device pointers ++ const torch::Tensor u, ++ const torch::Tensor delta, ++ const torch::Tensor A, ++ const torch::Tensor B, ++ const torch::Tensor C, ++ const torch::Tensor out, ++ const torch::Tensor z, ++ const torch::Tensor out_z, ++ const std::optional& D, ++ const std::optional& delta_bias, ++ const torch::Tensor ssm_states, ++ bool has_z, ++ bool delta_softplus, ++ const std::optional& query_start_loc, ++ const std::optional& cache_indices, ++ const std::optional& has_initial_state, ++ bool varlen, ++ int64_t pad_slot_id) { ++ ++ // Reset the parameters ++ memset(¶ms, 0, sizeof(params)); ++ ++ params.batch = batch; ++ params.dim = dim; ++ params.seqlen = seqlen; ++ params.dstate = dstate; ++ params.n_groups = n_groups; ++ params.dim_ngroups_ratio = dim / n_groups; ++ params.pad_slot_id = pad_slot_id; ++ ++ params.delta_softplus = delta_softplus; ++ ++ params.is_variable_B = is_variable_B; ++ params.is_variable_C = is_variable_C; ++ ++ // Set the pointers and strides. ++ params.u_ptr = u.data_ptr(); ++ params.delta_ptr = delta.data_ptr(); ++ params.A_ptr = A.data_ptr(); ++ params.B_ptr = B.data_ptr(); ++ params.C_ptr = C.data_ptr(); ++ params.D_ptr = D.has_value() ? D.value().data_ptr() : nullptr; ++ params.delta_bias_ptr = delta_bias.has_value() ? delta_bias.value().data_ptr() : nullptr; ++ params.out_ptr = out.data_ptr(); ++ params.ssm_states_ptr = ssm_states.data_ptr(); ++ params.z_ptr = has_z ? z.data_ptr() : nullptr; ++ params.out_z_ptr = has_z ? out_z.data_ptr() : nullptr; ++ params.query_start_loc_ptr = query_start_loc.has_value() ? query_start_loc.value().data_ptr() : nullptr; ++ params.cache_indices_ptr = cache_indices.has_value() ? cache_indices.value().data_ptr() : nullptr; ++ params.has_initial_state_ptr = has_initial_state.has_value() ? has_initial_state.value().data_ptr() : nullptr; ++ ++ ++ // All stride are in elements, not bytes. ++ params.A_d_stride = A.stride(0); ++ params.A_dstate_stride = A.stride(1); ++ ++ if (varlen){ ++ params.B_batch_stride = B.stride(2); ++ params.B_group_stride = B.stride(0); ++ params.B_dstate_stride = B.stride(1); ++ params.C_batch_stride = C.stride(2); ++ params.C_group_stride = C.stride(0); ++ params.C_dstate_stride = C.stride(1); ++ ++ params.u_batch_stride = u.stride(1); ++ params.u_d_stride = u.stride(0); ++ params.delta_batch_stride = delta.stride(1); ++ params.delta_d_stride = delta.stride(0); ++ if (has_z) { ++ params.z_batch_stride = z.stride(1); ++ params.z_d_stride = z.stride(0); ++ params.out_z_batch_stride = out_z.stride(1); ++ params.out_z_d_stride = out_z.stride(0); ++ } ++ params.out_batch_stride = out.stride(1); ++ params.out_d_stride = out.stride(0); ++ ++ } ++ else{ ++ if (!is_variable_B) { ++ params.B_d_stride = B.stride(0); ++ } else { ++ params.B_batch_stride = B.stride(0); ++ params.B_group_stride = B.stride(1); ++ } ++ params.B_dstate_stride = !is_variable_B ? B.stride(1) : B.stride(2); ++ if (!is_variable_C) { ++ params.C_d_stride = C.stride(0); ++ } else { ++ params.C_batch_stride = C.stride(0); ++ params.C_group_stride = C.stride(1); ++ } ++ params.C_dstate_stride = !is_variable_C ? C.stride(1) : C.stride(2); ++ params.u_batch_stride = u.stride(0); ++ params.u_d_stride = u.stride(1); ++ params.delta_batch_stride = delta.stride(0); ++ params.delta_d_stride = delta.stride(1); ++ if (has_z) { ++ params.z_batch_stride = z.stride(0); ++ params.z_d_stride = z.stride(1); ++ params.out_z_batch_stride = out_z.stride(0); ++ params.out_z_d_stride = out_z.stride(1); ++ } ++ params.out_batch_stride = out.stride(0); ++ params.out_d_stride = out.stride(1); ++ } ++} ++ ++void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta, ++ const torch::Tensor &A, const torch::Tensor &B, const torch::Tensor &C, ++ const std::optional &D_, ++ const std::optional &z_, ++ const std::optional &delta_bias_, ++ bool delta_softplus, ++ const std::optional &query_start_loc, ++ const std::optional &cache_indices, ++ const std::optional &has_initial_state, ++ const torch::Tensor &ssm_states, ++ // used to identify padding entries if cache_indices provided ++ // in case of padding, the kernel will return early ++ int64_t pad_slot_id) { ++ auto input_type = u.scalar_type(); ++ auto weight_type = A.scalar_type(); ++ TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16); ++ TORCH_CHECK(weight_type == at::ScalarType::Float); ++ ++ const bool is_variable_B = B.dim() >= 3; ++ const bool is_variable_C = C.dim() >= 3; ++ ++ TORCH_CHECK(delta.scalar_type() == input_type); ++ TORCH_CHECK(B.scalar_type() == (!is_variable_B ? weight_type : input_type)); ++ TORCH_CHECK(C.scalar_type() == (!is_variable_C ? weight_type : input_type)); ++ ++ TORCH_CHECK(u.is_cuda()); ++ TORCH_CHECK(delta.is_cuda()); ++ TORCH_CHECK(A.is_cuda()); ++ TORCH_CHECK(B.is_cuda()); ++ TORCH_CHECK(C.is_cuda()); ++ ++ TORCH_CHECK(u.stride(-1) == 1 || u.size(-1) == 1); ++ TORCH_CHECK(delta.stride(-1) == 1 || delta.size(-1) == 1); ++ ++ const auto sizes = u.sizes(); ++ const bool varlen = query_start_loc.has_value(); ++ const int batch_size = varlen ? query_start_loc.value().sizes()[0] - 1 : sizes[0]; ++ const int dim = varlen ? sizes[0] : sizes[1]; ++ const int seqlen = varlen ? sizes[1] : sizes[2]; ++ const int dstate = A.size(1); ++ const int n_groups = varlen ? B.size(0) : B.size(1); ++ ++ TORCH_CHECK(dstate <= 256, "selective_scan only supports state dimension <= 256"); ++ ++ if (varlen) { ++ CHECK_SHAPE(u, dim, seqlen); ++ CHECK_SHAPE(delta, dim, seqlen); ++ } else { ++ CHECK_SHAPE(u, batch_size, dim, seqlen); ++ CHECK_SHAPE(delta, batch_size, dim, seqlen); ++ } ++ CHECK_SHAPE(A, dim, dstate); ++ TORCH_CHECK(is_variable_B, "is_variable_B = False is disabled in favor of reduced binary size") ++ if (varlen) { ++ CHECK_SHAPE(B, n_groups, dstate, seqlen); ++ } else { ++ CHECK_SHAPE(B, batch_size, n_groups, dstate, seqlen); ++ } ++ TORCH_CHECK(B.stride(-1) == 1 || B.size(-1) == 1); ++ ++ TORCH_CHECK(is_variable_C, "is_variable_C = False is disabled in favor of reduced binary size") ++ if (varlen) { ++ CHECK_SHAPE(C, n_groups, dstate, seqlen); ++ } else { ++ CHECK_SHAPE(C, batch_size, n_groups, dstate, seqlen); ++ } ++ TORCH_CHECK(C.stride(-1) == 1 || C.size(-1) == 1); ++ ++ if (D_.has_value()) { ++ auto D = D_.value(); ++ TORCH_CHECK(D.scalar_type() == at::ScalarType::Float); ++ TORCH_CHECK(D.is_cuda()); ++ TORCH_CHECK(D.stride(-1) == 1 || D.size(-1) == 1); ++ CHECK_SHAPE(D, dim); ++ } ++ ++ if (delta_bias_.has_value()) { ++ auto delta_bias = delta_bias_.value(); ++ TORCH_CHECK(delta_bias.scalar_type() == at::ScalarType::Float); ++ TORCH_CHECK(delta_bias.is_cuda()); ++ TORCH_CHECK(delta_bias.stride(-1) == 1 || delta_bias.size(-1) == 1); ++ CHECK_SHAPE(delta_bias, dim); ++ } ++ ++ ++ if (has_initial_state.has_value()) { ++ auto has_initial_state_ = has_initial_state.value(); ++ TORCH_CHECK(has_initial_state_.scalar_type() == at::ScalarType::Bool); ++ TORCH_CHECK(has_initial_state_.is_cuda()); ++ CHECK_SHAPE(has_initial_state_, batch_size); ++ } ++ ++ ++ if (query_start_loc.has_value()) { ++ auto query_start_loc_ = query_start_loc.value(); ++ TORCH_CHECK(query_start_loc_.scalar_type() == at::ScalarType::Int); ++ TORCH_CHECK(query_start_loc_.is_cuda()); ++ } ++ ++ ++ if (cache_indices.has_value()) { ++ auto cache_indices_ = cache_indices.value(); ++ TORCH_CHECK(cache_indices_.scalar_type() == at::ScalarType::Int); ++ TORCH_CHECK(cache_indices_.is_cuda()); ++ CHECK_SHAPE(cache_indices_, batch_size); ++ } ++ ++ ++ at::Tensor z, out_z; ++ const bool has_z = z_.has_value(); ++ TORCH_CHECK(has_z, "has_z = False is disabled in favor of reduced binary size") ++ z = z_.value(); ++ TORCH_CHECK(z.scalar_type() == input_type); ++ TORCH_CHECK(z.is_cuda()); ++ TORCH_CHECK(z.stride(-1) == 1 || z.size(-1) == 1); ++ if (varlen){ ++ CHECK_SHAPE(z, dim, seqlen); ++ } else { ++ CHECK_SHAPE(z, batch_size, dim, seqlen); ++ } ++ ++ out_z = z; ++ ++ // Right now u has BHL layout and delta has HBL layout, and we want out to have HBL layout ++ at::Tensor out = delta; ++ TORCH_CHECK(ssm_states.scalar_type() == input_type); ++ TORCH_CHECK(ssm_states.is_cuda()); ++ TORCH_CHECK(ssm_states.stride(-1) == 1); ++ ++ SSMParamsBase params; ++ set_ssm_params_fwd(params, batch_size, dim, seqlen, dstate, n_groups, is_variable_B, is_variable_C, ++ u, delta, A, B, C, out, z, out_z, ++ D_, ++ delta_bias_, ++ ssm_states, ++ has_z, ++ delta_softplus, ++ query_start_loc, ++ cache_indices, ++ has_initial_state, ++ varlen, ++ pad_slot_id ++ ); ++ ++ ++ // Otherwise the kernel will be launched from cuda:0 device ++ // Cast to char to avoid compiler warning about narrowing ++ at::cuda::CUDAGuard device_guard{(char)u.get_device()}; ++ auto stream = at::cuda::getCurrentCUDAStream().stream(); ++ DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), "selective_scan_fwd", [&] { ++ selective_scan_fwd_cuda(params, stream); ++ }); ++} ++ +diff --git a/csrc/mamba/mamba_ssm/static_switch.h b/csrc/mamba/mamba_ssm/static_switch.h +new file mode 100644 +index 0000000..840cb23 +--- /dev/null ++++ b/csrc/mamba/mamba_ssm/static_switch.h +@@ -0,0 +1,28 @@ ++// Inspired by ++// https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h ++// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h ++ ++// clang-format off ++// adapted from https://github.com/state-spaces/mamba/blob/main/csrc/selective_scan/static_switch.h ++#pragma once ++ ++/// @param COND - a boolean expression to switch by ++/// @param CONST_NAME - a name given for the constexpr bool variable. ++/// @param ... - code to execute for true and false ++/// ++/// Usage: ++/// ``` ++/// BOOL_SWITCH(flag, BoolConst, [&] { ++/// some_function(...); ++/// }); ++/// ``` ++#define BOOL_SWITCH(COND, CONST_NAME, ...) \ ++ [&] { \ ++ if (COND) { \ ++ constexpr bool CONST_NAME = true; \ ++ return __VA_ARGS__(); \ ++ } else { \ ++ constexpr bool CONST_NAME = false; \ ++ return __VA_ARGS__(); \ ++ } \ ++ }() +diff --git a/csrc/moe/marlin_kernels/marlin_moe_kernel.h b/csrc/moe/marlin_kernels/marlin_moe_kernel.h +new file mode 100644 +index 0000000..a217401 +--- /dev/null ++++ b/csrc/moe/marlin_kernels/marlin_moe_kernel.h +@@ -0,0 +1,1616 @@ ++#pragma once ++ ++#include ++ ++#include ++#include ++#include ++#include ++#include ++ ++#include ++ ++#include "core/scalar_type.hpp" ++ ++namespace marlin_moe { ++ ++constexpr int ceildiv(int a, int b) { return (a + b - 1) / b; } ++ ++#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 ++ ++// Instances of `Vec` are used to organize groups of >>registers<<, as needed ++// for instance as inputs to tensor core operations. Consequently, all ++// corresponding index accesses must be compile-time constants, which is why we ++// extensively use `#pragma unroll` throughout the kernel code to guarantee ++// this. ++template ++struct Vec { ++ T elems[n]; ++ __device__ T& operator[](int i) { return elems[i]; } ++}; ++ ++using I4 = Vec; ++ ++// Matrix fragments for tensor core instructions; their precise layout is ++// documented here: ++// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type ++using FragA = Vec; ++using FragB = Vec; ++using FragC = Vec; ++using FragS = Vec; // quantization scales ++using FragZP = Vec; ++ ++// Predicated asynchronous global->shared copy; used for inputs A where we apply ++// predication to handle batchsizes that are not multiples of 16. ++__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, ++ bool pred = true) { ++ const int BYTES = 16; ++ uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); ++ asm volatile( ++ "{\n" ++ " .reg .pred p;\n" ++ " setp.ne.b32 p, %0, 0;\n" ++ " @p cp.async.cg.shared.global [%1], [%2], %3;\n" ++ "}\n" ::"r"((int)pred), ++ "r"(smem), "l"(glob_ptr), "n"(BYTES)); ++} ++ ++// Asynchronous global->shared copy ++__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) { ++ const int BYTES = 16; ++ uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); ++ asm volatile( ++ "{\n" ++ " cp.async.cg.shared.global [%0], [%1], %2;\n" ++ "}\n" ::"r"(smem), ++ "l"(glob_ptr), "n"(BYTES)); ++} ++ ++// Async copy fence. ++__device__ inline void cp_async_fence() { ++ asm volatile("cp.async.commit_group;\n" ::); ++} ++ ++// Wait until at most `n` async copy stages are still pending. ++template ++__device__ inline void cp_async_wait() { ++ asm volatile("cp.async.wait_group %0;\n" ::"n"(n)); ++} ++ ++// m16n8k16 tensor core mma instruction with fp16 inputs and fp32 ++// output/accumulation. ++__device__ inline void mma(const FragA& a_frag, const FragB& frag_b, ++ FragC& frag_c) { ++ const uint32_t* a = reinterpret_cast(&a_frag); ++ const uint32_t* b = reinterpret_cast(&frag_b); ++ float* c = reinterpret_cast(&frag_c); ++ asm volatile( ++ "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " ++ "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" ++ : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) ++ : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), ++ "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); ++} ++ ++// Instruction for loading a full 16x16 matrix fragment of operand A from shared ++// memory, directly in tensor core layout. ++__device__ inline void ldsm4(FragA& frag_a, const void* smem_ptr) { ++ uint32_t* a = reinterpret_cast(&frag_a); ++ uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); ++ asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n" ++ : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3]) ++ : "r"(smem)); ++} ++ ++// Lookup-table based 3-input logical operation; explicitly used for ++// dequantization as the compiler does not seem to automatically recognize it in ++// all cases. ++template ++__device__ inline int lop3(int a, int b, int c) { ++ int res; ++ asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" ++ : "=r"(res) ++ : "r"(a), "r"(b), "r"(c), "n"(lut)); ++ return res; ++} ++ ++// Constructs destination register by taking bytes from 2 sources (based on ++// mask) ++template ++__device__ inline uint32_t prmt(uint32_t a) { ++ uint32_t res; ++ asm volatile("prmt.b32 %0, %1, %2, %3;\n" ++ : "=r"(res) ++ : "r"(a), "n"(start_byte), "n"(mask)); ++ return res; ++} ++ ++template ++__device__ inline FragB dequant(int q); ++ ++// Efficiently dequantize 4bit values packed in an int32 value into a full ++// B-fragment of 4 fp16 values. We mostly follow the strategy in the link below, ++// with some small changes: ++// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L215-L287 ++template <> ++__device__ inline FragB dequant(int q) { ++ const int LO = 0x000f000f; ++ const int HI = 0x00f000f0; ++ const int EX = 0x64006400; ++ // Guarantee that the `(a & b) | c` operations are LOP3s. ++ int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX); ++ int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX); ++ // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point ++ // directly into `SUB` and `ADD`. ++ const int SUB = 0x64086408; ++ const int MUL = 0x2c002c00; ++ const int ADD = 0xd480d480; ++ FragB frag_b; ++ frag_b[0] = __hsub2(*reinterpret_cast(&lo), ++ *reinterpret_cast(&SUB)); ++ frag_b[1] = __hfma2(*reinterpret_cast(&hi), ++ *reinterpret_cast(&MUL), ++ *reinterpret_cast(&ADD)); ++ return frag_b; ++} ++ ++// Fast Int8ToFp16: Efficiently dequantize 8bit int values to fp16 ++// Reference: ++// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L53-L85 ++template <> ++__device__ inline FragB dequant(int q) { ++ static constexpr uint32_t mask_for_elt_01 = 0x5250; ++ static constexpr uint32_t mask_for_elt_23 = 0x5351; ++ static constexpr uint32_t start_byte_for_fp16 = 0x64646464; ++ ++ uint32_t lo = prmt(q); ++ uint32_t hi = prmt(q); ++ ++ static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480; ++ ++ FragB frag_b; ++ frag_b[0] = __hsub2(*reinterpret_cast(&lo), ++ *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); ++ frag_b[1] = __hsub2(*reinterpret_cast(&hi), ++ *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); ++ return frag_b; ++} ++ ++template <> ++__device__ inline FragB dequant(int q) { ++ const int LO = 0x000f000f; ++ const int HI = 0x00f000f0; ++ const int EX = 0x64006400; ++ // Guarantee that the `(a & b) | c` operations are LOP3s. ++ int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX); ++ int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX); ++ ++ const int SUB = 0x64006400; ++ const int MUL = 0x2c002c00; ++ const int ADD = 0xd400d400; ++ FragB frag_b; ++ frag_b[0] = __hsub2(*reinterpret_cast(&lo), ++ *reinterpret_cast(&SUB)); ++ frag_b[1] = __hfma2(*reinterpret_cast(&hi), ++ *reinterpret_cast(&MUL), ++ *reinterpret_cast(&ADD)); ++ return frag_b; ++} ++ ++template <> ++__device__ inline FragB dequant(int q) { ++ static constexpr uint32_t mask_for_elt_01 = 0x5250; ++ static constexpr uint32_t mask_for_elt_23 = 0x5351; ++ static constexpr uint32_t start_byte_for_fp16 = 0x64646464; ++ ++ uint32_t lo = prmt(q); ++ uint32_t hi = prmt(q); ++ ++ static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64006400; ++ ++ FragB frag_b; ++ frag_b[0] = __hsub2(*reinterpret_cast(&lo), ++ *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); ++ frag_b[1] = __hsub2(*reinterpret_cast(&hi), ++ *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); ++ return frag_b; ++} ++ ++// Multiply dequantized values by the corresponding quantization scale; used ++// only for grouped quantization. ++__device__ inline void scale(FragB& frag_b, FragS& frag_s, int i) { ++ half2 s = __half2half2(reinterpret_cast<__half*>(&frag_s)[i]); ++ frag_b[0] = __hmul2(frag_b[0], s); ++ frag_b[1] = __hmul2(frag_b[1], s); ++} ++ ++__device__ inline void sub_zp(FragB& frag_b, half2& frag_zp, int i) { ++ half2 zp = __half2half2(reinterpret_cast<__half*>(&frag_zp)[i]); ++ frag_b[0] = __hsub2(frag_b[0], zp); ++ frag_b[1] = __hsub2(frag_b[1], zp); ++} ++ ++// Same as above, but for act_order (each K is multiplied individually) ++__device__ inline void scale4(FragB& frag_b, FragS& frag_s_1, FragS& frag_s_2, ++ FragS& frag_s_3, FragS& frag_s_4, int i) { ++ __half2 s_val_1_2; ++ s_val_1_2.x = reinterpret_cast<__half*>(&frag_s_1)[i]; ++ s_val_1_2.y = reinterpret_cast<__half*>(&frag_s_2)[i]; ++ ++ __half2 s_val_3_4; ++ s_val_3_4.x = reinterpret_cast<__half*>(&frag_s_3)[i]; ++ s_val_3_4.y = reinterpret_cast<__half*>(&frag_s_4)[i]; ++ ++ frag_b[0] = __hmul2(frag_b[0], s_val_1_2); ++ frag_b[1] = __hmul2(frag_b[1], s_val_3_4); ++} ++ ++// Given 2 floats multiply by 2 scales (halves) ++__device__ inline void scale_float(float* c, FragS& s) { ++ __half* s_ptr = reinterpret_cast<__half*>(&s); ++ c[0] = __fmul_rn(c[0], __half2float(s_ptr[0])); ++ c[1] = __fmul_rn(c[1], __half2float(s_ptr[1])); ++} ++ ++// Wait until barrier reaches `count`, then lock for current threadblock. ++__device__ inline void barrier_acquire(int* lock, int count) { ++ if (threadIdx.x == 0) { ++ int state = -1; ++ do ++ // Guarantee that subsequent writes by this threadblock will be visible ++ // globally. ++ asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n" ++ : "=r"(state) ++ : "l"(lock)); ++ while (state != count); ++ } ++ __syncthreads(); ++} ++ ++// Release barrier and increment visitation count. ++__device__ inline void barrier_release(int* lock, bool reset = false) { ++ __syncthreads(); ++ if (threadIdx.x == 0) { ++ if (reset) { ++ lock[0] = 0; ++ return; ++ } ++ int val = 1; ++ // Make sure that all writes since acquiring this barrier are visible ++ // globally, while releasing the barrier. ++ asm volatile("fence.acq_rel.gpu;\n"); ++ asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n" ++ : ++ : "l"(lock), "r"(val)); ++ } ++} ++ ++template shared ++ // fetch pipeline ++ const bool has_act_order, // whether act_order is enabled ++ const bool has_zp, // whether zero-points are enabled ++ const int group_blocks = -1 // number of consecutive 16x16 blocks ++ // with a separate quantization scale ++ > ++__device__ void MarlinMoESingle( ++ const int4* __restrict__ A, // fp16 input matrix of shape mxk ++ const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn ++ int4* __restrict__ C, // fp16 output buffer of shape mxn ++ const int* __restrict__ sorted_ids, // int32 sorted ids of experts ++ const float* __restrict__ topk_weights, // float topk weights ++ const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape ++ // (k/groupsize)xn ++ const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape ++ // (k/groupsize)x(n/pack_factor) ++ const int* __restrict__ g_idx, // int32 group indices of shape k ++ const int* __restrict__ expert_offsets, ++ int num_groups, // number of scale groups per output channel ++ int expert_idx, // idx of current expert ++ int num_experts, // number of experts ++ int topk, // topk parameter of moe ++ int prob_m, // batch dimension m ++ int prob_n, // output dimension n ++ int prob_k, // reduction dimension k ++ int tot_m, // total number of rows in A and C ++ int* locks, // extra global storage for barrier synchronization ++ bool replicate_input, // do we use the same input for each expert? ++ bool apply_weights, // apply weights to output ++ int current_m_block // current m block to start kernel computation from ++) { ++ static constexpr auto w_type = vllm::ScalarType::from_id(w_type_id); ++ constexpr int pack_factor = 32 / w_type.size_bits(); ++ ++ // For larger GEMMs we run multiple batchsize 64 versions in parallel for a ++ // better partitioning with less reductions ++ int parallel = 1; ++ if (prob_m > 16 * thread_m_blocks) { ++ parallel = prob_m / (16 * thread_m_blocks); ++ prob_m = 16 * thread_m_blocks; ++ } ++ ++ int k_tiles = prob_k / 16 / thread_k_blocks; ++ int n_tiles = prob_n / 16 / thread_n_blocks; ++ int iters = ceildiv(k_tiles * n_tiles * parallel, gridDim.x); ++ ++ if constexpr (!has_act_order && group_blocks != -1) { ++ if (group_blocks >= thread_k_blocks) { ++ // Ensure that the number of tiles in each stripe is a multiple of the ++ // groupsize; this avoids an annoying special case where a stripe starts ++ // in the middle of group. ++ iters = (group_blocks / thread_k_blocks) * ++ ceildiv(iters, (group_blocks / thread_k_blocks)); ++ } ++ } ++ ++ int slice_row = (iters * blockIdx.x) % k_tiles; ++ int slice_col_par = (iters * blockIdx.x) / k_tiles; ++ int slice_col = slice_col_par; ++ int slice_iters; // number of threadblock tiles in the current slice ++ int slice_count = ++ 0; // total number of active threadblocks in the current slice ++ int slice_idx; // index of threadblock in current slice; numbered bottom to ++ // top ++ ++ // We can easily implement parallel problem execution by just remapping ++ // indices and advancing global pointers ++ if (slice_col_par >= n_tiles) { ++ locks += (slice_col_par / n_tiles) * n_tiles; ++ slice_col = slice_col_par % n_tiles; ++ sorted_ids += (slice_col_par / n_tiles) * 16 * thread_m_blocks; ++ } ++ ++ // Compute all information about the current slice which is required for ++ // synchronization. ++ auto init_slice = [&]() { ++ slice_iters = ++ iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row); ++ if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0; ++ if (slice_iters == 0) return; ++ if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row; ++ slice_count = 1; ++ slice_idx = 0; ++ int col_first = iters * ceildiv(k_tiles * slice_col_par, iters); ++ if (col_first <= k_tiles * (slice_col_par + 1)) { ++ int col_off = col_first - k_tiles * slice_col_par; ++ slice_count = ceildiv(k_tiles - col_off, iters); ++ if (col_off > 0) slice_count++; ++ int delta_first = iters * blockIdx.x - col_first; ++ if (delta_first < 0 || (col_off == 0 && delta_first == 0)) ++ slice_idx = slice_count - 1; ++ else { ++ slice_idx = slice_count - 1 - delta_first / iters; ++ if (col_off > 0) slice_idx--; ++ } ++ } ++ if (slice_col == n_tiles) { ++ sorted_ids += 16 * thread_m_blocks; ++ locks += n_tiles; ++ slice_col = 0; ++ } ++ }; ++ init_slice(); ++ ++ // A sizes/strides ++ ++ // stride of the A matrix in global memory ++ int a_gl_stride = prob_k / 8; ++ // stride of an A matrix tile in shared memory ++ constexpr int a_sh_stride = 16 * thread_k_blocks / 8; ++ // delta between subsequent A tiles in global memory ++ constexpr int a_gl_rd_delta_o = 16 * thread_k_blocks / 8; ++ // between subsequent accesses within a tile ++ int a_gl_rd_delta_i = a_gl_stride * (threads / a_gl_rd_delta_o); ++ // between shared memory writes ++ constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o); ++ // between shared memory tile reads ++ constexpr int a_sh_rd_delta_o = 2 * ((threads / 32) / (thread_n_blocks / 4)); ++ // within a shared memory tile ++ constexpr int a_sh_rd_delta_i = a_sh_stride * 16; ++ // overall size of a tile ++ constexpr int a_sh_stage = a_sh_stride * (16 * thread_m_blocks); ++ // number of shared write iterations for a tile ++ constexpr int a_sh_wr_iters = ceildiv(a_sh_stage, a_sh_wr_delta); ++ ++ // B sizes/strides ++ int b_gl_stride = 16 * prob_n / (pack_factor * 4); ++ constexpr int b_sh_stride = ((thread_n_blocks * 16) * 16 / pack_factor) / 4; ++ constexpr int b_thread_vecs = w_type.size_bits() == 4 ? 1 : 2; ++ constexpr int b_sh_stride_threads = b_sh_stride / b_thread_vecs; ++ ++ int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks; ++ int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride_threads); ++ constexpr int b_sh_wr_delta = threads * b_thread_vecs; ++ constexpr int b_sh_rd_delta = threads * b_thread_vecs; ++ constexpr int b_sh_stage = b_sh_stride * thread_k_blocks; ++ constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta; ++ ++ // Scale sizes/strides without act_order ++ int s_gl_stride = prob_n / 8; ++ constexpr int s_sh_stride = 16 * thread_n_blocks / 8; ++ constexpr int s_tb_groups = ++ !has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks ++ ? thread_k_blocks / group_blocks ++ : 1; ++ constexpr int s_sh_stage = s_tb_groups * s_sh_stride; ++ int s_gl_rd_delta = s_gl_stride; ++ // Scale size/strides with act_order ++ constexpr int tb_k = 16 * thread_k_blocks; ++ constexpr int g_idx_stage = has_act_order ? (tb_k * sizeof(int)) / 16 : 0; ++ // constexpr int act_s_row_stride = 1; ++ // int act_s_col_stride = act_s_row_stride * num_groups; ++ int act_s_col_stride = 1; ++ int act_s_col_warp_stride = act_s_col_stride * 8; ++ int tb_n_warps = thread_n_blocks / 4; ++ int act_s_col_tb_stride = act_s_col_warp_stride * tb_n_warps; ++ ++ // Zero-points sizes/strides ++ int zp_gl_stride = (prob_n / pack_factor) / 4; ++ constexpr int zp_sh_stride = ((16 * thread_n_blocks) / pack_factor) / 4; ++ constexpr int zp_tb_groups = s_tb_groups; ++ constexpr int zp_sh_stage = has_zp ? zp_tb_groups * zp_sh_stride : 0; ++ int zp_gl_rd_delta = zp_gl_stride; ++ ++ // Global A read index of current thread. ++ int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + ++ (threadIdx.x % a_gl_rd_delta_o); ++ a_gl_rd += a_gl_rd_delta_o * slice_row; ++ // Shared write index of current thread. ++ int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) + ++ (threadIdx.x % a_gl_rd_delta_o); ++ // Shared read index. ++ int a_sh_rd = ++ a_sh_stride * ((threadIdx.x % 32) % 16) + (threadIdx.x % 32) / 16; ++ a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4)); ++ ++ int b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride_threads) + ++ (threadIdx.x % b_sh_stride_threads) * b_thread_vecs; ++ b_gl_rd += b_sh_stride * slice_col; ++ b_gl_rd += b_gl_rd_delta_o * slice_row; ++ int b_sh_wr = threadIdx.x * b_thread_vecs; ++ int b_sh_rd = threadIdx.x * b_thread_vecs; ++ ++ // For act_order ++ constexpr int k_iter_size = tb_k / b_sh_wr_iters; ++ int slice_k_start = tb_k * slice_row; ++ int slice_k_finish = slice_k_start + tb_k * slice_iters; ++ int slice_k_start_shared_fetch = slice_k_start; ++ int slice_n_offset = act_s_col_tb_stride * slice_col; ++ ++ // No act_order ++ int s_gl_rd; ++ if constexpr (!has_act_order) { ++ if constexpr (group_blocks == -1) { ++ s_gl_rd = s_sh_stride * slice_col + threadIdx.x; ++ } else { ++ s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + ++ s_sh_stride * slice_col + threadIdx.x; ++ } ++ } ++ int s_sh_wr = threadIdx.x; ++ bool s_sh_wr_pred = threadIdx.x < s_sh_stride; ++ ++ // Zero-points ++ int zp_gl_rd; ++ if constexpr (has_zp) { ++ if constexpr (group_blocks == -1) { ++ zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x; ++ } else { ++ zp_gl_rd = zp_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + ++ zp_sh_stride * slice_col + threadIdx.x; ++ } ++ } ++ int zp_sh_wr = threadIdx.x; ++ bool zp_sh_wr_pred = threadIdx.x < zp_sh_stride; ++ ++ // We use a different scale layout for grouped and column-wise quantization as ++ // we scale a `half2` tile in column-major layout in the former and in ++ // row-major in the latter case. ++ int s_sh_rd; ++ if constexpr (group_blocks != -1) ++ s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + ++ (threadIdx.x % 32) / 4; ++ else ++ s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + ++ (threadIdx.x % 32) % 4; ++ ++ // Zero-points have the same read layout as the scales ++ // (without column-wise case) ++ constexpr int num_col_threads = 8; ++ constexpr int num_row_threads = 4; ++ constexpr int num_ints_per_thread = 8 / pack_factor; ++ int zp_sh_rd; ++ if constexpr (has_zp) { ++ zp_sh_rd = num_ints_per_thread * num_col_threads * ++ ((threadIdx.x / 32) % (thread_n_blocks / 4)) + ++ num_ints_per_thread * ((threadIdx.x % 32) / num_row_threads); ++ } ++ ++ int sh_first_group_id = -1; ++ int sh_num_groups = -1; ++ constexpr int sh_max_num_groups = 32; ++ ++ extern __shared__ int4 sh[]; ++ // Shared memory storage for global fetch pipelines. ++ int4* sh_a = sh; ++ int4* sh_b = sh_a + (stages * a_sh_stage); ++ int4* sh_g_idx = sh_b + (stages * b_sh_stage); ++ int4* sh_zp = sh_g_idx + (stages * g_idx_stage); ++ int4* sh_s = sh_zp + (stages * zp_sh_stage); ++ ++ // Precompute which thread should not read memory in which iterations; this is ++ // needed if there are more threads than required for a certain tilesize or ++ // when the batchsize is not a multiple of 16. ++ bool a_sh_wr_pred[a_sh_wr_iters]; ++ #pragma unroll ++ for (int i = 0; i < a_sh_wr_iters; i++) { ++ int a_idx = a_sh_wr_delta * i + a_sh_wr; ++ int row = a_idx / a_gl_rd_delta_o; ++ if (row >= prob_m) { ++ a_sh_wr_pred[i] = false; ++ } else { ++ a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m; ++ } ++ } ++ ++ // To ensure that writing and reading A tiles to/from shared memory, the ++ // latter in fragment format, is fully bank conflict free, we need to use a ++ // rather fancy XOR-based layout. The key here is that neither reads nor ++ // writes of the 16-byte `int4` blocks of 8 consecutive threads involve the ++ // same shared memory banks. Further, it seems (based on NSight-Compute) that ++ // each warp must also write a consecutive memory segment? ++ auto transform_a = [&](int i) { ++ int row = i / a_gl_rd_delta_o; ++ return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row; ++ }; ++ // Since the computation of this remapping is non-trivial and, due to our main ++ // loop unrolls, all shared memory accesses are static, we simply precompute ++ // both transformed reads and writes. ++ int a_sh_wr_trans[a_sh_wr_iters]; ++ #pragma unroll ++ for (int i = 0; i < a_sh_wr_iters; i++) ++ a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr); ++ int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks]; ++ #pragma unroll ++ for (int i = 0; i < b_sh_wr_iters; i++) { ++ #pragma unroll ++ for (int j = 0; j < thread_m_blocks; j++) ++ a_sh_rd_trans[i][j] = ++ transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd); ++ } ++ ++ // Since B-accesses have non-constant stride they have to be computed at ++ // runtime; we break dependencies between subsequent accesses with a tile by ++ // maintining multiple pointers (we have enough registers), a tiny ++ // optimization. ++ const int4* B_ptr[b_sh_wr_iters]; ++ #pragma unroll ++ for (int i = 0; i < b_sh_wr_iters; i++) ++ B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd; ++ ++ // Register storage for double buffer of shared memory reads. ++ FragA frag_a[2][thread_m_blocks]; ++ I4 frag_b_quant[2][b_thread_vecs]; ++ FragC frag_c[thread_m_blocks][4][2]; ++ FragS frag_s[2][4]; // No act-order ++ FragS act_frag_s[2][4][4]; // For act-order ++ int frag_qzp[2][num_ints_per_thread]; // Zero-points ++ FragZP frag_zp; // Zero-points in fp16 ++ ++ // Zero accumulators. ++ auto zero_accums = [&]() { ++ #pragma unroll ++ for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++) ++ reinterpret_cast(frag_c)[i] = 0; ++ }; ++ ++ auto fetch_scales_to_shared = [&](bool is_async, int first_group_id, ++ int last_group_id) { ++ sh_first_group_id = first_group_id; ++ sh_num_groups = last_group_id - first_group_id + 1; ++ ++ if (sh_num_groups < sh_max_num_groups) { ++ sh_num_groups = sh_max_num_groups; ++ } ++ ++ if (sh_first_group_id + sh_num_groups > num_groups) { ++ sh_num_groups = num_groups - sh_first_group_id; ++ } ++ ++ int row_offset = first_group_id * s_gl_stride; ++ ++ if (is_async) { ++ for (int i = 0; i < sh_num_groups; i++) { ++ if (threadIdx.x < s_sh_stride) { ++ cp_async4_pred(&sh_s[(i * s_sh_stride) + threadIdx.x], ++ &scales_ptr[row_offset + (i * s_gl_stride) + ++ slice_n_offset + threadIdx.x]); ++ } ++ } ++ } else { ++ for (int i = 0; i < sh_num_groups; i++) { ++ if (threadIdx.x < s_sh_stride) { ++ sh_s[(i * s_sh_stride) + threadIdx.x] = ++ scales_ptr[row_offset + (i * s_gl_stride) + slice_n_offset + ++ threadIdx.x]; ++ } ++ } ++ } ++ }; ++ // Asynchronously fetch the next A, B and s tile from global to the next ++ // shared memory pipeline location. ++ auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) { ++ if (pred) { ++ int4* sh_a_stage = sh_a + a_sh_stage * pipe; ++ #pragma unroll ++ for (int i = 0; i < a_sh_wr_iters; i++) { ++ int a_idx = a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off; ++ int row = a_idx / a_gl_stride; ++ int sorted_row = ++ replicate_input ? sorted_ids[row] / topk : sorted_ids[row]; ++ int new_idx = sorted_row * a_gl_stride + a_idx % a_gl_stride; ++ if (sorted_row < tot_m * (replicate_input ? 1 : topk) && ++ new_idx < a_gl_stride * tot_m * (replicate_input ? 1 : topk)) { ++ cp_async4_pred(&sh_a_stage[a_sh_wr_trans[i]], &A[new_idx], ++ a_sh_wr_pred[i]); ++ } ++ } ++ int4* sh_b_stage = sh_b + b_sh_stage * pipe; ++ #pragma unroll ++ for (int i = 0; i < b_sh_wr_iters; i++) { ++ #pragma unroll ++ for (int j = 0; j < b_thread_vecs; j++) { ++ cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j], B_ptr[i] + j); ++ } ++ B_ptr[i] += b_gl_rd_delta_o; ++ } ++ ++ if constexpr (has_act_order) { ++ // Fetch g_idx thread-block portion ++ int full_pipe = a_off; ++ int cur_k = slice_k_start_shared_fetch + tb_k * full_pipe; ++ if (cur_k < prob_k && cur_k < slice_k_finish) { ++ int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; ++ ++ int4 const* cur_g_idx_stage_ptr = ++ reinterpret_cast(&g_idx[cur_k]); ++ ++ if (threadIdx.x < g_idx_stage) { ++ cp_async4_pred(&sh_g_idx_stage[threadIdx.x], ++ &cur_g_idx_stage_ptr[threadIdx.x]); ++ } ++ } ++ } else { ++ if constexpr (group_blocks != -1) { ++ int4* sh_s_stage = sh_s + s_sh_stage * pipe; ++ ++ if constexpr (group_blocks >= thread_k_blocks) { ++ // Only fetch scales if this tile starts a new group ++ if (pipe % (group_blocks / thread_k_blocks) == 0) { ++ if (s_sh_wr_pred) { ++ cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]); ++ } ++ s_gl_rd += s_gl_rd_delta; ++ } ++ } else { ++ for (int i = 0; i < s_tb_groups; i++) { ++ if (s_sh_wr_pred) { ++ cp_async4(&sh_s_stage[i * s_sh_stride + s_sh_wr], ++ &scales_ptr[s_gl_rd]); ++ } ++ s_gl_rd += s_gl_rd_delta; ++ } ++ } ++ } ++ ++ if constexpr (has_zp && group_blocks != -1) { ++ int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe; ++ ++ if constexpr (group_blocks >= thread_k_blocks) { ++ // Only fetch zero-points if this tile starts a new group ++ if (pipe % (group_blocks / thread_k_blocks) == 0) { ++ if (zp_sh_wr_pred) { ++ cp_async4(&sh_zp_stage[zp_sh_wr], &zp_ptr[zp_gl_rd]); ++ } ++ zp_gl_rd += zp_gl_rd_delta; ++ } ++ } else { ++ for (int i = 0; i < zp_tb_groups; i++) { ++ if (zp_sh_wr_pred) { ++ cp_async4(&sh_zp_stage[i * zp_sh_stride + zp_sh_wr], ++ &zp_ptr[zp_gl_rd]); ++ } ++ zp_gl_rd += zp_gl_rd_delta; ++ } ++ } ++ } ++ } ++ } ++ // Insert a fence even when we are winding down the pipeline to ensure that ++ // waiting is also correct at this point. ++ cp_async_fence(); ++ }; ++ ++ auto fetch_zp_to_shared = [&]() { ++ if (zp_sh_wr_pred) { ++ cp_async4(&sh_zp[zp_sh_wr], &zp_ptr[zp_gl_rd]); ++ } ++ }; ++ ++ // Wait until the next thread tile has been loaded to shared memory. ++ auto wait_for_stage = [&]() { ++ // We only have `stages - 2` active fetches since we are double buffering ++ // and can only issue the next fetch when it is guaranteed that the previous ++ // shared memory load is fully complete (as it may otherwise be ++ // overwritten). ++ cp_async_wait(); ++ __syncthreads(); ++ }; ++ ++ // Load the next sub-tile from the current location in the shared memory pipe ++ // into the current register buffer. ++ auto fetch_to_registers = [&](int k, int pipe) { ++ int4* sh_a_stage = sh_a + a_sh_stage * pipe; ++ #pragma unroll ++ for (int i = 0; i < thread_m_blocks; i++) ++ ldsm4(frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]); ++ int4* sh_b_stage = sh_b + b_sh_stage * pipe; ++ ++ #pragma unroll ++ for (int i = 0; i < b_thread_vecs; i++) { ++ frag_b_quant[k % 2][i] = *reinterpret_cast( ++ &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]); ++ } ++ }; ++ ++ bool is_same_group[stages]; ++ int same_group_id[stages]; ++ ++ auto init_same_group = [&](int pipe) { ++ if constexpr (!has_act_order) { ++ is_same_group[pipe] = false; ++ same_group_id[pipe] = 0; ++ return; ++ } ++ ++ int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; ++ int* sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); ++ ++ int group_id_1 = sh_g_idx_int_ptr[0]; ++ int group_id_2 = sh_g_idx_int_ptr[tb_k - 1]; ++ ++ is_same_group[pipe] = group_id_1 == group_id_2; ++ same_group_id[pipe] = group_id_1; ++ }; ++ ++ auto fetch_scales_to_registers = [&](int k, int full_pipe) { ++ int pipe = full_pipe % stages; ++ ++ if constexpr (!has_act_order) { ++ // No act-order case ++ if constexpr (group_blocks != -1) { ++ if constexpr (group_blocks >= thread_k_blocks) { ++ int4* sh_s_stage = ++ sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * ++ (pipe / (group_blocks / thread_k_blocks))); ++ reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; ++ } else { ++ int warp_id = threadIdx.x / 32; ++ int n_warps = thread_n_blocks / 4; ++ ++ int warp_row = warp_id / n_warps; ++ ++ int cur_k = warp_row * 16; ++ cur_k += k_iter_size * (k % b_sh_wr_iters); ++ ++ int k_blocks = cur_k / 16; ++ int cur_group_id = k_blocks / group_blocks; ++ ++ int4* sh_s_stage = sh_s + s_sh_stage * pipe; ++ ++ reinterpret_cast(&frag_s[k % 2])[0] = ++ sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride]; ++ } ++ } ++ ++ return; ++ } ++ ++ // Act-order case ++ ++ // Determine K of the "current" thread-block ++ int cur_k = slice_k_start + tb_k * full_pipe; ++ if (cur_k >= prob_k || cur_k >= slice_k_finish) { ++ return; ++ } ++ ++ // Reset (to current thread-block) since we read g_idx portion from the ++ // shared memory ++ cur_k = 0; ++ ++ // Progress to current iteration ++ cur_k += k_iter_size * (k % b_sh_wr_iters); ++ ++ // Determine "position" inside the thread-block (based on warp and ++ // thread-id) ++ int warp_id = threadIdx.x / 32; ++ int n_warps = ++ thread_n_blocks / 4; // Each warp processes 4 16-size tiles over N ++ ++ int warp_row = warp_id / n_warps; ++ int warp_col = warp_id % n_warps; ++ ++ cur_k += warp_row * 16; ++ ++ int th_id = threadIdx.x % 32; ++ cur_k += (th_id % 4) * 2; // Due to tensor-core layout for fp16 B matrix ++ ++ int s_col_shift = ++ /*slice_n_offset +*/ (act_s_col_warp_stride * warp_col) + ++ (th_id / 4) * act_s_col_stride; ++ ++ if (is_same_group[pipe]) { ++ if (k % 2 == 0) { ++ *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))) = ++ sh_s[(same_group_id[pipe] - sh_first_group_id) * s_sh_stride + ++ s_col_shift]; ++ } else { ++ *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))) = ++ *(reinterpret_cast(&(act_frag_s[(k - 1) % 2][0][0]))); ++ } ++ ++ for (int i = 1; i < 4; i++) { ++ *(reinterpret_cast(&(act_frag_s[k % 2][i][0]))) = ++ *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))); ++ } ++ return; ++ } ++ ++ int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; ++ int* sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); ++ ++ constexpr int k_frag_offsets[4] = {0, 1, 8, ++ 9}; // Tensor core offsets per thread ++ ++ #pragma unroll ++ for (int i = 0; i < 4; i++) { ++ int actual_k = cur_k + k_frag_offsets[i]; ++ ++ int group_id = sh_g_idx_int_ptr[actual_k]; ++ int rel_group_id = group_id - sh_first_group_id; ++ ++ *(reinterpret_cast(&(act_frag_s[k % 2][i][0]))) = ++ sh_s[rel_group_id * s_sh_stride + s_col_shift]; ++ } ++ }; ++ ++ auto fetch_zp_to_registers = [&](int k, int full_pipe) { ++ // This code does not handle group_blocks == 0, ++ // which signifies act_order. ++ // has_zp implies AWQ, which doesn't have act_order, ++ static_assert(!has_zp || group_blocks != 0); ++ ++ if constexpr (has_zp) { ++ int pipe = full_pipe % stages; ++ ++ if constexpr (group_blocks == -1) { ++ for (int i = 0; i < num_ints_per_thread; i++) { ++ frag_qzp[k % 2][i] = (reinterpret_cast(sh_zp))[zp_sh_rd + i]; ++ } ++ ++ } else if constexpr (group_blocks >= thread_k_blocks) { ++ int4* sh_zp_stage = ++ sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) * ++ (pipe / (group_blocks / thread_k_blocks))); ++ for (int i = 0; i < num_ints_per_thread; i++) { ++ frag_qzp[k % 2][i] = ++ (reinterpret_cast(sh_zp_stage))[zp_sh_rd + i]; ++ } ++ } else { ++ int warp_id = threadIdx.x / 32; ++ int n_warps = thread_n_blocks / 4; ++ ++ int warp_row = warp_id / n_warps; ++ ++ int cur_k = warp_row * 16; ++ cur_k += k_iter_size * (k % b_sh_wr_iters); ++ ++ int k_blocks = cur_k / 16; ++ int cur_group_id = 0; ++ ++ // Suppress bogus and persistent divide-by-zero warning ++ #pragma nv_diagnostic push ++ #pragma nv_diag_suppress divide_by_zero ++ cur_group_id = k_blocks / group_blocks; ++ #pragma nv_diagnostic pop ++ ++ int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe; ++ ++ sh_zp_stage += cur_group_id * zp_sh_stride; ++ ++ for (int i = 0; i < num_ints_per_thread; i++) { ++ frag_qzp[k % 2][i] = ++ (reinterpret_cast(sh_zp_stage))[zp_sh_rd + i]; ++ } ++ } ++ } ++ }; ++ ++ // Execute the actual tensor core matmul of a sub-tile. ++ auto matmul = [&](int k) { ++ if constexpr (has_zp) { ++ FragB frag_zp_0; ++ FragB frag_zp_1; ++ int zp_quant_0, zp_quant_1; ++ ++ if constexpr (w_type.size_bits() == 4) { ++ zp_quant_0 = frag_qzp[k % 2][0]; ++ zp_quant_1 = zp_quant_0 >> 8; ++ } else { ++ static_assert(w_type.size_bits() == 8); ++ zp_quant_0 = frag_qzp[k % 2][0]; ++ zp_quant_1 = frag_qzp[k % 2][1]; ++ } ++ ++ frag_zp_0 = dequant(zp_quant_0); ++ frag_zp_1 = dequant(zp_quant_1); ++ ++ frag_zp[0] = frag_zp_0[0]; ++ frag_zp[1] = frag_zp_0[1]; ++ frag_zp[2] = frag_zp_1[0]; ++ frag_zp[3] = frag_zp_1[1]; ++ } ++ ++ // We have the m dimension as the inner loop in order to encourage overlapping ++ // dequantization and matmul operations. ++ #pragma unroll ++ for (int j = 0; j < 4; j++) { ++ int b_quant_0, b_quant_1; ++ if constexpr (w_type.size_bits() == 4) { ++ b_quant_0 = frag_b_quant[k % 2][0][j]; ++ b_quant_1 = b_quant_0 >> 8; ++ } else { ++ static_assert(w_type.size_bits() == 8); ++ int* frag_b_quant_ptr = reinterpret_cast(frag_b_quant[k % 2]); ++ b_quant_0 = frag_b_quant_ptr[j * 2 + 0]; ++ b_quant_1 = frag_b_quant_ptr[j * 2 + 1]; ++ } ++ ++ FragB frag_b0 = dequant(b_quant_0); ++ FragB frag_b1 = dequant(b_quant_1); ++ // Apply zero-point to frag_b0 ++ if constexpr (has_zp) { ++ sub_zp(frag_b0, frag_zp[j], 0); ++ } ++ ++ // Apply scale to frag_b0 ++ if constexpr (has_act_order) { ++ scale4(frag_b0, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j], ++ act_frag_s[k % 2][2][j], act_frag_s[k % 2][3][j], 0); ++ } else { ++ if constexpr (group_blocks != -1) { ++ scale(frag_b0, frag_s[k % 2][j], 0); ++ } ++ } ++ ++ // Apply zero-point to frag_b1 ++ if constexpr (has_zp) { ++ sub_zp(frag_b1, frag_zp[j], 1); ++ } ++ ++ // Apply scale to frag_b1 ++ if constexpr (has_act_order) { ++ scale4(frag_b1, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j], ++ act_frag_s[k % 2][2][j], act_frag_s[k % 2][3][j], 1); ++ ++ } else { ++ if constexpr (group_blocks != -1) { ++ scale(frag_b1, frag_s[k % 2][j], 1); ++ } ++ } ++ ++ #pragma unroll ++ for (int i = 0; i < thread_m_blocks; i++) { ++ mma(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]); ++ mma(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]); ++ } ++ } ++ }; ++ ++ // Since we slice across the k dimension of a tile in order to increase the ++ // number of warps while keeping the n dimension of a tile reasonable, we have ++ // multiple warps that accumulate their partial sums of the same output ++ // location; which we have to reduce over in the end. We do in shared memory. ++ auto thread_block_reduce = [&]() { ++ constexpr int red_off = threads / b_sh_stride_threads / 2; ++ if (red_off >= 1) { ++ int red_idx = threadIdx.x / b_sh_stride_threads; ++ constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2; ++ constexpr int red_sh_delta = b_sh_stride_threads; ++ int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) + ++ (threadIdx.x % b_sh_stride_threads); ++ ++ // Parallel logarithmic shared memory reduction. We make sure to avoid any ++ // unnecessary read or write iterations, e.g., for two warps we write only ++ // once by warp 1 and read only once by warp 0. ++ ++ #pragma unroll ++ for (int m_block = 0; m_block < thread_m_blocks; m_block++) { ++ #pragma unroll ++ for (int i = red_off; i > 0; i /= 2) { ++ if (i <= red_idx && red_idx < 2 * i) { ++ #pragma unroll ++ for (int j = 0; j < 4 * 2; j++) { ++ int red_sh_wr = ++ red_sh_delta * j + (red_sh_rd - red_sh_stride * i); ++ if (i < red_off) { ++ float* c_rd = ++ reinterpret_cast(&sh[red_sh_delta * j + red_sh_rd]); ++ float* c_wr = reinterpret_cast(&sh[red_sh_wr]); ++ #pragma unroll ++ for (int k = 0; k < 4; k++) ++ reinterpret_cast(frag_c)[4 * 2 * m_block + j][k] += ++ c_rd[k] + c_wr[k]; ++ } ++ sh[red_sh_wr] = ++ reinterpret_cast(&frag_c)[4 * 2 * m_block + j]; ++ } ++ } ++ __syncthreads(); ++ } ++ if (red_idx == 0) { ++ #pragma unroll ++ for (int i = 0; i < 4 * 2; i++) { ++ float* c_rd = ++ reinterpret_cast(&sh[red_sh_delta * i + red_sh_rd]); ++ #pragma unroll ++ for (int j = 0; j < 4; j++) ++ reinterpret_cast(frag_c)[4 * 2 * m_block + i][j] += ++ c_rd[j]; ++ } ++ } ++ __syncthreads(); ++ } ++ } ++ }; ++ ++ // Since multiple threadblocks may process parts of the same column slice, we ++ // finally have to globally reduce over the results. As the striped ++ // partitioning minimizes the number of such reductions and our outputs are ++ // usually rather small, we perform this reduction serially in L2 cache. ++ auto global_reduce = [&](bool first = false, bool last = false) { ++ // We are very careful here to reduce directly in the output buffer to ++ // maximize L2 cache utilization in this step. To do this, we write out ++ // results in FP16 (but still reduce with FP32 compute). ++ constexpr int active_threads = 32 * thread_n_blocks / 4; ++ if (threadIdx.x < active_threads) { ++ int c_gl_stride = prob_n / 8; ++ int c_gl_wr_delta_o = 8 * c_gl_stride; ++ int c_gl_wr_delta_i = 4 * (active_threads / 32); ++ int c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) + ++ 4 * (threadIdx.x / 32) + threadIdx.x % 4; ++ c_gl_wr += (2 * thread_n_blocks) * slice_col; ++ constexpr int c_sh_wr_delta = active_threads; ++ int c_sh_wr = threadIdx.x; ++ ++ int row = (threadIdx.x % 32) / 4; ++ ++ if (!first) { ++ // Interestingly, doing direct global accesses here really seems to mess up ++ // the compiler and lead to slowdowns, hence we also use async-copies even ++ // though these fetches are not actually asynchronous. ++ #pragma unroll ++ for (int i = 0; i < thread_m_blocks * 4; i++) { ++ int c_idx = ++ c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2); ++ int sorted_row = sorted_ids[c_idx / c_gl_stride]; ++ int new_idx = sorted_row * c_gl_stride + c_idx % c_gl_stride; ++ cp_async4_pred(&sh[c_sh_wr + c_sh_wr_delta * i], &C[new_idx], ++ sorted_row < tot_m * topk && ++ (8 * (i / 2) + row < prob_m && ++ (i < (thread_m_blocks - 1) * 4 || ++ sorted_ids[8 * (i / 2) + row] < tot_m * topk))); ++ } ++ cp_async_fence(); ++ cp_async_wait<0>(); ++ } ++ ++ #pragma unroll ++ for (int i = 0; i < thread_m_blocks * 4; i++) { ++ if (8 * (i / 2) + row < prob_m && ++ (i < (thread_m_blocks - 1) * 4 || ++ sorted_ids[8 * (i / 2) + row] < tot_m * topk)) { ++ if (!first) { ++ int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta]; ++ #pragma unroll ++ for (int j = 0; j < 2 * 4; j++) { ++ reinterpret_cast( ++ &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] += ++ __half2float(reinterpret_cast<__half*>(&c_red)[j]); ++ } ++ } ++ if (!last) { ++ int4 c; ++ #pragma unroll ++ for (int j = 0; j < 2 * 4; j++) { ++ reinterpret_cast<__half*>(&c)[j] = ++ __float2half(reinterpret_cast( ++ &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)]); ++ } ++ int c_idx = ++ c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2); ++ int row = sorted_ids[c_idx / c_gl_stride]; ++ if (row < tot_m * topk) { ++ int new_idx = row * c_gl_stride + c_idx % c_gl_stride; ++ C[new_idx] = c; ++ } ++ } ++ } ++ } ++ } ++ }; ++ ++ // Write out the reduce final result in the correct layout. We only actually ++ // reshuffle matrix fragments in this step, the reduction above is performed ++ // in fragment layout. ++ auto write_result = [&]() { ++ int c_gl_stride = prob_n / 8; ++ constexpr int c_sh_stride = 2 * thread_n_blocks + 1; ++ int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks)); ++ constexpr int c_sh_rd_delta = ++ c_sh_stride * (threads / (2 * thread_n_blocks)); ++ ++ int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) + ++ (threadIdx.x % (2 * thread_n_blocks)); ++ c_gl_wr += (2 * thread_n_blocks) * slice_col; ++ int c_sh_wr = ++ (4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4; ++ c_sh_wr += 32 * (threadIdx.x / 32); ++ int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) + ++ (threadIdx.x % (2 * thread_n_blocks)); ++ ++ int c_gl_wr_end = c_gl_stride * prob_m; ++ ++ // We first reorder in shared memory to guarantee the most efficient final ++ // global write patterns ++ auto write = [&](int idx, float c0, float c1, FragS& s) { ++ half2 res = __halves2half2(__float2half(c0), __float2half(c1)); ++ ++ // For per-column quantization we finally apply the scale here (only for ++ // 4-bit) ++ if constexpr (!has_act_order && group_blocks == -1 && ++ w_type.size_bits() == 4) { ++ res = __hmul2(res, s[0]); ++ } ++ ++ ((half2*)sh)[idx] = res; ++ }; ++ if (threadIdx.x / 32 < thread_n_blocks / 4) { ++ #pragma unroll ++ for (int i = 0; i < thread_m_blocks; i++) { ++ #pragma unroll ++ for (int j = 0; j < 4; j++) { ++ int wr = c_sh_wr + 8 * j; ++ write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0], ++ frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]); ++ write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2], ++ frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]); ++ write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0], ++ frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]); ++ write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2], ++ frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]); ++ } ++ c_sh_wr += 16 * (4 * c_sh_stride); ++ } ++ } ++ __syncthreads(); ++ ++ #pragma unroll ++ for (int i = 0; ++ i < ceildiv(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); ++ i++) { ++ if (c_gl_wr < c_gl_wr_end) { ++ int row = sorted_ids[c_gl_wr / c_gl_stride]; ++ if (row < tot_m * topk) { ++ int off = row * c_gl_stride + c_gl_wr % c_gl_stride; ++ if (!apply_weights) { ++ C[off] = sh[c_sh_rd]; ++ } else { ++ __half* ctrg = reinterpret_cast<__half*>(&C[off]); ++ __half* csrc = reinterpret_cast<__half*>(&sh[c_sh_rd]); ++ for (int j = 0; j < 8; ++j) { ++ ctrg[j] = __float2half(topk_weights[row] * __half2float(csrc[j])); ++ } ++ } ++ c_gl_wr += c_gl_wr_delta; ++ c_sh_rd += c_sh_rd_delta; ++ } ++ } ++ } ++ }; ++ ++ // Start global fetch and register load pipelines. ++ auto start_pipes = [&]() { ++ ++ #pragma unroll ++ for (int i = 0; i < stages - 1; i++) { ++ if (has_act_order && i == 0) { ++ int last_g_idx = slice_k_start + stages * tb_k * 2; ++ if (last_g_idx >= prob_k) { ++ last_g_idx = prob_k - 1; ++ } ++ fetch_scales_to_shared(true, g_idx[slice_k_start], g_idx[last_g_idx]); ++ } ++ ++ if constexpr (has_zp && group_blocks == -1) { ++ if (i == 0) { ++ fetch_zp_to_shared(); ++ } ++ } ++ fetch_to_shared(i, i, i < slice_iters); ++ } ++ ++ zero_accums(); ++ wait_for_stage(); ++ init_same_group(0); ++ fetch_to_registers(0, 0); ++ fetch_scales_to_registers(0, 0); ++ fetch_zp_to_registers(0, 0); ++ a_gl_rd += a_gl_rd_delta_o * (stages - 1); ++ slice_k_start_shared_fetch += tb_k * (stages - 1); ++ }; ++ if (slice_iters) { ++ start_pipes(); ++ } ++ ++ // Main loop. ++ while (slice_iters) { ++ // We unroll over both the global fetch and the register load pipeline to ++ // ensure all shared memory accesses are static. Note that both pipelines ++ // have even length meaning that the next iteration will always start at ++ // index 0. ++ #pragma unroll ++ for (int pipe = 0; pipe < stages;) { ++ #pragma unroll ++ for (int k = 0; k < b_sh_wr_iters; k++) { ++ fetch_to_registers(k + 1, pipe % stages); ++ fetch_scales_to_registers(k + 1, pipe); ++ fetch_zp_to_registers(k + 1, pipe); ++ if (k == b_sh_wr_iters - 2) { ++ fetch_to_shared((pipe + stages - 1) % stages, pipe, ++ slice_iters >= stages); ++ pipe++; ++ wait_for_stage(); ++ init_same_group(pipe % stages); ++ } ++ matmul(k); ++ } ++ slice_iters--; ++ if (slice_iters == 0) { ++ break; ++ } ++ } ++ ++ a_gl_rd += a_gl_rd_delta_o * stages; ++ slice_k_start += tb_k * stages; ++ slice_k_start_shared_fetch += tb_k * stages; ++ ++ if constexpr (has_act_order) { ++ int first_group_id = g_idx[slice_k_start]; ++ int last_g_idx = slice_k_start + stages * tb_k * 2; ++ if (last_g_idx >= prob_k) { ++ last_g_idx = prob_k - 1; ++ } ++ int last_group_id = g_idx[last_g_idx]; ++ if (last_group_id >= sh_first_group_id + sh_num_groups) { ++ fetch_scales_to_shared(false, first_group_id, last_group_id); ++ __syncthreads(); ++ } ++ } ++ ++ // Process results and, if necessary, proceed to the next column slice. ++ // While this pattern may not be the most readable, other ways of writing ++ // the loop seemed to noticeably worse performance after compilation. ++ if (slice_iters == 0) { ++ cp_async_wait<0>(); ++ bool last = slice_idx == slice_count - 1; ++ if constexpr (!has_act_order && group_blocks == -1) { ++ if constexpr (w_type.size_bits() == 8) { ++ if (s_sh_wr_pred) { ++ cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); ++ } ++ cp_async_fence(); ++ } else { ++ // For 4-bit per-column scales, we only fetch them here in the ++ // final step before write-out ++ if (last) { ++ if (s_sh_wr_pred) { ++ cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); ++ } ++ cp_async_fence(); ++ } ++ } ++ } ++ ++ thread_block_reduce(); ++ if constexpr (!has_act_order && group_blocks == -1) { ++ if constexpr (w_type.size_bits() == 8) { ++ cp_async_wait<0>(); ++ __syncthreads(); ++ if (threadIdx.x / 32 < thread_n_blocks / 4) { ++ reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; ++ reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; ++ } ++ ++ } else { ++ if (last) { ++ cp_async_wait<0>(); ++ __syncthreads(); ++ if (threadIdx.x / 32 < thread_n_blocks / 4) { ++ reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; ++ reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; ++ } ++ } ++ } ++ } ++ ++ // For 8-bit channelwise, we apply the scale before the global reduction ++ // that converts the fp32 results to fp16 (so that we avoid possible ++ // overflow in fp16) ++ if constexpr (!has_act_order && group_blocks == -1 && ++ w_type.size_bits() == 8) { ++ if (threadIdx.x / 32 < thread_n_blocks / 4) { ++ #pragma unroll ++ for (int i = 0; i < thread_m_blocks; i++) { ++ #pragma unroll ++ for (int j = 0; j < 4; j++) { ++ scale_float(reinterpret_cast(&frag_c[i][j][0][0]), ++ frag_s[j / 2][2 * (j % 2) + 0]); ++ scale_float(reinterpret_cast(&frag_c[i][j][0][2]), ++ frag_s[j / 2][2 * (j % 2) + 0]); ++ ++ scale_float(reinterpret_cast(&frag_c[i][j][1][0]), ++ frag_s[j / 2][2 * (j % 2) + 1]); ++ scale_float(reinterpret_cast(&frag_c[i][j][1][2]), ++ frag_s[j / 2][2 * (j % 2) + 1]); ++ } ++ } ++ } ++ } ++ ++ if (slice_count > 1) { // only globally reduce if there is more than one ++ // block in a slice ++ barrier_acquire(&locks[slice_col], slice_idx); ++ global_reduce(slice_idx == 0, last); ++ barrier_release(&locks[slice_col], last); ++ } ++ if (last) // only the last block in a slice actually writes the result ++ write_result(); ++ slice_row = 0; ++ slice_col_par++; ++ slice_col++; ++ init_slice(); ++ if (slice_iters) { ++ a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + ++ (threadIdx.x % a_gl_rd_delta_o); ++ #pragma unroll ++ for (int i = 0; i < b_sh_wr_iters; i++) ++ B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles; ++ if (slice_col == 0) { ++ #pragma unroll ++ for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride; ++ } ++ ++ // Update slice k/n for scales loading ++ if constexpr (has_act_order) { ++ slice_k_start = tb_k * slice_row; ++ slice_k_finish = slice_k_start + tb_k * slice_iters; ++ slice_k_start_shared_fetch = slice_k_start; ++ slice_n_offset = act_s_col_tb_stride * slice_col; ++ ++ } else { ++ s_gl_rd = s_sh_stride * slice_col + threadIdx.x; ++ zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x; ++ } ++ ++ start_pipes(); ++ } ++ } ++ } ++} ++ ++template shared ++ // fetch pipeline ++ const bool has_act_order, // whether act_order is enabled ++ const bool has_zp, // whether zero-points are enabled ++ const int group_blocks = -1 // number of consecutive 16x16 blocks ++ // with a separate quantization scale ++ > ++__global__ void MarlinMoE( ++ const int4* __restrict__ A, // fp16 input matrix of shape mxk ++ const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn ++ int4* __restrict__ C, // fp16 output buffer of shape mxn ++ const int* __restrict__ sorted_ids_base, // int32 sorted ids of experts ++ const float* __restrict__ topk_weights, // float topk weights ++ const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape ++ // (k/groupsize)xn ++ const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape ++ // (k/groupsize)x(n/pack_factor) ++ const int* __restrict__ g_idx, // int32 group indices of shape k ++ const int* __restrict__ expert_offsets, ++ int num_groups, // number of scale groups per output channel ++ int expert_idx, // idx of current expert ++ int num_experts, // number of experts ++ int topk, // topk parameter of moe ++ int prob_m, // batch dimension m ++ int prob_n, // output dimension n ++ int prob_k, // reduction dimension k ++ int tot_m, // total number of rows in A and C ++ int* locks, // extra global storage for barrier synchronization ++ bool replicate_input, // do we use the same input for each expert? ++ bool apply_weights, // apply weights to output ++ int current_m_block, // current m block to start kernel computation from ++ int max_par, // maximum parallelism ++ int cfg_max_m_blocks // upper bound on m blocks ++) { ++ int m_block_ctr = current_m_block; ++ ++ const int* sorted_ids_expert = ++ sorted_ids_base + expert_offsets[expert_idx] + m_block_ctr * 4 * max_par; ++ int tot_its = expert_offsets[expert_idx + 1] - expert_offsets[expert_idx]; ++ if (tot_its == 0) { ++ return; ++ } ++ int tot_m_blocks = ceildiv(tot_its, 16); ++ int pad = 16 * tot_m_blocks - tot_its; ++ ++ if (m_block_ctr >= tot_m_blocks) { ++ return; ++ } ++ ++ int max_block = tot_m_blocks - m_block_ctr; ++ prob_m = tot_its - 16 * m_block_ctr; ++ ++ int par = 1; ++ if (max_block > cfg_max_m_blocks) { ++ // Note that parallel > 1 currently only works for inputs without any ++ // padding ++ par = (16 * max_block - pad) / (16 * cfg_max_m_blocks); ++ if (par > max_par) par = max_par; ++ prob_m = (16 * cfg_max_m_blocks) * par; ++ m_block_ctr += cfg_max_m_blocks * (par - 1); ++ max_block = cfg_max_m_blocks; ++ } ++ ++ if (max_block == 1) { ++ MarlinMoESingle( ++ A, B, C, sorted_ids_expert, topk_weights, scales_ptr, zp_ptr, g_idx, ++ expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, ++ prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, ++ current_m_block); ++ } else if (max_block == 2) { ++ MarlinMoESingle( ++ A, B, C, sorted_ids_expert, topk_weights, scales_ptr, zp_ptr, g_idx, ++ expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, ++ prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, ++ current_m_block); ++ } else if (max_block == 3) { ++ MarlinMoESingle( ++ A, B, C, sorted_ids_expert, topk_weights, scales_ptr, zp_ptr, g_idx, ++ expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, ++ prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, ++ current_m_block); ++ } else { ++ MarlinMoESingle( ++ A, B, C, sorted_ids_expert, topk_weights, scales_ptr, zp_ptr, g_idx, ++ expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, ++ prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, ++ current_m_block); ++ } ++} ++ ++#else ++ ++template shared ++ // fetch pipeline ++ const bool has_act_order, // whether act_order is enabled ++ const bool has_zp, // whether zero-points are enabled ++ const int group_blocks = -1 // number of consecutive 16x16 blocks ++ // with a separate quantization scale ++ > ++__global__ void MarlinMoE( ++ const int4* __restrict__ A, // fp16 input matrix of shape mxk ++ const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn ++ int4* __restrict__ C, // fp16 output buffer of shape mxn ++ const int* __restrict__ sorted_ids, // int32 sorted ids of experts ++ const float* __restrict__ topk_weights, // float topk weights ++ const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape ++ // (k/groupsize)xn ++ const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape ++ // (k/groupsize)x(n/pack_factor) ++ const int* __restrict__ g_idx, // int32 group indices of shape k ++ const int* __restrict__ expert_offsets, ++ int num_groups, // number of scale groups per output channel ++ int expert_idx, // idx of current expert ++ int num_experts, // number of experts ++ int topk, // topk parameter of moe ++ int prob_m, // batch dimension m ++ int prob_n, // output dimension n ++ int prob_k, // reduction dimension k ++ int tot_m, // total number of rows in A and C ++ int* locks, // extra global storage for barrier synchronization ++ bool replicate_input, // do we use the same input for each expert? ++ bool apply_weights, // apply weights to output ++ int current_m_block, // current m block to start kernel computation from ++ int max_par, // maximum parallelism ++ int cfg_max_m_blocks // upper bound on m blocks ++) { ++ // Marlin is not implemented yet for SM < 8.0 ++ assert(false); ++ return; ++} ++ ++#endif ++ ++// 8 warps are a good choice since every SM has 4 schedulers and having more ++// than 1 warp per schedule allows some more latency hiding. At the same time, ++// we want relatively few warps to have many registers per warp and small tiles. ++const int USER_THREADS = ++ 256; // Note: This is only used with user-provided thread_k/n ++const int STAGES = 4; // 4 pipeline stages fit into shared memory ++ ++static constexpr int min_thread_n = 64; ++static constexpr int min_thread_k = 64; ++ ++#define __CALL_IF_MOE(W_TYPE, THREAD_N_BLOCKS, THREAD_K_BLOCKS, HAS_ACT_ORDER, \ ++ HAS_ZP, GROUP_BLOCKS, NUM_THREADS) \ ++ else if (q_type == W_TYPE && thread_n_blocks == THREAD_N_BLOCKS && \ ++ thread_k_blocks == THREAD_K_BLOCKS && \ ++ has_act_order == HAS_ACT_ORDER && has_zp == HAS_ZP && \ ++ group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS) { \ ++ cudaFuncSetAttribute( \ ++ MarlinMoE, \ ++ cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ ++ MarlinMoE \ ++ <<>>( \ ++ A_ptr, B_ptr, C_ptr, sorted_ids_ptr, topk_weights_ptr, s_ptr, \ ++ zp_ptr, g_idx_ptr, expert_offsets_ptr, num_groups, expert_idx, \ ++ num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, \ ++ replicate_input, apply_weights, m_block, max_par, \ ++ cfg_max_m_blocks); \ ++ } ++ ++#define GPTQ_CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ ++ __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \ ++ __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \ ++ __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \ ++ __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \ ++ __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) ++ ++#define AWQ_CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ ++ __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \ ++ __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \ ++ __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \ ++ __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) ++ ++} // namespace marlin_moe +diff --git a/csrc/moe/marlin_kernels/marlin_moe_kernel_ku4.cu b/csrc/moe/marlin_kernels/marlin_moe_kernel_ku4.cu +new file mode 100644 +index 0000000..77bc0dd +--- /dev/null ++++ b/csrc/moe/marlin_kernels/marlin_moe_kernel_ku4.cu +@@ -0,0 +1,31 @@ ++#include "marlin_moe_kernel_ku4.h" ++ ++namespace marlin_moe { ++ ++// We return bool so we can create these different kernel calls as a sequence ++// of if-elseif's. ++bool call_marlin_moe_kernel_ku4( ++ vllm::ScalarType const& q_type, int thread_n_blocks, int thread_k_blocks, ++ bool has_act_order, int group_blocks, int num_threads, int blocks, ++ int max_shared_mem, cudaStream_t stream, const int4* A_ptr, ++ const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr, ++ const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr, ++ const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups, ++ int expert_idx, int num_experts, int topk, int prob_m, int prob_n, ++ int prob_k, int tot_m, int* locks, bool replicate_input, bool apply_weights, ++ int m_block, int max_par, int cfg_max_m_blocks) { ++ bool has_zp = true; ++ ++ if (false) { ++ } ++ AWQ_CALL_IF_MOE(vllm::kU4, 16, 4, 256) ++ AWQ_CALL_IF_MOE(vllm::kU4, 8, 8, 256) ++ AWQ_CALL_IF_MOE(vllm::kU4, 8, 4, 128) ++ AWQ_CALL_IF_MOE(vllm::kU4, 4, 8, 128) ++ else { ++ return false; ++ } ++ return true; ++} ++ ++} // namespace marlin_moe +diff --git a/csrc/moe/marlin_kernels/marlin_moe_kernel_ku4.h b/csrc/moe/marlin_kernels/marlin_moe_kernel_ku4.h +new file mode 100644 +index 0000000..833fadf +--- /dev/null ++++ b/csrc/moe/marlin_kernels/marlin_moe_kernel_ku4.h +@@ -0,0 +1,20 @@ ++#pragma once ++ ++#include "marlin_moe_kernel.h" ++ ++namespace marlin_moe { ++ ++// We return bool so we can create these different kernel calls as a sequence ++// of if-elseif's. ++bool call_marlin_moe_kernel_ku4( ++ vllm::ScalarType const& q_type, int thread_n_blocks, int thread_k_blocks, ++ bool has_act_order, int group_blocks, int num_threads, int blocks, ++ int max_shared_mem, cudaStream_t stream, const int4* A_ptr, ++ const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr, ++ const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr, ++ const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups, ++ int expert_idx, int num_experts, int topk, int prob_m, int prob_n, ++ int prob_k, int tot_m, int* locks, bool replicate_input, bool apply_weights, ++ int m_block, int max_par, int cfg_max_m_blocks); ++ ++} // namespace marlin_moe +diff --git a/csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.cu b/csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.cu +new file mode 100644 +index 0000000..f7e57b0 +--- /dev/null ++++ b/csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.cu +@@ -0,0 +1,31 @@ ++#include "marlin_moe_kernel_ku4b8.h" ++ ++namespace marlin_moe { ++ ++// We return bool so we can create these different kernel calls as a sequence ++// of if-elseif's. ++bool call_marlin_moe_kernel_ku4b8( ++ vllm::ScalarType const& q_type, int thread_n_blocks, int thread_k_blocks, ++ bool has_act_order, int group_blocks, int num_threads, int blocks, ++ int max_shared_mem, cudaStream_t stream, const int4* A_ptr, ++ const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr, ++ const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr, ++ const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups, ++ int expert_idx, int num_experts, int topk, int prob_m, int prob_n, ++ int prob_k, int tot_m, int* locks, bool replicate_input, bool apply_weights, ++ int m_block, int max_par, int cfg_max_m_blocks) { ++ bool has_zp = false; ++ ++ if (false) { ++ } ++ GPTQ_CALL_IF_MOE(vllm::kU4B8, 16, 4, 256) ++ GPTQ_CALL_IF_MOE(vllm::kU4B8, 8, 8, 256) ++ GPTQ_CALL_IF_MOE(vllm::kU4B8, 8, 4, 128) ++ GPTQ_CALL_IF_MOE(vllm::kU4B8, 4, 8, 128) ++ else { ++ return false; ++ } ++ return true; ++} ++ ++} // namespace marlin_moe +diff --git a/csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.h b/csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.h +new file mode 100644 +index 0000000..494da8f +--- /dev/null ++++ b/csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.h +@@ -0,0 +1,20 @@ ++#pragma once ++ ++#include "marlin_moe_kernel.h" ++ ++namespace marlin_moe { ++ ++// We return bool so we can create these different kernel calls as a sequence ++// of if-elseif's. ++bool call_marlin_moe_kernel_ku4b8( ++ vllm::ScalarType const& q_type, int thread_n_blocks, int thread_k_blocks, ++ bool has_act_order, int group_blocks, int num_threads, int blocks, ++ int max_shared_mem, cudaStream_t stream, const int4* A_ptr, ++ const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr, ++ const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr, ++ const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups, ++ int expert_idx, int num_experts, int topk, int prob_m, int prob_n, ++ int prob_k, int tot_m, int* locks, bool replicate_input, bool apply_weights, ++ int m_block, int max_par, int cfg_max_m_blocks); ++ ++} // namespace marlin_moe +diff --git a/csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.cu b/csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.cu +new file mode 100644 +index 0000000..a901f0b +--- /dev/null ++++ b/csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.cu +@@ -0,0 +1,31 @@ ++#include "marlin_moe_kernel_ku8b128.h" ++ ++namespace marlin_moe { ++ ++// We return bool so we can create these different kernel calls as a sequence ++// of if-elseif's. ++bool call_marlin_moe_kernel_ku8b128( ++ vllm::ScalarType const& q_type, int thread_n_blocks, int thread_k_blocks, ++ bool has_act_order, int group_blocks, int num_threads, int blocks, ++ int max_shared_mem, cudaStream_t stream, const int4* A_ptr, ++ const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr, ++ const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr, ++ const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups, ++ int expert_idx, int num_experts, int topk, int prob_m, int prob_n, ++ int prob_k, int tot_m, int* locks, bool replicate_input, bool apply_weights, ++ int m_block, int max_par, int cfg_max_m_blocks) { ++ bool has_zp = false; ++ ++ if (false) { ++ } ++ GPTQ_CALL_IF_MOE(vllm::kU8B128, 16, 4, 256) ++ GPTQ_CALL_IF_MOE(vllm::kU8B128, 8, 8, 256) ++ GPTQ_CALL_IF_MOE(vllm::kU8B128, 8, 4, 128) ++ GPTQ_CALL_IF_MOE(vllm::kU8B128, 4, 8, 128) ++ else { ++ return false; ++ } ++ return true; ++} ++ ++} // namespace marlin_moe +diff --git a/csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.h b/csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.h +new file mode 100644 +index 0000000..f3018aa +--- /dev/null ++++ b/csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.h +@@ -0,0 +1,18 @@ ++#pragma once ++ ++#include "marlin_moe_kernel.h" ++ ++namespace marlin_moe { ++ ++bool call_marlin_moe_kernel_ku8b128( ++ vllm::ScalarType const& q_type, int thread_n_blocks, int thread_k_blocks, ++ bool has_act_order, int group_blocks, int num_threads, int blocks, ++ int max_shared_mem, cudaStream_t stream, const int4* A_ptr, ++ const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr, ++ const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr, ++ const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups, ++ int expert_idx, int num_experts, int topk, int prob_m, int prob_n, ++ int prob_k, int tot_m, int* locks, bool replicate_input, bool apply_weights, ++ int m_block, int max_par, int cfg_max_m_blocks); ++ ++} +diff --git a/csrc/moe/marlin_moe_ops.cu b/csrc/moe/marlin_moe_ops.cu +new file mode 100644 +index 0000000..5f12483 +--- /dev/null ++++ b/csrc/moe/marlin_moe_ops.cu +@@ -0,0 +1,588 @@ ++/* ++ * Modified by Neural Magic ++ * Copyright (C) Marlin.2024 Elias Frantar ++ * ++ * Licensed under the Apache License, Version 2.0 (the "License"); ++ * you may not use this file except in compliance with the License. ++ * You may obtain a copy of the License at ++ * ++ * http://www.apache.org/licenses/LICENSE-2.0 ++ * ++ * Unless required by applicable law or agreed to in writing, software ++ * distributed under the License is distributed on an "AS IS" BASIS, ++ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++ * See the License for the specific language governing permissions and ++ * limitations under the License. ++ */ ++ ++#include ++ ++#include ++#include ++#include ++#include ++#include ++ ++#include ++ ++#include "core/exception.hpp" ++#include "core/scalar_type.hpp" ++#include "core/registration.h" ++#include "marlin_kernels/marlin_moe_kernel_ku4b8.h" ++#include "marlin_kernels/marlin_moe_kernel_ku8b128.h" ++#include "marlin_kernels/marlin_moe_kernel_ku4.h" ++ ++template ++inline std::string str(T x) { ++ return std::to_string(x); ++} ++ ++namespace marlin_moe { ++ ++#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 ++ ++// For a given "a" of size [M,K] performs a permutation of the K columns based ++// on the given "perm" indices. ++__global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr, ++ int const* __restrict__ perm_int_ptr, ++ int4* __restrict__ out_int4_ptr, int size_m, ++ int size_k, int block_rows) { ++ int start_row = block_rows * blockIdx.x; ++ int finish_row = start_row + block_rows; ++ if (finish_row > size_m) { ++ finish_row = size_m; ++ } ++ int cur_block_rows = finish_row - start_row; ++ ++ int row_stride = size_k * sizeof(half) / 16; ++ ++ auto permute_row = [&](int row) { ++ int iters = size_k / blockDim.x; ++ int rest = size_k % blockDim.x; ++ ++ int offset = row * row_stride; ++ ++ half const* a_row_half = reinterpret_cast(a_int4_ptr + offset); ++ half* out_half = reinterpret_cast(out_int4_ptr + offset); ++ ++ int base_k = 0; ++ ++ for (int i = 0; i < iters; i++) { ++ int cur_k = base_k + threadIdx.x; ++ int src_pos = perm_int_ptr[cur_k]; ++ ++ out_half[cur_k] = a_row_half[src_pos]; ++ ++ base_k += blockDim.x; ++ } ++ ++ if (rest) { ++ if (threadIdx.x < rest) { ++ int cur_k = base_k + threadIdx.x; ++ int src_pos = perm_int_ptr[cur_k]; ++ ++ out_half[cur_k] = a_row_half[src_pos]; ++ } ++ } ++ }; ++ ++ for (int i = 0; i < cur_block_rows; i++) { ++ int cur_row = start_row + i; ++ if (cur_row < size_m) { ++ permute_row(cur_row); ++ } ++ } ++} ++ ++__global__ void compute_expert_offsets(int const* __restrict__ topk_ids, ++ int* __restrict__ expert_offsets, ++ int topk_length, int block_size) { ++ int expert_id = threadIdx.x; ++ int num_experts = blockDim.x; ++ ++ int occurrences = 0; ++ for (int i = 0; i < topk_length; ++i) { ++ occurrences += (topk_ids[i] == expert_id); ++ } ++ expert_offsets[expert_id + 1] = occurrences; ++ __syncthreads(); ++ ++ if (threadIdx.x == 0) { ++ int tot_offset = 0; ++ expert_offsets[0] = 0; ++ for (int i = 0; i < num_experts; ++i) { ++ tot_offset += ceildiv(expert_offsets[i + 1], block_size) * block_size; ++ expert_offsets[i + 1] = tot_offset; ++ } ++ } ++ __syncthreads(); ++} ++ ++#else ++ ++__global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr, ++ int const* __restrict__ perm_int_ptr, ++ int4* __restrict__ out_int4_ptr, int size_m, ++ int size_k, int block_rows) { ++ // Marlin is not implemented yet for SM < 8.0 ++ assert(false); ++ return; ++} ++ ++__global__ void compute_expert_offsets(int const* __restrict__ topk_ids, ++ int* __restrict__ expert_offsets, ++ int topk_length, int block_size) { ++ // Marlin is not implemented yet for SM < 8.0 ++ assert(false); ++ return; ++} ++ ++#endif ++ ++typedef struct { ++ int thread_k; ++ int thread_n; ++ int num_threads; ++} thread_config_t; ++ ++typedef struct { ++ int max_m_blocks; ++ thread_config_t tb_cfg; ++} exec_config_t; ++ ++thread_config_t small_batch_thread_configs[] = { ++ // Ordered by priority ++ ++ // thread_k, thread_n, num_threads ++ {128, 128, 256}, // Default ++ {128, 64, 128}, // Reduce N 2X, same K ++ {64, 256, 256}, // Reduce K 2X, increase N 2X ++ {64, 128, 128}, // Reduce K 2X, same N ++ {64, 64, 128}, // Reduce both 2X ++}; ++ ++thread_config_t large_batch_thread_configs[] = { ++ // Ordered by priority ++ ++ // thread_k, thread_n, num_threads ++ {64, 256, 256}, // Default ++ {128, 128, 256}, // Reduce N 2X, increase K 2X ++ {64, 128, 128}, // Reduce N 2X, same K ++ {128, 64, 128}, // Reduce N 4X, increase K 2X ++ {64, 64, 128}, // Reduce N 4X, same K ++}; ++ ++int get_scales_cache_size(thread_config_t const& th_config, int prob_m, ++ int prob_n, int prob_k, int num_bits, int group_size, ++ bool has_act_order, bool is_k_full) { ++ bool cache_scales_chunk = has_act_order && !is_k_full; ++ ++ int tb_n = th_config.thread_n; ++ int tb_k = th_config.thread_k; ++ ++ // Get max scale groups per thread-block ++ int tb_groups; ++ if (group_size == -1) { ++ tb_groups = 1; ++ } else if (group_size == 0) { ++ tb_groups = ceildiv(tb_k, 32); // Worst case is 32 group size ++ } else { ++ tb_groups = ceildiv(tb_k, group_size); ++ } ++ ++ if (cache_scales_chunk) { ++ int load_groups = ++ tb_groups * STAGES * 2; // Chunk size is 2x pipeline over dim K ++ load_groups = max(load_groups, 32); // We load at least 32 scale groups ++ return load_groups * tb_n * 4; ++ ++ } else { ++ int tb_scales = tb_groups * tb_n * 2; ++ ++ return tb_scales * STAGES; ++ } ++} ++ ++bool is_valid_cache_size(thread_config_t const& th_config, int max_m_blocks, ++ int prob_m, int prob_n, int prob_k, int num_bits, ++ int scales_cache_size, int max_shared_mem) { ++ int pack_factor = 32 / num_bits; ++ ++ // Get B size ++ int tb_k = th_config.thread_k; ++ int tb_n = th_config.thread_n; ++ ++ int b_size = (tb_k * tb_n / pack_factor) * 4; ++ ++ // Get A size ++ int m_blocks = ceildiv(prob_m, 16); ++ int tb_max_m = 16; ++ ++ while (true) { ++ if (m_blocks >= max_m_blocks) { ++ tb_max_m *= max_m_blocks; ++ break; ++ } ++ ++ max_m_blocks--; ++ if (max_m_blocks == 0) { ++ TORCH_CHECK(false, "Unexpected m_blocks = ", m_blocks); ++ } ++ } ++ ++ int a_size = (tb_max_m * tb_k) * 2; ++ ++ float pipe_size = (a_size + b_size) * STAGES; ++ ++ TORCH_CHECK(max_shared_mem / 2 > scales_cache_size); // Sanity ++ ++ return pipe_size < 0.95f * (max_shared_mem - scales_cache_size); ++} ++ ++bool is_valid_config(thread_config_t const& th_config, int max_m_blocks, ++ int prob_m, int prob_n, int prob_k, int num_bits, ++ int group_size, bool has_act_order, bool is_k_full, ++ int max_shared_mem) { ++ // Sanity ++ if (th_config.thread_k == -1 || th_config.thread_n == -1 || ++ th_config.num_threads == -1) { ++ return false; ++ } ++ ++ // Verify K/N are divisible by thread K/N ++ if (prob_k % th_config.thread_k != 0 || prob_n % th_config.thread_n != 0) { ++ return false; ++ } ++ ++ // thread_k can be only 128 or 64 (because it must be less than groupsize ++ // which is 128) ++ if (th_config.thread_k != 128 && th_config.thread_k != 64) { ++ return false; ++ } ++ ++ // Verify min for thread K/N ++ if (th_config.thread_n < min_thread_n || th_config.thread_k < min_thread_k) { ++ return false; ++ } ++ ++ // num_threads must be at least 128 (= 4 warps) ++ if (th_config.num_threads < 128) { ++ return false; ++ } ++ ++ // Determine cache for scales ++ int scales_cache_size = ++ get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits, ++ group_size, has_act_order, is_k_full); ++ ++ // Check that pipeline fits into cache ++ if (!is_valid_cache_size(th_config, max_m_blocks, prob_m, prob_n, prob_k, ++ num_bits, scales_cache_size, max_shared_mem)) { ++ return false; ++ } ++ ++ return true; ++} ++ ++exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k, ++ int num_bits, int group_size, ++ bool has_act_order, bool is_k_full, ++ int max_shared_mem) { ++ int max_m_blocks = 4; ++ while (max_m_blocks > 0) { ++ if (prob_m <= 16) { ++ for (auto th_config : small_batch_thread_configs) { ++ if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k, ++ num_bits, group_size, has_act_order, is_k_full, ++ max_shared_mem)) { ++ return exec_config_t{max_m_blocks, th_config}; ++ } ++ } ++ } else { ++ for (auto th_config : large_batch_thread_configs) { ++ if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k, ++ num_bits, group_size, has_act_order, is_k_full, ++ max_shared_mem)) { ++ return exec_config_t{max_m_blocks, th_config}; ++ } ++ } ++ } ++ ++ max_m_blocks--; // Process less M blocks per invocation to reduce cache ++ // usage ++ } ++ ++ return exec_config_t{0, {-1, -1, -1}}; ++} ++ ++#define CALL_MOE_KERNEL_FUNCTION(KERNEL_FUNCTION) \ ++ else if (KERNEL_FUNCTION( \ ++ q_type, thread_n_blocks, thread_k_blocks, has_act_order, \ ++ group_blocks, num_threads, blocks, max_shared_mem, stream, \ ++ A_ptr, B_ptr, C_ptr, sorted_ids_ptr, topk_weights_ptr, s_ptr, \ ++ zp_ptr, g_idx_ptr, expert_offsets_ptr, num_groups, expert_idx, \ ++ num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, \ ++ replicate_input, apply_weights, m_block, max_par, \ ++ exec_cfg.max_m_blocks)) { \ ++ } ++ ++void marlin_mm_moe(const void* A, const void* B, void* C, ++ const void* sorted_ids, const void* topk_weights, ++ const void* topk_ids, const void* s, void* zp, ++ const void* g_idx, const void* perm, void* a_tmp, ++ void* expert_offsets, int prob_m, int prob_n, int prob_k, ++ void* workspace, vllm::ScalarType const& q_type, ++ bool has_act_order, bool is_k_full, bool has_zp, ++ int num_groups, int group_size, int num_experts, int topk, ++ int moe_block_size, int dev, cudaStream_t stream, ++ int thread_k, int thread_n, int sms, int max_par, ++ bool replicate_input, bool apply_weights) { ++ TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m, ++ ", ", prob_n, ", ", prob_k, "]"); ++ ++ if (sms == -1) { ++ cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev); ++ } ++ ++ int max_shared_mem = 0; ++ cudaDeviceGetAttribute(&max_shared_mem, ++ cudaDevAttrMaxSharedMemoryPerBlockOptin, dev); ++ TORCH_CHECK(max_shared_mem > 0); ++ ++ int num_bits = q_type.size_bits(); ++ ++ // Set thread config ++ exec_config_t exec_cfg; ++ if (thread_k != -1 && thread_n != -1) { ++ // User-defined config ++ exec_cfg = ++ exec_config_t{4, thread_config_t{thread_k, thread_n, USER_THREADS}}; ++ } else { ++ // Auto config ++ exec_cfg = ++ determine_thread_config(prob_m, prob_n, prob_k, num_bits, group_size, ++ has_act_order, is_k_full, max_shared_mem); ++ } ++ ++ TORCH_CHECK(exec_cfg.max_m_blocks > 0 && ++ is_valid_config(exec_cfg.tb_cfg, exec_cfg.max_m_blocks, ++ prob_m, prob_n, prob_k, num_bits, group_size, ++ has_act_order, is_k_full, max_shared_mem), ++ "Invalid thread config: max_m_blocks = ", exec_cfg.max_m_blocks, ++ ", thread_k = ", exec_cfg.tb_cfg.thread_k, ++ ", thread_n = ", exec_cfg.tb_cfg.thread_n, ++ ", num_threads = ", exec_cfg.tb_cfg.num_threads, " for MKN = [", ++ prob_m, ", ", prob_k, ", ", prob_n, "] and num_bits = ", num_bits, ++ ", group_size = ", group_size, ++ ", has_act_order = ", has_act_order, ", is_k_full = ", is_k_full, ++ ", max_shared_mem = ", max_shared_mem); ++ ++ int num_threads = exec_cfg.tb_cfg.num_threads; ++ thread_k = exec_cfg.tb_cfg.thread_k; ++ thread_n = exec_cfg.tb_cfg.thread_n; ++ ++ int thread_k_blocks = thread_k / 16; ++ int thread_n_blocks = thread_n / 16; ++ ++ int blocks = sms; ++ ++ TORCH_CHECK(prob_n % thread_n == 0, "prob_n = ", prob_n, ++ " is not divisible by thread_n = ", thread_n); ++ TORCH_CHECK(prob_k % thread_k == 0, "prob_k = ", prob_k, ++ " is not divisible by thread_k = ", thread_k); ++ ++ int group_blocks = 0; ++ if (has_act_order) { ++ if (is_k_full) { ++ TORCH_CHECK(group_size != -1); ++ group_blocks = group_size / 16; ++ TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k, ++ " is not divisible by group_blocks = ", group_blocks); ++ } else { ++ TORCH_CHECK(group_size == 0); ++ group_blocks = 0; ++ } ++ ++ } else { ++ if (group_size == -1) { ++ group_blocks = -1; ++ } else { ++ group_blocks = group_size / 16; ++ TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k, ++ " is not divisible by group_blocks = ", group_blocks); ++ } ++ } ++ ++ int tot_m = prob_m; ++ ++ const int* topk_ids_ptr = (const int*)topk_ids; ++ int* expert_offsets_ptr = (int*)expert_offsets; ++ compute_expert_offsets<<<1, num_experts, 0, stream>>>( ++ topk_ids_ptr, expert_offsets_ptr, tot_m * topk, moe_block_size); ++ ++ bool do_permute_a = has_act_order; ++ ++ // If we have a full K, then we can run the non-act-order version of Marlin ++ // (since the weight rows are reordered by increasing group ids, and by ++ // having a full K, we have full original groups) ++ if (is_k_full) { ++ has_act_order = false; ++ } ++ ++ int pack_factor = 32 / q_type.size_bits(); ++ ++ for (int expert_idx = 0; expert_idx < num_experts; ++expert_idx) { ++ const int4* A_ptr = (const int4*)A; ++ int4* a_tmp_ptr = (int4*)a_tmp; ++ const int4* B_ptr = ++ (const int4*)B + (prob_n * prob_k / (pack_factor * 4)) * expert_idx; ++ int4* C_ptr = (int4*)C; ++ const float* topk_weights_ptr = (const float*)topk_weights; ++ const int* sorted_ids_ptr = (const int*)sorted_ids; ++ const int4* s_ptr = (const int4*)s + num_groups * prob_n / 8 * expert_idx; ++ const int4* zp_ptr = ++ (const int4*)zp + num_groups * prob_n / (pack_factor * 4) * expert_idx; ++ const int* g_idx_ptr = (const int*)g_idx + prob_k * expert_idx; ++ const int* perm_ptr = (const int*)perm + prob_k * expert_idx; ++ int* locks = (int*)workspace; ++ ++ if (do_permute_a) { ++ // Permute A columns ++ int topk_rows = replicate_input ? tot_m : tot_m * topk; ++ int block_rows = ceildiv(topk_rows, blocks); ++ permute_cols_kernel<<>>( ++ A_ptr, perm_ptr, a_tmp_ptr, topk_rows, prob_k, block_rows); ++ A_ptr = a_tmp_ptr; ++ } ++ ++ int tot_m_blocks = ceildiv(tot_m, 16); ++ for (int m_block = 0; m_block < tot_m_blocks; ++ m_block += 4 * exec_cfg.max_m_blocks) { ++ if (false) { ++ } ++ CALL_MOE_KERNEL_FUNCTION(call_marlin_moe_kernel_ku4b8) ++ CALL_MOE_KERNEL_FUNCTION(call_marlin_moe_kernel_ku8b128) ++ CALL_MOE_KERNEL_FUNCTION(call_marlin_moe_kernel_ku4) ++ else { ++ TORCH_CHECK(false, "Unsupported shapes: MNK = [" + str(prob_m) + ", " + ++ str(prob_n) + ", " + str(prob_k) + "]" + ++ ", has_act_order = " + str(has_act_order) + ++ ", num_groups = " + str(num_groups) + ++ ", group_size = " + str(group_size) + ++ ", thread_n_blocks = " + str(thread_n_blocks) + ++ ", thread_k_blocks = " + str(thread_k_blocks)); ++ } ++ } ++ } ++} ++ ++} // namespace marlin_moe ++ ++torch::Tensor marlin_gemm_moe( ++ const torch::Tensor& a, const torch::Tensor& b_q_weights, ++ const torch::Tensor& sorted_ids, const torch::Tensor& topk_weights, ++ const torch::Tensor& topk_ids, const torch::Tensor& b_scales, ++ torch::Tensor& b_zeros, const torch::Tensor& g_idx, ++ const torch::Tensor& perm, torch::Tensor& workspace, ++ vllm::ScalarTypeId const b_q_type_id, int64_t size_m, int64_t size_n, ++ int64_t size_k, bool is_k_full, int64_t num_experts, int64_t topk, ++ int64_t moe_block_size, bool replicate_input, bool apply_weights) { ++ vllm::ScalarType const b_q_type = vllm::ScalarType::from_id(b_q_type_id); ++ bool has_zp = b_zeros.size(1) != 0; ++ if (has_zp) { ++ TORCH_CHECK( ++ b_q_type == vllm::kU4, ++ "b_q_type must be u4 when has_zp = True. Got = ", b_q_type.str()); ++ } else { ++ TORCH_CHECK( ++ b_q_type == vllm::kU4B8 || b_q_type == vllm::kU8B128, ++ "b_q_type must be uint4b8 or uint8b128. Got = ", b_q_type.str()); ++ } ++ ++ int pack_factor = 32 / b_q_type.size_bits(); ++ ++ int max_par = 4; ++ ++ int dev = a.get_device(); ++ ++ auto options_dtype = ++ torch::TensorOptions().dtype(a.dtype()).device(a.device()); ++ auto options_int = ++ torch::TensorOptions().dtype(torch::kInt).device(a.device()); ++ torch::Tensor c = torch::zeros({size_m, topk, size_n}, options_dtype); ++ torch::Tensor a_tmp = ++ replicate_input ? torch::zeros({size_m, size_k}, options_dtype) ++ : torch::zeros({size_m, topk, size_k}, options_dtype); ++ torch::Tensor expert_offsets = torch::empty({num_experts + 1}, options_int); ++ ++ // thread_k: `k` size of a thread_tile in `weights` (can usually be left as ++ // auto -1) ++ int thread_k = -1; ++ // thread_n: `n` size of a thread_tile in `weights` (can usually be left as ++ // auto -1) ++ int thread_n = -1; ++ // sms: number of SMs to use for the kernel (can usually be left as auto -1) ++ int sms = -1; ++ ++ // Detect groupsize and act_order ++ int num_groups = -1; ++ int group_size = -1; ++ bool has_act_order = g_idx.size(1) != 0; ++ ++ int b_rank = b_scales.sizes().size(); ++ TORCH_CHECK(b_rank == 3, "b_scales rank = ", b_rank, " is not 3"); ++ TORCH_CHECK(b_scales.size(2) == size_n, "b_scales dim 2 = ", b_scales.size(2), ++ " is not size_n = ", size_n); ++ num_groups = b_scales.size(1); ++ ++ TORCH_CHECK(VLLM_IMPLIES(!is_k_full, has_act_order), ++ "if is_k_full is false, has_act_order must be true"); ++ ++ if (has_act_order) { ++ if (is_k_full) { ++ TORCH_CHECK(num_groups > 1, "For act_order, num_groups must be > 1"); ++ TORCH_CHECK(size_k % num_groups == 0, "size_k = ", size_k, ++ ", is not divisible by num_groups = ", num_groups); ++ group_size = size_k / num_groups; ++ } else { ++ group_size = 0; ++ } ++ ++ } else { ++ if (num_groups > 1) { ++ TORCH_CHECK( ++ size_k % num_groups == 0, "size_k = ", size_k, ++ ", is not divisible by b_scales.size(0) = ", b_scales.size(0)); ++ group_size = size_k / num_groups; ++ } else { ++ group_size = -1; ++ } ++ } ++ ++ // Verify b_zeros ++ if (has_zp) { ++ int rank = b_zeros.sizes().size(); ++ TORCH_CHECK(rank == 3, "b_zeros rank = ", rank, " is not 3"); ++ TORCH_CHECK(b_zeros.size(1) == num_groups, ++ "b_zeros dim 1 = ", b_zeros.size(1), ++ " is not num_groups = ", num_groups); ++ TORCH_CHECK(b_zeros.size(2) == size_n / pack_factor, ++ "b_zeros dim 2 = ", b_zeros.size(2), ++ " is not size_n / pack_factor = ", size_n / pack_factor); ++ } ++ ++ marlin_moe::marlin_mm_moe( ++ a.data_ptr(), b_q_weights.data_ptr(), c.data_ptr(), sorted_ids.data_ptr(), ++ topk_weights.data_ptr(), topk_ids.data_ptr(), b_scales.data_ptr(), ++ b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(), ++ expert_offsets.data_ptr(), size_m, size_n, size_k, workspace.data_ptr(), ++ b_q_type, has_act_order, is_k_full, has_zp, num_groups, group_size, ++ num_experts, topk, moe_block_size, dev, ++ at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, max_par, ++ replicate_input, apply_weights); ++ return c; ++} ++ ++TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) { ++ m.impl("marlin_gemm_moe", &marlin_gemm_moe); ++} +diff --git a/csrc/moe/moe_align_sum_kernels.cu b/csrc/moe/moe_align_sum_kernels.cu +new file mode 100644 +index 0000000..24341d6 +--- /dev/null ++++ b/csrc/moe/moe_align_sum_kernels.cu +@@ -0,0 +1,324 @@ ++#include ++#include ++#include ++ ++#include ++#include ++ ++#include "../cuda_compat.h" ++#include "../dispatch_utils.h" ++ ++#define CEILDIV(x, y) (((x) + (y) - 1) / (y)) ++ ++namespace vllm { ++namespace moe { ++ ++namespace { ++__device__ __forceinline__ int32_t index(int32_t total_col, int32_t row, ++ int32_t col) { ++ // don't worry about overflow because num_experts is relatively small ++ return row * total_col + col; ++} ++} // namespace ++ ++template ++__global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids, ++ int32_t* sorted_token_ids, ++ int32_t* expert_ids, ++ int32_t* total_tokens_post_pad, ++ int32_t num_experts, ++ int32_t block_size, size_t numel) { ++ const size_t tokens_per_thread = CEILDIV(numel, blockDim.x); ++ const size_t start_idx = threadIdx.x * tokens_per_thread; ++ ++ extern __shared__ int32_t shared_mem[]; ++ ++ int32_t* tokens_cnts = ++ shared_mem; // 2d tensor with shape (blockDim.x + 1, num_experts) ++ int32_t* cumsum = ++ shared_mem + ++ (blockDim.x + 1) * num_experts; // 1d tensor with shape (num_experts + 1) ++ ++ for (int i = 0; i < num_experts; ++i) { ++ tokens_cnts[index(num_experts, threadIdx.x + 1, i)] = 0; ++ } ++ ++ /** ++ * In the first step we compute token_cnts[thread_index + 1][expert_index], ++ * which counts how many tokens in the token shard of thread_index are ++ * assigned to expert expert_index. ++ */ ++ for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) { ++ ++tokens_cnts[index(num_experts, threadIdx.x + 1, topk_ids[i])]; ++ } ++ ++ __syncthreads(); ++ ++ // For each expert we accumulate the token counts from the different threads. ++ if (threadIdx.x < num_experts) { ++ tokens_cnts[index(num_experts, 0, threadIdx.x)] = 0; ++ for (int i = 1; i <= blockDim.x; ++i) { ++ tokens_cnts[index(num_experts, i, threadIdx.x)] += ++ tokens_cnts[index(num_experts, i - 1, threadIdx.x)]; ++ } ++ } ++ ++ __syncthreads(); ++ ++ // We accumulate the token counts of all experts in thread 0. ++ if (threadIdx.x == 0) { ++ cumsum[0] = 0; ++ for (int i = 1; i <= num_experts; ++i) { ++ cumsum[i] = cumsum[i - 1] + ++ CEILDIV(tokens_cnts[index(num_experts, blockDim.x, i - 1)], ++ block_size) * ++ block_size; ++ } ++ *total_tokens_post_pad = cumsum[num_experts]; ++ } ++ ++ __syncthreads(); ++ ++ /** ++ * For each expert, each thread processes the tokens of the corresponding ++ * blocks and stores the corresponding expert_id for each block. ++ */ ++ if (threadIdx.x < num_experts) { ++ for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1]; ++ i += block_size) { ++ expert_ids[i / block_size] = threadIdx.x; ++ } ++ } ++ ++ /** ++ * Each thread processes a token shard, calculating the index of each token ++ * after sorting by expert number. Given the example topk_ids = ++ * [0,1,2,1,2,3,0,3,4] and block_size = 4, then the output would be [0, 6, *, ++ * *, 1, 3, *, *, 2, 4, *, *, 5, 7, *, *, 8, *, *, *], where * represents a ++ * padding value(preset in python). ++ */ ++ for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) { ++ int32_t expert_id = topk_ids[i]; ++ /** The cumsum[expert_id] stores the starting index of the tokens that the ++ * expert with expert_id needs to process, and ++ * tokens_cnts[threadIdx.x][expert_id] stores the indices of the tokens ++ * processed by the expert with expert_id within the current thread's token ++ * shard. ++ */ ++ int32_t rank_post_pad = ++ tokens_cnts[index(num_experts, threadIdx.x, expert_id)] + ++ cumsum[expert_id]; ++ sorted_token_ids[rank_post_pad] = i; ++ ++tokens_cnts[index(num_experts, threadIdx.x, expert_id)]; ++ } ++} ++ ++// TODO(simon): this is temporarily adapted from ++// https://github.com/sgl-project/sglang/commit/31548116a8dc8c6df7e146e0587335a59fc5b9d7 ++// we did this to unblock Deepseek V3 but there should be a better ++// implementation to manage shared memory. ++template ++__global__ void moe_align_block_size_global_mem_kernel( ++ scalar_t* __restrict__ topk_ids, int32_t* sorted_token_ids, ++ int32_t* expert_ids, int32_t* total_tokens_post_pad, int32_t num_experts, ++ int32_t block_size, size_t numel, int32_t* tokens_cnts, int32_t* cumsum) { ++ const size_t tokens_per_thread = CEILDIV(numel, blockDim.x); ++ const size_t start_idx = threadIdx.x * tokens_per_thread; ++ ++ for (int i = 0; i < num_experts; ++i) { ++ tokens_cnts[index(num_experts, threadIdx.x + 1, i)] = 0; ++ } ++ ++ /** ++ * In the first step we compute token_cnts[thread_index + 1][expert_index], ++ * which counts how many tokens in the token shard of thread_index are ++ * assigned to expert expert_index. ++ */ ++ for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) { ++ ++tokens_cnts[index(num_experts, threadIdx.x + 1, topk_ids[i])]; ++ } ++ ++ __syncthreads(); ++ ++ // For each expert we accumulate the token counts from the different threads. ++ if (threadIdx.x < num_experts) { ++ tokens_cnts[index(num_experts, 0, threadIdx.x)] = 0; ++ for (int i = 1; i <= blockDim.x; ++i) { ++ tokens_cnts[index(num_experts, i, threadIdx.x)] += ++ tokens_cnts[index(num_experts, i - 1, threadIdx.x)]; ++ } ++ } ++ ++ __syncthreads(); ++ ++ // We accumulate the token counts of all experts in thread 0. ++ if (threadIdx.x == 0) { ++ cumsum[0] = 0; ++ for (int i = 1; i <= num_experts; ++i) { ++ cumsum[i] = cumsum[i - 1] + ++ CEILDIV(tokens_cnts[index(num_experts, blockDim.x, i - 1)], ++ block_size) * ++ block_size; ++ } ++ *total_tokens_post_pad = cumsum[num_experts]; ++ } ++ ++ __syncthreads(); ++ ++ /** ++ * For each expert, each thread processes the tokens of the corresponding ++ * blocks and stores the corresponding expert_id for each block. ++ */ ++ if (threadIdx.x < num_experts) { ++ for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1]; ++ i += block_size) { ++ expert_ids[i / block_size] = threadIdx.x; ++ } ++ } ++ ++ /** ++ * Each thread processes a token shard, calculating the index of each token ++ * after sorting by expert number. Given the example topk_ids = ++ * [0,1,2,1,2,3,0,3,4] and block_size = 4, then the output would be [0, 6, *, ++ * *, 1, 3, *, *, 2, 4, *, *, 5, 7, *, *, 8, *, *, *], where * represents a ++ * padding value(preset in python). ++ */ ++ for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) { ++ int32_t expert_id = topk_ids[i]; ++ /** The cumsum[expert_id] stores the starting index of the tokens that the ++ * expert with expert_id needs to process, and ++ * tokens_cnts[threadIdx.x][expert_id] stores the indices of the tokens ++ * processed by the expert with expert_id within the current thread's token ++ * shard. ++ */ ++ int32_t rank_post_pad = ++ tokens_cnts[index(num_experts, threadIdx.x, expert_id)] + ++ cumsum[expert_id]; ++ sorted_token_ids[rank_post_pad] = i; ++ ++tokens_cnts[index(num_experts, threadIdx.x, expert_id)]; ++ } ++} ++ ++template ++__global__ void moe_sum_kernel( ++ scalar_t* __restrict__ out, // [..., d] ++ const scalar_t* __restrict__ input, // [..., topk, d] ++ const int d) { ++ const int64_t token_idx = blockIdx.x; ++ for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) { ++ scalar_t x = 0.0; ++#pragma unroll ++ for (int k = 0; k < TOPK; ++k) { ++ x += VLLM_LDG(&input[token_idx * TOPK * d + k * d + idx]); ++ } ++ out[token_idx * d + idx] = x; ++ } ++} ++ ++} // namespace moe ++} // namespace vllm ++ ++void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, ++ int64_t block_size, torch::Tensor sorted_token_ids, ++ torch::Tensor experts_ids, ++ torch::Tensor num_tokens_post_pad) { ++ const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); ++ ++ // If we have very large number of experts, we can no longer use shared ++ // memory. ++ // TODO(simon): the right solution should be calculating the exact right ++ // amount of shared memory and use that. The num_experts >= 256 is just a ++ // temporary solution to unblock Deepseek V3. ++ if (num_experts >= 256) { ++ VLLM_DISPATCH_INTEGRAL_TYPES( ++ topk_ids.scalar_type(), "moe_align_block_size_global_mem_kernel", [&] { ++ // calc needed amount of shared mem for `tokens_cnts` and `cumsum` ++ // tensors ++ const int32_t num_thread = max((int32_t)num_experts, WARP_SIZE); ++ ++ const int32_t mem_tokens_cnts = ++ ((num_experts + 1) * num_experts) * sizeof(int32_t); ++ const int32_t mem_cumsum = (num_experts + 1) * sizeof(int32_t); ++ // allocate global memory ++ int32_t* tokens_cnts; ++ int32_t* cumsum; ++ cudaMalloc(&tokens_cnts, mem_tokens_cnts); ++ cudaMalloc(&cumsum, mem_cumsum); ++ ++ auto kernel = ++ vllm::moe::moe_align_block_size_global_mem_kernel; ++ kernel<<<1, num_thread, 0, stream>>>( ++ topk_ids.data_ptr(), ++ sorted_token_ids.data_ptr(), ++ experts_ids.data_ptr(), ++ num_tokens_post_pad.data_ptr(), num_experts, block_size, ++ topk_ids.numel(), tokens_cnts, cumsum); ++ cudaFree(tokens_cnts); ++ cudaFree(cumsum); ++ }); ++ } else { ++ VLLM_DISPATCH_INTEGRAL_TYPES( ++ topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] { ++ // calc needed amount of shared mem for `tokens_cnts` and `cumsum` ++ // tensors ++ const int32_t num_thread = max((int32_t)num_experts, WARP_SIZE); ++ const int32_t shared_mem = ++ ((num_thread + 1) * num_experts + (num_experts + 1)) * ++ sizeof(int32_t); ++ ++ // set dynamic shared mem ++ auto kernel = vllm::moe::moe_align_block_size_kernel; ++ AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( ++ (void*)kernel, shared_mem)); ++ kernel<<<1, num_thread, shared_mem, stream>>>( ++ topk_ids.data_ptr(), ++ sorted_token_ids.data_ptr(), ++ experts_ids.data_ptr(), ++ num_tokens_post_pad.data_ptr(), num_experts, block_size, ++ topk_ids.numel()); ++ }); ++ } ++} ++ ++void moe_sum(torch::Tensor& input, // [num_tokens, topk, hidden_size] ++ torch::Tensor& output) // [num_tokens, hidden_size] ++{ ++ const int hidden_size = input.size(-1); ++ const int num_tokens = output.numel() / hidden_size; ++ const int topk = input.size(1); ++ ++ dim3 grid(num_tokens); ++ dim3 block(std::min(hidden_size, 1024)); ++ const at::cuda::OptionalCUDAGuard device_guard(device_of(output)); ++ const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); ++ ++ switch (topk) { ++ case 2: ++ VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_sum_kernel", [&] { ++ vllm::moe::moe_sum_kernel<<>>( ++ output.data_ptr(), input.data_ptr(), ++ hidden_size); ++ }); ++ break; ++ ++ case 3: ++ VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_sum_kernel", [&] { ++ vllm::moe::moe_sum_kernel<<>>( ++ output.data_ptr(), input.data_ptr(), ++ hidden_size); ++ }); ++ break; ++ ++ case 4: ++ VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_sum_kernel", [&] { ++ vllm::moe::moe_sum_kernel<<>>( ++ output.data_ptr(), input.data_ptr(), ++ hidden_size); ++ }); ++ break; ++ ++ default: ++ at::sum_out(output, input, 1); ++ break; ++ } ++} +diff --git a/csrc/moe/moe_ops.h b/csrc/moe/moe_ops.h +index a01be3e..596cc0a 100644 +--- a/csrc/moe/moe_ops.h ++++ b/csrc/moe/moe_ops.h +@@ -1,9 +1,14 @@ + #pragma once + +-#include ++#include + +-void topk_softmax( +- torch::Tensor& topk_weights, +- torch::Tensor& topk_indices, +- torch::Tensor& token_expert_indices, +- torch::Tensor& gating_output); ++void topk_softmax(torch::Tensor& topk_weights, torch::Tensor& topk_indices, ++ torch::Tensor& token_expert_indices, ++ torch::Tensor& gating_output); ++ ++void moe_sum(torch::Tensor& input, torch::Tensor& output); ++ ++void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, ++ int64_t block_size, torch::Tensor sorted_token_ids, ++ torch::Tensor experts_ids, ++ torch::Tensor num_tokens_post_pad); +diff --git a/csrc/moe/topk_softmax_kernels.cu b/csrc/moe/topk_softmax_kernels.cu +index 8c65f40..de9747b 100644 +--- a/csrc/moe/topk_softmax_kernels.cu ++++ b/csrc/moe/topk_softmax_kernels.cu +@@ -16,18 +16,25 @@ + * See the License for the specific language governing permissions and + * limitations under the License. + */ +-#include ++#include + #include + #include ++#include "../cuda_compat.h" + +-#include +-#include ++#ifndef USE_ROCM ++ #include ++ #include ++#else ++ #include ++ #include ++#endif ++ ++#define MAX(a, b) ((a) > (b) ? (a) : (b)) ++#define MIN(a, b) ((a) < (b) ? (a) : (b)) + + namespace vllm { + namespace moe { + +-static constexpr int WARP_SIZE = 32; +- + /// Aligned array type + template < + typename T, +@@ -265,7 +272,7 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__ + #pragma unroll + for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) + { +- thread_max = max(thread_max, __shfl_xor_sync(0xFFFFFFFF, thread_max, mask, THREADS_PER_ROW)); ++ thread_max = max(thread_max, VLLM_SHFL_XOR_SYNC_WIDTH(thread_max, mask, THREADS_PER_ROW)); + } + + // From this point, thread max in all the threads have the max within the row. +@@ -282,7 +289,7 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__ + #pragma unroll + for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) + { +- row_sum += __shfl_xor_sync(0xFFFFFFFF, row_sum, mask, THREADS_PER_ROW); ++ row_sum += VLLM_SHFL_XOR_SYNC_WIDTH(row_sum, mask, THREADS_PER_ROW); + } + + // From this point, all threads have the max and the sum for their rows in the thread_max and thread_sum variables +@@ -332,8 +339,8 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__ + #pragma unroll + for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) + { +- float other_max = __shfl_xor_sync(0xFFFFFFFF, max_val, mask, THREADS_PER_ROW); +- int other_expert = __shfl_xor_sync(0xFFFFFFFF, expert, mask, THREADS_PER_ROW); ++ float other_max = VLLM_SHFL_XOR_SYNC_WIDTH(max_val, mask, THREADS_PER_ROW); ++ int other_expert = VLLM_SHFL_XOR_SYNC_WIDTH(expert, mask, THREADS_PER_ROW); + + // We want lower indices to "win" in every thread so we break ties this way + if (other_max > max_val || (other_max == max_val && other_expert < expert)) +@@ -383,7 +390,7 @@ struct TopkConstants + { + static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(float); + static_assert(EXPERTS / (ELTS_PER_LDG * WARP_SIZE) == 0 || EXPERTS % (ELTS_PER_LDG * WARP_SIZE) == 0, ""); +- static constexpr int VECs_PER_THREAD = std::max(1, EXPERTS / (ELTS_PER_LDG * WARP_SIZE)); ++ static constexpr int VECs_PER_THREAD = MAX(1, EXPERTS / (ELTS_PER_LDG * WARP_SIZE)); + static constexpr int VPT = VECs_PER_THREAD * ELTS_PER_LDG; + static constexpr int THREADS_PER_ROW = EXPERTS / VPT; + static constexpr int ROWS_PER_WARP = WARP_SIZE / THREADS_PER_ROW; +@@ -396,7 +403,7 @@ void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, f + { + static constexpr std::size_t MAX_BYTES_PER_LDG = 16; + +- static constexpr int BYTES_PER_LDG = std::min(MAX_BYTES_PER_LDG, sizeof(float) * EXPERTS); ++ static constexpr int BYTES_PER_LDG = MIN(MAX_BYTES_PER_LDG, sizeof(float) * EXPERTS); + using Constants = detail::TopkConstants; + static constexpr int VPT = Constants::VPT; + static constexpr int ROWS_PER_WARP = Constants::ROWS_PER_WARP; +diff --git a/csrc/moe/torch_bindings.cpp b/csrc/moe/torch_bindings.cpp +new file mode 100644 +index 0000000..f3a558c +--- /dev/null ++++ b/csrc/moe/torch_bindings.cpp +@@ -0,0 +1,39 @@ ++#include "core/registration.h" ++#include "moe_ops.h" ++ ++TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { ++ // Apply topk softmax to the gating outputs. ++ m.def( ++ "topk_softmax(Tensor! topk_weights, Tensor! topk_indices, Tensor! " ++ "token_expert_indices, Tensor gating_output) -> ()"); ++ m.impl("topk_softmax", torch::kCUDA, &topk_softmax); ++ ++ // Calculate the result of moe by summing up the partial results ++ // from all selected experts. ++ m.def("moe_sum(Tensor! input, Tensor output) -> ()"); ++ m.impl("moe_sum", torch::kCUDA, &moe_sum); ++ ++ // Aligning the number of tokens to be processed by each expert such ++ // that it is divisible by the block size. ++ m.def( ++ "moe_align_block_size(Tensor topk_ids, int num_experts," ++ " int block_size, Tensor! sorted_token_ids," ++ " Tensor! experts_ids," ++ " Tensor! num_tokens_post_pad) -> ()"); ++ m.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size); ++ ++#ifndef USE_ROCM ++ m.def( ++ "marlin_gemm_moe(Tensor! a, Tensor! b_q_weights, Tensor! sorted_ids, " ++ "Tensor! topk_weights, Tensor! topk_ids, Tensor! b_scales, Tensor! " ++ "b_zeros, Tensor! g_idx, Tensor! perm, Tensor! workspace, " ++ "int b_q_type, SymInt size_m, " ++ "SymInt size_n, SymInt size_k, bool is_k_full, int num_experts, int " ++ "topk, " ++ "int moe_block_size, bool replicate_input, bool apply_weights)" ++ " -> Tensor"); ++ // conditionally compiled so impl registration is in source file ++#endif ++} ++ ++REGISTER_EXTENSION(TORCH_EXTENSION_NAME) +diff --git a/csrc/ops.h b/csrc/ops.h +index 9541adc..9efd9b0 100644 +--- a/csrc/ops.h ++++ b/csrc/ops.h +@@ -1,206 +1,245 @@ + #pragma once + +-#include ++#include ++#include ++ ++#include "core/scalar_type.hpp" ++ ++#include ++ ++torch::Tensor weak_ref_tensor(torch::Tensor& tensor) { ++ // Ensure tensor is on CUDA ++ if (!tensor.is_cuda()) { ++ throw std::runtime_error("Tensor must be on CUDA device"); ++ } ++ ++ // Get the raw data pointer ++ void* data_ptr = tensor.data_ptr(); ++ ++ // Get tensor sizes and strides ++ std::vector sizes = tensor.sizes().vec(); ++ std::vector strides = tensor.strides().vec(); ++ ++ // Get tensor options (dtype, device) ++ auto options = tensor.options(); ++ ++ // Create a new tensor from the raw data pointer ++ auto new_tensor = torch::from_blob(data_ptr, sizes, strides, options); ++ ++ return new_tensor; ++} + + void paged_attention_v1( +- torch::Tensor& out, +- torch::Tensor& query, +- torch::Tensor& key_cache, +- torch::Tensor& value_cache, +- int num_kv_heads, +- float scale, +- torch::Tensor& block_tables, +- torch::Tensor& seq_lens, +- int block_size, +- int max_seq_len, +- const c10::optional& alibi_slopes, +- const std::string& kv_cache_dtype, +- float kv_scale); ++ torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache, ++ torch::Tensor& value_cache, int64_t num_kv_heads, double scale, ++ torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, ++ int64_t max_seq_len, const std::optional& alibi_slopes, ++ const std::string& kv_cache_dtype, double k_scale, double v_scale, ++ const int64_t tp_rank, const int64_t blocksparse_local_blocks, ++ const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, ++ const int64_t blocksparse_head_sliding_step); + + void paged_attention_v2( +- torch::Tensor& out, +- torch::Tensor& exp_sums, +- torch::Tensor& max_logits, +- torch::Tensor& tmp_out, +- torch::Tensor& query, +- torch::Tensor& key_cache, +- torch::Tensor& value_cache, +- int num_kv_heads, +- float scale, +- torch::Tensor& block_tables, +- torch::Tensor& seq_lens, +- int block_size, +- int max_seq_len, +- const c10::optional& alibi_slopes, +- const std::string& kv_cache_dtype, +- float kv_scale); +- +-void rms_norm( +- torch::Tensor& out, +- torch::Tensor& input, +- torch::Tensor& weight, +- float epsilon); +- +-void fused_add_rms_norm( +- torch::Tensor& input, +- torch::Tensor& residual, +- torch::Tensor& weight, +- float epsilon); +- +-void rotary_embedding( +- torch::Tensor& positions, +- torch::Tensor& query, +- torch::Tensor& key, +- int head_size, +- torch::Tensor& cos_sin_cache, +- bool is_neox); +- +-void batched_rotary_embedding( +- torch::Tensor& positions, +- torch::Tensor& query, +- torch::Tensor& key, +- int head_size, +- torch::Tensor& cos_sin_cache, +- bool is_neox, +- int rot_dim, +- torch::Tensor& cos_sin_cache_offsets); +- +-void silu_and_mul( +- torch::Tensor& out, +- torch::Tensor& input); +- +-void gelu_and_mul( +- torch::Tensor& out, +- torch::Tensor& input); +- +-void gelu_tanh_and_mul( +- torch::Tensor& out, +- torch::Tensor& input); +- +-void gelu_new( +- torch::Tensor& out, +- torch::Tensor& input); +- +-void gelu_fast( +- torch::Tensor& out, +- torch::Tensor& input); ++ torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, ++ torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, ++ torch::Tensor& value_cache, int64_t num_kv_heads, double scale, ++ torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, ++ int64_t max_seq_len, const std::optional& alibi_slopes, ++ const std::string& kv_cache_dtype, double k_scale, double v_scale, ++ const int64_t tp_rank, const int64_t blocksparse_local_blocks, ++ const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, ++ const int64_t blocksparse_head_sliding_step); ++ ++void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight, ++ double epsilon); ++ ++void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual, ++ torch::Tensor& weight, double epsilon); ++ ++void rms_norm_static_fp8_quant(torch::Tensor& out, torch::Tensor& input, ++ torch::Tensor& weight, torch::Tensor& scale, ++ double epsilon); ++ ++void fused_add_rms_norm_static_fp8_quant(torch::Tensor& out, ++ torch::Tensor& input, ++ torch::Tensor& residual, ++ torch::Tensor& weight, ++ torch::Tensor& scale, double epsilon); ++ ++void rms_norm_dynamic_per_token_quant(torch::Tensor& out, ++ torch::Tensor const& input, ++ torch::Tensor const& weight, ++ torch::Tensor& scales, ++ double const epsilon, ++ std::optional scale_ub, ++ std::optional residual); ++ ++void rotary_embedding(torch::Tensor& positions, torch::Tensor& query, ++ torch::Tensor& key, int64_t head_size, ++ torch::Tensor& cos_sin_cache, bool is_neox); ++ ++void batched_rotary_embedding(torch::Tensor& positions, torch::Tensor& query, ++ torch::Tensor& key, int64_t head_size, ++ torch::Tensor& cos_sin_cache, bool is_neox, ++ int64_t rot_dim, ++ torch::Tensor& cos_sin_cache_offsets); ++ ++void silu_and_mul(torch::Tensor& out, torch::Tensor& input); ++ ++void gelu_and_mul(torch::Tensor& out, torch::Tensor& input); ++ ++void gelu_tanh_and_mul(torch::Tensor& out, torch::Tensor& input); ++ ++void fatrelu_and_mul(torch::Tensor& out, torch::Tensor& input, ++ double threshold); ++ ++void gelu_new(torch::Tensor& out, torch::Tensor& input); ++ ++void gelu_fast(torch::Tensor& out, torch::Tensor& input); ++ ++void gelu_quick(torch::Tensor& out, torch::Tensor& input); ++ ++void advance_step_flashattn(int64_t num_seqs, int64_t num_queries, ++ int64_t block_size, torch::Tensor& input_tokens, ++ torch::Tensor& sampled_token_ids, ++ torch::Tensor& input_positions, ++ torch::Tensor& seq_lens, ++ torch::Tensor& slot_mapping, ++ torch::Tensor& block_tables); ++ ++void advance_step_flashinfer( ++ int64_t num_seqs, int64_t num_queries, int64_t block_size, ++ torch::Tensor& input_tokens, torch::Tensor& sampled_token_ids, ++ torch::Tensor& input_positions, torch::Tensor& seq_lens, ++ torch::Tensor& slot_mapping, torch::Tensor& block_tables, ++ torch::Tensor& paged_kv_indices, torch::Tensor& paged_kv_indptr, ++ torch::Tensor& paged_kv_last_page_len, torch::Tensor& block_table_bounds); + + #ifndef USE_ROCM +-torch::Tensor aqlm_gemm( +- const torch::Tensor& input, +- const torch::Tensor& codes, +- const torch::Tensor& codebooks, +- const torch::Tensor& scales, +- const torch::Tensor& codebook_partition_sizes, +- const std::optional& bias +-); ++torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes, ++ const torch::Tensor& codebooks, ++ const torch::Tensor& scales, ++ const std::vector& codebook_partition_sizes, ++ const std::optional& bias); + + torch::Tensor aqlm_dequant( +- const torch::Tensor& codes, +- const torch::Tensor& codebooks, +- const torch::Tensor& codebook_partition_sizes +-); +- +-torch::Tensor awq_gemm( +- torch::Tensor _in_feats, +- torch::Tensor _kernel, +- torch::Tensor _scaling_factors, +- torch::Tensor _zeros, +- int split_k_iters); +- +-torch::Tensor awq_dequantize( +- torch::Tensor _kernel, +- torch::Tensor _scaling_factors, +- torch::Tensor _zeros, +- int split_k_iters, +- int thx, +- int thy); +- +-torch::Tensor marlin_gemm( +- torch::Tensor& a, +- torch::Tensor& b_q_weight, +- torch::Tensor& b_scales, +- torch::Tensor& workspace, +- int64_t size_m, +- int64_t size_n, +- int64_t size_k); +- +-torch::Tensor gptq_marlin_gemm( +- torch::Tensor &a, +- torch::Tensor &b_q_weight, +- torch::Tensor &b_scales, +- torch::Tensor &g_idx, +- torch::Tensor &perm, +- torch::Tensor &workspace, +- int64_t num_bits, +- int64_t size_m, +- int64_t size_n, +- int64_t size_k, +- bool is_k_full); +- +-torch::Tensor gptq_marlin_repack( +- torch::Tensor &b_q_weight, +- torch::Tensor &perm, +- int64_t size_k, +- int64_t size_n, +- int64_t num_bits); ++ const torch::Tensor& codes, const torch::Tensor& codebooks, ++ const std::vector& codebook_partition_sizes); ++ ++torch::Tensor awq_gemm(torch::Tensor _in_feats, torch::Tensor _kernel, ++ torch::Tensor _scaling_factors, torch::Tensor _zeros, ++ int64_t split_k_iters); ++ ++torch::Tensor awq_dequantize(torch::Tensor _kernel, ++ torch::Tensor _scaling_factors, ++ torch::Tensor _zeros, int64_t split_k_iters, ++ int64_t thx, int64_t thy); ++ ++torch::Tensor permute_cols(torch::Tensor const& A, torch::Tensor const& perm); ++#endif ++ ++torch::Tensor ggml_dequantize(torch::Tensor W, int64_t type, int64_t m, ++ int64_t n); ++ ++torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W, torch::Tensor X, ++ int64_t type, int64_t row); ++ ++torch::Tensor ggml_mul_mat_a8(torch::Tensor W, torch::Tensor X, int64_t type, ++ int64_t row); ++ ++#ifndef USE_ROCM ++bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability); ++ ++void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a, ++ torch::Tensor const& b, torch::Tensor const& a_scales, ++ torch::Tensor const& b_scales, ++ std::optional const& bias); ++ ++void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a, ++ torch::Tensor const& b, ++ torch::Tensor const& a_scales, ++ torch::Tensor const& b_scales, ++ torch::Tensor const& azp_adj, ++ std::optional const& azp, ++ std::optional const& bias); ++ ++bool cutlass_sparse_scaled_mm_supported(int64_t cuda_device_capability); ++ ++void cutlass_scaled_sparse_mm(torch::Tensor& out, torch::Tensor const& a, ++ torch::Tensor const& b, torch::Tensor const& e, ++ torch::Tensor const& a_scales, ++ torch::Tensor const& b_scales, ++ std::optional const& bias); ++ ++bool cutlass_sparse_compress_entry(torch::Tensor& a_compressed, ++ torch::Tensor& e, torch::Tensor const& a); + #endif + +-void squeezellm_gemm( +- torch::Tensor vec, +- torch::Tensor mat, +- torch::Tensor mul, +- torch::Tensor lookup_table); +- +-torch::Tensor gptq_gemm( +- torch::Tensor a, +- torch::Tensor b_q_weight, +- torch::Tensor b_gptq_qzeros, +- torch::Tensor b_gptq_scales, +- torch::Tensor b_g_idx, +- bool use_exllama, +- int bit); +- +-void gptq_shuffle( +- torch::Tensor q_weight, +- torch::Tensor q_perm, +- int bit); +- +-void static_scaled_fp8_quant( +- torch::Tensor& out, +- torch::Tensor& input, +- torch::Tensor& scale); +- +-void dynamic_scaled_fp8_quant( +- torch::Tensor& out, +- torch::Tensor& input, +- torch::Tensor& scale); +- +-void moe_align_block_size( +- torch::Tensor topk_ids, +- int num_experts, +- int block_size, +- torch::Tensor sorted_token_ids, +- torch::Tensor experts_ids, +- torch::Tensor num_tokens_post_pad); ++void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input, ++ torch::Tensor const& scale, ++ std::optional const& azp); ++ ++void dynamic_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input, ++ torch::Tensor& scales, ++ std::optional const& azp); ++ ++torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight, ++ torch::Tensor b_gptq_qzeros, ++ torch::Tensor b_gptq_scales, torch::Tensor b_g_idx, ++ bool use_exllama, int64_t bit); ++ ++void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit); ++ ++void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input, ++ torch::Tensor const& scale); ++ ++void dynamic_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input, ++ torch::Tensor& scale); ++ ++void dynamic_per_token_scaled_fp8_quant( ++ torch::Tensor& out, torch::Tensor const& input, torch::Tensor& scale, ++ std::optional const& scale_ub); ++ ++void selective_scan_fwd(const torch::Tensor& u, const torch::Tensor& delta, ++ const torch::Tensor& A, const torch::Tensor& B, ++ const torch::Tensor& C, ++ const std::optional& D_, ++ const std::optional& z_, ++ const std::optional& delta_bias_, ++ bool delta_softplus, ++ const std::optional& query_start_loc, ++ const std::optional& cache_indices, ++ const std::optional& has_initial_state, ++ const torch::Tensor& ssm_states, int64_t pad_slot_id); ++ ++void causal_conv1d_update(const at::Tensor& x, const at::Tensor& conv_state, ++ const at::Tensor& weight, ++ const std::optional& bias_, ++ bool silu_activation, ++ const std::optional& cache_seqlens_, ++ const std::optional& conv_state_indices_, ++ int64_t pad_slot_id); ++ ++void causal_conv1d_fwd(const at::Tensor& x, const at::Tensor& weight, ++ const std::optional& bias_, ++ const std::optional& conv_states, ++ const std::optional& query_start_loc, ++ const std::optional& cache_indices, ++ const std::optional& has_initial_state, ++ bool silu_activation, int64_t pad_slot_id); + + #ifndef USE_ROCM +-using fptr_t = uint64_t; +-fptr_t init_custom_ar(torch::Tensor &meta, torch::Tensor &rank_data, +- const std::vector &handles, +- const std::vector &offsets, int rank, +- bool full_nvlink); +-bool should_custom_ar(torch::Tensor &inp, int max_size, int world_size, +- bool full_nvlink); +-void all_reduce_reg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out); +-void all_reduce_unreg(fptr_t _fa, torch::Tensor &inp, torch::Tensor ®_buffer, +- torch::Tensor &out); ++using fptr_t = int64_t; ++fptr_t init_custom_ar(const std::vector& fake_ipc_ptrs, ++ torch::Tensor& rank_data, int64_t rank, bool full_nvlink); ++void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out, ++ fptr_t reg_buffer, int64_t reg_buffer_sz_bytes); + void dispose(fptr_t _fa); +-int meta_size(); +-void register_buffer(fptr_t _fa, torch::Tensor &t, +- const std::vector &handles, +- const std::vector &offsets); +-std::pair, std::vector> get_graph_buffer_ipc_meta(fptr_t _fa); +-void register_graph_buffers(fptr_t _fa, const std::vector &handles, +- const std::vector> &offsets); ++int64_t meta_size(); ++void register_buffer(fptr_t _fa, const std::vector& fake_ipc_ptrs); ++std::tuple, std::vector> ++get_graph_buffer_ipc_meta(fptr_t _fa); ++void register_graph_buffers(fptr_t _fa, ++ const std::vector>& handles, ++ const std::vector>& offsets); + #endif +diff --git a/csrc/permute_cols.cu b/csrc/permute_cols.cu +new file mode 100644 +index 0000000..f51fa73 +--- /dev/null ++++ b/csrc/permute_cols.cu +@@ -0,0 +1,88 @@ ++#include ++ ++#include ++#include ++ ++#include ++ ++static constexpr int default_threads = 256; ++static constexpr int div_ceil(int a, int b) { return (a + b - 1) / b; } ++ ++// For a given "a" of size [M,K] performs a permutation of the K columns based ++// on the given "perm" indices. ++// Currently only supports 16bit types (since we permute half types) ++__global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr, ++ int const* __restrict__ perm_int_ptr, ++ int4* __restrict__ out_int4_ptr, int size_m, ++ int size_k, int block_rows) { ++ int start_row = block_rows * blockIdx.x; ++ int finish_row = start_row + block_rows; ++ if (finish_row > size_m) { ++ finish_row = size_m; ++ } ++ int cur_block_rows = std::max(finish_row - start_row, 0); ++ ++ int row_stride = size_k * sizeof(half) / 16; ++ ++ auto permute_row = [&](int row) { ++ int iters = size_k / default_threads; ++ int rest = size_k % default_threads; ++ ++ int offset = row * row_stride; ++ ++ half const* a_row_half = reinterpret_cast(a_int4_ptr + offset); ++ half* out_half = reinterpret_cast(out_int4_ptr + offset); ++ ++ int base_k = 0; ++ ++ for (int i = 0; i < iters; i++) { ++ int cur_k = base_k + threadIdx.x; ++ int src_pos = perm_int_ptr[cur_k]; ++ ++ out_half[cur_k] = a_row_half[src_pos]; ++ ++ base_k += default_threads; ++ } ++ ++ if (rest) { ++ if (threadIdx.x < rest) { ++ int cur_k = base_k + threadIdx.x; ++ int src_pos = perm_int_ptr[cur_k]; ++ ++ out_half[cur_k] = a_row_half[src_pos]; ++ } ++ } ++ }; ++ ++ for (int i = 0; i < cur_block_rows; i++) { ++ int cur_row = start_row + i; ++ if (cur_row < size_m) { ++ permute_row(cur_row); ++ } ++ } ++} ++ ++// More efficient version of A[..., perm] ++// taken from gptq_marlin.cu ++torch::Tensor permute_cols(torch::Tensor const& A, torch::Tensor const& perm) { ++ const at::cuda::OptionalCUDAGuard device_guard(device_of(A)); ++ auto dev = A.get_device(); ++ auto stream = at::cuda::getCurrentCUDAStream(dev); ++ ++ TORCH_CHECK(A.scalar_type() == at::kHalf || A.scalar_type() == at::kBFloat16, ++ "Currently only 16bit types are supported"); ++ TORCH_CHECK(A.is_contiguous(), "A must be contiguous"); ++ TORCH_CHECK(A.size(-1) % 8 == 0, ++ "A columns must be a multiple of 8 (128bits)"); ++ auto A_2d = A.view({-1, A.size(-1)}); ++ ++ torch::Tensor D = torch::empty_like(A); ++ int sms; ++ cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev); ++ int block_rows = div_ceil(A_2d.size(0), sms); ++ permute_cols_kernel<<>>( ++ reinterpret_cast(A_2d.const_data_ptr()), ++ perm.const_data_ptr(), reinterpret_cast(D.mutable_data_ptr()), ++ A_2d.size(0), A_2d.size(1), block_rows); ++ return D; ++} +\ No newline at end of file +diff --git a/csrc/pos_encoding_kernels.cu b/csrc/pos_encoding_kernels.cu +index d80cb69..97184a8 100644 +--- a/csrc/pos_encoding_kernels.cu ++++ b/csrc/pos_encoding_kernels.cu +@@ -1,4 +1,4 @@ +-#include ++#include + #include + #include + +@@ -7,14 +7,10 @@ + + namespace vllm { + +-template ++template + inline __device__ void apply_token_rotary_embedding( +- scalar_t* __restrict__ arr, +- const scalar_t* __restrict__ cos_ptr, +- const scalar_t* __restrict__ sin_ptr, +- int rot_offset, +- int embed_dim) +-{ ++ scalar_t* __restrict__ arr, const scalar_t* __restrict__ cos_ptr, ++ const scalar_t* __restrict__ sin_ptr, int rot_offset, int embed_dim) { + int x_index, y_index; + scalar_t cos, sin; + if (IS_NEOX) { +@@ -37,19 +33,17 @@ inline __device__ void apply_token_rotary_embedding( + arr[y_index] = y * cos + x * sin; + } + +-template ++template + inline __device__ void apply_rotary_embedding( +- scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size] +- scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size] +- const scalar_t* cache_ptr, +- const int head_size, +- const int num_heads, +- const int num_kv_heads, +- const int rot_dim, +- const int token_idx, +- const int64_t query_stride, +- const int64_t key_stride) +-{ ++ scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads, ++ // head_size] or [num_tokens, num_heads, ++ // head_size] ++ scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, ++ // head_size] or [num_tokens, num_kv_heads, ++ // head_size] ++ const scalar_t* cache_ptr, const int head_size, const int num_heads, ++ const int num_kv_heads, const int rot_dim, const int token_idx, ++ const int64_t query_stride, const int64_t key_stride) { + const int embed_dim = rot_dim / 2; + const scalar_t* cos_ptr = cache_ptr; + const scalar_t* sin_ptr = cache_ptr + embed_dim; +@@ -59,8 +53,8 @@ inline __device__ void apply_rotary_embedding( + const int head_idx = i / embed_dim; + const int64_t token_head = token_idx * query_stride + head_idx * head_size; + const int rot_offset = i % embed_dim; +- apply_token_rotary_embedding(query + token_head, cos_ptr, +- sin_ptr, rot_offset, embed_dim); ++ apply_token_rotary_embedding( ++ query + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim); + } + + const int nk = num_kv_heads * embed_dim; +@@ -68,62 +62,74 @@ inline __device__ void apply_rotary_embedding( + const int head_idx = i / embed_dim; + const int64_t token_head = token_idx * key_stride + head_idx * head_size; + const int rot_offset = i % embed_dim; +- apply_token_rotary_embedding(key + token_head, cos_ptr, +- sin_ptr, rot_offset, embed_dim); ++ apply_token_rotary_embedding( ++ key + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim); + } + } + +-template ++template + __global__ void rotary_embedding_kernel( +- const int64_t* __restrict__ positions, // [batch_size, seq_len] or [num_tokens] +- scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size] +- scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size] +- const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2] +- const int rot_dim, +- const int64_t query_stride, +- const int64_t key_stride, +- const int num_heads, +- const int num_kv_heads, +- const int head_size) { ++ const int64_t* __restrict__ positions, // [batch_size, seq_len] or ++ // [num_tokens] ++ scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads, ++ // head_size] or [num_tokens, num_heads, ++ // head_size] ++ scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, ++ // head_size] or [num_tokens, num_kv_heads, ++ // head_size] ++ const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // ++ // 2] ++ const int rot_dim, const int64_t query_stride, const int64_t key_stride, ++ const int num_heads, const int num_kv_heads, const int head_size) { + // Each thread block is responsible for one token. + const int token_idx = blockIdx.x; + int64_t pos = positions[token_idx]; + const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim; + +- apply_rotary_embedding(query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim, token_idx, query_stride, key_stride); ++ apply_rotary_embedding( ++ query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim, ++ token_idx, query_stride, key_stride); + } + +-template ++template + __global__ void batched_rotary_embedding_kernel( +- const int64_t* __restrict__ positions, // [batch_size, seq_len] or [num_tokens] +- scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size] +- scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size] +- const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2] +- const int64_t* __restrict__ cos_sin_cache_offsets, // [batch_size, seq_len] or [num_tokens] +- const int rot_dim, +- const int64_t query_stride, +- const int64_t key_stride, +- const int num_heads, +- const int num_kv_heads, +- const int head_size) { ++ const int64_t* __restrict__ positions, // [batch_size, seq_len] or ++ // [num_tokens] ++ scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads, ++ // head_size] or [num_tokens, num_heads, ++ // head_size] ++ scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, ++ // head_size] or [num_tokens, num_kv_heads, ++ // head_size] ++ const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // ++ // 2] ++ const int64_t* __restrict__ cos_sin_cache_offsets, // [batch_size, seq_len] ++ // or [num_tokens] ++ const int rot_dim, const int64_t query_stride, const int64_t key_stride, ++ const int num_heads, const int num_kv_heads, const int head_size) { + // Each thread block is responsible for one token. + const int token_idx = blockIdx.x; + int64_t pos = positions[token_idx]; + int64_t cos_sin_cache_offset = cos_sin_cache_offsets[token_idx]; +- const scalar_t* cache_ptr = cos_sin_cache + (cos_sin_cache_offset + pos) * rot_dim; ++ const scalar_t* cache_ptr = ++ cos_sin_cache + (cos_sin_cache_offset + pos) * rot_dim; + +- apply_rotary_embedding(query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim, token_idx, query_stride, key_stride); ++ apply_rotary_embedding( ++ query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim, ++ token_idx, query_stride, key_stride); + } + +-} // namespace vllm ++} // namespace vllm + + void rotary_embedding( +- torch::Tensor& positions, // [batch_size, seq_len] or [num_tokens] +- torch::Tensor& query, // [batch_size, seq_len, num_heads * head_size] or [num_tokens, num_heads * head_size] +- torch::Tensor& key, // [batch_size, seq_len, num_kv_heads * head_size] or [num_tokens, num_kv_heads * head_size] +- int head_size, +- torch::Tensor& cos_sin_cache, // [max_position, rot_dim] +- bool is_neox) { ++ torch::Tensor& positions, // [batch_size, seq_len] or [num_tokens] ++ torch::Tensor& query, // [batch_size, seq_len, num_heads * head_size] or ++ // [num_tokens, num_heads * head_size] ++ torch::Tensor& key, // [batch_size, seq_len, num_kv_heads * head_size] or ++ // [num_tokens, num_kv_heads * head_size] ++ int64_t head_size, ++ torch::Tensor& cos_sin_cache, // [max_position, rot_dim] ++ bool is_neox) { + int64_t num_tokens = query.numel() / query.size(-1); + int rot_dim = cos_sin_cache.size(1); + int num_heads = query.size(-1) / head_size; +@@ -132,39 +138,24 @@ void rotary_embedding( + int64_t key_stride = key.stride(-2); + + dim3 grid(num_tokens); +- dim3 block(std::min(num_heads * rot_dim / 2, 512)); ++ dim3 block(std::min(num_heads * rot_dim / 2, 512)); + const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); +- VLLM_DISPATCH_FLOATING_TYPES( +- query.scalar_type(), +- "rotary_embedding", +- [&] { +- if (is_neox) { +- vllm::rotary_embedding_kernel<<>>( +- positions.data_ptr(), +- query.data_ptr(), +- key.data_ptr(), +- cos_sin_cache.data_ptr(), +- rot_dim, +- query_stride, +- key_stride, +- num_heads, +- num_kv_heads, +- head_size); +- } else { +- vllm::rotary_embedding_kernel<<>>( +- positions.data_ptr(), +- query.data_ptr(), +- key.data_ptr(), +- cos_sin_cache.data_ptr(), +- rot_dim, +- query_stride, +- key_stride, +- num_heads, +- num_kv_heads, +- head_size); +- } +- }); ++ VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "rotary_embedding", [&] { ++ if (is_neox) { ++ vllm::rotary_embedding_kernel<<>>( ++ positions.data_ptr(), query.data_ptr(), ++ key.data_ptr(), cos_sin_cache.data_ptr(), rot_dim, ++ query_stride, key_stride, num_heads, num_kv_heads, head_size); ++ } else { ++ vllm::rotary_embedding_kernel ++ <<>>( ++ positions.data_ptr(), query.data_ptr(), ++ key.data_ptr(), cos_sin_cache.data_ptr(), ++ rot_dim, query_stride, key_stride, num_heads, num_kv_heads, ++ head_size); ++ } ++ }); + } + + /* +@@ -172,14 +163,15 @@ Batched version of rotary embedding, pack multiple LoRAs together + and process in batched manner. + */ + void batched_rotary_embedding( +- torch::Tensor& positions, // [batch_size, seq_len] or [num_tokens] +- torch::Tensor& query, // [batch_size, seq_len, num_heads * head_size] or [num_tokens, num_heads * head_size] +- torch::Tensor& key, // [batch_size, seq_len, num_kv_heads * head_size] or [num_tokens, num_kv_heads * head_size] +- int head_size, +- torch::Tensor& cos_sin_cache, // [max_position, rot_dim] +- bool is_neox, +- int rot_dim, +- torch::Tensor& cos_sin_cache_offsets // [num_tokens] ++ torch::Tensor& positions, // [batch_size, seq_len] or [num_tokens] ++ torch::Tensor& query, // [batch_size, seq_len, num_heads * head_size] or ++ // [num_tokens, num_heads * head_size] ++ torch::Tensor& key, // [batch_size, seq_len, num_kv_heads * head_size] or ++ // [num_tokens, num_kv_heads * head_size] ++ int64_t head_size, ++ torch::Tensor& cos_sin_cache, // [max_position, rot_dim] ++ bool is_neox, int64_t rot_dim, ++ torch::Tensor& cos_sin_cache_offsets // [num_tokens] + ) { + int64_t num_tokens = cos_sin_cache_offsets.size(0); + int num_heads = query.size(-1) / head_size; +@@ -188,39 +180,24 @@ void batched_rotary_embedding( + int64_t key_stride = key.stride(-2); + + dim3 grid(num_tokens); +- dim3 block(std::min(num_heads * rot_dim / 2, 512)); ++ dim3 block(std::min(num_heads * rot_dim / 2, 512)); + const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); +- VLLM_DISPATCH_FLOATING_TYPES( +- query.scalar_type(), +- "rotary_embedding", +- [&] { +- if (is_neox) { +- vllm::batched_rotary_embedding_kernel<<>>( +- positions.data_ptr(), +- query.data_ptr(), +- key.data_ptr(), +- cos_sin_cache.data_ptr(), +- cos_sin_cache_offsets.data_ptr(), +- rot_dim, +- query_stride, +- key_stride, +- num_heads, +- num_kv_heads, +- head_size); +- } else { +- vllm::batched_rotary_embedding_kernel<<>>( +- positions.data_ptr(), +- query.data_ptr(), +- key.data_ptr(), +- cos_sin_cache.data_ptr(), +- cos_sin_cache_offsets.data_ptr(), +- rot_dim, +- query_stride, +- key_stride, +- num_heads, +- num_kv_heads, +- head_size); +- } +- }); ++ VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "rotary_embedding", [&] { ++ if (is_neox) { ++ vllm::batched_rotary_embedding_kernel ++ <<>>( ++ positions.data_ptr(), query.data_ptr(), ++ key.data_ptr(), cos_sin_cache.data_ptr(), ++ cos_sin_cache_offsets.data_ptr(), rot_dim, query_stride, ++ key_stride, num_heads, num_kv_heads, head_size); ++ } else { ++ vllm::batched_rotary_embedding_kernel ++ <<>>( ++ positions.data_ptr(), query.data_ptr(), ++ key.data_ptr(), cos_sin_cache.data_ptr(), ++ cos_sin_cache_offsets.data_ptr(), rot_dim, query_stride, ++ key_stride, num_heads, num_kv_heads, head_size); ++ } ++ }); + } +diff --git a/csrc/prepare_inputs/advance_step.cu b/csrc/prepare_inputs/advance_step.cu +new file mode 100644 +index 0000000..bd184ee +--- /dev/null ++++ b/csrc/prepare_inputs/advance_step.cu +@@ -0,0 +1,327 @@ ++/* ++ * The goal of this GPU kernel is to advance input tensors on the GPU directly ++ * PR: https://github.com/vllm-project/vllm/pull/6338 ++ * Current restrictions: ++ * 1. Specialized for DraftModelRunner ++ * 2. Supports flash_attn only ++ */ ++ ++#include "advance_step.cuh" ++ ++namespace prepare_inputs { ++ ++// ++template ++__global__ void advance_step_flashattn_kernel( ++ int num_seqs, int num_queries, int block_size, long* input_tokens_ptr, ++ long const* sampled_token_ids_ptr, long* input_positions_ptr, ++ int* seq_lens_ptr, long* slot_mapping_ptr, int const* block_tables_ptr, ++ int64_t const block_tables_stride) { ++ int const n_pad = num_seqs - num_queries; ++ if (n_pad && blockIdx.x == 0) { ++ // Handle cuda graph padding ++ int const offset = num_queries; ++ for (int i = threadIdx.x; i < n_pad; i += blockDim.x) { ++ input_tokens_ptr[offset + i] = 0; ++ input_positions_ptr[offset + i] = 0; ++ slot_mapping_ptr[offset + i] = -1; ++ } ++ } ++ ++ int num_query_blocks = div_ceil(num_queries, num_threads); ++ ++ if (blockIdx.x >= num_query_blocks) { ++ return; ++ } ++ ++ int cur_query_id = blockIdx.x * num_threads + threadIdx.x; ++ ++ if (cur_query_id >= num_queries) { ++ return; ++ } ++ ++ // Update input_tokens ++ input_tokens_ptr[cur_query_id] = sampled_token_ids_ptr[cur_query_id]; ++ ++ int seq_len = seq_lens_ptr[cur_query_id]; ++ int next_seq_len = seq_len + 1; ++ int next_input_pos = next_seq_len - 1; ++ ++ // Update seq_lens ++ seq_lens_ptr[cur_query_id] = next_seq_len; ++ // Update input_positions ++ input_positions_ptr[cur_query_id] = next_input_pos; ++ ++ int const* seq_block_tables_ptr = ++ block_tables_ptr + block_tables_stride * cur_query_id; ++ ++ int block_index = next_input_pos / block_size; ++ int block_offset = next_input_pos % block_size; ++ ++ int slot_num = seq_block_tables_ptr[block_index] * block_size + block_offset; ++ // Update slot_mapping ++ slot_mapping_ptr[cur_query_id] = slot_num; ++} ++ ++inline void verify_tensor(std::string const& name, torch::Tensor const& t, ++ int64_t const size_0, int64_t const size_1, ++ c10::ScalarType const type) { ++ bool size_0_cond = true; ++ if (size_0 != -1) { ++ size_0_cond = t.size(0) == size_0; ++ } ++ ++ bool size_1_cond = true; ++ if (size_1 != -1) { ++ size_1_cond = t.size(1) == size_1; ++ } ++ ++ bool is_contiguous = t.is_contiguous(); ++ bool same_type = t.dtype() == type; ++ ++ bool pass = size_0_cond && size_1_cond && is_contiguous && same_type; ++ if (!pass) { ++ TORCH_CHECK(false, "tensor: name = ", name, ", shape = ", t.sizes(), ++ " is_cont = ", t.is_contiguous(), ", type = ", t.dtype(), ++ " is not as expected: shape = [", size_0, ", ", size_1, ++ "], type = ", type); ++ } ++} ++ ++/// each thread processes a block per query ++__global__ void advance_step_flashinfer_kernel( ++ int num_threads, int num_seqs, int num_queries, int block_size, ++ long* input_tokens_ptr, long const* sampled_token_ids_ptr, ++ long* input_positions_ptr, int* seq_lens_ptr, long* slot_mapping_ptr, ++ int const* block_tables_ptr, int64_t const block_tables_stride, ++ int* paged_kv_last_page_len_ptr, int* block_table_bound_ptr) { ++ int num_query_blocks = div_ceil(num_queries, num_threads); ++ ++ if (blockIdx.x < num_query_blocks) { ++ int cur_query_id = blockIdx.x * num_threads + threadIdx.x; ++ ++ if (cur_query_id < num_queries) { ++ // Update input_tokens ++ input_tokens_ptr[cur_query_id] = sampled_token_ids_ptr[cur_query_id]; ++ ++ int seq_len = seq_lens_ptr[cur_query_id]; ++ int next_seq_len = seq_len + 1; ++ int next_input_pos = next_seq_len - 1; ++ ++ // Update seq_lens ++ seq_lens_ptr[cur_query_id] = next_seq_len; ++ // Update input_positions ++ input_positions_ptr[cur_query_id] = next_input_pos; ++ ++ int const* seq_block_tables_ptr = ++ block_tables_ptr + block_tables_stride * cur_query_id; ++ ++ int block_index = next_input_pos / block_size; ++ int block_offset = next_input_pos % block_size; ++ ++ // Update paged_kv_last_page_len ++ paged_kv_last_page_len_ptr[cur_query_id] = block_offset + 1; ++ ++ int slot_num = ++ seq_block_tables_ptr[block_index] * block_size + block_offset; ++ // Update slot_mapping ++ slot_mapping_ptr[cur_query_id] = slot_num; ++ block_table_bound_ptr[cur_query_id] = div_ceil(next_seq_len, block_size); ++ } ++ } ++} ++ ++__global__ void advance_step_flashinfer_indptr_kernel( ++ int num_threads, int num_seqs, int num_queries, int* paged_kv_indptr_ptr, ++ int* block_table_bound_ptr) { ++ int idx = blockIdx.x * num_threads + threadIdx.x; ++ // Update paged_kv_indptr ++ if (idx == 0) { ++ paged_kv_indptr_ptr[idx] = 0; ++ } ++ if (idx < num_queries) { ++ int sum = 0; ++ for (int i = 0; i <= idx; ++i) { ++ sum += block_table_bound_ptr[i]; ++ } ++ paged_kv_indptr_ptr[idx + 1] = sum; ++ } ++} ++ ++__global__ void advance_step_flashinfer_indices_kernel( ++ int num_seqs, int num_queries, int const* block_tables_ptr, ++ int64_t const max_num_blocks_per_seq, int* paged_kv_indices_ptr, ++ int* paged_kv_indptr_ptr, int* block_table_bound_ptr) { ++ // note: max_num_blocks_per_seq = block_tables.stride(0) ++ int tid = blockIdx.x * blockDim.x + threadIdx.x; ++ ++ // when cuda graphs are enabled, paged_kv_indptr tensor ++ // has to be updated for the padded queries ++ // tid represents a query# for paged_kv_indptr tensor ++ if (num_queries < tid && tid <= num_seqs) { ++ paged_kv_indptr_ptr[tid] = paged_kv_indptr_ptr[num_queries]; ++ } ++ ++ // each thread processes a block_ptr in block_tables ++ // block_tables shape: [num_queries, max_num_blocks_per_seq] ++ // paged_kv_indices is flattened block_tables. ++ for (int idx = tid; idx < (num_seqs * max_num_blocks_per_seq); ++ idx += (gridDim.x * blockDim.x)) { ++ // block_tables-row = paged_kv_indptr[queryNum] ++ int queryNum = idx / max_num_blocks_per_seq; ++ int col = idx % max_num_blocks_per_seq; ++ if (queryNum < num_queries && col < block_table_bound_ptr[queryNum]) { ++ int indices_arr_idx = paged_kv_indptr_ptr[queryNum] + col; ++ int block_tables_idx = queryNum * max_num_blocks_per_seq + col; ++ paged_kv_indices_ptr[indices_arr_idx] = ++ block_tables_ptr[block_tables_idx]; ++ } ++ } ++} ++ ++void advance_step_flashattn(int num_seqs, int num_queries, int block_size, ++ torch::Tensor& input_tokens, // type: long ++ torch::Tensor& sampled_token_ids, // type: long ++ torch::Tensor& input_positions, // type: long ++ torch::Tensor& seq_lens, // type: int ++ torch::Tensor& slot_mapping, // type: long ++ torch::Tensor& block_tables) { // type: int ++ ++ if (logging) { ++ printf("advance_step_flashattn:\n"); ++ printf(" num_seqs = %d\n", num_seqs); ++ printf(" num_queries = %d\n", num_queries); ++ printf(" block_size = %d\n", block_size); ++ } ++ // Verify all tensors ++ verify_tensor("input_tokens", input_tokens, num_seqs, -1, at::kLong); ++ verify_tensor("sampled_token_ids", sampled_token_ids, num_queries, 1, ++ at::kLong); ++ verify_tensor("input_positions", input_positions, num_seqs, -1, at::kLong); ++ verify_tensor("seq_lens", seq_lens, num_seqs, -1, at::kInt); ++ verify_tensor("slot_mapping", slot_mapping, num_seqs, -1, at::kLong); ++ verify_tensor("block_tables", block_tables, num_seqs, -1, at::kInt); ++ ++ int dev = sampled_token_ids.get_device(); ++ cudaStream_t stream = at::cuda::getCurrentCUDAStream(dev); ++ ++ int blocks; ++ cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev); ++ ++ advance_step_flashattn_kernel ++ <<>>( ++ num_seqs, num_queries, block_size, ++ reinterpret_cast(input_tokens.data_ptr()), ++ reinterpret_cast(sampled_token_ids.data_ptr()), ++ reinterpret_cast(input_positions.data_ptr()), ++ reinterpret_cast(seq_lens.data_ptr()), ++ reinterpret_cast(slot_mapping.data_ptr()), ++ reinterpret_cast(block_tables.data_ptr()), ++ block_tables.stride(0)); ++} ++ ++void advance_step_flashinfer( ++ int num_seqs, int num_queries, int block_size, ++ torch::Tensor& input_tokens, // type: long ++ torch::Tensor& sampled_token_ids, // type: long ++ torch::Tensor& input_positions, // type: long ++ torch::Tensor& seq_lens, // type: int ++ torch::Tensor& slot_mapping, // type: long ++ torch::Tensor& block_tables, // type: int ++ torch::Tensor& paged_kv_indices, // type: int ++ torch::Tensor& paged_kv_indptr, // type: int ++ torch::Tensor& paged_kv_last_page_len, // type: int ++ torch::Tensor& block_table_bound) { // type: int ++ ++ if (logging) { ++ printf("advance_step_flashinfer:\n"); ++ printf(" num_seqs = %d\n", num_seqs); ++ printf(" num_queries = %d\n", num_queries); ++ printf(" block_size = %d\n", block_size); ++ printf(" block_tables.stride(0) = %zu\n", block_tables.stride(0)); ++ } ++ // Verify all tensors ++ verify_tensor("input_tokens", input_tokens, num_seqs, -1, at::kLong); ++ // verify_tensor("sampled_token_ids", sampled_token_ids, num_queries, 1, ++ // at::kLong); ++ verify_tensor("input_positions", input_positions, num_seqs, -1, at::kLong); ++ verify_tensor("seq_lens", seq_lens, num_seqs, -1, at::kInt); ++ verify_tensor("slot_mapping", slot_mapping, num_seqs, -1, at::kLong); ++ verify_tensor("block_tables", block_tables, num_seqs, -1, at::kInt); ++ ++ verify_tensor("paged_kv_indices", paged_kv_indices, -1, -1, at::kInt); ++ verify_tensor("paged_kv_indptr", paged_kv_indptr, num_seqs + 1, -1, at::kInt); ++ verify_tensor("paged_kv_last_page_len", paged_kv_last_page_len, num_seqs, -1, ++ at::kInt); ++ ++ verify_tensor("block_table_bound", block_table_bound, num_seqs, -1, at::kInt); ++ ++ int dev = sampled_token_ids.get_device(); ++ cudaStream_t stream = at::cuda::getCurrentCUDAStream(dev); ++ ++ int blocks; ++ int threads; ++ cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev); ++ cudaDeviceGetAttribute(&threads, cudaDevAttrMaxThreadsPerBlock, dev); ++ ++ int block_tables_stride = block_tables.stride(0); ++ TORCH_CHECK((blocks * threads > num_queries), ++ "multi-step: not enough threads to map to num_queries = ", ++ num_queries, " block_tables.stride(0) = ", block_tables.stride(0), ++ " blocks = ", blocks, " max_threads = ", threads); ++ if (logging) { ++ printf("launching kernels with %d blocks and %d threads\n", blocks, ++ threads); ++ } ++ advance_step_flashinfer_kernel<<>>( ++ threads, num_seqs, num_queries, block_size, ++ reinterpret_cast(input_tokens.data_ptr()), ++ reinterpret_cast(sampled_token_ids.data_ptr()), ++ reinterpret_cast(input_positions.data_ptr()), ++ reinterpret_cast(seq_lens.data_ptr()), ++ reinterpret_cast(slot_mapping.data_ptr()), ++ reinterpret_cast(block_tables.data_ptr()), ++ block_tables.stride(0), ++ reinterpret_cast(paged_kv_last_page_len.data_ptr()), ++ reinterpret_cast(block_table_bound.data_ptr())); ++ ++ advance_step_flashinfer_indptr_kernel<<>>( ++ threads, num_seqs, num_queries, ++ reinterpret_cast(paged_kv_indptr.data_ptr()), ++ reinterpret_cast(block_table_bound.data_ptr())); ++ ++ advance_step_flashinfer_indices_kernel<<>>( ++ num_seqs, num_queries, ++ reinterpret_cast(block_tables.data_ptr()), ++ block_tables.stride(0), ++ reinterpret_cast(paged_kv_indices.data_ptr()), ++ reinterpret_cast(paged_kv_indptr.data_ptr()), ++ reinterpret_cast(block_table_bound.data_ptr())); ++} ++ ++} // namespace prepare_inputs ++ ++void advance_step_flashattn(int64_t num_seqs, int64_t num_queries, ++ int64_t block_size, torch::Tensor& input_tokens, ++ torch::Tensor& sampled_token_ids, ++ torch::Tensor& input_positions, ++ torch::Tensor& seq_lens, ++ torch::Tensor& slot_mapping, ++ torch::Tensor& block_tables) { ++ prepare_inputs::advance_step_flashattn( ++ num_seqs, num_queries, block_size, input_tokens, sampled_token_ids, ++ input_positions, seq_lens, slot_mapping, block_tables); ++} ++ ++void advance_step_flashinfer( ++ int64_t num_seqs, int64_t num_queries, int64_t block_size, ++ torch::Tensor& input_tokens, torch::Tensor& sampled_token_ids, ++ torch::Tensor& input_positions, torch::Tensor& seq_lens, ++ torch::Tensor& slot_mapping, torch::Tensor& block_tables, ++ torch::Tensor& paged_kv_indices, torch::Tensor& paged_kv_indptr, ++ torch::Tensor& paged_kv_last_page_len, torch::Tensor& block_table_bound) { ++ prepare_inputs::advance_step_flashinfer( ++ num_seqs, num_queries, block_size, input_tokens, sampled_token_ids, ++ input_positions, seq_lens, slot_mapping, block_tables, paged_kv_indices, ++ paged_kv_indptr, paged_kv_last_page_len, block_table_bound); ++} +diff --git a/csrc/prepare_inputs/advance_step.cuh b/csrc/prepare_inputs/advance_step.cuh +new file mode 100644 +index 0000000..f215746 +--- /dev/null ++++ b/csrc/prepare_inputs/advance_step.cuh +@@ -0,0 +1,19 @@ ++#pragma once ++ ++#include ++ ++#include ++#include ++#include ++#include ++#include ++#include ++ ++namespace prepare_inputs { ++ ++static constexpr int max_threads = 256; ++static constexpr bool logging = false; ++ ++constexpr int div_ceil(int a, int b) { return (a + b - 1) / b; } ++ ++} // namespace prepare_inputs +diff --git a/csrc/quantization/aqlm/gemm_kernels.cu b/csrc/quantization/aqlm/gemm_kernels.cu +index 4415316..79cd2c6 100644 +--- a/csrc/quantization/aqlm/gemm_kernels.cu ++++ b/csrc/quantization/aqlm/gemm_kernels.cu +@@ -18,39 +18,35 @@ + #include + #include + #include +-#include ++#include + #include + #include + + #include + #include + +- + namespace vllm { + namespace aqlm { + + __global__ void Code1x16MatVec( +- const int4* __restrict__ A, +- const int4* __restrict__ B, +- int4* __restrict__ C, +- const int4* __restrict__ codebook, +- const int prob_m, +- const int prob_k, +- const int4 codebook_a_sizes, // cumulative sizes of A spanning each codebook, at most 3 long. +- const int codebook_stride // as int4. ++ const int4* __restrict__ A, const int4* __restrict__ B, ++ int4* __restrict__ C, const int4* __restrict__ codebook, const int prob_m, ++ const int prob_k, ++ const int4 codebook_a_sizes, // cumulative sizes of A spanning each ++ // codebook, at most 3 long. ++ const int codebook_stride // as int4. + ) { + int a_gl_stride = prob_k / 8 / 8; + int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32); + bool pred = a_gl_rd < prob_m; + +- if (pred) +- { +- // advance to the correct codebook, this easy because we only multiply one column of the codebook. ++ if (pred) { ++ // advance to the correct codebook, this easy because we only multiply one ++ // column of the codebook. + auto codebook_size = &codebook_a_sizes.x; +- while (a_gl_rd >= *codebook_size) +- { +- codebook += codebook_stride; +- ++codebook_size; ++ while (a_gl_rd >= *codebook_size) { ++ codebook += codebook_stride; ++ ++codebook_size; + } + } + +@@ -67,8 +63,7 @@ __global__ void Code1x16MatVec( + // We pad shared memory to avoid bank conflicts during reads + __syncthreads(); + for (int i = threadIdx.x; i < 32 * 8; i += blockDim.x) { +- if (b_gl_rd + i < prob_k / 8) +- sh_b[9 * (i / 8) + i % 8] = B[b_gl_rd + i]; ++ if (b_gl_rd + i < prob_k / 8) sh_b[9 * (i / 8) + i % 8] = B[b_gl_rd + i]; + } + __syncthreads(); + b_gl_rd += 32 * 8; +@@ -76,22 +71,19 @@ __global__ void Code1x16MatVec( + int b_sh_rd = 9 * (threadIdx.x % 32); + if (pred && a_gl_rd < a_gl_end) { + const uint16_t* enc = reinterpret_cast(&A[a_gl_rd]); +- #pragma unroll ++#pragma unroll + for (int i = 0; i < 8; i++) { + uint32_t dec[4]; +- // We bypass the L1 cache to avoid massive amounts of memory streaming that doesn't +- // actually help us; this brings > 2x speedup. +- asm volatile ( +- "ld.cg.global.v4.u32 {%0, %1, %2, %3}, [%4];" +- : "=r"(dec[0]), "=r"(dec[1]), "=r"(dec[2]), "=r"(dec[3]) +- : "l"((void*) &codebook[enc[i]]) +- ); ++ // We bypass the L1 cache to avoid massive amounts of memory streaming ++ // that doesn't actually help us; this brings > 2x speedup. ++ asm volatile("ld.cg.global.v4.u32 {%0, %1, %2, %3}, [%4];" ++ : "=r"(dec[0]), "=r"(dec[1]), "=r"(dec[2]), "=r"(dec[3]) ++ : "l"((void*)&codebook[enc[i]])); + half2* a = reinterpret_cast(&dec); + half2* b = reinterpret_cast(&sh_b[b_sh_rd]); + half2 res2 = {}; +- #pragma unroll +- for (int j = 0; j < 4; j++) +- res2 = __hfma2(a[j], b[j], res2); ++#pragma unroll ++ for (int j = 0; j < 4; j++) res2 = __hfma2(a[j], b[j], res2); + res += __half2float(res2.x) + __half2float(res2.y); + b_sh_rd++; + } +@@ -100,37 +92,33 @@ __global__ void Code1x16MatVec( + } + + if (pred) { +- #pragma unroll +- for (int i = 16; i > 0; i /= 2) +- res += __shfl_down_sync(0xffffffff, res, i); ++#pragma unroll ++ for (int i = 16; i > 0; i /= 2) res += __shfl_down_sync(0xffffffff, res, i); + if (threadIdx.x % 32 == 0) + reinterpret_cast<__half*>(C)[c_gl_wr] = __float2half(res); + } + } + + __global__ void Code2x8MatVec( +- const int4* __restrict__ A, +- const int4* __restrict__ B, +- int4* __restrict__ C, +- const int4* __restrict__ codebook, +- int prob_m, +- int prob_k, +- const int4 codebook_a_sizes, // cumulative sizes of A spanning each codebook, at most 3 long. +- const int codebook_stride // as int4. ++ const int4* __restrict__ A, const int4* __restrict__ B, ++ int4* __restrict__ C, const int4* __restrict__ codebook, int prob_m, ++ int prob_k, ++ const int4 codebook_a_sizes, // cumulative sizes of A spanning each ++ // codebook, at most 3 long. ++ const int codebook_stride // as int4. + + ) { + int a_gl_stride = prob_k / 8 / 8; + int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32); + bool pred = a_gl_rd < prob_m; + +- if (pred) +- { +- // advance to the correct codebook, this easy because we only multiply one column of the codebook. ++ if (pred) { ++ // advance to the correct codebook, this easy because we only multiply one ++ // column of the codebook. + auto codebook_size = &codebook_a_sizes.x; +- while (a_gl_rd >= *codebook_size) +- { +- codebook += codebook_stride; +- ++codebook_size; ++ while (a_gl_rd >= *codebook_size) { ++ codebook += codebook_stride; ++ ++codebook_size; + } + } + +@@ -148,9 +136,8 @@ __global__ void Code2x8MatVec( + + for (int i = threadIdx.x; i < 2 * 256; i += blockDim.x) { + int4 dec = codebook[i]; +- #pragma unroll +- for (int j = 0; j < 8; j++) +- sh_code[8 * i + (j + lane) % 8] = dec; ++#pragma unroll ++ for (int j = 0; j < 8; j++) sh_code[8 * i + (j + lane) % 8] = dec; + } + __syncthreads(); + +@@ -161,8 +148,7 @@ __global__ void Code2x8MatVec( + // We pad shared memory to avoid bank conflicts during reads + __syncthreads(); + for (int i = threadIdx.x; i < 32 * 8; i += blockDim.x) { +- if (b_gl_rd + i < prob_k / 8) +- sh_b[9 * (i / 8) + i % 8] = B[b_gl_rd + i]; ++ if (b_gl_rd + i < prob_k / 8) sh_b[9 * (i / 8) + i % 8] = B[b_gl_rd + i]; + } + __syncthreads(); + b_gl_rd += 32 * 8; +@@ -170,13 +156,15 @@ __global__ void Code2x8MatVec( + int b_sh_rd = 9 * (threadIdx.x % 32); + if (pred && a_gl_rd < a_gl_end) { + const uint8_t* enc = reinterpret_cast(&A[a_gl_rd]); +- #pragma unroll ++#pragma unroll + for (int i = 0; i < 8; i++) { +- half2* a0 = reinterpret_cast(&sh_code0[8 * enc[2 * i + 0] + lane]); +- half2* a1 = reinterpret_cast(&sh_code1[8 * enc[2 * i + 1] + lane]); +- half2* b = reinterpret_cast(&sh_b[b_sh_rd]); ++ half2* a0 = ++ reinterpret_cast(&sh_code0[8 * enc[2 * i + 0] + lane]); ++ half2* a1 = ++ reinterpret_cast(&sh_code1[8 * enc[2 * i + 1] + lane]); ++ half2* b = reinterpret_cast(&sh_b[b_sh_rd]); + half2 res2 = {}; +- #pragma unroll ++#pragma unroll + for (int j = 0; j < 4; j++) + res2 = __hfma2(__hadd2(a0[j], a1[j]), b[j], res2); + res += __half2float(res2.x) + __half2float(res2.y); +@@ -187,36 +175,31 @@ __global__ void Code2x8MatVec( + } + + if (pred) { +- #pragma unroll +- for (int i = 16; i > 0; i /= 2) +- res += __shfl_down_sync(0xffffffff, res, i); ++#pragma unroll ++ for (int i = 16; i > 0; i /= 2) res += __shfl_down_sync(0xffffffff, res, i); + if (threadIdx.x % 32 == 0) + reinterpret_cast<__half*>(C)[c_gl_wr] = __float2half(res); + } + } + +- + __global__ void Code1x16Dequant( +- const int4* __restrict__ A, +- int4* __restrict__ C, +- const int4* __restrict__ codebook, +- int prob_m, +- int prob_k, +- const int4 codebook_a_sizes, // cumulative sizes of A spanning each codebook, at most 3 long, sums to m. +- const int codebook_stride // as int4 ++ const int4* __restrict__ A, int4* __restrict__ C, ++ const int4* __restrict__ codebook, int prob_m, int prob_k, ++ const int4 codebook_a_sizes, // cumulative sizes of A spanning each ++ // codebook, at most 3 long, sums to m. ++ const int codebook_stride // as int4 + ) { + int a_gl_stride = prob_k / 8 / 8; + int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32); + bool pred = a_gl_rd < prob_m; + +- if (pred) +- { +- // advance to the correct codebook, this easy because we only multiply one column of the codebook. ++ if (pred) { ++ // advance to the correct codebook, this easy because we only multiply one ++ // column of the codebook. + auto codebook_size = &codebook_a_sizes.x; +- while (a_gl_rd >= *codebook_size) +- { +- codebook += codebook_stride; +- ++codebook_size; ++ while (a_gl_rd >= *codebook_size) { ++ codebook += codebook_stride; ++ ++codebook_size; + } + } + +@@ -231,17 +214,15 @@ __global__ void Code1x16Dequant( + while (iters--) { + if (pred && a_gl_rd < a_gl_end) { + const uint16_t* enc = reinterpret_cast(&A[a_gl_rd]); +- #pragma unroll ++#pragma unroll + for (int i = 0; i < 8; i++) { + int4 chunk; + auto dec = reinterpret_cast(&chunk); +- // We bypass the L1 cache to avoid massive amounts of memory streaming that doesn't +- // actually help us; this brings > 2x speedup. +- asm volatile ( +- "ld.cg.global.v4.u32 {%0, %1, %2, %3}, [%4];" +- : "=r"(dec[0]), "=r"(dec[1]), "=r"(dec[2]), "=r"(dec[3]) +- : "l"((void*) &codebook[enc[i]]) +- ); ++ // We bypass the L1 cache to avoid massive amounts of memory streaming ++ // that doesn't actually help us; this brings > 2x speedup. ++ asm volatile("ld.cg.global.v4.u32 {%0, %1, %2, %3}, [%4];" ++ : "=r"(dec[0]), "=r"(dec[1]), "=r"(dec[2]), "=r"(dec[3]) ++ : "l"((void*)&codebook[enc[i]])); + + C[a_gl_rd * 8 + i] = chunk; + } +@@ -250,28 +231,25 @@ __global__ void Code1x16Dequant( + } + } + +- + __global__ void Code2x8Dequant( +- const int4* __restrict__ A, +- int4* __restrict__ C, +- const int4* __restrict__ codebook, +- int prob_m, +- int prob_k, +- const int4 codebook_a_sizes, // cumulative sizes of A spanning each codebook, at most 3 long, corresponds to cols. +- const int codebook_stride // as int4 ++ const int4* __restrict__ A, int4* __restrict__ C, ++ const int4* __restrict__ codebook, int prob_m, int prob_k, ++ const int4 ++ codebook_a_sizes, // cumulative sizes of A spanning each codebook, at ++ // most 3 long, corresponds to cols. ++ const int codebook_stride // as int4 + ) { + int a_gl_stride = prob_k / 8 / 8; + int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32); + bool pred = a_gl_rd < prob_m; + +- if (pred) +- { +- // advance to the correct codebook, this easy because we only multiply one column of the codebook. ++ if (pred) { ++ // advance to the correct codebook, this easy because we only multiply one ++ // column of the codebook. + auto codebook_size = &codebook_a_sizes.x; +- while (a_gl_rd >= *codebook_size) +- { +- codebook += codebook_stride; +- ++codebook_size; ++ while (a_gl_rd >= *codebook_size) { ++ codebook += codebook_stride; ++ ++codebook_size; + } + } + +@@ -290,24 +268,23 @@ __global__ void Code2x8Dequant( + + for (int i = threadIdx.x; i < 2 * 256; i += blockDim.x) { + int4 dec = codebook[i]; +- #pragma unroll +- for (int j = 0; j < 8; j++) +- sh_code[8 * i + (j + lane) % 8] = dec; ++#pragma unroll ++ for (int j = 0; j < 8; j++) sh_code[8 * i + (j + lane) % 8] = dec; + } + __syncthreads(); + +- float res = 0; +- + int iters = (prob_k / 8 - 1) / (8 * 32) + 1; + while (iters--) { + if (pred && a_gl_rd < a_gl_end) { + const uint8_t* enc = reinterpret_cast(&A[a_gl_rd]); +- #pragma unroll ++#pragma unroll + for (int i = 0; i < 8; i++) { + int4 chunk; +- half2* a0 = reinterpret_cast(&sh_code0[8 * enc[2 * i + 0] + lane]); +- half2* a1 = reinterpret_cast(&sh_code1[8 * enc[2 * i + 1] + lane]); +- #pragma unroll ++ half2* a0 = ++ reinterpret_cast(&sh_code0[8 * enc[2 * i + 0] + lane]); ++ half2* a1 = ++ reinterpret_cast(&sh_code1[8 * enc[2 * i + 1] + lane]); ++#pragma unroll + for (int j = 0; j < 4; j++) + reinterpret_cast(&chunk)[j] = __hadd2(a0[j], a1[j]); + C[a_gl_rd * 8 + i] = chunk; +@@ -317,22 +294,15 @@ __global__ void Code2x8Dequant( + } + } + +-inline int ceildiv(int a, int b) { +- return (a + b - 1) / b; +-} ++inline int ceildiv(int a, int b) { return (a + b - 1) / b; } + + const int THREAD_M = 16; + +-void code1x16_matvec_cuda( +- const void* __restrict__ A, +- const void* __restrict__ B, +- void* __restrict__ C, +- const void* __restrict__ codebook, +- int prob_m, +- int prob_k, +- const int4 codebook_a_sizes, +- const int codebook_stride +-) { ++void code1x16_matvec_cuda(const void* __restrict__ A, ++ const void* __restrict__ B, void* __restrict__ C, ++ const void* __restrict__ codebook, int prob_m, ++ int prob_k, const int4 codebook_a_sizes, ++ const int codebook_stride) { + int sms; + cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, 0); + int waves = 0; +@@ -345,28 +315,16 @@ void code1x16_matvec_cuda( + int blocks = ceildiv(prob_m, thread_m); + int threads = 32 * thread_m; + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); +- Code1x16MatVec<<>>( +- (const int4*) A, +- (const int4*) B, +- (int4*) C, +- (const int4*) codebook, +- prob_m, +- prob_k, +- codebook_a_sizes, +- codebook_stride +- ); ++ Code1x16MatVec<<>>( ++ (const int4*)A, (const int4*)B, (int4*)C, (const int4*)codebook, prob_m, ++ prob_k, codebook_a_sizes, codebook_stride); + } + +-void code2x8_matvec_cuda( +- const void* __restrict__ A, +- const void* __restrict__ B, +- void* __restrict__ C, +- const void* __restrict__ codebook, +- int prob_m, +- int prob_k, +- const int4 codebook_a_sizes, +- const int codebook_stride +-) { ++void code2x8_matvec_cuda(const void* __restrict__ A, const void* __restrict__ B, ++ void* __restrict__ C, ++ const void* __restrict__ codebook, int prob_m, ++ int prob_k, const int4 codebook_a_sizes, ++ const int codebook_stride) { + int sms; + cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, 0); + int waves = 0; +@@ -379,30 +337,20 @@ void code2x8_matvec_cuda( + int blocks = ceildiv(prob_m, thread_m); + int threads = 32 * thread_m; + int shared = 16 * (2 * 256 * 8 + 32 * 9); +- cudaFuncSetAttribute( +- Code2x8MatVec, cudaFuncAttributeMaxDynamicSharedMemorySize, shared +- ); ++ cudaFuncSetAttribute(Code2x8MatVec, ++ cudaFuncAttributeMaxDynamicSharedMemorySize, shared); + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + Code2x8MatVec<<>>( +- (const int4*) A, +- (const int4*) B, +- (int4*) C, +- (const int4*) codebook, +- prob_m, +- prob_k, +- codebook_a_sizes, +- codebook_stride +- ); ++ (const int4*)A, (const int4*)B, (int4*)C, (const int4*)codebook, prob_m, ++ prob_k, codebook_a_sizes, codebook_stride); + } + + void code1x16_dequant_cuda( +- const void* __restrict__ A, +- void* __restrict__ C, +- const void* __restrict__ codebook, +- int prob_m, +- int prob_k, +- const int4 codebook_a_sizes, // cumulative sizes of A spanning each codebook, at most 3 long. +- const int codebook_stride // as int4. ++ const void* __restrict__ A, void* __restrict__ C, ++ const void* __restrict__ codebook, int prob_m, int prob_k, ++ const int4 codebook_a_sizes, // cumulative sizes of A spanning each ++ // codebook, at most 3 long. ++ const int codebook_stride // as int4. + ) { + int sms; + cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, 0); +@@ -417,25 +365,21 @@ void code1x16_dequant_cuda( + int threads = 32 * thread_m; + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + Code1x16Dequant<<>>( +- (const int4*) A, +- (int4*) C, +- (const int4*) codebook, +- prob_m, +- prob_k, +- codebook_a_sizes, // cumulative sizes of A spanning each codebook, at most 3 long. +- codebook_stride // as int4. ++ (const int4*)A, (int4*)C, (const int4*)codebook, prob_m, prob_k, ++ codebook_a_sizes, // cumulative sizes of A spanning each codebook, at ++ // most 3 long. ++ codebook_stride // as int4. + ); + } + + // Dequantizes the code and codebook into weights. +-void code2x8_dequant_cuda( +- const void* __restrict__ A, +- void* __restrict__ C, +- const void* __restrict__ codebook, +- int prob_m, +- int prob_k, +- const int4 codebook_a_sizes, // cumulative sizes of A spanning each codebook, at most 3 long, corresponds to cols. +- const int codebook_stride // as int4 ++void code2x8_dequant_cuda( ++ const void* __restrict__ A, void* __restrict__ C, ++ const void* __restrict__ codebook, int prob_m, int prob_k, ++ const int4 ++ codebook_a_sizes, // cumulative sizes of A spanning each codebook, at ++ // most 3 long, corresponds to cols. ++ const int codebook_stride // as int4 + ) { + int sms; + cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, 0); +@@ -451,74 +395,50 @@ void code2x8_dequant_cuda( + int shared = 16 * (2 * 256 * 8 + 32 * 9); + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + +- cudaFuncSetAttribute( +- Code2x8Dequant, cudaFuncAttributeMaxDynamicSharedMemorySize, shared +- ); ++ cudaFuncSetAttribute(Code2x8Dequant, ++ cudaFuncAttributeMaxDynamicSharedMemorySize, shared); + Code2x8Dequant<<>>( +- (const int4*) A, +- (int4*) C, +- (const int4*) codebook, +- prob_m, +- prob_k, +- codebook_a_sizes, +- codebook_stride +- ); ++ (const int4*)A, (int4*)C, (const int4*)codebook, prob_m, prob_k, ++ codebook_a_sizes, codebook_stride); + } + +-int codebook_stride(const torch::Tensor& codebooks) +-{ ++int codebook_stride(const torch::Tensor& codebooks) { + return codebooks.stride(0) * codebooks.element_size() / sizeof(int4); + } + + void code1x16_matvec( +- const torch::Tensor& A, +- const torch::Tensor& B, +- torch::Tensor& C, +- const torch::Tensor& codebook, +- const int4 codebook_a_sizes // cumulative sizes of A spanning each codebook, at most 3 long. ++ const torch::Tensor& A, const torch::Tensor& B, torch::Tensor& C, ++ const torch::Tensor& codebook, ++ const int4 codebook_a_sizes // cumulative sizes of A spanning each ++ // codebook, at most 3 long. + ) { + const at::cuda::OptionalCUDAGuard device_guard(device_of(A)); + int prob_m = C.size(0); + int prob_k = B.size(0); + +- code1x16_matvec_cuda( +- A.data_ptr(), +- B.data_ptr(), +- C.data_ptr(), +- codebook.data_ptr(), +- prob_m, +- prob_k, +- codebook_a_sizes, +- codebook_stride(codebook) +- ); ++ code1x16_matvec_cuda(A.data_ptr(), B.data_ptr(), C.data_ptr(), ++ codebook.data_ptr(), prob_m, prob_k, codebook_a_sizes, ++ codebook_stride(codebook)); + } + +-torch::Tensor code1x16_matmat( +- const torch::Tensor& input, +- const torch::Tensor& codes, +- const torch::Tensor& codebooks, +- const torch::Tensor& scales, +- const int4 codebook_a_sizes, +- const std::optional& bias) { ++torch::Tensor code1x16_matmat(const torch::Tensor& input, ++ const torch::Tensor& codes, ++ const torch::Tensor& codebooks, ++ const torch::Tensor& scales, ++ const int4 codebook_a_sizes, ++ const std::optional& bias) { + auto input_sizes = input.sizes(); + auto out_features = codes.size(0) * codebooks.size(2); + auto flat_input = input.reshape({-1, input.size(-1)}); +- auto flat_output = torch::empty({flat_input.size(0), out_features}, +- torch::TensorOptions() +- .dtype(input.dtype()) +- .device(input.device()) +- ); ++ auto flat_output = torch::empty( ++ {flat_input.size(0), out_features}, ++ torch::TensorOptions().dtype(input.dtype()).device(input.device())); + + for (int i = 0; i < flat_input.size(0); ++i) { + auto input_vec = flat_input.index({i}); + auto output_vec = flat_output.index({i}); +- code1x16_matvec( +- codes.squeeze(2), +- input_vec, +- output_vec, +- codebooks, +- codebook_a_sizes +- ); ++ code1x16_matvec(codes.squeeze(2), input_vec, output_vec, codebooks, ++ codebook_a_sizes); + } + flat_output *= scales.flatten().unsqueeze(0); + +@@ -533,55 +453,35 @@ torch::Tensor code1x16_matmat( + return output; + } + +-void code2x8_matvec( +- const torch::Tensor& A, +- const torch::Tensor& B, +- torch::Tensor& C, +- const torch::Tensor& codebook, +- const int4 codebook_a_sizes +-) { ++void code2x8_matvec(const torch::Tensor& A, const torch::Tensor& B, ++ torch::Tensor& C, const torch::Tensor& codebook, ++ const int4 codebook_a_sizes) { + const at::cuda::OptionalCUDAGuard device_guard(device_of(A)); + int prob_m = C.size(0); + int prob_k = B.size(0); +- code2x8_matvec_cuda( +- A.data_ptr(), +- B.data_ptr(), +- C.data_ptr(), +- codebook.data_ptr(), +- prob_m, +- prob_k, +- codebook_a_sizes, +- 2 * codebook_stride(codebook) +- ); ++ code2x8_matvec_cuda(A.data_ptr(), B.data_ptr(), C.data_ptr(), ++ codebook.data_ptr(), prob_m, prob_k, codebook_a_sizes, ++ 2 * codebook_stride(codebook)); + } + +-torch::Tensor code2x8_matmat( +- const torch::Tensor& input, +- const torch::Tensor& codes, +- const torch::Tensor& codebooks, +- const torch::Tensor& scales, +- const int4 codebook_a_sizes, +- const std::optional& bias +-) { ++torch::Tensor code2x8_matmat(const torch::Tensor& input, ++ const torch::Tensor& codes, ++ const torch::Tensor& codebooks, ++ const torch::Tensor& scales, ++ const int4 codebook_a_sizes, ++ const std::optional& bias) { + auto input_sizes = input.sizes(); + auto out_features = codes.size(0) * codebooks.size(2); + auto flat_input = input.reshape({-1, input.size(-1)}); +- auto flat_output = torch::empty({flat_input.size(0), out_features}, +- torch::TensorOptions() +- .dtype(input.dtype()) +- .device(input.device()) +- ); ++ auto flat_output = torch::empty( ++ {flat_input.size(0), out_features}, ++ torch::TensorOptions().dtype(input.dtype()).device(input.device())); + + for (int i = 0; i < flat_input.size(0); ++i) { + auto input_vec = flat_input.index({i}); + auto output_vec = flat_output.index({i}); +- code2x8_matvec( +- codes.squeeze(2), +- input_vec, +- output_vec, +- codebooks, +- codebook_a_sizes +- ); ++ code2x8_matvec(codes.squeeze(2), input_vec, output_vec, codebooks, ++ codebook_a_sizes); + } + flat_output *= scales.flatten().unsqueeze(0); + if (bias.has_value()) { +@@ -596,66 +496,58 @@ torch::Tensor code2x8_matmat( + } + + // Accumulate the partition sizes. +-int4 accumulate_sizes(const torch::Tensor& codebook_partition_sizes) +-{ ++int4 accumulate_sizes(const std::vector& codebook_partition_sizes) { + int4 cumulative_sizes; + auto cumulative_size = &cumulative_sizes.x; +- int i = 0; ++ size_t i = 0; + int last = 0; +- assert(codebook_partition_sizes.size(0) <= 4); +- for (; i < codebook_partition_sizes.size(0); ++i, ++cumulative_size) +- { +- *cumulative_size = codebook_partition_sizes[i].item() + last; ++ assert(codebook_partition_sizes.size() <= 4); ++ for (; i < codebook_partition_sizes.size(); ++i, ++cumulative_size) { ++ *cumulative_size = codebook_partition_sizes[i] + last; + last = *cumulative_size; + } + // fill in the rest with unreachable. +- for (; i < 4; ++i, ++cumulative_size) +- { +- *cumulative_size = last*10; ++ for (; i < 4; ++i, ++cumulative_size) { ++ *cumulative_size = last * 10; + } + return cumulative_sizes; + } + +-} // namespace aqlm +-} // namespace vllm +- ++} // namespace aqlm ++} // namespace vllm + +-torch::Tensor aqlm_gemm( +- const torch::Tensor& input, +- const torch::Tensor& codes, +- const torch::Tensor& codebooks, +- const torch::Tensor& scales, +- const torch::Tensor& codebook_partition_sizes, +- const std::optional& bias +-) +-{ +- int4 cumulative_sizes = vllm::aqlm::accumulate_sizes(codebook_partition_sizes); ++torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes, ++ const torch::Tensor& codebooks, ++ const torch::Tensor& scales, ++ const std::vector& codebook_partition_sizes, ++ const std::optional& bias) { ++ int4 cumulative_sizes = ++ vllm::aqlm::accumulate_sizes(codebook_partition_sizes); + +- int const nbooks = codebooks.size(0) / codebook_partition_sizes.size(0); ++ int const nbooks = codebooks.size(0) / codebook_partition_sizes.size(); + int const entries = codebooks.size(1); + +- if (nbooks == 1 && entries == (1 << 16)) +- { +- return vllm::aqlm::code1x16_matmat(input, codes, codebooks, scales, cumulative_sizes, bias); ++ if (nbooks == 1 && entries == (1 << 16)) { ++ return vllm::aqlm::code1x16_matmat(input, codes, codebooks, scales, ++ cumulative_sizes, bias); + } +- if (nbooks == 2 && entries == (1 << 8)) +- { +- return vllm::aqlm::code2x8_matmat(input, codes, codebooks, scales, cumulative_sizes, bias); ++ if (nbooks == 2 && entries == (1 << 8)) { ++ return vllm::aqlm::code2x8_matmat(input, codes, codebooks, scales, ++ cumulative_sizes, bias); + } + +- TORCH_CHECK(false, "AQLM with ", nbooks, " codebooks and ", entries, " entries is not currently supported.") ++ TORCH_CHECK(false, "AQLM with ", nbooks, " codebooks and ", entries, ++ " entries is not currently supported.") + return {}; + } + + torch::Tensor aqlm_dequant( +- const torch::Tensor& codes, +- const torch::Tensor& codebooks, +- const torch::Tensor& codebook_partition_sizes +-) +-{ +- int4 cumulative_sizes = vllm::aqlm::accumulate_sizes(codebook_partition_sizes); +- +- int const nbooks = codebooks.size(0) / codebook_partition_sizes.size(0); ++ const torch::Tensor& codes, const torch::Tensor& codebooks, ++ const std::vector& codebook_partition_sizes) { ++ int4 cumulative_sizes = ++ vllm::aqlm::accumulate_sizes(codebook_partition_sizes); ++ ++ int const nbooks = codebooks.size(0) / codebook_partition_sizes.size(); + int const entries = codebooks.size(1); + + const at::cuda::OptionalCUDAGuard device_guard(device_of(codes)); +@@ -665,48 +557,41 @@ torch::Tensor aqlm_dequant( + auto in_features = codes.size(1) * 8; + auto out_features = codes.size(0); + +- assert(out_features = codebook_partition_sizes.sum().item()); ++ assert(out_features == std::accumulate(codebook_partition_sizes.begin(), ++ codebook_partition_sizes.end(), 0)); + + auto weights = torch::empty({out_features, in_features}, +- torch::TensorOptions() +- .dtype(codebooks.dtype()) +- .device(codebooks.device()) +- ); ++ torch::TensorOptions() ++ .dtype(codebooks.dtype()) ++ .device(codebooks.device())); + +- if (nbooks == 1 && entries == (1 << 16)) +- { +- vllm::aqlm::code1x16_dequant_cuda( +- codes.data_ptr(), +- weights.data_ptr(), +- codebooks.data_ptr(), +- out_features, +- in_features, +- cumulative_sizes, +- vllm::aqlm::codebook_stride(codebooks)); +- +- // if you wanted to flip to scaling the weights, (though it's 30%-ish slower and not consistent with gemv implementation.) +- // weights *= scales.index({"...", 0, 0}); +- +- return weights; ++ if (nbooks == 1 && entries == (1 << 16)) { ++ vllm::aqlm::code1x16_dequant_cuda(codes.data_ptr(), weights.data_ptr(), ++ codebooks.data_ptr(), out_features, ++ in_features, cumulative_sizes, ++ vllm::aqlm::codebook_stride(codebooks)); ++ ++ // if you wanted to flip to scaling the weights, (though it's 30%-ish slower ++ // and not consistent with gemv implementation.) weights *= ++ // scales.index({"...", 0, 0}); ++ ++ return weights; + } + +- if (nbooks == 2 && entries == (1 << 8)) +- { +- vllm::aqlm::code2x8_dequant_cuda( +- codes.data_ptr(), +- weights.data_ptr(), +- codebooks.data_ptr(), +- out_features, +- in_features, +- cumulative_sizes, +- vllm::aqlm::codebook_stride(codebooks)); +- +- // if you wanted to flip to scaling the weights, (though it's 30%-ish slower and not consistent with gemv implementation) +- // weights *= scales.index({"...", 0, 0}); +- +- return weights; ++ if (nbooks == 2 && entries == (1 << 8)) { ++ vllm::aqlm::code2x8_dequant_cuda(codes.data_ptr(), weights.data_ptr(), ++ codebooks.data_ptr(), out_features, ++ in_features, cumulative_sizes, ++ vllm::aqlm::codebook_stride(codebooks)); ++ ++ // if you wanted to flip to scaling the weights, (though it's 30%-ish slower ++ // and not consistent with gemv implementation) weights *= ++ // scales.index({"...", 0, 0}); ++ ++ return weights; + } + +- TORCH_CHECK(false, "AQLM with ", nbooks, " codebooks and ", entries, " entries is not currently supported.") ++ TORCH_CHECK(false, "AQLM with ", nbooks, " codebooks and ", entries, ++ " entries is not currently supported.") + return {}; + } +diff --git a/csrc/quantization/awq/dequantize.cuh b/csrc/quantization/awq/dequantize.cuh +index d1d926d..5fa4b5f 100644 +--- a/csrc/quantization/awq/dequantize.cuh ++++ b/csrc/quantization/awq/dequantize.cuh +@@ -1,11 +1,11 @@ + /* + Adapted from https://github.com/mit-han-lab/llm-awq +-Modified from NVIDIA FasterTransformer: https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h ++Modified from NVIDIA FasterTransformer: ++https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h + @article{lin2023awq, +- title={AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration}, +- author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Dang, Xingyu and Han, Song}, +- journal={arXiv}, +- year={2023} ++ title={AWQ: Activation-aware Weight Quantization for LLM Compression and ++Acceleration}, author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, ++Shang and Dang, Xingyu and Han, Song}, journal={arXiv}, year={2023} + } + */ + +@@ -14,74 +14,89 @@ Modified from NVIDIA FasterTransformer: https://github.com/NVIDIA/FasterTransfor + namespace vllm { + namespace awq { + +-__device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source) +-{ ++__device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source) { + #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750 + assert(false); + #else +- uint4 result; ++ uint4 result; + +- uint32_t* h = reinterpret_cast(&result); +- uint32_t const i4s = reinterpret_cast(source); ++ uint32_t* h = reinterpret_cast(&result); ++ uint32_t const i4s = reinterpret_cast(source); + +- // First, we extract the i4s and construct an intermediate fp16 number. +- static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa; +- static constexpr uint32_t BOTTOM_MASK = 0x000f000f; +- static constexpr uint32_t TOP_MASK = 0x00f000f0; +- static constexpr uint32_t I4s_TO_F16s_MAGIC_NUM = 0x64006400; ++ // First, we extract the i4s and construct an intermediate fp16 number. ++ static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa; ++ static constexpr uint32_t BOTTOM_MASK = 0x000f000f; ++ static constexpr uint32_t TOP_MASK = 0x00f000f0; ++ static constexpr uint32_t I4s_TO_F16s_MAGIC_NUM = 0x64006400; + +- // Note that the entire sequence only requires 1 shift instruction. This is thanks to the register packing +- // format and the fact that we force our integers to be unsigned, and account for this in the fp16 subtractions. +- // In addition, I exploit the fact that sub and fma have the same throughput in order to convert elt_23 and +- // elt_67 to fp16 without having to shift them to the bottom bits before hand. ++ // Note that the entire sequence only requires 1 shift instruction. This is ++ // thanks to the register packing format and the fact that we force our ++ // integers to be unsigned, and account for this in the fp16 subtractions. In ++ // addition, I exploit the fact that sub and fma have the same throughput in ++ // order to convert elt_23 and elt_67 to fp16 without having to shift them to ++ // the bottom bits before hand. + +- // Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW dependency if we issue +- // immediately before required. +- const uint32_t top_i4s = i4s >> 8; +- // Extract elt_01 - (i4s & 0x000f000f) | 0x64006400 +- asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" +- : "=r"(h[0]) +- : "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); +- // Extract elt_23 (i4s & 0x00f000f0) | 0x64006400 +- asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" +- : "=r"(h[1]) +- : "r"(i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); +- // Extract elt_45 (top_i4s & 0x000f000f) | 0x64006400 +- asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" +- : "=r"(h[2]) +- : "r"(top_i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); +- // Extract elt_67 (top_i4s & 0x00f000f0) | 0x64006400 +- asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" +- : "=r"(h[3]) +- : "r"(top_i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); ++ // Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW ++ // dependency if we issue immediately before required. ++ const uint32_t top_i4s = i4s >> 8; ++ // Extract elt_01 - (i4s & 0x000f000f) | 0x64006400 ++ asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" ++ : "=r"(h[0]) ++ : "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), ++ "n"(immLut)); ++ // Extract elt_23 (i4s & 0x00f000f0) | 0x64006400 ++ asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" ++ : "=r"(h[1]) ++ : "r"(i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), ++ "n"(immLut)); ++ // Extract elt_45 (top_i4s & 0x000f000f) | 0x64006400 ++ asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" ++ : "=r"(h[2]) ++ : "r"(top_i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), ++ "n"(immLut)); ++ // Extract elt_67 (top_i4s & 0x00f000f0) | 0x64006400 ++ asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" ++ : "=r"(h[3]) ++ : "r"(top_i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), ++ "n"(immLut)); + +- // I use inline PTX below because I am not sure if the compiler will emit float2half instructions if I use the +- // half2 ctor. In this case, I chose performance reliability over code readability. ++ // I use inline PTX below because I am not sure if the compiler will emit ++ // float2half instructions if I use the half2 ctor. In this case, I chose ++ // performance reliability over code readability. + +- // This is the half2 {1032, 1032} represented as an integer. +- // static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64086408; +- // Haotian: subtract {1024, 1024} instead, we do not need to map to [-8, 7] +- static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64006400; +- // This is the half2 {1 / 16, 1 / 16} represented as an integer. +- static constexpr uint32_t ONE_SIXTEENTH = 0x2c002c00; +- // This is the half2 {-72, -72} represented as an integer. +- // static constexpr uint32_t NEG_72 = 0xd480d480; +- // Haotian: Let's use {-64, -64}. +- static constexpr uint32_t NEG_64 = 0xd400d400; ++ // This is the half2 {1032, 1032} represented as an integer. ++ // static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64086408; ++ // Haotian: subtract {1024, 1024} instead, we do not need to map to [-8, 7] ++ static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64006400; ++ // This is the half2 {1 / 16, 1 / 16} represented as an integer. ++ static constexpr uint32_t ONE_SIXTEENTH = 0x2c002c00; ++ // This is the half2 {-72, -72} represented as an integer. ++ // static constexpr uint32_t NEG_72 = 0xd480d480; ++ // Haotian: Let's use {-64, -64}. ++ static constexpr uint32_t NEG_64 = 0xd400d400; + +- // Finally, we construct the output numbers. +- // Convert elt_01 +- asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM)); +- // Convert elt_23 +- asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_64)); +- // Convert elt_45 +- asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM)); +- // Convert elt_67 +- asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_64)); ++ // Finally, we construct the output numbers. ++ // Convert elt_01 ++ asm volatile("sub.f16x2 %0, %1, %2;\n" ++ : "=r"(h[0]) ++ : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM)); ++ // Convert elt_23 ++ asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" ++ : "=r"(h[1]) ++ : "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_64)); ++ // Convert elt_45 ++ asm volatile("sub.f16x2 %0, %1, %2;\n" ++ : "=r"(h[2]) ++ : "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM)); ++ // Convert elt_67 ++ asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" ++ : "=r"(h[3]) ++ : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_64)); + +- return result; ++ return result; + #endif ++ __builtin_unreachable(); // Suppress missing return statement warning + } + +-} // namespace awq +-} // namespace vllm ++} // namespace awq ++} // namespace vllm +diff --git a/csrc/quantization/awq/gemm_kernels.cu b/csrc/quantization/awq/gemm_kernels.cu +index 5aefb0b..9da724a 100644 +--- a/csrc/quantization/awq/gemm_kernels.cu ++++ b/csrc/quantization/awq/gemm_kernels.cu +@@ -1,15 +1,13 @@ + /* + Adapted from https://github.com/mit-han-lab/llm-awq + @article{lin2023awq, +- title={AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration}, +- author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Dang, Xingyu and Han, Song}, +- journal={arXiv}, +- year={2023} ++ title={AWQ: Activation-aware Weight Quantization for LLM Compression and ++Acceleration}, author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, ++Shang and Dang, Xingyu and Han, Song}, journal={arXiv}, year={2023} + } + */ + +- +-#include ++#include + #include + + #include "dequantize.cuh" +@@ -19,27 +17,13 @@ Adapted from https://github.com/mit-han-lab/llm-awq + namespace vllm { + namespace awq { + +-// Pack two half values. +-static inline __device__ __host__ unsigned +-__pack_half2(const half x, const half y) { +- unsigned v0 = *((unsigned short *)&x); +- unsigned v1 = *((unsigned short *)&y); +- return (v1 << 16) | v0; +-} +- +-template +-__global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16nXk32( +- int G, +- int split_k_iters, +- half* __restrict__ A, +- int* __restrict__ B, +- half* __restrict__ scaling_factors, +- int* __restrict__ zeros, +- int M, +- int IC, +- int OC, +- half* __restrict__ C) +-{ ++template ++__global__ void __launch_bounds__(64) ++ gemm_forward_4bit_cuda_m16nXk32(int G, int split_k_iters, ++ half* __restrict__ A, int* __restrict__ B, ++ half* __restrict__ scaling_factors, ++ int* __restrict__ zeros, int M, int IC, ++ int OC, half* __restrict__ C) { + // Only support matrix n = 64 or 128 + assert(N == 64 || N == 128); + #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750 +@@ -50,11 +34,7 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16nXk32( + __shared__ half A_shared[16 * (32 + 8)]; + __shared__ half B_shared[32 * (N + 8)]; + +- __shared__ half scaling_factors_shared[N]; +- __shared__ half zeros_shared[N]; +- + int j_factors1 = ((OC + N - 1) / N); +- int blockIdx_x = 0; + int blockIdx_y = blockIdx.x % ((M + 16 - 1) / 16 * j_factors1); + int blockIdx_z = blockIdx.x / ((M + 16 - 1) / 16 * j_factors1); + +@@ -68,45 +48,47 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16nXk32( + + static constexpr int row_stride_warp = 32 * 8 / 32; + static constexpr int row_stride = 2 * 32 * 8 / N; +- bool ld_zero_flag = (threadIdx.y * 32 + threadIdx.x) * 8 < N; + // TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16 +- bool ld_A_flag = (blockIdx_y / j_factors1 * 16 + threadIdx.y * row_stride_warp + threadIdx.x * 8 / 32) < M; // threadIdx.y is warp_id ++ bool ld_A_flag = ++ (blockIdx_y / j_factors1 * 16 + threadIdx.y * row_stride_warp + ++ threadIdx.x * 8 / 32) < M; // threadIdx.y is warp_id + // bool wb_C_flag = (threadIdx.x / 4) < M; + +- half* A_ptr = A +- + (((int)blockIdx_y) / j_factors1 * 16 + (((int)threadIdx.y) * row_stride_warp) + ((int)threadIdx.x) / (32 / 8)) * IC +- + (((int)threadIdx.x) % (32 / 8)) * 8; +- +- int* B_ptr = B +- + ((int)threadIdx.y) * (OC / 8) * (256 / N) +- + (((int)threadIdx.x) / (N / 8)) * (OC / 8) +- + (((int)blockIdx_y) % j_factors1) * (N / 8) +- + (((int)threadIdx.x) % (N / 8)) * 1; +-// Why * 1 in the above line? +- +- half* A_shared_ptr = A_shared +- + ((int)threadIdx.y) * row_stride_warp * (32 + 8) +- + (((int)threadIdx.x) / (32 / 8)) * (32 + 8) +- + (((int)threadIdx.x) % (32 / 8) ) * 8; +- +- half* B_shared_ptr = B_shared +- + ((int)threadIdx.y) * (row_stride / 2) * (N + 8) +- + (((int)threadIdx.x) / (N / 8)) * (N + 8) +- + (((int)threadIdx.x) % (N / 8)) * 8; +- +- int* zeros_ptr = zeros +- + (((int)blockIdx_y) % j_factors1) * (N / 8) +- + ((int)threadIdx.x) % (N / 8); +- +- half* scaling_factors_ptr = scaling_factors +- + (((int)blockIdx_y) % j_factors1) * N +- + (((int)threadIdx.x) % (N / 8)) * 8; +- +- half* C_ptr = C +- + static_cast(blockIdx_z) * M * OC // blockIdz.x -> split_k dim +- + (((int)blockIdx_y) % j_factors1) * N +- + ((int)threadIdx.y) * (N / 2) +- + (((int)threadIdx.x) % 4) * 2; ++ half* A_ptr = ++ A + ++ (((int)blockIdx_y) / j_factors1 * 16 + ++ (((int)threadIdx.y) * row_stride_warp) + ((int)threadIdx.x) / (32 / 8)) * ++ IC + ++ (((int)threadIdx.x) % (32 / 8)) * 8; ++ ++ int* B_ptr = B + ((int)threadIdx.y) * (OC / 8) * (256 / N) + ++ (((int)threadIdx.x) / (N / 8)) * (OC / 8) + ++ (((int)blockIdx_y) % j_factors1) * (N / 8) + ++ (((int)threadIdx.x) % (N / 8)) * 1; ++ // Why * 1 in the above line? ++ ++ half* A_shared_ptr = A_shared + ++ ((int)threadIdx.y) * row_stride_warp * (32 + 8) + ++ (((int)threadIdx.x) / (32 / 8)) * (32 + 8) + ++ (((int)threadIdx.x) % (32 / 8)) * 8; ++ ++ half* B_shared_ptr = B_shared + ++ ((int)threadIdx.y) * (row_stride / 2) * (N + 8) + ++ (((int)threadIdx.x) / (N / 8)) * (N + 8) + ++ (((int)threadIdx.x) % (N / 8)) * 8; ++ ++ int* zeros_ptr = zeros + (((int)blockIdx_y) % j_factors1) * (N / 8) + ++ ((int)threadIdx.x) % (N / 8); ++ ++ half* scaling_factors_ptr = scaling_factors + ++ (((int)blockIdx_y) % j_factors1) * N + ++ (((int)threadIdx.x) % (N / 8)) * 8; ++ ++ half* C_ptr = ++ C + ++ static_cast(blockIdx_z) * M * OC // blockIdz.x -> split_k dim ++ + (((int)blockIdx_y) % j_factors1) * N + ((int)threadIdx.y) * (N / 2) + ++ (((int)threadIdx.x) % 4) * 2; + + // preload s.f. and zeros + int k_bound = (IC / 32 + split_k_iters - 1) / split_k_iters; +@@ -115,57 +97,79 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16nXk32( + int k_0_0 = _k_0_0 * split_k_iters + blockIdx_z; + __syncthreads(); + // TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16 +- if (ld_A_flag) +- { ++ if (ld_A_flag) { + *(uint4*)(A_shared_ptr) = *(uint4*)(A_ptr + (k_0_0 * 32)); +- } +- else +- { ++ } else { + *(uint4*)(A_shared_ptr) = make_uint4(0, 0, 0, 0); + } + + // for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 2; ++ax0_ax1_fused_0) { + uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr + k_0_0 * 32 / G * (OC / 8)); + uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded); +- uint4 B_loaded_scale = *(uint4*)(scaling_factors_ptr + k_0_0 * 32 / G * (OC)); ++ uint4 B_loaded_scale = ++ *(uint4*)(scaling_factors_ptr + k_0_0 * 32 / G * (OC)); + /* +- if (blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 0 && threadIdx.y == 0){ +- printf("%x %x %x %x %x %x %x %x\n", B_loaded_scale.x, B_loaded_scale.y, B_loaded_scale.z, B_loaded_scale.w, B_loaded_zero.x, B_loaded_zero.y, B_loaded_zero.z, B_loaded_zero.w); ++ if (blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 0 && ++ threadIdx.y == 0){ printf("%x %x %x %x %x %x %x %x\n", B_loaded_scale.x, ++ B_loaded_scale.y, B_loaded_scale.z, B_loaded_scale.w, B_loaded_zero.x, ++ B_loaded_zero.y, B_loaded_zero.z, B_loaded_zero.w); + } + */ + // uint4 B_loaded_scale = make_uint4(0, 0, 0, 0); + int* B_ptr_local = B_ptr + k_0_0 * 32 * (OC / 8); + + for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < N / 16; ++ax0_ax1_fused_0) { +- + // B: 32 x 136 (128+8) float16 + // each warp: 32 x 4 +- // each thr: read 32 bit -> convert to 8xFP16 (a UINT4) -> scale and minus zero -> WB UINT4 +- // *(uint4*)(B_shared + ((((ax0_ax1_fused_0 * 544) + (((int)threadIdx.y) * 272)) + ((((int)threadIdx.x) >> 4) * 136)) + ((((int)threadIdx.x) & 15) * 8))) = *(uint4*)(B + ((((((k_0_0 * 163840) + (ax0_ax1_fused_0 * 20480)) + (((int)threadIdx.y) * 10240)) + ((((int)threadIdx.x) >> 4) * 5120)) + (((int)blockIdx_y) * 128)) + ((((int)threadIdx.x) & 15) * 8))); +- // row stride in shared memory: (NWARPS * 32 * 8 / cta_N) +- uint32_t B_loaded = *(uint32_t*)(B_ptr_local + ax0_ax1_fused_0 * row_stride * (OC / 8)); ++ // each thr: read 32 bit -> convert to 8xFP16 (a UINT4) -> scale and minus ++ // zero -> WB UINT4 ++ // *(uint4*)(B_shared + ((((ax0_ax1_fused_0 * 544) + (((int)threadIdx.y) * ++ // 272)) + ((((int)threadIdx.x) >> 4) * 136)) + ((((int)threadIdx.x) & 15) ++ // * 8))) = *(uint4*)(B + ((((((k_0_0 * 163840) + (ax0_ax1_fused_0 * ++ // 20480)) + (((int)threadIdx.y) * 10240)) + ((((int)threadIdx.x) >> 4) * ++ // 5120)) + (((int)blockIdx_y) * 128)) + ((((int)threadIdx.x) & 15) * ++ // 8))); row stride in shared memory: (NWARPS * 32 * 8 / cta_N) ++ uint32_t B_loaded = ++ *(uint32_t*)(B_ptr_local + ax0_ax1_fused_0 * row_stride * (OC / 8)); + uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded); +- //uint4 B_loaded_zero = *(uint4*)(zeros_shared + (threadIdx.x % (cta_N / 8)) * 8); + +- // uint4 B_loaded_scale = *(uint4*)(scaling_factors_shared + (threadIdx.x % (cta_N / 8)) * 8); + // - zero and * scale +- // TODO (Haotian): can save 4 assembly instructions if sormulate as deq = q * scale - zero * scale. +- asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x)); +- asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO)); +- asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y)); +- asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO)); +- asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z)); +- asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO)); +- asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w)); +- asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO)); ++ // TODO (Haotian): can save 4 assembly instructions if sormulate as deq = ++ // q * scale - zero * scale. ++ asm volatile("sub.f16x2 %0, %1, %2;\n" ++ : "=r"(B_loaded_fp16.x) ++ : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x)); ++ asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" ++ : "=r"(B_loaded_fp16.x) ++ : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO)); ++ asm volatile("sub.f16x2 %0, %1, %2;\n" ++ : "=r"(B_loaded_fp16.y) ++ : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y)); ++ asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" ++ : "=r"(B_loaded_fp16.y) ++ : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO)); ++ asm volatile("sub.f16x2 %0, %1, %2;\n" ++ : "=r"(B_loaded_fp16.z) ++ : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z)); ++ asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" ++ : "=r"(B_loaded_fp16.z) ++ : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO)); ++ asm volatile("sub.f16x2 %0, %1, %2;\n" ++ : "=r"(B_loaded_fp16.w) ++ : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w)); ++ asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" ++ : "=r"(B_loaded_fp16.w) ++ : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO)); + /* +- if (ax0_ax1_fused_0 == 0 && blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 17 && threadIdx.y == 0){ +- printf("[x] %X %X %X %X\n", B_loaded_fp16.x, B_loaded_fp16.y, B_loaded_fp16.z, B_loaded_fp16.w); ++ if (ax0_ax1_fused_0 == 0 && blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == ++ 0 && threadIdx.x == 17 && threadIdx.y == 0){ printf("[x] %X %X %X %X\n", ++ B_loaded_fp16.x, B_loaded_fp16.y, B_loaded_fp16.z, B_loaded_fp16.w); + } + */ + + // write back +- *(uint4*)(B_shared_ptr + ax0_ax1_fused_0 * row_stride * (N + 8)) = B_loaded_fp16; ++ *(uint4*)(B_shared_ptr + ax0_ax1_fused_0 * row_stride * (N + 8)) = ++ B_loaded_fp16; + } + __syncthreads(); + +@@ -173,123 +177,184 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16nXk32( + { + unsigned int addr; + __asm__ __volatile__( +- "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n" +- : "=r"(addr) +- : "l"((void *)((&(A_shared[(k_0_1 * 16)])) + (((((int)threadIdx.x) & 15) * 40) + ((((int)threadIdx.x) >> 4) * 8)))) +- ); +- ++ "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, " ++ "addr; }\n" ++ : "=r"(addr) ++ : "l"((void*)((&(A_shared[(k_0_1 * 16)])) + ++ (((((int)threadIdx.x) & 15) * 40) + ++ ((((int)threadIdx.x) >> 4) * 8))))); + + __asm__ __volatile__( +- "ldmatrix.sync.aligned.m8n8.x4.shared.b16" +- "{%0, %1, %2, %3}, [%4];\n" +- : "=r"(((unsigned *)(A_shared_warp + 0))[0]), "=r"(((unsigned *)(A_shared_warp + 0))[1]), "=r"(((unsigned *)(A_shared_warp + 0))[2]), "=r"(((unsigned *)(A_shared_warp + 0))[3]) +- : "r"(addr) +- ); ++ "ldmatrix.sync.aligned.m8n8.x4.shared.b16" ++ "{%0, %1, %2, %3}, [%4];\n" ++ : "=r"(((unsigned*)(A_shared_warp + 0))[0]), ++ "=r"(((unsigned*)(A_shared_warp + 0))[1]), ++ "=r"(((unsigned*)(A_shared_warp + 0))[2]), ++ "=r"(((unsigned*)(A_shared_warp + 0))[3]) ++ : "r"(addr)); + } + + for (int ax1_0 = 0; ax1_0 < N / 32; ++ax1_0) { + { + unsigned int addr; + __asm__ __volatile__( +- "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n" +- : "=r"(addr) +- : "l"((void *)((&(B_shared[(((k_0_1 * (N * 16 + 128)) + (((int)threadIdx.y) * (N / 2))) + (ax1_0 * 16))])) + (((((int)threadIdx.x) & 15) * (N + 8)) + ((((int)threadIdx.x) >> 4) * 8)))) +- ); ++ "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, " ++ "addr; }\n" ++ : "=r"(addr) ++ : "l"((void*)((&(B_shared[(((k_0_1 * (N * 16 + 128)) + ++ (((int)threadIdx.y) * (N / 2))) + ++ (ax1_0 * 16))])) + ++ (((((int)threadIdx.x) & 15) * (N + 8)) + ++ ((((int)threadIdx.x) >> 4) * 8))))); + __asm__ __volatile__( +- "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16" +- "{%0, %1, %2, %3}, [%4];\n" +- : "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[0]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[1]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[2]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[3]) +- : "r"(addr) +- ); ++ "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16" ++ "{%0, %1, %2, %3}, [%4];\n" ++ : "=r"(((unsigned*)(B_shared_warp + (ax1_0 * 8)))[0]), ++ "=r"(((unsigned*)(B_shared_warp + (ax1_0 * 8)))[1]), ++ "=r"(((unsigned*)(B_shared_warp + (ax1_0 * 8)))[2]), ++ "=r"(((unsigned*)(B_shared_warp + (ax1_0 * 8)))[3]) ++ : "r"(addr)); + } + } + for (int j_0_4 = 0; j_0_4 < N / 32; ++j_0_4) { +-#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750 ++ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750 + { + __asm__ __volatile__( +- "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" +- "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" +- : "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3]) +- : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3])); ++ "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" ++ "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" ++ : "=f"(((float*)(C_warp + (j_0_4 * 8)))[0]), ++ "=f"(((float*)(C_warp + (j_0_4 * 8)))[1]), ++ "=f"(((float*)(C_warp + (j_0_4 * 8)))[2]), ++ "=f"(((float*)(C_warp + (j_0_4 * 8)))[3]) ++ : "r"(((unsigned*)(A_shared_warp + 0))[0]), ++ "r"(((unsigned*)(A_shared_warp + 0))[1]), ++ "r"(((unsigned*)(B_shared_warp + (j_0_4 * 8)))[0]), ++ "f"(((float*)(C_warp + (j_0_4 * 8)))[0]), ++ "f"(((float*)(C_warp + (j_0_4 * 8)))[1]), ++ "f"(((float*)(C_warp + (j_0_4 * 8)))[2]), ++ "f"(((float*)(C_warp + (j_0_4 * 8)))[3])); + } + + { + __asm__ __volatile__( +- "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" +- "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" +- : "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]) +- : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])); ++ "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" ++ "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" ++ : "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]), ++ "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]), ++ "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]), ++ "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3]) ++ : "r"(((unsigned*)(A_shared_warp + 0))[0]), ++ "r"(((unsigned*)(A_shared_warp + 0))[1]), ++ "r"(((unsigned*)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), ++ "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]), ++ "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]), ++ "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]), ++ "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3])); + } + + { + __asm__ __volatile__( +- "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" +- "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" +- : "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3]) +- : "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3])); ++ "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" ++ "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" ++ : "=f"(((float*)(C_warp + (j_0_4 * 8)))[0]), ++ "=f"(((float*)(C_warp + (j_0_4 * 8)))[1]), ++ "=f"(((float*)(C_warp + (j_0_4 * 8)))[2]), ++ "=f"(((float*)(C_warp + (j_0_4 * 8)))[3]) ++ : "r"(((unsigned*)(A_shared_warp + 0))[2]), ++ "r"(((unsigned*)(A_shared_warp + 0))[3]), ++ "r"(((unsigned*)(B_shared_warp + (j_0_4 * 8)))[1]), ++ "f"(((float*)(C_warp + (j_0_4 * 8)))[0]), ++ "f"(((float*)(C_warp + (j_0_4 * 8)))[1]), ++ "f"(((float*)(C_warp + (j_0_4 * 8)))[2]), ++ "f"(((float*)(C_warp + (j_0_4 * 8)))[3])); + } + + { + __asm__ __volatile__( +- "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" +- "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" +- : "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]) +- : "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])); ++ "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" ++ "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" ++ : "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]), ++ "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]), ++ "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]), ++ "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3]) ++ : "r"(((unsigned*)(A_shared_warp + 0))[2]), ++ "r"(((unsigned*)(A_shared_warp + 0))[3]), ++ "r"(((unsigned*)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), ++ "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]), ++ "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]), ++ "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]), ++ "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3])); + } +-#else ++ #else + { + __asm__ __volatile__( +- "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32" +- "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n" +- : "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3]) +- : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3])); ++ "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32" ++ "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, " ++ "%13};\n" ++ : "=f"(((float*)(C_warp + (j_0_4 * 8)))[0]), ++ "=f"(((float*)(C_warp + (j_0_4 * 8)))[1]), ++ "=f"(((float*)(C_warp + (j_0_4 * 8)))[2]), ++ "=f"(((float*)(C_warp + (j_0_4 * 8)))[3]) ++ : "r"(((unsigned*)(A_shared_warp + 0))[0]), ++ "r"(((unsigned*)(A_shared_warp + 0))[1]), ++ "r"(((unsigned*)(A_shared_warp + 0))[2]), ++ "r"(((unsigned*)(A_shared_warp + 0))[3]), ++ "r"(((unsigned*)(B_shared_warp + (j_0_4 * 8)))[0]), ++ "r"(((unsigned*)(B_shared_warp + (j_0_4 * 8)))[1]), ++ "f"(((float*)(C_warp + (j_0_4 * 8)))[0]), ++ "f"(((float*)(C_warp + (j_0_4 * 8)))[1]), ++ "f"(((float*)(C_warp + (j_0_4 * 8)))[2]), ++ "f"(((float*)(C_warp + (j_0_4 * 8)))[3])); + } + + { + __asm__ __volatile__( +- "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32" +- "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n" +- : "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]) +- : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])); ++ "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32" ++ "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, " ++ "%13};\n" ++ : "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]), ++ "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]), ++ "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]), ++ "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3]) ++ : "r"(((unsigned*)(A_shared_warp + 0))[0]), ++ "r"(((unsigned*)(A_shared_warp + 0))[1]), ++ "r"(((unsigned*)(A_shared_warp + 0))[2]), ++ "r"(((unsigned*)(A_shared_warp + 0))[3]), ++ "r"(((unsigned*)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), ++ "r"(((unsigned*)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), ++ "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]), ++ "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]), ++ "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]), ++ "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3])); + } + +-#endif ++ #endif + } + } + } + +-// TODO: Shang: Hoist loop invariance. ++ // TODO: Shang: Hoist loop invariance. + for (int ax1_0_1 = 0; ax1_0_1 < 4; ++ax1_0_1) { + for (int local_id = 0; local_id < 8; ++local_id) { +- int row_offset = (((int)blockIdx_y) / j_factors1) * 16 + ((int)threadIdx.x) / 4 + (local_id % 4) / 2 * 8; +- if (row_offset < M) +- { +- *(C_ptr + ax1_0_1 * 16 + row_offset * OC + (local_id / 4) * 8 + local_id % 2) = __float2half(C_warp[(ax1_0_1 * 8) + local_id]); ++ int row_offset = (((int)blockIdx_y) / j_factors1) * 16 + ++ ((int)threadIdx.x) / 4 + (local_id % 4) / 2 * 8; ++ if (row_offset < M) { ++ *(C_ptr + ax1_0_1 * 16 + row_offset * OC + (local_id / 4) * 8 + ++ local_id % 2) = __float2half(C_warp[(ax1_0_1 * 8) + local_id]); + } + } + } + #endif + } + +-__global__ void __launch_bounds__(64) dequantize_weights( +- int* __restrict__ B, +- half* __restrict__ scaling_factors, +- int* __restrict__ zeros, +- half* __restrict__ C, +- int G +-) +-{ +- int j_factors1 = 4; +- int row_stride2 = 4; +- int split_k_iters = 1; ++__global__ void __launch_bounds__(64) ++ dequantize_weights(int* __restrict__ B, half* __restrict__ scaling_factors, ++ int* __restrict__ zeros, half* __restrict__ C, int G) { + static constexpr uint32_t ZERO = 0x0; + half B_shared[32 * (128 + 8)]; + + half* B_shared_ptr2 = B_shared; + +- half B_shared_warp[32]; +- int OC = 512; +- + int N = blockDim.x * gridDim.x; // 2 + int col = (blockIdx.x * blockDim.x + threadIdx.x); + int row = blockIdx.y * blockDim.y + threadIdx.y; +@@ -310,14 +375,30 @@ __global__ void __launch_bounds__(64) dequantize_weights( + + uint32_t B_loaded = *(uint32_t*)B_ptr2; + uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded); +- asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x)); +- asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO)); +- asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y)); +- asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO)); +- asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z)); +- asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO)); +- asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w)); +- asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO)); ++ asm volatile("sub.f16x2 %0, %1, %2;\n" ++ : "=r"(B_loaded_fp16.x) ++ : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x)); ++ asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" ++ : "=r"(B_loaded_fp16.x) ++ : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO)); ++ asm volatile("sub.f16x2 %0, %1, %2;\n" ++ : "=r"(B_loaded_fp16.y) ++ : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y)); ++ asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" ++ : "=r"(B_loaded_fp16.y) ++ : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO)); ++ asm volatile("sub.f16x2 %0, %1, %2;\n" ++ : "=r"(B_loaded_fp16.z) ++ : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z)); ++ asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" ++ : "=r"(B_loaded_fp16.z) ++ : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO)); ++ asm volatile("sub.f16x2 %0, %1, %2;\n" ++ : "=r"(B_loaded_fp16.w) ++ : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w)); ++ asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" ++ : "=r"(B_loaded_fp16.w) ++ : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO)); + + *(uint4*)B_shared_ptr2 = B_loaded_fp16; + +@@ -326,58 +407,57 @@ __global__ void __launch_bounds__(64) dequantize_weights( + } + } + +-} // namespace awq +-} // namespace vllm +- +-torch::Tensor awq_dequantize( +- torch::Tensor _kernel, +- torch::Tensor _scaling_factors, +- torch::Tensor _zeros, +- int split_k_iters, +- int thx, +- int thy) +-{ +- int in_c = _kernel.size(0); +- int qout_c = _kernel.size(1); +- int out_c = qout_c * 8; +- int G = in_c / _scaling_factors.size(0); +- +- int x_thread = thx; +- int y_thread = thy; +- +- int x_blocks = 1; +- int y_blocks = 1; +- if (thx==0) { +- x_thread = qout_c; +- } +- if (thy==0) { +- y_thread = in_c; +- } +- if (thx==0 && thy==0) { +- x_thread = 8; +- y_thread = 8; +- x_blocks = (int)(qout_c / 8); +- y_blocks = (int)(in_c / 8); +- } ++} // namespace awq ++} // namespace vllm ++ ++torch::Tensor awq_dequantize(torch::Tensor _kernel, ++ torch::Tensor _scaling_factors, ++ torch::Tensor _zeros, int64_t split_k_iters, ++ int64_t thx, int64_t thy) { ++ int in_c = _kernel.size(0); ++ int qout_c = _kernel.size(1); ++ int out_c = qout_c * 8; ++ int G = in_c / _scaling_factors.size(0); ++ ++ int x_thread = thx; ++ int y_thread = thy; ++ ++ int x_blocks = 1; ++ int y_blocks = 1; ++ if (thx == 0) { ++ x_thread = qout_c; ++ } ++ if (thy == 0) { ++ y_thread = in_c; ++ } ++ if (thx == 0 && thy == 0) { ++ x_thread = 8; ++ y_thread = 8; ++ x_blocks = (int)(qout_c / 8); ++ y_blocks = (int)(in_c / 8); ++ } + +- const at::cuda::OptionalCUDAGuard device_guard(device_of(_scaling_factors)); ++ const at::cuda::OptionalCUDAGuard device_guard(device_of(_scaling_factors)); + +- auto options = torch::TensorOptions().dtype(_scaling_factors.dtype()).device(_scaling_factors.device()); +- at::Tensor _de_kernel = torch::empty({in_c, out_c}, options); ++ auto options = torch::TensorOptions() ++ .dtype(_scaling_factors.dtype()) ++ .device(_scaling_factors.device()); ++ at::Tensor _de_kernel = torch::empty({in_c, out_c}, options); + +- auto kernel = reinterpret_cast(_kernel.data_ptr()); +- auto de_kernel = reinterpret_cast(_de_kernel.data_ptr()); +- auto scaling_factors = reinterpret_cast(_scaling_factors.data_ptr()); +- auto zeros = reinterpret_cast(_zeros.data_ptr()); ++ auto kernel = reinterpret_cast(_kernel.data_ptr()); ++ auto de_kernel = reinterpret_cast(_de_kernel.data_ptr()); ++ auto scaling_factors = ++ reinterpret_cast(_scaling_factors.data_ptr()); ++ auto zeros = reinterpret_cast(_zeros.data_ptr()); + +- dim3 num_blocks(x_blocks, y_blocks); +- dim3 threads_per_block(x_thread, y_thread); ++ dim3 num_blocks(x_blocks, y_blocks); ++ dim3 threads_per_block(x_thread, y_thread); + +- const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); +- vllm::awq::dequantize_weights<<>>( +- kernel, scaling_factors, zeros, de_kernel, G); ++ const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); ++ vllm::awq::dequantize_weights<<>>( ++ kernel, scaling_factors, zeros, de_kernel, G); + +- return _de_kernel; ++ return _de_kernel; + } + + // in_feats: M, IC [float16] +@@ -386,61 +466,61 @@ torch::Tensor awq_dequantize( + // zeros: IC // G, OC // 8 [int32] -> cast to IC // G, OC [uint4b] + // assume that batch_size < 16 for now + +-torch::Tensor awq_gemm( +- torch::Tensor _in_feats, +- torch::Tensor _kernel, +- torch::Tensor _scaling_factors, +- torch::Tensor _zeros, +- int split_k_iters) +-{ +- int num_in_feats = _in_feats.size(0); +- int num_in_channels = _in_feats.size(1); +- const at::cuda::OptionalCUDAGuard device_guard(device_of(_in_feats)); +- +- auto options = torch::TensorOptions().dtype(_in_feats.dtype()).device(_in_feats.device()); +- at::Tensor _out_feats = torch::empty({split_k_iters, num_in_feats, _kernel.size(1) * 8}, options); +- int num_out_feats = _out_feats.size(-2); +- int num_out_channels = _out_feats.size(-1); +- +- auto in_feats = reinterpret_cast(_in_feats.data_ptr()); +- auto kernel = reinterpret_cast(_kernel.data_ptr()); +- auto out_feats = reinterpret_cast(_out_feats.data_ptr()); +- auto scaling_factors = reinterpret_cast(_scaling_factors.data_ptr()); +- auto zeros = reinterpret_cast(_zeros.data_ptr()); +- int group_size = num_in_channels / _scaling_factors.size(0); +- +- if (num_out_channels % 64 != 0) +- throw std::invalid_argument("OC is not multiple of cta_N = 64"); +- if (num_out_channels % 8 != 0) +- throw std::invalid_argument("OC is not multiple of pack_num = 8"); +- if (group_size % 32 != 0) +- throw std::invalid_argument("Group size should be a multiple of 32"); +- if (num_out_channels % group_size != 0) +- throw std::invalid_argument("OC is not multiple of Group size"); +- +- const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); +- if (num_out_channels % 128 == 0) +- { +- int j_factors1 = num_out_channels / 128 / 1; +- dim3 num_blocks((num_out_feats + 16 - 1) / 16 * j_factors1 * split_k_iters); +- // threadIdx.x: 32 +- // threadIdx.y: i_factors[2] * j_factors[2] +- dim3 threads_per_block(32, 2); +- vllm::awq::gemm_forward_4bit_cuda_m16nXk32<128><<>>( +- group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, +- num_out_channels, out_feats); +- } +- else if (num_out_channels % 64 == 0) +- { +- int j_factors1 = num_out_channels / 64 / 1; +- dim3 num_blocks(1 * (num_out_feats + 16 - 1) / 16 * j_factors1 * split_k_iters); +- +- // threadIdx.x: 32 +- // threadIdx.y: i_factors[2] * j_factors[2] +- dim3 threads_per_block(32, 2); +- vllm::awq::gemm_forward_4bit_cuda_m16nXk32<64><<>>( +- group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, +- num_out_channels, out_feats); +- } +- return _out_feats.sum(0); ++torch::Tensor awq_gemm(torch::Tensor _in_feats, torch::Tensor _kernel, ++ torch::Tensor _scaling_factors, torch::Tensor _zeros, ++ int64_t split_k_iters) { ++ int num_in_feats = _in_feats.size(0); ++ int num_in_channels = _in_feats.size(1); ++ const at::cuda::OptionalCUDAGuard device_guard(device_of(_in_feats)); ++ ++ auto options = torch::TensorOptions() ++ .dtype(_in_feats.dtype()) ++ .device(_in_feats.device()); ++ at::Tensor _out_feats = ++ torch::empty({split_k_iters, num_in_feats, _kernel.size(1) * 8}, options); ++ int num_out_feats = _out_feats.size(-2); ++ int num_out_channels = _out_feats.size(-1); ++ ++ auto in_feats = reinterpret_cast(_in_feats.data_ptr()); ++ auto kernel = reinterpret_cast(_kernel.data_ptr()); ++ auto out_feats = reinterpret_cast(_out_feats.data_ptr()); ++ auto scaling_factors = ++ reinterpret_cast(_scaling_factors.data_ptr()); ++ auto zeros = reinterpret_cast(_zeros.data_ptr()); ++ int group_size = num_in_channels / _scaling_factors.size(0); ++ ++ if (num_out_channels % 64 != 0) ++ throw std::invalid_argument("OC is not multiple of cta_N = 64"); ++ if (num_out_channels % 8 != 0) ++ throw std::invalid_argument("OC is not multiple of pack_num = 8"); ++ if (group_size % 32 != 0) ++ throw std::invalid_argument("Group size should be a multiple of 32"); ++ if (num_out_channels % group_size != 0) ++ throw std::invalid_argument("OC is not multiple of Group size"); ++ ++ const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); ++ if (num_out_channels % 128 == 0) { ++ int j_factors1 = num_out_channels / 128 / 1; ++ dim3 num_blocks((num_out_feats + 16 - 1) / 16 * j_factors1 * split_k_iters); ++ // threadIdx.x: 32 ++ // threadIdx.y: i_factors[2] * j_factors[2] ++ dim3 threads_per_block(32, 2); ++ vllm::awq::gemm_forward_4bit_cuda_m16nXk32<128> ++ <<>>( ++ group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, ++ num_in_feats, num_in_channels, num_out_channels, out_feats); ++ } else if (num_out_channels % 64 == 0) { ++ int j_factors1 = num_out_channels / 64 / 1; ++ dim3 num_blocks(1 * (num_out_feats + 16 - 1) / 16 * j_factors1 * ++ split_k_iters); ++ ++ // threadIdx.x: 32 ++ // threadIdx.y: i_factors[2] * j_factors[2] ++ dim3 threads_per_block(32, 2); ++ vllm::awq::gemm_forward_4bit_cuda_m16nXk32<64> ++ <<>>( ++ group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, ++ num_in_feats, num_in_channels, num_out_channels, out_feats); ++ } ++ return _out_feats.sum(0); + } +diff --git a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu +new file mode 100644 +index 0000000..e797858 +--- /dev/null ++++ b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu +@@ -0,0 +1,286 @@ ++#include ++#include ++#include ++ ++#include "../../dispatch_utils.h" ++ ++#ifndef USE_ROCM ++ #include ++ #include ++#else ++ #include ++ #include ++#endif ++ ++static inline __device__ int8_t float_to_int8_rn(float x) { ++#ifdef USE_ROCM ++ static constexpr auto i8_min = ++ static_cast(std::numeric_limits::min()); ++ static constexpr auto i8_max = ++ static_cast(std::numeric_limits::max()); ++ ++ // To match the rounding mode of CUDA, we use nearbyint. ++ // It uses the current rounding mode, which is always FE_TONEAREST on HIP. ++ // If that changes in the future, we may need to set the rounding mode ++ // explicitly, either at runtime or compile time. ++ float dst = std::nearbyint(x); ++ ++ // saturate ++ dst = std::clamp(dst, i8_min, i8_max); ++ return static_cast(dst); ++#else ++ // CUDA path ++ uint32_t dst; ++ asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=r"(dst) : "f"(x)); ++ return reinterpret_cast(dst); ++#endif ++} ++ ++static inline __device__ int32_t float_to_int32_rn(float x) { ++#ifdef USE_ROCM ++ // int32_max is not exactly representable as float. ++ // Therefore, we need to be careful and manually return int32_max on overflow. ++ // For symmetry, we also do the same for int32_min, even though it is exactly ++ // representable as float and the conversion should be exact. ++ static constexpr auto i32_min = std::numeric_limits::min(); ++ static constexpr auto i32_min_f = static_cast(i32_min); ++ static constexpr auto i32_max = std::numeric_limits::max(); ++ static constexpr auto i32_max_f = static_cast(i32_max); ++ ++ // To match the rounding mode of CUDA, we use nearbyint. ++ // It uses the current rounding mode, which is always FE_TONEAREST on HIP. ++ // If that changes in the future, we may need to set the rounding mode ++ // explicitly, either at runtime or compile time. ++ float dst = std::nearbyint(x); ++ ++ // saturate on the higher end. ++ if (dst >= i32_max_f) { ++ return i32_max; ++ } ++ // saturate on the lower end. ++ if (dst <= i32_min_f) { ++ return i32_min; ++ } ++ ++ return static_cast(dst); ++#else ++ // CUDA path ++ uint32_t dst; ++ asm volatile("cvt.rni.sat.s32.f32 %0, %1;" : "=r"(dst) : "f"(x)); ++ return reinterpret_cast(dst); ++#endif ++} ++ ++static inline __device__ int8_t int32_to_int8(int32_t x) { ++#ifdef USE_ROCM ++ static constexpr auto i8_min = ++ static_cast(std::numeric_limits::min()); ++ static constexpr auto i8_max = ++ static_cast(std::numeric_limits::max()); ++ ++ // saturate ++ int32_t dst = std::clamp(x, i8_min, i8_max); ++ return static_cast(dst); ++#else ++ // CUDA path ++ uint32_t dst; ++ asm volatile("cvt.sat.s8.s32 %0, %1;" : "=r"(dst) : "r"(x)); ++ return reinterpret_cast(dst); ++#endif ++} ++ ++namespace vllm { ++ ++template ++__global__ void static_scaled_int8_quant_kernel( ++ scalar_t const* __restrict__ input, int8_t* __restrict__ out, ++ scale_type const* scale_ptr, const int hidden_size) { ++ int const tid = threadIdx.x; ++ int64_t const token_idx = blockIdx.x; ++ scale_type const scale = *scale_ptr; ++ ++ // Must be performed using 64-bit math to avoid integer overflow. ++ out += token_idx * hidden_size; ++ input += token_idx * hidden_size; ++ ++ for (int i = tid; i < hidden_size; i += blockDim.x) { ++ out[i] = float_to_int8_rn(static_cast(input[i]) / scale); ++ } ++} ++ ++template ++__global__ void static_scaled_int8_azp_quant_kernel( ++ scalar_t const* __restrict__ input, int8_t* __restrict__ out, ++ scale_type const* scale_ptr, azp_type const* azp_ptr, ++ const int hidden_size) { ++ int const tid = threadIdx.x; ++ int64_t const token_idx = blockIdx.x; ++ scale_type const scale = *scale_ptr; ++ azp_type const azp = *azp_ptr; ++ ++ // Must be performed using 64-bit math to avoid integer overflow. ++ out += token_idx * hidden_size; ++ input += token_idx * hidden_size; ++ ++ for (int i = tid; i < hidden_size; i += blockDim.x) { ++ auto const val = static_cast(input[i]); ++ auto const quant_val = int32_to_int8(float_to_int32_rn(val / scale) + azp); ++ out[i] = quant_val; ++ } ++} ++ ++template ++__global__ void dynamic_scaled_int8_quant_kernel( ++ scalar_t const* __restrict__ input, int8_t* __restrict__ out, ++ scale_type* scale, const int hidden_size) { ++ int const tid = threadIdx.x; ++ int64_t const token_idx = blockIdx.x; ++ float absmax_val = 0.0f; ++ float const zero = 0.0f; ++ ++ // Must be performed using 64-bit math to avoid integer overflow. ++ out += token_idx * hidden_size; ++ input += token_idx * hidden_size; ++ ++ for (int i = tid; i < hidden_size; i += blockDim.x) { ++ float val = static_cast(input[i]); ++ val = val > zero ? val : -val; ++ absmax_val = val > absmax_val ? val : absmax_val; ++ } ++ ++ using BlockReduce = cub::BlockReduce; ++ __shared__ typename BlockReduce::TempStorage reduceStorage; ++ float const block_absmax_val_maybe = ++ BlockReduce(reduceStorage).Reduce(absmax_val, cub::Max{}, blockDim.x); ++ __shared__ float block_absmax_val; ++ if (tid == 0) { ++ block_absmax_val = block_absmax_val_maybe; ++ scale[token_idx] = block_absmax_val / 127.0f; ++ } ++ __syncthreads(); ++ ++ float const tmp_scale = 127.0f / block_absmax_val; ++ for (int i = tid; i < hidden_size; i += blockDim.x) { ++ out[i] = float_to_int8_rn(static_cast(input[i]) * tmp_scale); ++ } ++} ++ ++template ++__global__ void dynamic_scaled_int8_azp_quant_kernel( ++ scalar_t const* __restrict__ input, int8_t* __restrict__ out, ++ scale_type* scale, azp_type* azp, const int hidden_size) { ++ int64_t const token_idx = blockIdx.x; ++ ++ // Must be performed using 64-bit math to avoid integer overflow. ++ out += token_idx * hidden_size; ++ input += token_idx * hidden_size; ++ ++ // Scan for the min and max value for this token ++ float max_val = std::numeric_limits::min(); ++ float min_val = std::numeric_limits::max(); ++ for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) { ++ auto val = static_cast(input[i]); ++ max_val = std::max(max_val, val); ++ min_val = std::min(min_val, val); ++ } ++ ++ // Reduce the max and min values across the block ++ using BlockReduce = cub::BlockReduce; ++ __shared__ typename BlockReduce::TempStorage reduceStorage; ++ max_val = BlockReduce(reduceStorage).Reduce(max_val, cub::Max{}, blockDim.x); ++ __syncthreads(); // Make sure min doesn't mess with max shared memory ++ min_val = BlockReduce(reduceStorage).Reduce(min_val, cub::Min{}, blockDim.x); ++ ++ __shared__ scale_type scale_sh; ++ __shared__ azp_type azp_sh; ++ ++ // Compute the scale and zero point and store them, only on the first thread ++ if (threadIdx.x == 0) { ++ float const scale_val = (max_val - min_val) / 255.0f; ++ // Use rounding to even (same as torch.round) ++ auto const azp_float = std::nearbyint(-128.0f - min_val / scale_val); ++ auto const azp_val = static_cast(azp_float); ++ ++ // Store the scale and azp into shared and global ++ scale[token_idx] = scale_sh = scale_val; ++ azp[token_idx] = azp_sh = azp_val; ++ } ++ ++ // Wait for the scale and azp to be computed ++ __syncthreads(); ++ ++ float const scale_val = scale_sh; ++ azp_type const azp_val = azp_sh; ++ ++ // Quantize the values ++ for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) { ++ auto const val = static_cast(input[i]); ++ auto const quant_val = ++ int32_to_int8(float_to_int32_rn(val / scale_val) + azp_val); ++ out[i] = quant_val; ++ } ++} ++ ++} // namespace vllm ++ ++void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size] ++ torch::Tensor const& input, // [..., hidden_size] ++ torch::Tensor const& scale, ++ std::optional const& azp) { ++ TORCH_CHECK(input.is_contiguous()); ++ TORCH_CHECK(out.is_contiguous()); ++ TORCH_CHECK(scale.numel() == 1); ++ TORCH_CHECK(!azp || azp->numel() == 1); ++ ++ int const hidden_size = input.size(-1); ++ int const num_tokens = input.numel() / hidden_size; ++ dim3 const grid(num_tokens); ++ dim3 const block(std::min(hidden_size, 1024)); ++ const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); ++ VLLM_DISPATCH_FLOATING_TYPES( ++ input.scalar_type(), "static_scaled_int8_quant_kernel", [&] { ++ if (!azp) { ++ vllm::static_scaled_int8_quant_kernel ++ <<>>( ++ input.data_ptr(), out.data_ptr(), ++ scale.data_ptr(), hidden_size); ++ } else { ++ vllm::static_scaled_int8_azp_quant_kernel ++ <<>>( ++ input.data_ptr(), out.data_ptr(), ++ scale.data_ptr(), azp->data_ptr(), ++ hidden_size); ++ } ++ }); ++} ++ ++void dynamic_scaled_int8_quant( ++ torch::Tensor& out, // [..., hidden_size] ++ torch::Tensor const& input, // [..., hidden_size] ++ torch::Tensor& scales, std::optional const& azp) { ++ TORCH_CHECK(input.is_contiguous()); ++ TORCH_CHECK(out.is_contiguous()); ++ TORCH_CHECK(scales.is_contiguous()); ++ TORCH_CHECK(!azp || azp->is_contiguous()); ++ ++ int const hidden_size = input.size(-1); ++ int const num_tokens = input.numel() / hidden_size; ++ dim3 const grid(num_tokens); ++ dim3 const block(std::min(hidden_size, 1024)); ++ const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); ++ VLLM_DISPATCH_FLOATING_TYPES( ++ input.scalar_type(), "dynamic_scaled_int8_quant_kernel", [&] { ++ if (!azp) { ++ vllm::dynamic_scaled_int8_quant_kernel ++ <<>>( ++ input.data_ptr(), out.data_ptr(), ++ scales.data_ptr(), hidden_size); ++ } else { ++ vllm::dynamic_scaled_int8_azp_quant_kernel ++ <<>>( ++ input.data_ptr(), out.data_ptr(), ++ scales.data_ptr(), azp->data_ptr(), ++ hidden_size); ++ } ++ }); ++} +diff --git a/csrc/quantization/cutlass_w8a8/Epilogues.md b/csrc/quantization/cutlass_w8a8/Epilogues.md +new file mode 100644 +index 0000000..aae0415 +--- /dev/null ++++ b/csrc/quantization/cutlass_w8a8/Epilogues.md +@@ -0,0 +1,147 @@ ++# CUTLASS Epilogues ++ ++## Introduction ++This document describes the various CUTLASS epilogues implemented for fusing de-quantization operations onto GEMMs. ++ ++Currently, we only support symmetric quantization for weights, ++and symmetric and asymmetric quantization for activations. ++Both can be quantized per-tensor or per-channel (weights) / per-token (activations). ++ ++There are 4 epilogues: ++1. ScaledEpilogue: symmetric quantization for activations, no bias. ++1. ScaledEpilogueBias: symmetric quantization for activations, supports bias. ++1. ScaledEpilogueAzp: asymmetric per-tensor quantization for activations, supports bias. ++1. ScaledEpilogueAzpPerToken: asymmetric per-token quantization for activations, supports bias. ++ ++We do not have epilogues for asymmetric quantization of activations without bias in order to reduce final binary size. ++Instead, if no bias is passed, the epilogue will use 0 as the bias. ++That induces a redundant addition operation (and runtime check), but the performance impact is minor. ++ ++## Underlying Linear Algebra ++ ++More details available in the [Activation Quantization RFC](https://github.com/vllm-project/vllm/issues/3975). ++ ++If $` \widehat X `$ is the quantized $` X `$, our matrices become the following ++ ++```math ++A = s_a (\widehat A - J_a z_a) ++``` ++```math ++B = s_b \widehat B ++``` ++```math ++D = A B + C ++``` ++```math ++D = s_a s_b \widehat D + C ++``` ++ ++Here, D is the output of the GEMM, and C is the bias. ++A is the activations and supports asymmetric quantization, ++and B is the weights and only supports symmetric quantization. ++$ s_a $ and $s_b$ are the scales for activations and weights, respectively. ++$ z_a $ is the zero-point for activations, and $ J_a $ is the matrix of all ones with dimensions of A. ++Additional epilogues would be required to support asymmetric quantization for weights. ++ ++Expanding further, we can calculate $` \widehat D `$ as follows: ++ ++```math ++A B = s_a ( \widehat A - J_a z_a ) s_b \widehat B ++``` ++```math ++A B = s_a s_b \left( \widehat A \widehat B - J_a z_a \widehat B \right) ++``` ++```math ++\widehat D = \widehat A \widehat B - z_a J_a \widehat B ++``` ++ ++Note that $` \widehat A \widehat B `$ is the raw output of the GEMM, ++and $` J_a \widehat B `$ is known ahead of time. ++Each row of it is equal to $` \mathbf 1 \widehat B `$, which is a row-vector of column sums of $` \widehat B `$. ++ ++## Epilogues ++ ++### ScaledEpilogue ++This epilogue computes the symmetric quantization for activations without bias, meaning $` C = 0 `$ and $` z_a = 0 `$. ++The output of the GEMM is: ++ ++```math ++\widehat D = \widehat A \widehat B ++``` ++```math ++D = s_a s_b \widehat D ++``` ++```math ++D = s_a s_b \widehat A \widehat B ++``` ++ ++Epilogue parameters: ++- `scale_a` is the scale for activations, can be per-tensor (scalar) or per-token (column-vector). ++- `scale_b` is the scale for weights, can be per-tensor (scalar) or per-channel (row-vector). ++ ++### ScaledEpilogueBias ++This epilogue computes the symmetric quantization for activations with bias, meaning $` z_a = 0 `$. ++The output of the GEMM is: ++ ++```math ++\widehat D = \widehat A \widehat B ++``` ++```math ++D = s_a s_b \widehat D + C ++``` ++```math ++D = s_a s_b \widehat A \widehat B + C ++``` ++ ++ ++Epilogue parameters: ++- `scale_a` is the scale for activations, can be per-tensor (scalar) or per-token (column-vector). ++- `scale_b` is the scale for weights, can be per-tensor (scalar) or per-channel (row-vector). ++- `bias` is the bias, is always per-channel (row-vector). ++ ++### ScaledEpilogueAzp ++This epilogue computes the asymmetric per-tensor quantization for activations with bias. ++The output of the GEMM is: ++ ++```math ++\widehat D = \widehat A \widehat B - z_a J_a \widehat B ++``` ++```math ++D = s_a s_b \widehat D + C ++``` ++```math ++D = s_a s_b \left( \widehat A \widehat B - z_a J_a \widehat B \right) + C ++``` ++ ++Because $` z_a `$ is a scalar, the zero-point term $` z_a J_a \widehat B `$ has every row equal to $` z_a \mathbf 1 B `$. ++That is precomputed and stored in `azp_with_adj` as a row-vector. ++ ++Epilogue parameters: ++- `scale_a` is the scale for activations, can be per-tensor (scalar) or per-token (column-vector). ++ - Generally this will be per-tensor as the zero-points are per-tensor. ++- `scale_b` is the scale for weights, can be per-tensor (scalar) or per-channel (row-vector). ++- `azp_with_adj` is the precomputed zero-point term ($` z_a J_a \widehat B `$), is per-channel (row-vector). ++- `bias` is the bias, is always per-channel (row-vector). ++ ++To use these kernels efficiently, users must precompute the `azp_with_adj` term offline and pass it to the kernel. ++ ++### ScaledEpilogueAzpPerToken ++This epilogue computes the asymmetric per-token quantization for activations with bias. ++ ++The output of the GEMM is the same as above, but the $` z_a `$ is a column-vector. ++That means the zero-point term $` z_a J_a \widehat B `$ becomes an outer product of $` z_a `$ and $` \mathbf 1 \widehat B `$. ++ ++Epilogue parameters: ++- `scale_a` is the scale for activations, can be per-tensor (scalar) or per-token (column-vector). ++ - Generally this will be per-token as the zero-points are per-token. ++- `scale_b` is the scale for weights, can be per-tensor (scalar) or per-channel (row-vector). ++- `azp_adj` is the precomputed zero-point adjustment term ($` \mathbf 1 \widehat B `$), is per-channel (row-vector). ++- `azp` is the zero-point (`z_a`), is per-token (column-vector). ++- `bias` is the bias, is always per-channel (row-vector). ++ ++To use these kernels efficiently, users must precompute the `azp_adj` term offline and pass it to the kernel. ++ ++The epilogue performs the following computation (where `Dq` is the raw quantized output of the GEMM): ++``` ++out = scale_a * scale_b * (Dq - azp_adj * azp) + bias ++``` +diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu +new file mode 100644 +index 0000000..865fef5 +--- /dev/null ++++ b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu +@@ -0,0 +1,199 @@ ++#include ++#include ++#include "cutlass/cutlass.h" ++ ++#include "scaled_mm_c2x.cuh" ++#include "scaled_mm_c2x_sm75_dispatch.cuh" ++#include "scaled_mm_c2x_sm80_dispatch.cuh" ++#include "scaled_mm_c2x_sm89_fp8_dispatch.cuh" ++#include "scaled_mm_c2x_sm89_int8_dispatch.cuh" ++ ++#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp" ++ ++using namespace vllm; ++ ++/* ++ This file defines quantized GEMM operations using the CUTLASS 2.x API, for ++ NVIDIA GPUs with SM versions prior to sm90 (Hopper). ++*/ ++ ++template