diff --git a/torch_npu/_inductor/__init__.py b/torch_npu/_inductor/__init__.py index d1c588ddd5f3a9e7959e405c3146f6109402e01b..2352246d3d17d392e065f97c5d9333f857f46245 100644 --- a/torch_npu/_inductor/__init__.py +++ b/torch_npu/_inductor/__init__.py @@ -1,106 +1,117 @@ import os import torch -from torch._dynamo.device_interface import register_interface_for_device, get_interface_for_device -from torch._inductor import lowering as inductor_lowering -from torch._inductor.choices import InductorChoices -from torch._inductor.codegen.common import register_backend_for_device, register_device_op_overrides -from torch._inductor.runtime import autotune_cache -from torch_npu.npu import device_count -from torch_npu.utils._dynamo_device import NpuInterface, current_device, set_device -from torch_npu.utils._inductor import NPUDeviceOpOverrides -from torch_npu.utils._triton import patch_triton_for_inductor +from torch._inductor import config as inductor_config + +if os.getenv('TORCHINDUCTOR_MAX_AUTOTUNE', '0') == '1': + _has_inited = False + if not _has_inited: + _has_inited = True + from .ascend_npu_ir.build_ext import build_ascend_npu_ir_ext + build_ascend_npu_ir_ext() + from .ascend_npu_ir.ascend_npu_ir.npu import npu_inductor_plugin +else: + from torch._dynamo.device_interface import register_interface_for_device, get_interface_for_device + from torch._inductor import lowering as inductor_lowering + from torch._inductor.choices import InductorChoices + from torch._inductor.codegen.common import register_backend_for_device, register_device_op_overrides + from torch._inductor.runtime import autotune_cache -from . import config as npu_config -from . import codegen -from .npu_fusion_attention_graph import register_fa_pass -from .codecache import patch_cache_base_get_system -from .config import aggresive_autotune, num_vector_core, set_compile_threads -from .config import log as npulog -from .decomposition import _register_npu_inductor_decompositons -from .lowering import make_reduction, npu_make_fallback -from .npu_choices import should_use_persistent_reduction -from .npu_device import NewNPUDeviceOpOverrides -from .runtime import _load_cached_autotuning -from .utils import get_current_raw_stream + from torch_npu.npu import device_count + from torch_npu.utils._dynamo_device import NpuInterface, current_device, set_device + from torch_npu.utils._inductor import NPUDeviceOpOverrides + from torch_npu.utils._triton import patch_triton_for_inductor -set_compile_threads() + from . import config as npu_config + from . 
import codegen + from .npu_fusion_attention_graph import register_fa_pass + from .codecache import patch_cache_base_get_system + from .config import aggresive_autotune, num_vector_core, set_compile_threads + from .config import log as npulog + from .decomposition import _register_npu_inductor_decompositons + from .lowering import make_reduction, npu_make_fallback + from .npu_choices import should_use_persistent_reduction + from .npu_device import NewNPUDeviceOpOverrides + from .runtime import _load_cached_autotuning + from .utils import get_current_raw_stream + set_compile_threads() -def _inductor_register_backend_for_device(): - from .codegen.scheduling import NPUTritonScheduling - from .codegen.wrapper import NPUWrapperCodeGen - from .codegen.cpp_wrapper import CppWrapperNpu - register_backend_for_device('npu', NPUTritonScheduling, NPUWrapperCodeGen, CppWrapperNpu) + def _inductor_register_backend_for_device(): + from .codegen.scheduling import NPUTritonScheduling + from .codegen.wrapper import NPUWrapperCodeGen + from .codegen.cpp_wrapper import CppWrapperNpu + register_backend_for_device('npu', NPUTritonScheduling, NPUWrapperCodeGen, CppWrapperNpu) -_inductor_register_backend_for_device() + _inductor_register_backend_for_device() -def _inductor_register_device_op_overrides(): - register_device_op_overrides('npu', NewNPUDeviceOpOverrides()) + def _inductor_register_device_op_overrides(): + register_device_op_overrides('npu', NewNPUDeviceOpOverrides()) -_inductor_register_device_op_overrides() -device = get_interface_for_device("npu") + _inductor_register_device_op_overrides() -inductor_lowering.make_reduction = make_reduction -inductor_lowering.make_fallback = npu_make_fallback + device = get_interface_for_device("npu") + inductor_lowering.make_reduction = make_reduction + inductor_lowering.make_fallback = npu_make_fallback -def patch_torch_for_aoti(): - from .graph import patch_codegen_with_cpp_wrapper - from .cpp_builder import patch_get_cpp_torch_device_options - from .codegen.cpp_utils import patch_device_to_aten - from .utils import patch_is_same_tensor - from .fx_passes.joint_graph import patch_constant_fold_uniform_value - from .ir import patch_fallback_kernel_codegen - from .codecache import patch_aot_code_compiler_compile - patch_codegen_with_cpp_wrapper() - patch_get_cpp_torch_device_options() - patch_device_to_aten() - patch_is_same_tensor() - patch_constant_fold_uniform_value() - patch_fallback_kernel_codegen() - patch_aot_code_compiler_compile() + def patch_torch_for_aoti(): + from .graph import patch_codegen_with_cpp_wrapper + from .cpp_builder import patch_get_cpp_torch_device_options + from .codegen.cpp_utils import patch_device_to_aten + from .utils import patch_is_same_tensor + from .fx_passes.joint_graph import patch_constant_fold_uniform_value + from .ir import patch_fallback_kernel_codegen + from .codecache import patch_aot_code_compiler_compile + patch_codegen_with_cpp_wrapper() + patch_get_cpp_torch_device_options() + patch_device_to_aten() + patch_is_same_tensor() + patch_constant_fold_uniform_value() + patch_fallback_kernel_codegen() + patch_aot_code_compiler_compile() -if os.environ.get("DISABLE_AOTI_PATCH", "0") != "1": - patch_torch_for_aoti() + if os.environ.get("DISABLE_AOTI_PATCH", "0") != "1": + patch_torch_for_aoti() -if npu_config.dump_fx_graph: - from .codegen.ir_fx import _patch_npu_inductor_ir - _patch_npu_inductor_ir() + if npu_config.dump_fx_graph: + from .codegen.ir_fx import _patch_npu_inductor_ir -if npu_config.dump_fx_graph: - from .lowering_fx 
import _register_npu_inductor_fallbacks -else: - from .lowering import _register_npu_inductor_fallbacks + _patch_npu_inductor_ir() + + if npu_config.dump_fx_graph: + from .lowering_fx import _register_npu_inductor_fallbacks + else: + from .lowering import _register_npu_inductor_fallbacks -_register_npu_inductor_fallbacks() -_register_npu_inductor_decompositons() + _register_npu_inductor_fallbacks() + _register_npu_inductor_decompositons() -# register fx_pass should be put behind of _register_npu_inductor_decompositons -def _replace_benchmark_all_configs(): - from torch._inductor.triton_heuristics import CachingAutotuner - from .npu_triton_heuristics import benchmark_all_configs - CachingAutotuner.benchmark_all_configs = benchmark_all_configs + # register fx_pass should be put behind of _register_npu_inductor_decompositons + def _replace_benchmark_all_configs(): + from torch._inductor.triton_heuristics import CachingAutotuner + from .npu_triton_heuristics import benchmark_all_configs + CachingAutotuner.benchmark_all_configs = benchmark_all_configs -if (aggresive_autotune): - _replace_benchmark_all_configs() - import os + if (aggresive_autotune): + _replace_benchmark_all_configs() + import os - os.environ["TRITON_BENCH_METHOD"] = "npu" + os.environ["TRITON_BENCH_METHOD"] = "npu" -InductorChoices.should_use_persistent_reduction = should_use_persistent_reduction -autotune_cache._load_cached_autotuning = _load_cached_autotuning + InductorChoices.should_use_persistent_reduction = should_use_persistent_reduction + autotune_cache._load_cached_autotuning = _load_cached_autotuning -register_fa_pass() + register_fa_pass() -patch_cache_base_get_system() -patch_triton_for_inductor() + patch_cache_base_get_system() + patch_triton_for_inductor() diff --git a/torch_npu/_inductor/ascend_npu_ir/ascend_npu_ir/_C/extension.cpp b/torch_npu/_inductor/ascend_npu_ir/ascend_npu_ir/_C/extension.cpp new file mode 100644 index 0000000000000000000000000000000000000000..6c5460c456b57e22ef85ed037e2f170cbf36d1a3 --- /dev/null +++ b/torch_npu/_inductor/ascend_npu_ir/ascend_npu_ir/_C/extension.cpp @@ -0,0 +1,70 @@ +#include +#include +#include + +#include "hacl_rt.h" + +namespace py = pybind11; +static uint64_t kBiShengStartAddr = 0xbadbeef; + +void ReadBinFile(const char* file_name, uint32_t* fileSize, char** buffer) +{ + std::filebuf* pbuf; + std::ifstream filestr; + size_t size; + filestr.open(file_name, std::ios::binary); + if (!filestr) { + printf("open file failed!"); + throw std::runtime_error("open file failed!\n"); + } + pbuf = filestr.rdbuf(); + size = pbuf->pubseekoff(0, std::ios::end, std::ios::in); + pbuf->pubseekpos(0, std::ios::in); + + *buffer = new char[size]; + if (NULL == *buffer) { + printf("cannot malloc buffer size\n"); + throw std::runtime_error("cannot malloc buffer size"); + } + pbuf->sgetn(*buffer, size); + *fileSize = size; + + filestr.close(); +} + +const uintptr_t RegisterBinaryKernel(const char* func_name, const char* bin_file, char* buffer) +{ + rtDevBinary_t binary; + void* binHandle = NULL; + uint32_t bufferSize = 0; + ReadBinFile(bin_file, &bufferSize, &buffer); + if (NULL == buffer) { + printf("ReadBinFile failed\n"); + return reinterpret_cast(nullptr); + } + binary.data = buffer; + binary.length = bufferSize; + binary.magic = RT_DEV_BINARY_MAGIC_ELF_AIVEC; + binary.version = 0; + rtError_t rtRet = rtDevBinaryRegister(&binary, &binHandle); + if (rtRet != RT_ERROR_NONE) { + printf("rtDevBinaryRegister failed!\n"); + return reinterpret_cast(nullptr); + } + kBiShengStartAddr += 1; + rtRet = 
rtFunctionRegister(binHandle, reinterpret_cast(kBiShengStartAddr), func_name, (void*)func_name, 0); + if (rtRet != RT_ERROR_NONE) { + printf("rtFunctionRegister failed!\n"); + return reinterpret_cast(nullptr); + } + return reinterpret_cast (kBiShengStartAddr); +} + +PYBIND11_MODULE(_C, m) { + m.def("load_kernel_binary", [](const char* func_name, const char* bin_file){ + char* buffer = nullptr; + const uintptr_t kBiShengStartAddr = RegisterBinaryKernel(func_name, bin_file, buffer); + delete buffer; + return kBiShengStartAddr; + }); +} \ No newline at end of file diff --git a/torch_npu/_inductor/ascend_npu_ir/ascend_npu_ir/_C/include/hacl_rt.h b/torch_npu/_inductor/ascend_npu_ir/ascend_npu_ir/_C/include/hacl_rt.h new file mode 100644 index 0000000000000000000000000000000000000000..cb85a27dd8fbd5dad5859fa21233f5826962ec6f --- /dev/null +++ b/torch_npu/_inductor/ascend_npu_ir/ascend_npu_ir/_C/include/hacl_rt.h @@ -0,0 +1,397 @@ +#ifndef __HACL_RT_H__ +#define __HACL_RT_H__ + +#include + +#ifdef __cplusplus +extern "C" { +#endif + +// If you need export the function of this library in Win32 dll, use __declspec(dllexport) +#ifndef RTS_API +#ifdef RTS_DLL_EXPORT +#define RTS_API __declspec(dllexport) +#else +#define RTS_API +#endif +#endif + +/** + * @ingroup dvrt_base + * @brief stream handle. + */ +typedef void *rtStream_t; + +/** + * @ingroup dvrt_base + * @brief runtime error numbers. + */ +typedef enum tagRtError { + RT_ERROR_NONE = 0x0, // success + RT_ERROR_INVALID_VALUE = 0x1, // invalid value + RT_ERROR_MEMORY_ALLOCATION = 0x2, // memory allocation fail + RT_ERROR_INVALID_RESOURCE_HANDLE = 0x3, // invalid handle + RT_ERROR_INVALID_DEVICE_POINTER = 0x4, // invalid device point + RT_ERROR_INVALID_MEMCPY_DIRECTION = 0x5, // invalid memory copy dirction + RT_ERROR_INVALID_DEVICE = 0x6, // invalid device + RT_ERROR_NO_DEVICE = 0x7, // no valid device + RT_ERROR_CMD_OCCUPY_FAILURE = 0x8, // command occpuy failure + RT_ERROR_SET_SIGNAL_FAILURE = 0x9, // set signal failure + RT_ERROR_UNSET_SIGNAL_FAILURE = 0xA, // unset signal failure + RT_ERROR_OPEN_FILE_FAILURE = 0xB, // unset signal failure + RT_ERROR_WRITE_FILE_FAILURE = 0xC, + RT_ERROR_MEMORY_ADDRESS_UNALIGNED = 0xD, + RT_ERROR_DRV_ERR = 0xE, + RT_ERROR_LOST_HEARTBEAT = 0xF, + RT_ERROR_REPORT_TIMEOUT = 0x10, + RT_ERROR_NOT_READY = 0x11, + RT_ERROR_DATA_OPERATION_FAIL = 0x12, + RT_ERROR_INVALID_L2_INSTR_SIZE = 0x13, + RT_ERROR_DEVICE_PROC_HANG_OUT = 0x14, + RT_ERROR_DEVICE_POWER_UP_FAIL = 0x15, + RT_ERROR_DEVICE_POWER_DOWN_FAIL = 0x16, + RT_ERROR_FEATURE_NOT_SUPPROT = 0x17, + RT_ERROR_KERNEL_DUPLICATE = 0x18, // register same kernel repeatly + RT_ERROR_MODEL_STREAM_EXE_FAILED = 0x91, // the model stream failed + RT_ERROR_MODEL_LOAD_FAILED = 0x94, // the model stream failed + RT_ERROR_END_OF_SEQUENCE = 0x95, // end of sequence + RT_ERROR_NO_STREAM_CB_REG = 0x96, // no callback register info for stream + RT_ERROR_DATA_DUMP_LOAD_FAILED = 0x97, // data dump load info fail + RT_ERROR_CALLBACK_THREAD_UNSUBSTRIBE = 0x98, // callback thread unsubstribe + RT_ERROR_RESERVED +} rtError_t; + +/** + * @ingroup rt_kernel + * @brief device binary type + */ +typedef struct tagRtDevBinary { + uint32_t magic; // magic number + uint32_t version; // version of binary + const void *data; // binary data + uint64_t length; // binary length +} rtDevBinary_t; + +/** + * @ingroup rt_kernel + * @brief shared memory data control + */ +typedef struct tagRtSmData { + uint64_t L2_mirror_addr; // preload or swap source address + uint32_t L2_data_section_size; // every data 
size + uint8_t L2_preload; // 1 - preload from mirrorAddr, 0 - no preload + uint8_t modified; // 1 - data will be modified by kernel, 0 - no modified + uint8_t priority; // data priority + int8_t prev_L2_page_offset_base; // remap source section offset + uint8_t L2_page_offset_base; // remap destination section offset + uint8_t L2_load_to_ddr; // 1 - need load out, 0 - no need + uint8_t reserved[2]; // reserved +} rtSmData_t; + +/** + * @ingroup rt_kernel + * @brief shared memory description + */ +typedef struct tagRtSmCtrl { + rtSmData_t data[8]; // data description + uint64_t size; // max page Num + uint8_t remap[64]; /* just using for static remap mode, default:0xFF + array index: virtual l2 page id, array value: physic l2 page id */ + uint8_t l2_in_main; // 0-DDR, 1-L2, default:0xFF + uint8_t reserved[3]; +} rtSmDesc_t; + +/** + * @ingroup rt_kernel + * @brief magic number of plain binary for aicore + */ +#define RT_DEV_BINARY_MAGIC_PLAIN 0xabceed50 + +/** + * @ingroup rt_kernel + * @brief magic number of plain binary for aicpu + */ +#define RT_DEV_BINARY_MAGIC_PLAIN_AICPU 0xabceed51 + +/** + * @ingroup rt_kernel + * @brief magic number of plain binary for aivector + */ +#define RT_DEV_BINARY_MAGIC_PLAIN_AIVEC 0xabceed52 + +/** + * @ingroup rt_kernel + * @brief magic number of elf binary for aicore + */ +#define RT_DEV_BINARY_MAGIC_ELF 0x43554245 + +/** + * @ingroup rt_kernel + * @brief magic number of elf binary for aicpu + */ +#define RT_DEV_BINARY_MAGIC_ELF_AICPU 0x41415243 + +/** + * @ingroup rt_kernel + * @brief magic number of elf binary for aivector + */ +#define RT_DEV_BINARY_MAGIC_ELF_AIVEC 0x41415246 + +/** + * @ingroup rt_kernel + * @brief register device binary + * @param [in] bin device binary description + * @param [out] handle device binary handle + * @return RT_ERROR_NONE for ok + * @note:if this interface is changed, pls notify the compiler changing at the same time. + */ +RTS_API rtError_t rtDevBinaryRegister(const rtDevBinary_t *bin, void **handle); + +/** + * @ingroup rt_kernel + * @brief register device function + * @param [in] binHandle device binary handle + * @param [in] stubFunc stub function + * @param [in] stubName stub function name + * @param [in] devFunc device function description. symbol name or address + * offset, depending binary type. + * @return RT_ERROR_NONE for ok + * @note:if this interface is changed, pls notify the compiler changing at the same time. 
+ */
+RTS_API rtError_t rtFunctionRegister(void *binHandle, const void *stubFunc, const char *stubName, const void *devFunc,
+                                     uint32_t funcMode);
+
+/**
+ * @ingroup rt_kernel
+ * @brief launch kernel to device
+ * @param [in] stubFunc stub function
+ * @param [in] blockDim block dimensions
+ * @param [in] args arguments address for kernel function
+ * @param [in] argsSize arguments size
+ * @param [in] smDesc shared memory description
+ * @param [in] stream associated stream
+ * @return RT_ERROR_NONE for ok, errno for failed
+ */
+RTS_API rtError_t rtKernelLaunch(const void *stubFunc, uint32_t blockDim, void *args, uint32_t argsSize,
+                                 rtSmDesc_t *smDesc, rtStream_t stream);
+
+typedef struct tagRtArgsEx {
+    void *args;                 // input + output + scalar + tiling addr + tiling data
+    void *hostInputInfoPtr;     // nullptr
+    uint32_t argsSize;          // input addr size + output addr size + scalar size + tiling addr size + tiling data size
+    uint16_t tilingAddrOffset;  // size to tiling addr
+    uint16_t tilingDataOffset;  // size to tiling data
+    uint16_t hostInputInfoNum;  // 0
+    uint8_t hasTiling;          // has tiling
+    uint8_t isNoNeedH2DCopy;    // no need for rtKernelLaunchWithFlag to copy tiling from host to device
+    uint8_t reserved[4];
+} rtArgsEx_t;
+
+/**
+ * @ingroup rt_kernel
+ * @brief launch kernel and tiling to device
+ * @param [in] stubFunc stub function
+ * @param [in] blockDim block dimensions
+ * @param [in] argsInfo arguments address for kernel function
+ * @param [in] smDesc shared memory description
+ * @param [in] stream associated stream
+ * @param [in] flag not used, set to 0
+ * @note if this interface is changed, please notify the compiler side at the same time.
+ */
+
+RTS_API rtError_t rtKernelLaunchWithFlag(const void *stubFunc, uint32_t blockDim, rtArgsEx_t *argsInfo,
+                                         rtSmDesc_t *smDesc, rtStream_t stream, uint32_t flag = 0);
+
+/**
+ * @ingroup dvrt_mem
+ * @brief memory type
+ */
+#define RT_MEMORY_DEFAULT ((uint32_t)0x0)   // default memory on device
+#define RT_MEMORY_HBM ((uint32_t)0x2)       // HBM memory on device
+#define RT_MEMORY_DDR ((uint32_t)0x4)       // DDR memory on device
+#define RT_MEMORY_SPM ((uint32_t)0x8)       // shared physical memory on device
+#define RT_MEMORY_P2P_HBM ((uint32_t)0x10)  // HBM memory on other 4P device
+#define RT_MEMORY_P2P_DDR ((uint32_t)0x11)  // DDR memory on other device
+#define RT_MEMORY_DDR_NC ((uint32_t)0x20)   // DDR memory of non-cache
+#define RT_MEMORY_TS_4G ((uint32_t)0x40)
+#define RT_MEMORY_TS ((uint32_t)0x80)
+#define RT_MEMORY_RESERVED ((uint32_t)0x100)
+
+#define RT_MEMORY_L1 ((uint32_t)0x1<<16)
+#define RT_MEMORY_L2 ((uint32_t)0x1<<17)
+
+/**
+ * @ingroup dvrt_mem
+ * @brief memory type | memory policy
+ */
+typedef uint32_t rtMemType_t;
+
+/**
+ * @ingroup dvrt_mem
+ * @brief memory copy type
+ */
+typedef enum tagRtMemcpyKind {
+    RT_MEMCPY_HOST_TO_HOST = 0,       // host to host
+    RT_MEMCPY_HOST_TO_DEVICE,         // host to device
+    RT_MEMCPY_DEVICE_TO_HOST,         // device to host
+    RT_MEMCPY_DEVICE_TO_DEVICE,       // device to device, 1P && P2P
+    RT_MEMCPY_MANAGED,                // managed memory
+    RT_MEMCPY_ADDR_DEVICE_TO_DEVICE,
+    RT_MEMCPY_HOST_TO_DEVICE_EX,      // host to device ex (only used for 8 bytes)
+    RT_MEMCPY_RESERVED,
+} rtMemcpyKind_t;
+
+/**
+ * @ingroup dvrt_mem
+ * @brief alloc device memory
+ * @param [in|out] devPtr memory pointer
+ * @param [in] size memory size
+ * @param [in] type memory type
+ * @return RT_ERROR_NONE for ok
+ * @return RT_ERROR_MEMORY_ALLOCATION for memory allocation failed
+ */
+RTS_API rtError_t rtMalloc(void **devPtr, uint64_t size, rtMemType_t type);
+
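+/**
+ * @ingroup rt_kernel
+ * @brief illustrative usage sketch of this API: register an ELF AIVEC binary, bind a stub
+ *        function, stage inputs, launch, then read results back. Names such as elf_data,
+ *        elf_len, stub, hostIn, hostOut, args and argsSize are placeholders, and return-code
+ *        checks against RT_ERROR_NONE are omitted for brevity.
+ * @code
+ *   rtDevBinary_t bin = {RT_DEV_BINARY_MAGIC_ELF_AIVEC, 0, elf_data, elf_len};
+ *   void *handle = NULL;
+ *   rtDevBinaryRegister(&bin, &handle);
+ *   rtFunctionRegister(handle, stub, "kernel0", "kernel0", 0);    // stub: unique host-side key
+ *
+ *   void *devBuf = NULL;
+ *   rtMalloc(&devBuf, nbytes, RT_MEMORY_HBM);
+ *   rtMemcpy(devBuf, nbytes, hostIn, nbytes, RT_MEMCPY_HOST_TO_DEVICE);
+ *
+ *   rtStream_t stream = NULL;
+ *   rtStreamCreate(&stream, 0);
+ *   rtKernelLaunch(stub, blockDim, args, argsSize, NULL, stream); // args packs devBuf etc.
+ *   rtStreamSynchronize(stream);
+ *
+ *   rtMemcpy(hostOut, nbytes, devBuf, nbytes, RT_MEMCPY_DEVICE_TO_HOST);
+ *   rtFree(devBuf);
+ *   rtStreamDestroy(stream);
+ * @endcode
+ */
+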
+/**
+ * @ingroup dvrt_mem
+ * @brief free device memory
+ * @param [in|out] devPtr memory pointer
+ * @return RT_ERROR_NONE for ok
+ * @return RT_ERROR_INVALID_DEVICE_POINTER for invalid device memory pointer
+ */
+RTS_API rtError_t rtFree(void *devPtr);
+
+
+/**
+ * @ingroup dvrt_mem
+ * @brief alloc host memory
+ * @param [in|out] hostPtr memory pointer
+ * @param [in] size memory size
+ * @return RT_ERROR_NONE for ok
+ * @return RT_ERROR_MEMORY_ALLOCATION for memory allocation failed
+ */
+RTS_API rtError_t rtMallocHost(void **hostPtr, uint64_t size);
+
+/**
+ * @ingroup dvrt_mem
+ * @brief free host memory
+ * @param [in] hostPtr memory pointer
+ * @return RT_ERROR_NONE for ok
+ * @return RT_ERROR_INVALID_DEVICE_POINTER for invalid host memory pointer
+ */
+RTS_API rtError_t rtFreeHost(void *hostPtr);
+
+/**
+ * @ingroup dvrt_mem
+ * @brief synchronous memcpy
+ * @param [in] dst destination address pointer
+ * @param [in] destMax max length of destination address memory
+ * @param [in] src source address pointer
+ * @param [in] count the number of bytes to copy
+ * @param [in] kind memcpy type
+ * @return RT_ERROR_NONE for ok
+ * @return RT_ERROR_INVALID_VALUE for invalid count
+ * @return RT_ERROR_INVALID_DEVICE_POINTER for invalid dst or src memory pointer
+ * @return RT_ERROR_INVALID_MEMCPY_DIRECTION for invalid copy direction (kind)
+ */
+RTS_API rtError_t rtMemcpy(void *dst, uint64_t destMax, const void *src, uint64_t count, rtMemcpyKind_t kind);
+
+/**
+ * @ingroup dvrt_mem
+ * @brief asynchronous memcpy
+ * @param [in] dst destination address pointer
+ * @param [in] destMax max length of destination address memory
+ * @param [in] src source address pointer
+ * @param [in] count the number of bytes to copy
+ * @param [in] kind memcpy type
+ * @param [in] stream asynchronous task stream
+ * @return RT_ERROR_NONE for ok
+ * @return RT_ERROR_INVALID_VALUE for invalid count or stream
+ * @return RT_ERROR_INVALID_DEVICE_POINTER for invalid dst or src memory pointer
+ * @return RT_ERROR_INVALID_MEMCPY_DIRECTION for invalid copy direction (kind)
+ */
+RTS_API rtError_t rtMemcpyAsync(void *dst, uint64_t destMax, const void *src, uint64_t count, rtMemcpyKind_t kind,
+                                rtStream_t stream);
+
+/**
+ * @ingroup dvrt_stream
+ * @brief create stream instance
+ * @param [in|out] stream created stream
+ * @param [in] priority stream priority
+ * @return RT_ERROR_NONE for ok
+ * @return RT_ERROR_INVALID_RESOURCE_HANDLE for invalid stream handle
+ * @return RT_ERROR_INVALID_VALUE for invalid priority
+ */
+RTS_API rtError_t rtStreamCreate(rtStream_t *stream, int32_t priority);
+
+/**
+ * @ingroup dvrt_stream
+ * @brief destroy stream instance.
+ * @param [in] stream the stream to destroy + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_RESOURCE_HANDLE for error input stream handle + */ +RTS_API rtError_t rtStreamDestroy(rtStream_t stream); + +/** + * @ingroup dvrt_stream + * @brief wait stream to be complete + * @param [in] stream stream to wait + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_RESOURCE_HANDLE for error input stream or event handle + */ +RTS_API rtError_t rtStreamSynchronize(rtStream_t stream); + +/** + * @ingroup dvrt_dev + * @brief set target device for current thread + * @param [int] device the device id + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_DEVICE for can not match ID and device + */ +RTS_API rtError_t rtSetDevice(int32_t device); + +/** + * @ingroup dvrt_dev + * @brief reset all opened device + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_DEVICE if no device set + */ +RTS_API rtError_t rtDeviceReset(int32_t device); + +RTS_API rtError_t rtGetTaskIdAndStreamID(uint32_t *taskId, uint32_t *streamId); + +RTS_API rtError_t rtGetC2cCtrlAddr(uint64_t *addr, uint32_t *len); + +#ifndef char_t +typedef char char_t; +#endif + +/** + * @ingroup dvrt_dev + * @brief get chipType + * @return RT_ERROR_NONE for ok + */ +RTS_API rtError_t rtGetSocVersion(char_t *ver, const uint32_t maxLen); + +/** + * @ingroup + * @brief get AI core count + * @param [in] aiCoreCnt + * @return aiCoreCnt + */ +RTS_API rtError_t rtGetAiCoreCount(uint32_t *aiCoreCnt); + +/** + * @ingroup + * @brief get AI cpu count + * @param [in] aiCpuCnt + * @return aiCpuCnt + */ +RTS_API rtError_t rtGetAiCpuCount(uint32_t *aiCpuCnt); + +#ifdef __cplusplus +} +#endif + +#endif // __HACL_RT_H__ diff --git a/torch_npu/_inductor/ascend_npu_ir/ascend_npu_ir/__init__.py b/torch_npu/_inductor/ascend_npu_ir/ascend_npu_ir/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/torch_npu/_inductor/ascend_npu_ir/ascend_npu_ir/build_info.py b/torch_npu/_inductor/ascend_npu_ir/ascend_npu_ir/build_info.py new file mode 100644 index 0000000000000000000000000000000000000000..7ad413fa9320beea41f974c9f38951c37c60a18e --- /dev/null +++ b/torch_npu/_inductor/ascend_npu_ir/ascend_npu_ir/build_info.py @@ -0,0 +1 @@ +ABI_TAG = 0 diff --git a/torch_npu/_inductor/ascend_npu_ir/ascend_npu_ir/cache.py b/torch_npu/_inductor/ascend_npu_ir/ascend_npu_ir/cache.py new file mode 100644 index 0000000000000000000000000000000000000000..366cc39ba252bfce1bb4eeaddd5f5eff5de1e6cc --- /dev/null +++ b/torch_npu/_inductor/ascend_npu_ir/ascend_npu_ir/cache.py @@ -0,0 +1,295 @@ +import importlib +import json +import os +import uuid +from abc import ABC, abstractmethod +from pathlib import Path +from typing import Dict, List, Optional +import base64 +import hashlib + + +def get_home_dir(): + return os.getenv("TRITON_HOME", Path.home()) + + +def default_cache_dir(): + return os.path.join(get_home_dir(), ".triton", "cache") + + +def default_override_dir(): + return os.path.join(get_home_dir(), ".triton", "override") + + +def default_dump_dir(): + return os.path.join(get_home_dir(), ".triton", "dump") + + +class CacheManager(ABC): + + def __init__(self, key): + pass + + @abstractmethod + def get_file(self, filename) -> Optional[str]: + pass + + @abstractmethod + def put(self, data, filename, binary=True) -> str: + pass + + @abstractmethod + def get_group(self, filename: str) -> Optional[Dict[str, str]]: + pass + + @abstractmethod + def put_group(self, filename: str, 
group: Dict[str, str]): + pass + + +class FileCacheManager(CacheManager): + + def __init__(self, key, override=False, dump=False): + self.key = key + self.lock_path = None + if dump: + self.cache_dir = os.getenv("TRITON_DUMP_DIR", "").strip() or default_dump_dir() + self.cache_dir = os.path.join(self.cache_dir, self.key) + self.lock_path = os.path.join(self.cache_dir, "lock") + os.makedirs(self.cache_dir, exist_ok=True) + elif override: + self.cache_dir = os.getenv("TRITON_OVERRIDE_DIR", "").strip() or default_override_dir() + self.cache_dir = os.path.join(self.cache_dir, self.key) + else: + # create cache directory if it doesn't exist + self.cache_dir = os.getenv("TRITON_CACHE_DIR", "").strip() or default_cache_dir() + if self.cache_dir: + self.cache_dir = os.path.join(self.cache_dir, self.key) + self.lock_path = os.path.join(self.cache_dir, "lock") + os.makedirs(self.cache_dir, exist_ok=True) + else: + raise RuntimeError("Could not create or locate cache dir") + + def _make_path(self, filename) -> str: + return os.path.join(self.cache_dir, filename) + + def has_file(self, filename) -> bool: + if not self.cache_dir: + raise RuntimeError("Could not create or locate cache dir") + return os.path.exists(self._make_path(filename)) + + def get_file(self, filename) -> Optional[str]: + if self.has_file(filename): + return self._make_path(filename) + else: + return None + + def get_group(self, filename: str) -> Optional[Dict[str, str]]: + grp_filename = f"__grp__{filename}" + if not self.has_file(grp_filename): + return None + grp_filepath = self._make_path(grp_filename) + with open(grp_filepath) as f: + grp_data = json.load(f) + child_paths = grp_data.get("child_paths", None) + # Invalid group data. + if child_paths is None: + return None + result = {} + for c, p in child_paths.items(): + if os.path.exists(p): + result[c] = p + return result + + # Note a group of pushed files as being part of a group + def put_group(self, filename: str, group: Dict[str, str]) -> str: + if not self.cache_dir: + raise RuntimeError("Could not create or locate cache dir") + grp_contents = json.dumps({"child_paths": group}) + grp_filename = f"__grp__{filename}" + return self.put(grp_contents, grp_filename, binary=False) + + def put(self, data, filename, binary=True) -> str: + if not self.cache_dir: + raise RuntimeError("Could not create or locate cache dir") + binary = isinstance(data, bytes) + if not binary: + data = str(data) + assert self.lock_path is not None + filepath = self._make_path(filename) + # Random ID to avoid any collisions + rnd_id = str(uuid.uuid4()) + # we use the PID in case a bunch of these around so we can see what PID made it + pid = os.getpid() + # use temp dir to be robust against program interruptions + temp_dir = os.path.join(self.cache_dir, f"tmp.pid_{pid}_{rnd_id}") + os.makedirs(temp_dir, exist_ok=True) + temp_path = os.path.join(temp_dir, filename) + + mode = "wb" if binary else "w" + with open(temp_path, mode) as f: + f.write(data) + # Replace is guaranteed to be atomic on POSIX systems if it succeeds + # so filepath cannot see a partial write + os.replace(temp_path, filepath) + os.removedirs(temp_dir) + return filepath + + +class RemoteCacheBackend: + """ + A backend implementation for accessing a remote/distributed cache. 
+ """ + + def __init__(self, key: str): + pass + + @abstractmethod + def get(self, filenames: List[str]) -> Dict[str, bytes]: + pass + + @abstractmethod + def put(self, filename: str, data: bytes): + pass + + +class RedisRemoteCacheBackend(RemoteCacheBackend): + + def __init__(self, key): + import redis + self._key = key + self._key_fmt = os.environ.get("TRITON_REDIS_KEY_FORMAT", "triton:{key}:{filename}") + self._redis = redis.Redis( + host=os.environ.get("TRITON_REDIS_HOST", "localhost"), + port=int(os.environ.get("TRITON_REDIS_PORT", 6379)), + ) + + def _get_key(self, filename: str) -> str: + return self._key_fmt.format(key=self._key, filename=filename) + + def get(self, filenames: List[str]) -> Dict[str, str]: + results = self._redis.mget([self._get_key(f) for f in filenames]) + return {filename: result for filename, result in zip(filenames, results) if result is not None} + + def put(self, filename: str, data: bytes) -> Dict[str, bytes]: + self._redis.set(self._get_key(filename), data) + + +class RemoteCacheManager(CacheManager): + + def __init__(self, key, override=False, dump=False): + # Setup backend pointed too by `TRITON_REMOTE_CACHE_BACKEND`. + remote_cache_manager = os.environ["TRITON_REMOTE_CACHE_BACKEND"] + module_path, clz_nme = remote_cache_manager.split(":") + module = importlib.import_module(module_path) + remote_cache_cls = getattr(module, clz_nme) + self._backend = remote_cache_cls(key) + + self._override = override + self._dump = dump + + # Use a `FileCacheManager` to materialize remote cache paths locally. + self._file_cache_manager = FileCacheManager(key, override=override, dump=dump) + + def _materialize(self, filename: str, data: bytes): + # We use a backing `FileCacheManager` to provide the materialized data. + return self._file_cache_manager.put(data, filename, binary=True) + + def get_file(self, filename: str) -> Optional[str]: + # We don't handle the dump/override cases. + if self._dump or self._override: + return self._file_cache_manager.get_file(filename) + + # We always check the remote cache backend -- even if our internal file- + # based cache has the item -- to make sure LRU accounting works as + # expected. + results = self._backend.get([filename]) + if len(results) == 0: + return None + (_, data), = results.items() + return self._materialize(filename, data) + + def put(self, data, filename: str, binary=True) -> str: + # We don't handle the dump/override cases. + if self._dump or self._override: + return self._file_cache_manager.put(data, filename, binary=binary) + + if not isinstance(data, bytes): + data = str(data).encode("utf-8") + self._backend.put(filename, data) + return self._materialize(filename, data) + + def get_group(self, filename: str) -> Optional[Dict[str, str]]: + # We don't handle the dump/override cases. + if self._dump or self._override: + return self._file_cache_manager.get_group(filename) + + grp_filename = f"__grp__{filename}" + grp_filepath = self.get_file(grp_filename) + if grp_filepath is None: + return None + with open(grp_filepath) as f: + grp_data = json.load(f) + child_paths = grp_data.get("child_paths", None) + + result = None + + # Found group data. + if child_paths is not None: + result = {} + for child_path, data in self._backend.get(child_paths).items(): + result[child_path] = self._materialize(child_path, data) + + return result + + def put_group(self, filename: str, group: Dict[str, str]): + # We don't handle the dump/override cases. 
+ if self._dump or self._override: + return self._file_cache_manager.put_group(filename, group) + + grp_contents = json.dumps({"child_paths": sorted(list(group.keys()))}) + grp_filename = f"__grp__{filename}" + return self.put(grp_contents, grp_filename) + + +__cache_cls = FileCacheManager +__cache_cls_nme = "DEFAULT" + + +def _base64(key): + # Assume key is a hex string. + return base64.urlsafe_b64encode(bytes.fromhex(key)).decode("utf-8").rstrip("=") + + +def get_cache_manager(key) -> CacheManager: + import os + + user_cache_manager = os.environ.get("TRITON_CACHE_MANAGER", None) + global __cache_cls + global __cache_cls_nme + + if user_cache_manager is not None and user_cache_manager != __cache_cls_nme: + module_path, clz_nme = user_cache_manager.split(":") + module = importlib.import_module(module_path) + __cache_cls = getattr(module, clz_nme) + __cache_cls_nme = user_cache_manager + + return __cache_cls(_base64(key)) + + +def get_override_manager(key) -> CacheManager: + return __cache_cls(_base64(key), override=True) + + +def get_dump_manager(key) -> CacheManager: + return __cache_cls(_base64(key), dump=True) + + +def make_so_cache_key(version_hash, signature, constants, ids, **kwargs): + # Get unique key for the compiled code + signature = {k: 'ptr' if v[0] == '*' else v for k, v in signature.items()} + key = f"{version_hash}-{''.join(signature.values())}-{constants}-{ids}" + for kw in kwargs: + key = f"{key}-{kwargs.get(kw)}" + key = hashlib.sha256(key.encode("utf-8")).hexdigest() + return _base64(key) \ No newline at end of file diff --git a/torch_npu/_inductor/ascend_npu_ir/ascend_npu_ir/codecache.py b/torch_npu/_inductor/ascend_npu_ir/ascend_npu_ir/codecache.py new file mode 100644 index 0000000000000000000000000000000000000000..9ab93ba0a72b329045eab254af44282c7cc526cb --- /dev/null +++ b/torch_npu/_inductor/ascend_npu_ir/ascend_npu_ir/codecache.py @@ -0,0 +1,281 @@ +import os +import sys +import functools +import importlib + +from typing import ( + Callable, + List, + Any, + Dict +) + +import torch +from time import time +from concurrent.futures import Future +from torch._inductor.utils import developer_warning + +from torch._inductor.async_compile import ( + AsyncCompile, + _compile_start, + SubprocPool, + get_compile_threads, + _pool_set, + log +) + +from torch._dynamo.device_interface import get_interface_for_device + +from torch._inductor.codecache import ( + config, + Union, + TritonFuture, + ModuleType, + TritonCodeCache + ) + +from .npu.mlir_compiler import NpuMlirCompiler +from . 
import config as anir_config +from .npu.utils import logger + +class CompiledKernel: + def __init__(self, kernel_call): + self.kernel_call = kernel_call + + def run(self, *args, **kwargs): + return self.kernel_call(*args, **kwargs) + +def codegen_subgraph_dump(inds, shapes, strides, dtypes, inds2): + codes = ["args = [arg.cuda() if isinstance(arg, torch.Tensor) else arg for arg in args]"] + codes.append(f'new_args = [None] * {len(inds) + len(inds2)}') + for ind, shape, stride, dtype in zip(inds, shapes, strides, dtypes): + codes.append(f'new_args[{ind}] = rand_strided({shape}, {stride}, device="cuda:0", dtype={dtype})') + codes.append(f'indices = {inds2}') + codes.append(f'for i, ind in enumerate(indices):') + codes.append(f' new_args[ind] = args[i]') + codes.append(f'args = new_args') + return '\n'.join(codes) + + +def _worker_compile( + kernel, cc: int, device: torch.device, logger_level=None, extra_env=None +) -> None: + device_info = (device, device.index) + try: + kernel.get_best_kernel() + except: + kernel.precompile(device_info=device_info, logger_level=logger_level) + +def _load_kernel( + kernel_name: str, + source_code: str, + no_more_compile=False, + suppress_error=False, + kernel_meta=None, + extra_env=None) -> ModuleType: + device_str = kernel_meta.get('device_str') + device_interface = get_interface_for_device(device_str) + device = torch.device(device_str, device_interface.current_device()) + device_info = (device, device.index) + kernel = NpuMlirCompiler(kernel_name, no_more_compile=no_more_compile, kernel_meta=kernel_meta) + kernel.init(module=source_code, extra_env=extra_env) + try: + kernel.get_best_kernel() + except: + kernel.precompile(device_info=device_info, suppress_error=suppress_error) + return kernel + +def _load_fx_graph(kernel_name: str, source_code=None, extra_env=None, kernel_meta=None, autotune=True) -> ModuleType: + kernel = NpuMlirCompiler(kernel_name, kernel_meta=kernel_meta, autotune=autotune) + if source_code is not None: + kernel.init(module=source_code, extra_env=extra_env) + kernel.register_fx_fallback(kernel_meta) + os.makedirs(os.path.join(kernel_meta.get('traced_graph_cache'), str(kernel_meta.get('device_index')), kernel_meta.get('traced_graph_hash'), 'keep'), exist_ok=True) + return kernel + +class MulitprocessCompileFuture(TritonFuture): + kernel: ModuleType + + def __init__( + self, + kernel_name: str, + source_code: str, + futures: List[Future], + kernel_meta, + extra_env, + ) -> None: + self.kernel_name = kernel_name + self.source_code = source_code + self.futures = futures + self.kernel_meta = kernel_meta + self.extra_env = extra_env + + # @dynamo_utils.dynamo_timed + def result(self) -> ModuleType: + t0 = time() + if hasattr(self, "kernel"): + return self.kernel + errors = [] + for future in self.futures: + try: + future.result() + except Exception as e: + logger.warning(f"Error detected when multiprocess compile, error message: {e}") + errors.append(e) + + if len(errors) < len(self.futures): + kernel = self.kernel = _load_kernel(self.kernel_name, self.source_code, + no_more_compile=True, suppress_error=True, + kernel_meta=self.kernel_meta, extra_env=self.extra_env) + elif self.kernel_meta.get('num_outputs', 0): # All compiles fail and auto fallback + print("==========================Kernel compiled failed!=======================================") + print(f'kernel name: {self.kernel_name}') + print(f'{self.source_code}') + print("========================================================================================") + kernel = self.kernel 
= _load_fx_graph( + self.kernel_name, source_code=self.source_code, extra_env=self.extra_env, kernel_meta=self.kernel_meta) + else: + raise errors[0] + + latency = time() - t0 + if latency > 50: + developer_warning( + f"Detected long compilation time of {latency} seconds for kernel name {self.kernel_name}" + ) + developer_warning(self.source_code) + del self.kernel_name, self.source_code, self.futures + return kernel + + +class NPUTritonFuture(TritonFuture): + kernel: ModuleType + + def __init__( + self, + kernel_name: str, + source_code: str, + future, + kernel_meta, + extra_env + ) -> None: + self.kernel_name = kernel_name + self.source_code = source_code + self.future = future + self.kernel_meta = kernel_meta + self.extra_env = extra_env + + # @dynamo_utils.dynamo_timed + def result(self) -> ModuleType: + t0 = time() + if hasattr(self, "kernel"): + return self.kernel + # If the worker failed this will throw an exception. + if self.kernel_meta.get('num_outputs'): + try: + self.future.result() + kernel = self.kernel = _load_kernel(self.kernel_name, self.source_code, no_more_compile=True, kernel_meta=self.kernel_meta, extra_env=self.extra_env) + except Exception as e: + kernel = self.kernel = _load_fx_graph( + self.kernel_name, source_code=self.source_code, extra_env=self.extra_env, kernel_meta=self.kernel_meta) + else: + self.future.result() + kernel = self.kernel = _load_kernel(self.kernel_name, self.source_code, no_more_compile=True, kernel_meta=self.kernel_meta, extra_env=self.extra_env) + latency = time() - t0 + if latency > 50: + developer_warning( + f"Detected long compilation time of {latency} seconds for kernel name {self.kernel_name}" + ) + developer_warning(self.source_code) + del self.kernel_name, self.source_code, self.future + return kernel + + +class CustomAsyncCompile(AsyncCompile): + @staticmethod + @functools.lru_cache(1) + def process_pool() -> SubprocPool: + assert get_compile_threads() > 1 + # Wrapper around ProcessPoolExecutor forks in a new process we control + log.info("Creating subprocess pool with %d workers", get_compile_threads()) + os.environ['TORCHINDUCTOR_MAX_AUTOTUNE'] = '1' + pool = SubprocPool(get_compile_threads()) + + # Set an attribute we can check to see if the pool is ready. 
+ pool.ready_future = pool.submit(AsyncCompile._get_ready) # type: ignore[attr-defined] + _pool_set.add(pool) + return pool + + def mlir( + self, kernel_name: str, source_code: str, device_str: str = "npu" + ) -> Union[NPUTritonFuture, ModuleType]: + if 'PY_DIR_PATH' in os.environ: + raise RuntimeError('Stop early.') + _compile_start() + + if config.compile_threads > 1: + device_interface = get_interface_for_device(device_str) + device = torch.device(device_str, device_interface.current_device()) + cc = device_interface.get_compute_capability(device) + future = self.process_pool().submit( + _worker_compile, kernel_name, source_code, cc, device, logger_level=logger.level + ) + return NPUTritonFuture(kernel_name, source_code, future) + else: + return _load_kernel(kernel_name, source_code) + + def mlir_auto_fallback( + self, kernel_name: str, source_code: str, kernel_meta: Dict[str, Any]) -> Callable: + _compile_start() + + device_interface = get_interface_for_device(kernel_meta.get('device_str')) + device = torch.device(kernel_meta.get('device_str'), device_interface.current_device()) + cc = device_interface.get_compute_capability(device) + env_vars = ["TORCHINDUCTOR_CACHE_DIR", "TRITON_CACHE_DIR"] + extra_env = {v: os.environ[v] for v in env_vars if v in os.environ} + + if config.compile_threads > 1: + if anir_config.multiprocess_compile: + device_info = (device, device.index) + kernel = NpuMlirCompiler(kernel_name, multiprocess_compile=True, kernel_meta=kernel_meta) + kernel.init(module=source_code, extra_env=extra_env) + try: + kernel.get_best_kernel() + return kernel + except: + compile_args = kernel.get_autotune_config() + futures = [] + for cargs in compile_args: + future = self.process_pool().submit( + kernel.compile_mlir, device_info, cargs, logger.level + ) + futures.append(future) + return MulitprocessCompileFuture(kernel_name, source_code, futures, kernel_meta, extra_env) + else: + kernel = NpuMlirCompiler(kernel_name, multiprocess_compile=True, kernel_meta=kernel_meta) + kernel.init(module=source_code, extra_env=extra_env) + try: + kernel.get_best_kernel() + return kernel + except: + future = self.process_pool().submit( + _worker_compile, kernel, cc, device, logger_level=logger.level, extra_env=extra_env + ) + return NPUTritonFuture(kernel_name, source_code, future, kernel_meta, extra_env) + else: + kernel = _load_kernel(kernel_name, source_code, suppress_error=anir_config.autotune, kernel_meta=kernel_meta, extra_env=extra_env) + if len(kernel.launchers) == 0: + logger.info(f"fallback to fx graph call") + return _load_fx_graph(kernel_name, source_code=source_code, extra_env=extra_env, kernel_meta=kernel_meta) + return kernel + + def import_fx( + self, module_name: str, kernel_meta: Dict[str, Any]) -> Callable: + + device_interface = get_interface_for_device(kernel_meta.get('device_str')) + device = torch.device(kernel_meta.get('device_str'), device_interface.current_device()) + cc = device_interface.get_compute_capability(device) + env_vars = ["TORCHINDUCTOR_CACHE_DIR", "TRITON_CACHE_DIR"] + extra_env = {v: os.environ[v] for v in env_vars if v in os.environ} + + _compile_start() + return _load_fx_graph(module_name, kernel_meta=kernel_meta, autotune=False) \ No newline at end of file diff --git a/torch_npu/_inductor/ascend_npu_ir/ascend_npu_ir/config.py b/torch_npu/_inductor/ascend_npu_ir/ascend_npu_ir/config.py new file mode 100644 index 0000000000000000000000000000000000000000..9cb4948894ba6f158adc568ef5ca5f802fa1d0bd --- /dev/null +++ 
b/torch_npu/_inductor/ascend_npu_ir/ascend_npu_ir/config.py @@ -0,0 +1,255 @@ +import os +import sys +import time + +import torch +from os.path import abspath, dirname +from typing import ( + Any, Callable, Dict, Optional, List, + Set, Type, TYPE_CHECKING, Union, + ) + +from torch._inductor import config, inductor_prims + +aten = torch.ops.aten +tr_c10d = torch.ops.tr_c10d +prims = torch.ops.prims +npu = torch.ops.npu + +always_compile = False +enable_graph_trace = True +acc_comp_mode = True +disable_any_pbr = True +autotune_fx_fallback = False +cache_named_op = False + +traced_graph_cache = os.environ.get("ANIR_TRACED_GRAPH_CACHE", None) +torch_mlir_dump_path = os.environ.get("ANIR_TORCH_MLIR_DUMP", None) + +online_acc_comp = os.environ.get("ANIR_ONLINE_ACC_COMP", "0") == "1" +runtime_error_dump = os.environ.get("ANIR_RUNTIME_ERROR_DUMP", "0") == "1" +fallback_dump = os.environ.get("ANIR_FALLBACK_DUMP", "0") == "1" +acc_comp_tol = { + torch.float32: {'rtol': 1.3e-6, 'atol': 1e-5}, + torch.float16: {'rtol': 1e-3, 'atol': 1e-5}, + torch.bfloat16: {'rtol': 1.6e-2, 'atol': 1e-5}, + "default": {'rtol': 1.3e-6, 'atol': 1e-5}, +} + +autotune = os.environ.get("AUTOTUNE", "1") == "1" +multiprocess_compile = autotune and os.environ.get("DISABLE_MP_COMPILE", "0") == "0" + +mode = os.getenv('ANIR_MODE', 'O1') +if mode not in ["O0", "O1"]: + raise ValueError(f"Invalid MODE value: {mode}. Allowed values are 'O0' and 'O1'.") + +''' +Add extra command for bisheng compile. Some useful commands: +"-mlir-print-ir-before-all": print the entire IR before each pass. +"-mlir-print-ir-after-all": print the entire IR after each pass. +Add extra commands before model excute, code like: + +from torch_npu._inductor.ascend_npu_ir.ascend_npu_ir import config as anir_config +anir_config.extra_command += [ + "-mlir-print-ir-before-all", + "-mlir-print-ir-after-all" +] +''' +extra_command = [] + +debug = os.environ.get("ANIR_DEBUG", "0") == "1" +fallback_warning = os.environ.get("ANIR_FALLBACK_WARNING", "0") == "1" + +''' +set force_fallback_kernel_names while your faces runtime errors, and want to skip the kernel by kernel name. +set force_fallback_kernel_paths while your faces runtime errors, and want to skip the kernel by kernel cache paths. +examples: +force_fallback_kernel_names = {'mlir_fused_add_1', 'mlir_fused_add_2'} +force_fallback_kernel_paths = {'/path/to/mlir_fused_add_1_0_True_True.o', '/path/to/mlir_fused_add_2_0_True_True.o'} +''' +force_fallback_kernel_names = {} +force_fallback_kernel_paths = {} + + +if debug: + debug_dir = os.environ.get("ANIR_DEBUG_DIR", f"{os.environ['PWD']}/anir_debug") + os.makedirs(debug_dir, exist_ok=True) + +fx_graph_dump_path: str = None + +# dump fx subgraph for debugging +fx_subgraph_dump_path: str = os.environ.get("FX_SUBGRAPH_DUMP_PATH", None) +""" +compile_mode introductions: +"default" refers to the mode of fully compiling with MLIR. Currently, it is not fully supported, but it will be set as the default once the capability matures. +"complete_fallback" refers to completely falling back to the eager execution mode of the FX graph, without performing any MLIR compilation. It is primarily used for debugging. +"auto_fallback" refers to automatically falling back to the fx_graph_backend when compilation fails. +auto_fallback mechanism is designed to provide a fallback strategy when the primary compilation process encounters an issue. 
It works in conjunction with the fx_graph_backend configuration, allowing for the fallback approach: +Fallback to fx_graph_backend: If the first fallback attempt fails, the system falls back to the fx_graph_backend. +If you need further clarification or have other questions, please let me know! +""" +compile_mode: str = 'auto_fallback' +def _get_compile_mode(): + m = sys.modules[__name__] + if isinstance(m.compile_mode, str): + mode = m.compile_mode + if mode not in ['default', 'complete_fallback', 'auto_fallback']: + raise ValueError(f"Invalid mode {mode=}") + else: + raise ValueError(f"Please use the *str* type to set the compile mode, current type is {type(compile_mode)=}") + + return mode + +block_dim = 48 + +""" +support {"off", "include", "exclude"}, to +"off": No fallback at all. +"include": At compile-time, Aten IR included in FALLBACK_LIST will fall back to aten. +"exclude": At compile-time, Aten IR excluded from GENERATE_LIST will fall back to aten. +""" +fallback_to_aten_mode: str = "exclude" + +REDUCTION_OPS = [ + aten.sum, + prims.sum, + aten.prod, + aten.any, + aten.max, + aten.min, + prims.xor_sum, + aten.amax, + aten.amin, + aten.argmax, + aten.argmin, + aten.mean, + aten.var, + prims.var, + aten.var_mean, +] + +# fall back to aten exclude GENERATE_LIST, all aten IR except +GENERATE_LIST = [ + aten.mul, + aten.add, + aten.sub, + aten.div, + aten.exp, + aten.pow, + aten.rsqrt, + aten.neg, + aten.lt, + aten.gt, + aten.ge, + aten.le, + aten.eq, + aten.sigmoid, + prims.convert_element_type, + torch.ops.npu.npu_dtype_cast, + torch.ops.npu.npu_dtype_cast_backward, + torch.ops.npu._npu_dtype_cast, + torch.ops.npu._npu_dtype_cast_backward, + + aten.squeeze, + aten.unsqueeze, + aten.expand, + aten.repeat, + aten.clone, + aten.reshape, + aten.sin, + aten.cos, + aten.var_mean, + aten.sum, + aten.mean, + aten.full, + aten.slice, + aten.split, + aten.split_with_sizes, + aten.reciprocal, + aten.select, + # prims.iota, + aten.relu, + aten.copy_, + aten.where, + aten.log, + aten.scalar_tensor, + aten.permute, + # aten.cat, + aten.constant_pad_nd, + aten.amax, + aten.slice_scatter, + aten.sqrt, + aten.copy, + aten.clamp_min, + aten.clamp_max, + aten.bitwise_not, + aten.tanh, + aten.unbind, + aten.lift_fresh_copy, +] + + +FALLBACK_LIST = [ + aten.mm, + aten.bmm, + aten.addmm, + aten.convolution, + aten.convolution_backward, + aten._adaptive_avg_pool2d, + aten.max_pool2d_with_indices, + aten.max_pool2d_with_indices_backward, + aten.avg_pool2d, + aten.avg_pool2d_backward, + inductor_prims.lookup_seed, + inductor_prims.random, + prims.device_put, + aten.upsample_nearest2d, + aten.upsample_nearest2d_backward, + aten.embedding, + # aten.cat, + # aten.permute, + aten.constant_pad_nd, + aten.abs, + aten.max, + aten.amax, + aten.amin, + aten.slice_scatter, + aten.select_scatter, + aten.gather, + aten.scatter, + npu._npu_dropout, + aten.empty, + aten.index, + aten.copy_ +] + + +decomps_to_exclude_npu = [ + aten.gelu.default, + aten.gelu_backward.default, + aten.embedding, + aten.embedding_backward, + aten.embedding_dense_backward, + aten.upsample_nearest2d, + aten.upsample_nearest2d_backward, + aten.upsample_nearest1d, + aten.upsample_nearest1d_backward, + aten.upsample_nearest3d, + aten.upsample_nearest3d_backward, + aten.upsample_bilinear2d, + aten.upsample_bilinear2d_backward, + aten.nll_loss2d_forward, + aten.nll_loss2d_backward, + aten.nll_loss_backward, + aten.nll_loss_forward, + aten.triu, + aten.convolution_backward, + aten._softmax_backward_data.default, + aten.max_pool2d_with_indices, + 
aten.max_pool2d_with_indices_backward, + aten.slice.Tensor, + aten.reflection_pad2d_backward, + aten.reflection_pad2d, + aten.grid_sampler_2d, + aten.grid_sampler_2d_backward, +] \ No newline at end of file diff --git a/torch_npu/_inductor/ascend_npu_ir/ascend_npu_ir/cpp_common/cpp_common.cpp b/torch_npu/_inductor/ascend_npu_ir/ascend_npu_ir/cpp_common/cpp_common.cpp new file mode 100644 index 0000000000000000000000000000000000000000..268d8eeeaf526c1675cb26dbc971dedd237438f8 --- /dev/null +++ b/torch_npu/_inductor/ascend_npu_ir/ascend_npu_ir/cpp_common/cpp_common.cpp @@ -0,0 +1,229 @@ +#include "cpp_common.h" + +#include +#include +#include +#include +#include +#include + +extern "C" { +#pragma pack(1) +typedef struct ApiDef { + unsigned short int magicNumber; + unsigned short int level; + unsigned int type; + unsigned int threadId; + unsigned int reserve; + unsigned long int beginTime; + unsigned long int endTime; + unsigned long int itemId; +}MsprofApi; + +typedef struct NodeBasicInfo { + unsigned long int opName; + unsigned int taskType; + unsigned long int opType; + unsigned int blockDim; + unsigned int opFlag; +}MsprofNodeBasicInfo; + +typedef struct CompactInfo { + unsigned short int magicNumber; + unsigned short int level; + unsigned int type; + unsigned int threadId; + unsigned int dataLen; + unsigned long int timeStamp; + union { + unsigned char info[40]; + MsprofNodeBasicInfo nodeBasicInfo; + } data; +}MsprofCompactInfo; + +extern int MsprofReportApi(unsigned int agingFlag, const MsprofApi *api); +extern int MsprofReportCompactInfo(unsigned int agingFlag, const void* data, unsigned int length); +extern unsigned long int MsprofGetHashId(char *hashInfo, size_t length); +extern unsigned long int MsprofSysCycleTime(); + +extern aclError aclrtMallocHost(void **hostPtr, size_t size); +extern aclError aclrtMalloc(void **devPtr, size_t size, aclrtMemMallocPolicy policy); +extern aclError aclrtMemcpy(void *dst, size_t destMax, const void *src, size_t count, aclrtMemcpyKind kind); + +extern aclError aclrtFreeHost(void *hostPtr); +extern aclError aclrtFree(void *devPtr); + +typedef struct TilingMem { + std::unique_ptr arg_tiling_host; + std::unique_ptr arg_tiling_device; + TilingMem() + : arg_tiling_host(nullptr, aclrtFreeHost), + arg_tiling_device(nullptr, aclrtFree) + {} +}TilingMemInfo; + +TilingMemInfo MEM_CACHE; + +typedef struct WorkspaceMem { + std::unique_ptr arg_workspace_host; + std::unique_ptr arg_workspace_device; + WorkspaceMem() + : arg_workspace_host(nullptr, aclrtFreeHost), + arg_workspace_device(nullptr, aclrtFree) + {} +}WorkspaceMemInfo; + +WorkspaceMemInfo MEM_WORK_CACHE; +} + +rtError_t common_launch(char* kernelName, const void* func, uint32_t gridX, + void* args, uint32_t argsSize, rtStream_t stream) +{ + unsigned long int beginTime = 0; + unsigned long int endTime = 0; + unsigned long int opName = 0; + unsigned int threadId = 0; + size_t length = strlen(kernelName); + + if (torch_npu::profiler::GetTraceLevel() != -1) { + beginTime = MsprofSysCycleTime(); + } + rtError_t ret = rtKernelLaunch(func, gridX, args, argsSize, NULL, stream); + + if (torch_npu::profiler::GetTraceLevel() != -1) { + endTime = MsprofSysCycleTime(); + opName = MsprofGetHashId(kernelName, length); + threadId = (unsigned int)(syscall(SYS_gettid)); + MsprofApi info; + info.magicNumber = 0x5a5a; // MSPROF_REPORT_DATA_MAGIC_NUM + info.level = 10000; // MSPROF_REPORT_NODE_LEVEL + info.type = 5; // MSPROF_REPORT_NODE_LAUNCH_TYPE + info.threadId = threadId; + info.reserve = 0; + info.beginTime = 
beginTime; + info.endTime = endTime; + info.itemId = opName; + MsprofReportApi(0, &info); + } + if (torch_npu::profiler::GetTraceLevel() >= 1) { + MsprofCompactInfo nodeBasicInfo; + nodeBasicInfo.magicNumber = 0x5a5a; // MSPROF_REPORT_DATA_MAGIC_NUM + nodeBasicInfo.level = 10000; // MSPROF_REPORT_NODE_LEVEL + nodeBasicInfo.type = 0; // MSPROF_REPORT_NODE_BASIC_INFO_TYPE + nodeBasicInfo.threadId = threadId; + nodeBasicInfo.timeStamp = endTime; + nodeBasicInfo.data.nodeBasicInfo.opName = opName; + nodeBasicInfo.data.nodeBasicInfo.taskType = 0; // MSPROF_GE_TASK_TYPE_AI_CORE + nodeBasicInfo.data.nodeBasicInfo.opType = opName; + nodeBasicInfo.data.nodeBasicInfo.blockDim = gridX; + MsprofReportCompactInfo(0, &nodeBasicInfo, sizeof(MsprofCompactInfo)); + } + return ret; +} + +static void prepare_tiling(void* args, void* tiling_func, int64_t tilingSize, void* arg_tiling_host, + void* arg_tiling_device, uint32_t gridX, rtStream_t stream, uint32_t argsSize) +{ + uint32_t args_num = argsSize / sizeof(void *); + void **args_cast = static_cast(args); + + args_cast[args_num - 5] = arg_tiling_host; // MEM_CACHE.arg_tiling_host.get(); // 5: TilingMemrefAlignedOffset + args_cast[args_num - 4] = arg_tiling_host; //MEM_CACHE.arg_tiling_host.get(); // 4: TilingMemrefAllocatedOffset + + // tiling_func to update args + typedef int64_t (*mlir_tiling_func)(void*); + mlir_tiling_func func_tiling_pre = reinterpret_cast(tiling_func); + + // update args with tiling_key from tiling_func + + func_tiling_pre(args); + + // copy host arg_tiling to device arg_tiling, and also replace corresponding place in args + aclError err = aclrtMemcpy(arg_tiling_device, tilingSize, arg_tiling_host, + tilingSize, ACL_MEMCPY_HOST_TO_DEVICE); + if (err != ACL_ERROR_NONE) { + printf("aclrtMemcpy Failed, err: %d \n", err); + return; + } + + args_cast[args_num - 5] = arg_tiling_device; + args_cast[args_num - 4] = arg_tiling_device; +} + +rtError_t common_launch_dyn(char* kernelName, void* func, void* tiling_func, int64_t tilingSize, uint32_t gridX, + void* args, uint32_t argsSize, rtStream_t stream) +{ + unsigned long int beginTime = 0; + unsigned long int endTime = 0; + unsigned long int opName = 0; + unsigned int threadId = 0; + size_t length = strlen(kernelName); + + if (tilingSize != 0) { + void *arg_tiling_host = nullptr; + void *arg_tiling_device = nullptr; + aclError err = aclrtMallocHost((void **)&arg_tiling_host, tilingSize); + if (err != ACL_ERROR_NONE) { + printf("Failed to malloc arg_tiling_host, err: %d \n", err); + } + // malloc device memory for device arg_tiling + err = aclrtMalloc((void **)&arg_tiling_device, tilingSize, ACL_MEM_MALLOC_HUGE_FIRST); + if (err != ACL_ERROR_NONE) { + printf("Failed to malloc arg_tiling_device, err: %d \n", err); + } + prepare_tiling(args, tiling_func, tilingSize, arg_tiling_host, arg_tiling_device, gridX, stream, argsSize); + typedef void (*mlir_func)(uint32_t, void*, void*, void*); + mlir_func func_cast = (mlir_func)func; + if (torch_npu::profiler::GetTraceLevel() != -1) { + beginTime = MsprofSysCycleTime(); + } + func_cast(gridX, nullptr, stream, args); + } + else { + typedef void (*mlir_func)(uint32_t, void*, void*, void*); + mlir_func func_cast = (mlir_func)func; + if (torch_npu::profiler::GetTraceLevel() != -1) { + beginTime = MsprofSysCycleTime(); + } + func_cast(gridX, nullptr, stream, args); + } + + if (torch_npu::profiler::GetTraceLevel() != -1) { + endTime = MsprofSysCycleTime(); + opName = MsprofGetHashId(kernelName, length); + threadId = (unsigned int)(syscall(SYS_gettid)); + 
MsprofApi info; + info.magicNumber = 0x5a5a; // MSPROF_REPORT_DATA_MAGIC_NUM + info.level = 10000; // MSPROF_REPORT_NODE_LEVEL + info.type = 5; // MSPROF_REPORT_NODE_LAUNCH_TYPE + info.threadId = threadId; + info.reserve = 0; + info.beginTime = beginTime; + info.endTime = endTime; + info.itemId = opName; + MsprofReportApi(0, &info); + } + if (torch_npu::profiler::GetTraceLevel() >= 1) { + MsprofCompactInfo nodeBasicInfo; + nodeBasicInfo.magicNumber = 0x5a5a; // MSPROF_REPORT_DATA_MAGIC_NUM + nodeBasicInfo.level = 10000; // MSPROF_REPORT_NODE_LEVEL + nodeBasicInfo.type = 0; // MSPROF_REPORT_NODE_BASIC_INFO_TYPE + nodeBasicInfo.threadId = threadId; + nodeBasicInfo.timeStamp = endTime; + nodeBasicInfo.data.nodeBasicInfo.opName = opName; + nodeBasicInfo.data.nodeBasicInfo.taskType = 0; // MSPROF_GE_TASK_TYPE_AI_CORE + nodeBasicInfo.data.nodeBasicInfo.opType = opName; + nodeBasicInfo.data.nodeBasicInfo.blockDim = gridX; + MsprofReportCompactInfo(0, &nodeBasicInfo, sizeof(MsprofCompactInfo)); + } + + return RT_ERROR_NONE; +} + +void opcommand_call(const char *name, std::function launch_call) +{ + at_npu::native::OpCommand cmd; + cmd.Name(name) + .SetCustomHandler(launch_call) + .Run(); +} diff --git a/torch_npu/_inductor/ascend_npu_ir/ascend_npu_ir/cpp_common/cpp_common.h b/torch_npu/_inductor/ascend_npu_ir/ascend_npu_ir/cpp_common/cpp_common.h new file mode 100644 index 0000000000000000000000000000000000000000..688d464bec4c2adfeed4a2cb3af8606f5b1c116b --- /dev/null +++ b/torch_npu/_inductor/ascend_npu_ir/ascend_npu_ir/cpp_common/cpp_common.h @@ -0,0 +1,10 @@ +#define PY_SSIZE_T_CLEAN +#include +#include +#include "hacl_rt.h" + +rtError_t common_launch(char* kernelName, const void* func, uint32_t gridX, + void* args, uint32_t argsSize, rtStream_t stream); +rtError_t common_launch_dyn(char* kernelName, void* func, void* tiling_func, int64_t tilingSize, uint32_t gridX, + void* args, uint32_t argsSize, rtStream_t stream); +void opcommand_call(const char *name, std::function launch_call); \ No newline at end of file diff --git a/torch_npu/_inductor/ascend_npu_ir/ascend_npu_ir/npu/__init__.py b/torch_npu/_inductor/ascend_npu_ir/ascend_npu_ir/npu/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/torch_npu/_inductor/ascend_npu_ir/ascend_npu_ir/npu/codegen/__init__.py b/torch_npu/_inductor/ascend_npu_ir/ascend_npu_ir/npu/codegen/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/torch_npu/_inductor/ascend_npu_ir/ascend_npu_ir/npu/codegen/cpp_wrapper.py b/torch_npu/_inductor/ascend_npu_ir/ascend_npu_ir/npu/codegen/cpp_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..bada89bfdc18a8e0ddddaf9e313d6a003e27f0b4 --- /dev/null +++ b/torch_npu/_inductor/ascend_npu_ir/ascend_npu_ir/npu/codegen/cpp_wrapper.py @@ -0,0 +1,393 @@ +def cpp_launcher(signature, kernel_name, ranks, dynamic=False) -> str: + def _ty_to_cpp(ty): + if ty[0] == '*': + return "void*" + return { + "i1": "int32_t", + "i8": "int8_t", + "i16": "int16_t", + "i32": "int32_t", + "i64": "int64_t", + "u32": "uint32_t", + "u64": "uint64_t", + "f16": "float", + "fp16": "float", + "bf16": "float", + "fp32": "float", + "f32": "float", + "fp64": "double", + }[ty] + + def _extracted_ty(ty): + if ty[0] == '*': + return "PyObject*" + return { + 'i1': 'int32_t', + 'i32': 'int32_t', + 'i64': 'int64_t', + 'u32': 'uint32_t', + 'u64': 'uint64_t', + 'f16': 'float', + 
'fp16': 'float', + 'bf16': 'float', + 'fp32': 'float', + 'f32': 'float', + 'fp64': 'double', + }[ty] + + def _format_of(ty): + return { + "PyObject*": "O", + "float": "f", + "double": "d", + "long": "l", + "uint32_t": "I", + "int32_t": "i", + "uint64_t": "K", + "int64_t": "L", + }[ty] + if dynamic: + arg_decls = ', '.join( + f"{_ty_to_cpp(ty)} arg{i}" + + ("" if "torch." in ty else f", {_ty_to_cpp(ty)} arg_allocate{i}, {_ty_to_cpp(ty)} offset{i}, " + + ', '.join(f"{_ty_to_cpp(ty)} sizes{i}_{rank}" for rank in range(ranks[i])) + ', ' + + ', '.join(f"{_ty_to_cpp(ty)} strides{i}_{rank}" for rank in range(ranks[i]))) + for i, ty in signature.items() + ) + format = "iKkkLOOO" + ''.join([_format_of(_extracted_ty(ty)) + ('' if "torch." in ty else _format_of(_extracted_ty(ty)) + 'L' + 'L'*ranks[i]*2) for i, ty in signature.items()]) + return f""" +#include +#include +#include +#include +#include + +typedef struct _DevicePtrInfo {{ + void *dev_ptr; + bool valid; +}} DevicePtrInfo; + +static inline DevicePtrInfo getPointer(PyObject *obj, int idx) {{ + DevicePtrInfo ptr_info; + ptr_info.dev_ptr = 0; + ptr_info.valid = true; + if (PyLong_Check(obj)) {{ + ptr_info.dev_ptr = reinterpret_cast(PyLong_AsLongLong(obj)); + return ptr_info; + }} + if (obj == Py_None) {{ + // valid nullptr + return ptr_info; + }} + PyObject *ptr = PyObject_GetAttrString(obj, "data_ptr"); + if(ptr){{ + PyObject *empty_tuple = PyTuple_New(0); + PyObject *ret = PyObject_Call(ptr, empty_tuple, NULL); + Py_DECREF(empty_tuple); + Py_DECREF(ptr); + if (!PyLong_Check(ret)) {{ + PyErr_SetString(PyExc_TypeError, "data_ptr method of Pointer object must return 64-bit int"); + ptr_info.valid = false; + return ptr_info; + }} + ptr_info.dev_ptr = reinterpret_cast(PyLong_AsLongLong(ret)); + if(!ptr_info.dev_ptr) + return ptr_info; + Py_DECREF(ret); + return ptr_info; + }} + PyErr_SetString(PyExc_TypeError, "Pointer argument must be either int64 or have data_ptr method"); + return ptr_info; +}} + +static void _launch(void* func, void* tiling_func, int64_t tiling_size, rtStream_t stream, int gridX, {arg_decls}) {{ + // only 1D parallelization is supported for NPU + // Pointer type becomes flattend 1-D Memref tuple: base_ptr, data_ptr, offset, shape, stride + // base_ptr offset shape and stride are not used, arbitrarily set for now + + if (tiling_size == 0) {{ + auto launch_call = [func, tiling_func, tiling_size, gridX, stream, {', '.join(f"arg{i}" + ("" if "torch." in ty else f", arg_allocate{i}, offset{i}, " + ', '.join(f"sizes{i}_{rank}" for rank in range(ranks[i])) + ', ' + ', '.join(f"strides{i}_{rank}" for rank in range(ranks[i]))) for i, ty in signature.items())}]() {{ + struct __attribute__((packed)) {{ + + {' '.join(f'{_ty_to_cpp(ty)} arg{i} __attribute__((aligned({4 if ty[0] != "*" and ty[-2:] != "64" else 8}))); ' + ('' if "torch." in ty else f'{_ty_to_cpp(ty)} arg_allocate{i} __attribute__((aligned({4 if ty[0] != "*" and ty[-2:] != "64" else 8}))); {_ty_to_cpp(ty)} offset{i} __attribute__((aligned(8))); ' + ' '.join(f'{_ty_to_cpp(ty)} sizes{i}_{rank} __attribute__((aligned(8)));' for rank in range(ranks[i])) + ' ' + ' '.join(f'{_ty_to_cpp(ty)} strides{i}_{rank} __attribute__((aligned(8)));' for rank in range(ranks[i]))) for i, ty in signature.items())} + + }} args = {{ + {', '.join(f"static_cast<{_ty_to_cpp(ty)}>(arg{i})" + ("" if "torch." 
in ty else f", static_cast<{_ty_to_cpp(ty)}>(arg_allocate{i}), static_cast<{_ty_to_cpp(ty)}>(offset{i}), " + ', '.join(f"static_cast<{_ty_to_cpp(ty)}>(sizes{i}_{rank})" for rank in range(ranks[i])) + ', ' + ', '.join(f"static_cast<{_ty_to_cpp(ty)}>(strides{i}_{rank})" for rank in range(ranks[i]))) for i, ty in signature.items())} + + }}; + + rtError_t ret = common_launch_dyn(const_cast("{kernel_name}"), func, tiling_func, tiling_size, gridX, static_cast(&args), sizeof(args), stream); + return ret; + }}; + opcommand_call("{kernel_name}", launch_call); + }} else {{ + int64_t __attribute__((aligned(8))) key_tiling; + void* arg_tiling = 0; + void* arg_allocate_tiling = 0; + void* offset_tiling = 0; + void* sizes_tiling = (void*)(tiling_size / sizeof(int64_t)); + void* strides_tiling = (void*)1; + auto launch_call = [func, tiling_func, tiling_size, gridX, stream, {', '.join(f"arg{i}" + ("" if "torch." in ty else f", arg_allocate{i}, offset{i}, " + ', '.join(f"sizes{i}_{rank}" for rank in range(ranks[i])) + ', ' + ', '.join(f"strides{i}_{rank}" for rank in range(ranks[i]))) for i, ty in signature.items())}, key_tiling, arg_tiling, arg_allocate_tiling, offset_tiling, sizes_tiling, strides_tiling]() {{ + struct __attribute__((packed)) {{ + + {' '.join(f'{_ty_to_cpp(ty)} arg{i} __attribute__((aligned({4 if ty[0] != "*" and ty[-2:] != "64" else 8}))); ' + ('' if "torch." in ty else f'{_ty_to_cpp(ty)} arg_allocate{i} __attribute__((aligned({4 if ty[0] != "*" and ty[-2:] != "64" else 8}))); {_ty_to_cpp(ty)} offset{i} __attribute__((aligned(8))); ' + ' '.join(f'{_ty_to_cpp(ty)} sizes{i}_{rank} __attribute__((aligned(8)));' for rank in range(ranks[i])) + ' ' + ' '.join(f'{_ty_to_cpp(ty)} strides{i}_{rank} __attribute__((aligned(8)));' for rank in range(ranks[i]))) for i, ty in signature.items())} + + void* key_tiling __attribute__((aligned(8))); + void* arg_tiling __attribute__((aligned(8))); + void* arg_allocate_tiling __attribute__((aligned(8))); + void* offset_tiling __attribute__((aligned(8))); + void* sizes_tiling __attribute__((aligned(8))); + void* strides_tiling __attribute__((aligned(8))); + + }} args = {{ + {', '.join(f"static_cast<{_ty_to_cpp(ty)}>(arg{i})" + ("" if "torch." in ty else f", static_cast<{_ty_to_cpp(ty)}>(arg_allocate{i}), static_cast<{_ty_to_cpp(ty)}>(offset{i}), " + ', '.join(f"static_cast<{_ty_to_cpp(ty)}>(sizes{i}_{rank})" for rank in range(ranks[i])) + ', ' + ', '.join(f"static_cast<{_ty_to_cpp(ty)}>(strides{i}_{rank})" for rank in range(ranks[i]))) for i, ty in signature.items())+ ', '} + + (void*)(&key_tiling), static_cast(arg_tiling), static_cast(arg_allocate_tiling), static_cast(offset_tiling), static_cast(sizes_tiling), static_cast(strides_tiling) + }}; + + rtError_t ret = common_launch_dyn(const_cast("{kernel_name}"), func, tiling_func, tiling_size, gridX, static_cast(&args), sizeof(args), stream); + return ret; + }}; + opcommand_call("{kernel_name}", launch_call); + }} +}} + +static PyObject* launch(PyObject* self, PyObject* args) {{ + int gridX; + rtStream_t stream; + PyObject *func; + PyObject *tiling_func; + int64_t tiling_size; + PyObject *launch_enter_hook = NULL; + PyObject *launch_exit_hook = NULL; + PyObject *metadata = NULL; + {'; '.join(f"{_extracted_ty(ty)} _arg{i}" + ("" if "torch." 
in ty else f"; {_extracted_ty(ty)} _arg_allocate{i}; int64_t offset{i}; " + ''.join(f"int64_t sizes{i}_{rank}; " for rank in range(ranks[i])) + '; '.join(f"int64_t strides{i}_{rank}" for rank in range(ranks[i]))) for i, ty in signature.items()) + '; '} + if(!PyArg_ParseTuple( + args, \"{format}\", + &gridX, &stream, &func, &tiling_func, &tiling_size, + &launch_enter_hook, &launch_exit_hook, &metadata + {', ' + ', '.join((f"&_arg{i}" + ("" if "torch." in ty else f", &_arg_allocate{i}, &offset{i}" + ', ' + ', '.join(f"&sizes{i}_{rank}" for rank in range(ranks[i])) + ', ' + ', '.join(f"&strides{i}_{rank}" for rank in range(ranks[i])))) for i, ty in signature.items()) if len(signature) > 0 else ''} + ) + ) {{ + return NULL; + }} + + + if (launch_enter_hook != Py_None && !PyObject_CallObject(launch_enter_hook, args)) {{ + return NULL; + }} + + // raise exception asap + {"; ".join([f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}); if (!ptr_info{i}.valid) return NULL" if ty[0] == "*" else "" for i, ty in signature.items()]) + "; "} + {"; ".join([f"DevicePtrInfo ptr_allocate_info{i} = getPointer(_arg_allocate{i}, {i}); if (!ptr_allocate_info{i}.valid) return NULL" if (ty[0] == "*" and "torch." not in ty) else "" for i, ty in signature.items()]) + "; "} + + _launch(reinterpret_cast(func), reinterpret_cast(tiling_func), tiling_size, stream, gridX, {', '.join([f"ptr_info{i}.dev_ptr" + ('' if "torch." in ty else ', ' + f"ptr_allocate_info{i}.dev_ptr, reinterpret_cast(offset{i}), " + ', '.join(f"reinterpret_cast(sizes{i}_{rank})" for rank in range(ranks[i])) + ', ' + ', '.join(f"reinterpret_cast(strides{i}_{rank})" for rank in range(ranks[i]))) for i, ty in signature.items()])}); + + if (PyErr_Occurred()) {{ + return NULL; + }} + if (launch_exit_hook != Py_None && !PyObject_CallObject(launch_exit_hook, args)) {{ + return NULL; + }} + + // return None + Py_INCREF(Py_None); + return Py_None; +}} + +static PyObject* get_host_func_and_tiling_size(PyObject* self, PyObject* args) {{ + const char *func_name; + const char *tiling_func_name; + const char *get_tiling_struct_size_func_name; + const char *so_file; + if(!PyArg_ParseTuple( + args, "ssss", &func_name, &tiling_func_name, &get_tiling_struct_size_func_name, &so_file + ) + ) {{ + return NULL; + }} + void *handle = dlopen(so_file, RTLD_LAZY); + if (handle == NULL) {{ + std::cout<<"handle == NULL!"<(func)), PyLong_FromUnsignedLong(reinterpret_cast(tiling_func)), PyLong_FromLongLong(tilingSize)); +}} + +static PyMethodDef ModuleMethods[] = {{ + {{"launch", launch, METH_VARARGS, "Entry point for all kernels with this signature"}}, + {{"get_host_func_and_tiling_size", get_host_func_and_tiling_size, METH_VARARGS, "Get host func from kernel.so"}}, + {{NULL, NULL, 0, NULL}} // sentinel +}}; + +static struct PyModuleDef ModuleDef = {{ + PyModuleDef_HEAD_INIT, + "__launcher", + NULL, //documentation + -1, //size + ModuleMethods +}}; + +PyMODINIT_FUNC PyInit___launcher(void) {{ + PyObject *m = PyModule_Create(&ModuleDef); + if(m == NULL) {{ + return NULL; + }} + PyModule_AddFunctions(m, ModuleMethods); + return m; +}} +""" + arg_decls = ', '.join(f"{_ty_to_cpp(ty)} arg{i}" for i, ty in signature.items()) + format = "iKKOOO" + ''.join([_format_of(_extracted_ty(ty)) for ty in signature.values()]) + return f""" +#include +#include +#include + +typedef struct _DevicePtrInfo {{ + void *dev_ptr; + bool valid; +}} DevicePtrInfo; + +static inline DevicePtrInfo getPointer(PyObject *obj, int idx) {{ + DevicePtrInfo ptr_info; + ptr_info.dev_ptr = 0; + ptr_info.valid = 
true; + if (PyLong_Check(obj)) {{ + ptr_info.dev_ptr = reinterpret_cast(PyLong_AsUnsignedLongLong(obj)); + return ptr_info; + }} + if (obj == Py_None) {{ + // valid nullptr + return ptr_info; + }} + PyObject *ptr = PyObject_GetAttrString(obj, "data_ptr"); + if(ptr){{ + PyObject *empty_tuple = PyTuple_New(0); + PyObject *ret = PyObject_Call(ptr, empty_tuple, NULL); + Py_DECREF(empty_tuple); + Py_DECREF(ptr); + if (!PyLong_Check(ret)) {{ + PyErr_SetString(PyExc_TypeError, "data_ptr method of Pointer object must return 64-bit int"); + ptr_info.valid = false; + return ptr_info; + }} + ptr_info.dev_ptr = reinterpret_cast(PyLong_AsUnsignedLongLong(ret)); + if(!ptr_info.dev_ptr) + return ptr_info; + Py_DECREF(ret); + return ptr_info; + }} + PyErr_SetString(PyExc_TypeError, "Pointer argument must be either uint64 or have data_ptr method"); + return ptr_info; +}} + +static void _launch(const void* func, rtStream_t stream, int gridX, {arg_decls}) {{ + // only 1D parallelization is supported for NPU + // Pointer type becomes flattend 1-D Memref tuple: base_ptr, data_ptr, offset, shape, stride + // base_ptr offset shape and stride are not used, arbitrarily set for now + auto launch_call = [func, gridX, stream, {', '.join(f" arg{i}" for i, ty in signature.items())}]() {{ + struct __attribute__((packed)) {{ + {' '.join(f'{_ty_to_cpp(ty)} arg{i} __attribute__((aligned({4 if ty[0] != "*" and ty[-2:] != "64" else 8})));' for i, ty in signature.items())} + }} args = {{ + {', '.join(f'static_cast<{_ty_to_cpp(ty)}>(arg{i})' for i, ty in signature.items())} + }}; + + rtError_t ret = common_launch(const_cast("{kernel_name}"), func, gridX, static_cast(&args), sizeof(args), stream); + return ret; + }}; + opcommand_call("{kernel_name}", launch_call); +}} + +static PyObject* launch(PyObject* self, PyObject* args) {{ + int gridX; + rtStream_t stream; + const void *function; + PyObject *launch_enter_hook = NULL; + PyObject *launch_exit_hook = NULL; + PyObject *metadata = NULL; + {' '.join([f"{_extracted_ty(ty)} _arg{i}; " for i, ty in signature.items()])} + if(!PyArg_ParseTuple( + args, \"{format}\", + &gridX, &stream, &function, + &launch_enter_hook, &launch_exit_hook, &metadata + {', ' + ', '.join(f"&_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else ''} + ) + ) {{ + return NULL; + }} + + if (launch_enter_hook != Py_None && !PyObject_CallObject(launch_enter_hook, args)) {{ + return NULL; + }} + + // raise exception asap + {"; ".join([f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}); if (!ptr_info{i}.valid) return NULL;" if ty[0] == "*" else "" for i, ty in signature.items()])}; + + _launch(function, stream, gridX, {', '.join(f"ptr_info{i}.dev_ptr" if ty[0]=="*" else f"_arg{i}"for i, ty in signature.items())}); + + if (PyErr_Occurred()) {{ + return NULL; + }} + if (launch_exit_hook != Py_None && !PyObject_CallObject(launch_exit_hook, args)) {{ + return NULL; + }} + + // return None + Py_INCREF(Py_None); + return Py_None; +}} + +static PyMethodDef ModuleMethods[] = {{ + {{"launch", launch, METH_VARARGS, "Entry point for all kernels with this signature"}}, + {{NULL, NULL, 0, NULL}} // sentinel +}}; + +static struct PyModuleDef ModuleDef = {{ + PyModuleDef_HEAD_INIT, + "__launcher", + NULL, //documentation + -1, //size + ModuleMethods +}}; + +PyMODINIT_FUNC PyInit___launcher(void) {{ + PyObject *m = PyModule_Create(&ModuleDef); + if(m == NULL) {{ + return NULL; + }} + PyModule_AddFunctions(m, ModuleMethods); + return m; +}} +""" diff --git 
a/torch_npu/_inductor/ascend_npu_ir/ascend_npu_ir/npu/codegen/mlir.py b/torch_npu/_inductor/ascend_npu_ir/ascend_npu_ir/npu/codegen/mlir.py new file mode 100644 index 0000000000000000000000000000000000000000..7cff96c65fdfdbf56ec0b883846190588b198a5f --- /dev/null +++ b/torch_npu/_inductor/ascend_npu_ir/ascend_npu_ir/npu/codegen/mlir.py @@ -0,0 +1,467 @@ +import os +import copy +import sympy +import collections +import textwrap + +import torch +from sympy import Expr, Integer +from itertools import count +from torch._dynamo.utils import counters + +from typing import List, Union, Optional, Tuple, Any, Dict + +from torch.fx.experimental.proxy_tensor import make_fx +from torch._inductor.scheduler import BaseSchedulerNode, Scheduler, WhyNoFuse +from torch.utils._ordered_set import OrderedSet +from torch._inductor.codegen.simd import ( + log, + OrderedSet, + EnableReduction, + DisableReduction, + SIMDKernel, + SIMDKernelFeatures, + MultiKernel, + code_hash +) +from torch._inductor.codegen.triton import ( + SIMDScheduling, + FixedTritonConfig, +) + +from torch._inductor.ir import IRNode + +from torch._inductor.codegen.common import ( + IndentedBuffer, + Kernel, +) + +from torch._inductor.codegen.triton import ( + TritonKernel +) + +from torch._inductor.utils import ( + get_fused_kernel_name, +) +from torch._inductor import config, ir, scheduler +from ... import config as anir_config +from torch._inductor.virtualized import V +from ...npu.utils import ( + MLIRProcessor, + parse_fx_example_inputs, + fx_graph_op_types, + npu_cast_to_prim_cast, + get_fx_graph_code, + scalarize_tensor_ops_on_scalars, + to_folder, + modify_gm_for_acc_comp, + get_num_call_functions, + is_fx_dynamic, + view_to_reshape +) + +if anir_config.enable_graph_trace: + from ...npu.inductor_patch.lowering import ( + merge_fx_graphs, + map_strings_to_operators + ) + +id_iter = count() + +class NpuMlirKernel(Kernel): + def __init__(self, + gm: torch.fx.GraphModule, + snodes: List[scheduler.SchedulerNode], + call_args: List[str], + non_contiguous_indices: List[int], + num_outputs: List[int]=None, + mutated_indices: List[int]=None + ): + super().__init__() + self._gm = gm + self._gm_with_prim_cast = self.build_gm_with_prim_cast(gm) + self._is_dynamic = is_fx_dynamic(self._gm) + if anir_config.online_acc_comp: + modify_gm_for_acc_comp(self._gm) + self._snodes = snodes + self._call_args = call_args + self.non_contiguous_indices = non_contiguous_indices + self.num_outputs = num_outputs + self.mutated_indices = mutated_indices + + + def imports_for_benchmark_kernel(self): + return textwrap.dedent( + """ + from torch._dynamo.testing import rand_strided + {} + import torch + """.format( + V.graph.device_ops.import_get_raw_stream_as("get_raw_stream") + ) + ) + + def build_gm_with_prim_cast(self, gm): + gm_with_prim_cast = npu_cast_to_prim_cast(gm) + return gm_with_prim_cast + + def codegen_kernel(self, name=None): + + code = IndentedBuffer() + + import torch_mlir + from ..torch_mlir_patch import stateless_fx_import + from torch._functorch.aot_autograd import ( + set_model_name, + get_aot_compilation_context, + ) + from torch_mlir.compiler_utils import ( + run_pipeline_with_repro_report, + lower_mlir_module, + ) + from torch_mlir.compiler_utils import OutputType + + scalarize_tensor_ops_on_scalars(self._gm_with_prim_cast) + set_model_name(f'MODEL_NAME') + *_, model_name, nth_graph = get_aot_compilation_context() + mlir_module = stateless_fx_import( + self._gm_with_prim_cast, + model_name=model_name, + import_symbolic_shape_expressions=False) 
+ run_pipeline_with_repro_report( + mlir_module, + # f"builtin.module(torch-function-to-torch-backend-pipeline{option_string})", + f"builtin.module(torch-lower-to-backend-contract)", + "Lowering TorchFX IR -> Torch Backend IR", + ) + + with mlir_module.context: + for func in mlir_module.body.operations: + if isinstance(func, torch_mlir.dialects.func.FuncOp): + func.attributes["torch.assume_strict_symbolic_shapes"] = torch_mlir.ir.UnitAttr.get() + + code.splice(f'{str(mlir_module)}') + + return code.getvalue() + + def get_call_args(self): + return self._call_args + + def call_kernel(self, name: str, node: Optional[IRNode] = None): + wrapper = V.graph.wrapper_code + call_args = self.get_call_args() + for call_arg in call_args: + if call_arg.startswith('_'): + expression = map_strings_to_operators(call_arg) + wrapper.writeline(f'{call_arg} = {expression}') + if len(call_args) > 0: + wrapper.generate_kernel_call( + name, + call_args, + ) + + def codegen_debug_performance(self, fd): + from ...npu.utils import generate_compiler_repro_string, generate_fake_inputs + name_to_example_inputs = parse_fx_example_inputs(self._gm) + call_args_str = ", ".join(list(name_to_example_inputs.keys())) + fd.write( + generate_compiler_repro_string( + self._gm, + ) + ) + fd.write("\n") + fd.write("if __name__ == '__main__':\n") + fd.write(" from torch._inductor.utils import print_performance\n") + fd.write(f" with torch.no_grad():\n" + ) + fd.write(generate_fake_inputs(name_to_example_inputs)) + fd.write('\n') + fd.write( + f" fn = lambda: mod({call_args_str})\n" + f" print_performance(fn, times=10, repeat=10)\n" + ) + + +class NpuTritonKernel(TritonKernel): + def __init__(self, + tiling: Dict[str, sympy.Expr], + min_elem_per_thread=0, + optimize_mask=True, + fixed_config: Optional[FixedTritonConfig] = None, + **kwargs, + ): + super().__init__( + tiling, + min_elem_per_thread=min_elem_per_thread, + optimize_mask=optimize_mask, + fixed_config = fixed_config, + **kwargs, + ) + + @staticmethod + def inductor_meta_common(): + return {} + + def call_kernel(self, call_args, name: str): + wrapper = V.graph.wrapper_code + for call_arg in call_args: + if call_arg.startswith('_'): + expression = map_strings_to_operators(call_arg) + wrapper.writeline(f'{call_arg} = {expression}') + if len(call_args) > 0: + wrapper.generate_kernel_call( + name, + call_args, + ) + +def find_common_positions(list1, list2): + common_elements = set(list1) & set(list2) + merged_list = list1 + list2 + positions = [index for index, element in enumerate(merged_list) if element in common_elements] + + return merged_list, positions + +def create_fx_from_snodes_by_traced_graph(snodes: List[scheduler.SchedulerNode], triton_kernel: TritonKernel): + call_inputs = [] + for snode in snodes: + snode.node.data.traced_graph.last_node.name = snode.node.get_name() + if len(snodes) == 1: + traced_graph = snodes[0].node.data.traced_graph + else: + traced_graph = merge_fx_graphs([snode.node.data.traced_graph for snode in snodes]) + inputs = [] + for node in traced_graph.graph.nodes: + if node.op == 'placeholder': + call_inputs.append(node.target) + inputs.append(node.meta['val']) + non_contiguous_indices = {} + non_contiguous_indices["inputs"] = [i for i, inp in enumerate(inputs) if torch.is_tensor(inp) and not inp.is_contiguous()] + num_inputs = len(call_inputs) + call_outputs = [] + for snode in snodes: + if snode.has_aliasing_or_mutation(): + for buf in snode.get_outputs(): + if len(buf.get_mutations()): + call_outputs.extend(buf.get_mutations()) + elif 
len(buf.get_aliases()): + call_outputs.append(buf.get_name()) + elif snode.node.get_name() not in (V.graph.removed_buffers | V.graph.inplaced_to_remove): + call_outputs.append(snode.node.get_name()) + num_outputs = len(call_outputs) + call_args, mutated_indices = find_common_positions(call_inputs, call_outputs) + outputs = traced_graph.last_node if isinstance(traced_graph.last_node, List) \ + else [traced_graph.last_node] + outputs = [output for output in outputs if output.name not in (V.graph.removed_buffers | V.graph.inplaced_to_remove)] + traced_graph.graph.output(tuple(outputs)) + traced_graph.graph.lint() + orig_module = torch.nn.Module() + gm = torch.fx.GraphModule(orig_module, traced_graph.graph) + gm.recompile() + def runnable_gm(*args): + return torch.fx.Interpreter(gm).run(*args) + with V.graph.fake_mode: + gm = make_fx(runnable_gm)(*inputs) + view_to_reshape(gm) + non_contiguous_indices["outputs"] = [i + num_inputs for i, call_output in enumerate(call_outputs) \ + if not V.graph.try_get_buffer(call_output).layout.is_contiguous()] + return (gm, call_args, {"num_outputs": num_outputs, + "non_contiguous_indices": non_contiguous_indices, + "mutated_indices": mutated_indices,}) + + +class NpuMlirScheduling(SIMDScheduling): + + kernel_type = NpuTritonKernel + + def __init__(self, scheduler: Scheduler): + super().__init__(scheduler) + self.orig_fnode_name_to_fnode = {} + + def define_kernel(self, src_code, mlir_kernel, traced_graph, mode=None): + if mode is None: + mode = anir_config._get_compile_mode() + kernel_key = (src_code, tuple(mlir_kernel.non_contiguous_indices)) + wrapper = V.graph.wrapper_code + + if kernel_key in wrapper.src_to_kernel: + kernel_name = wrapper.src_to_kernel[kernel_key] + else: + fused_kernel_name = get_fused_kernel_name(mlir_kernel._snodes, config.triton.descriptive_names) + if mode in ["complete_fallback", "auto_fallback"]: + fx_graph_suffix = f"{next(id_iter)}" + else: + kernel_suffix = V.graph.wrapper_code.next_kernel_suffix() + kernel_name = "_".join(["mlir", fused_kernel_name, fx_graph_suffix if mode in ["complete_fallback", "auto_fallback"] else kernel_suffix]) + + traced_graph_hash = code_hash(traced_graph.print_readable(print_output=False) + kernel_name) + + num_call_functions = get_num_call_functions(mlir_kernel._gm) + + if num_call_functions <= 1 or kernel_name in anir_config.force_fallback_kernel_names: + mode = "complete_fallback" + wrapper.src_to_kernel[kernel_key] = kernel_name + + if mode in ["auto_fallback", "default"]: + src_code = src_code.replace("MODEL_NAME", kernel_name) + + mlir_processor = MLIRProcessor() + src_code, kernel_info = mlir_processor.get_named_op_str(src_code, kernel_name, dynamic=mlir_kernel._is_dynamic) + current_device = V.graph.get_current_device_or_throw() + + kernel_meta = { + 'device_str': current_device.type, + 'device_index': current_device.index, + 'num_outputs': mlir_kernel.num_outputs, + 'non_contiguous_indices': mlir_kernel.non_contiguous_indices, + 'dynamic': mlir_kernel._is_dynamic, + 'mutated_indices': mlir_kernel.mutated_indices, + 'traced_graph_cache': anir_config.traced_graph_cache, + 'traced_graph_hash': traced_graph_hash, + 'num_call_functions': num_call_functions, + **kernel_info + } + + compile_wrapper = IndentedBuffer() + if mode == "auto_fallback": + compile_wrapper.writeline(f"{kernel_name} = async_compile.mlir_auto_fallback({kernel_name!r}, '''") + compile_wrapper.splice(src_code, strip=True) + if 'PY_DIR_PATH' in os.environ: + kernel_path = os.path.join(os.environ['PY_DIR_PATH'], kernel_name + 
'.mlir') + with open(kernel_path, 'w') as f: + f.write(src_code) + line = f"''', kernel_meta={kernel_meta})" + compile_wrapper.writeline(line) + metadata_comment = '' + wrapper.header.splice(f"\n\n{metadata_comment}{compile_wrapper.getvalue()}") + elif mode == "complete_fallback": + compile_wrapper.writeline(f"async_compile.import_fx({kernel_name!r}, kernel_meta={kernel_meta})") + metadata_comment = f'"""\n{mlir_kernel._gm.print_readable(print_output=False)}\n"""' + wrapper.define_kernel(kernel_name, compile_wrapper.getvalue(), metadata_comment) + elif mode == "default": + compile_wrapper.writeline(f"async_compile.mlir({kernel_name!r}, '''") + compile_wrapper.splice(src_code, strip=True) + compile_wrapper.writeline( + f"''', device_str='{V.graph.scheduler.current_device.type}')" + ) + metadata_comment = '' + if anir_config.debug: + with open(f"{anir_config.debug_dir}/fx_graph_runnable_{kernel_name}.py", 'w') as fd: + mlir_kernel.codegen_debug_performance(fd) + comment = 'related Nodes: ' + '+'.join(fx_graph_op_types(mlir_kernel._gm)) + metadata_comment = f"'''\nsource fx graph:\n{comment}\n{mlir_kernel._gm.print_readable(print_output=False)}\n'''" + wrapper.define_kernel( + kernel_name, compile_wrapper.getvalue(), metadata_comment + ) + + num_args = len(mlir_kernel._gm.code.split('forward(', )[1].split(')')[0].split(', ')) - 1 + + if mode in ["complete_fallback", "auto_fallback"]: + dump_path = os.path.join(os.getenv("TORCHINDUCTOR_CACHE_DIR"), anir_config.traced_graph_cache, str(current_device.index), traced_graph_hash) + if not os.path.exists(dump_path): + os.makedirs(dump_path, exist_ok=True) + to_folder(mlir_kernel._gm, dump_path, graph_hash=traced_graph_hash, module_name=traced_graph_hash) + + if anir_config.fx_subgraph_dump_path is not None and mode in ["complete_fallback", "auto_fallback"]: + subgraph_dump_path = os.path.join(anir_config.fx_subgraph_dump_path, str(current_device.index), kernel_name) + os.makedirs(subgraph_dump_path, exist_ok=True) + + if mode == "complete_fallback": + fx_graph_code = get_fx_graph_code(mlir_kernel._gm.code, num_args, runnable=False) + runnable_fx_graph_code = get_fx_graph_code(mlir_kernel._gm.code, num_args, runnable=True) + else: + fx_graph_code = get_fx_graph_code(mlir_kernel._gm.code, num_args, runnable=False, kernel_code=compile_wrapper.getvalue(), \ + kernel_name=kernel_name) + runnable_fx_graph_code = get_fx_graph_code(mlir_kernel._gm.code, num_args, runnable=True, kernel_code=compile_wrapper.getvalue(), \ + kernel_name=kernel_name) + with open(os.path.join(subgraph_dump_path, f'{kernel_name}.py'), 'w') as f: + f.write(fx_graph_code) + with open(os.path.join(subgraph_dump_path, f'runnable_{kernel_name}.py'), 'w') as f: + f.write(runnable_fx_graph_code) + + if mode == "auto_fallback": + with open(os.path.join(subgraph_dump_path, f'{kernel_name}.mlir'), 'w') as f: + f.write(src_code) + + return kernel_name + + # transform indexing before call codegen_node_schedule_with_kernel + def codegen_node_schedule(self, kernel_features: SIMDKernelFeatures, nodes): + node_schedule = kernel_features.node_schedule + tiling = self.select_tiling( + node_schedule, kernel_features.numel, kernel_features.reduction_numel + ) + kernels = self.create_kernel_choices( + kernel_features, [tiling], {"features": kernel_features} + ) + for kernel in kernels: + self.codegen_node_schedule_with_kernel(node_schedule, kernel) + MultiKernel.merge_workspaces_inplace(kernels) + for kernel in kernels: + V.graph.removed_buffers |= kernel.removed_buffers + V.graph.inplaced_to_remove 
|= kernel.inplaced_to_remove + if not anir_config.traced_graph_cache: + anir_config.traced_graph_cache = "traced_graph_cache" + os.makedirs(os.path.join(os.getenv("TORCHINDUCTOR_CACHE_DIR"), anir_config.traced_graph_cache), exist_ok=True) + traced_graph, call_args, compile_kwargs = create_fx_from_snodes_by_traced_graph(nodes, kernel) + mlir_kernel = NpuMlirKernel(traced_graph, nodes, call_args, **compile_kwargs) + with V.set_kernel_handler(mlir_kernel): + src_code = mlir_kernel.codegen_kernel() + kernel_name = self.define_kernel(src_code, mlir_kernel, traced_graph) + log.debug("Generating kernel code with kernel_name: %s", kernel_name) + kernel.kernel_name = kernel_name + kernel.code_hash = code_hash(src_code) + del kernel + + final_kernel: Union[SIMDKernel, MultiKernel] + if len(kernels) > 1: + raise RuntimeError("MultiKernel not Implemented!") + else: + (final_kernel,) = kernels + + with V.set_kernel_handler(final_kernel): + for node in kernel_features.scheduler_nodes(): + node.mark_run() + + self.codegen_comment(node_schedule) + final_kernel.call_kernel(call_args, final_kernel.kernel_name) + + if config.nan_asserts: + final_kernel.codegen_nan_check() + if config.warn_mix_layout: + final_kernel.warn_mix_layout(kernels[0].kernel_name) + + if ( + V.graph.wrapper_code.supports_intermediate_hooks + and config.generate_intermediate_hooks + ): + # Not every node in the schedule will actually be live on output; + # we can't check dead buffers. + live_outs = kernels[0].args.live_output_buffers() + for node in kernel_features.scheduler_nodes(): + name = node.get_name() + if name not in live_outs: + continue + if node.node is None: + raise RuntimeError("assert node.node is not None") + + origin_node = node.node.get_origin_node() + if origin_node is not None: + counters["inductor"]["intermediate_hooks"] += 1 + V.graph.wrapper_code.writeline( + f"run_intermediate_hooks({origin_node.name!r}, {name})" + ) + + self.scheduler.free_buffers() + + + def codegen_node(self, node: Union[scheduler.FusedSchedulerNode, scheduler.SchedulerNode]): + """ + Given a set of pre-fused nodes, generate a Mlir kernel. + """ + nodes: List[scheduler.SchedulerNode] = node.get_nodes() # type: ignore[assignment] + + _, (numel, rnumel) = max(nodes, key=lambda x: int(x.is_reduction())).group + + node_schedule = self.generate_node_schedule(nodes, numel, rnumel) + kernel_features = SIMDKernelFeatures(node_schedule, numel, rnumel) + return self.codegen_node_schedule(kernel_features, nodes) \ No newline at end of file diff --git a/torch_npu/_inductor/ascend_npu_ir/ascend_npu_ir/npu/codegen/wrapper.py b/torch_npu/_inductor/ascend_npu_ir/ascend_npu_ir/npu/codegen/wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..8c5d9136cf27eac59e873ee3f07399abecd96da9 --- /dev/null +++ b/torch_npu/_inductor/ascend_npu_ir/ascend_npu_ir/npu/codegen/wrapper.py @@ -0,0 +1,152 @@ +import sympy +import functools +import torch +from torch._inductor.virtualized import V +from torch._inductor import config, ir + +from typing import List, Optional, Tuple, Union, Callable, Dict +from torch._inductor.codegen.wrapper import ( + PythonWrapperCodegen, + pexpr, + cache_on_self, + SubgraphPythonWrapperCodegen, + counters, +) +from ... 
import codecache +from torch._inductor.codegen.common import ( + IndentedBuffer, +) + +class NpuMlirWrapperCodeGen(PythonWrapperCodegen): + def __init__(self): + super().__init__() + self.write_get_raw_stream = functools.lru_cache(None)( # type: ignore[assignment] + self.write_get_raw_stream + ) + + @staticmethod + def create( + is_subgraph: bool, subgraph_name: str, parent_wrapper: PythonWrapperCodegen + ): + if is_subgraph: + return SubgraphPythonWrapperCodegen(subgraph_name, parent_wrapper) + return NpuMlirWrapperCodeGen() + + @cache_on_self + def write_triton_header_once(self) -> None: + self.header.splice( + """ + {} + """.format( + V.graph.device_ops.import_get_raw_stream_as("get_raw_stream") + ) + ) + + def write_header(self) -> None: + self.header.splice( + f""" + from ctypes import c_void_p, c_long + import torch + import torch_npu + import math + import random + import os + import tempfile + from math import inf, nan + from torch._inductor.hooks import run_intermediate_hooks + from torch._inductor.utils import maybe_profile + from torch._inductor.codegen.memory_planning import _align as align + + from torch import device, empty_strided + from {codecache.__name__} import CustomAsyncCompile + from torch._inductor.select_algorithm import extern_kernels + from torch._inductor.codegen.multi_kernel import MultiKernelCall + + aten = torch.ops.aten + inductor_ops = torch.ops.inductor + assert_size_stride = torch._C._dynamo.guards.assert_size_stride + empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu + empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda + alloc_from_pool = torch.ops.inductor._alloc_from_pool + reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor + async_compile = CustomAsyncCompile() + + """ + ) + + def generate_extern_kernel_alloc(self, extern_kernel, args): + # If it's a NoneLayout then the extern_kernel should essentially be + # treated as if it doesn't return anything + no_return = isinstance(extern_kernel.layout, ir.NoneLayout) + output_name = extern_kernel.get_name() + origin_node = extern_kernel.get_origin_node() + kernel_name = extern_kernel.get_kernel_name() + ending = self.ending + if config.memory_planning and "view_as_complex" in kernel_name: + # view operation fallbacks cause issues since inductor + # doesn't know the memory is still needed and might reuse it. 
+ ending = f".clone(){ending}" + + if no_return: + self.writeline(f"{self.declare}{kernel_name}({', '.join(args)}){ending}") + else: + self.writeline( + f"{self.declare}{output_name} = {kernel_name}({', '.join(args)}){ending}" + ) + if kernel_name == 'torch.ops.npu_stream.npu_set_stream.default': + device_idx = V.graph.scheduler.current_device.index + name = f'stream{device_idx}' + self.writeline(f"{name} = get_raw_stream({device_idx})") + if ( + self.supports_intermediate_hooks + and config.generate_intermediate_hooks + and origin_node is not None + ): + counters["inductor"]["intermediate_hooks"] += 1 + self.writeline( + f"run_intermediate_hooks({origin_node.name!r}, {output_name})" + ) + + def write_get_raw_stream(self, device_idx: int, graph=None) -> str: + self.write_triton_header_once() + name = f"stream{device_idx}" + self.writeline(f"{name} = get_raw_stream({device_idx})") + self.header.writeline( + f"torch_npu.npu.set_device({device_idx})" + ) + return name + + def generate_kernel_call( + self, + kernel_name, + call_args, + grid=None, + device_index=None, + gpu=True, + triton=True, + arg_types=None, + raw_args=None, + grid_fn: str = "grid", + triton_meta=None, + autotune_configs=None, + grid_extra_kwargs="", + ): + """ + Generates kernel call code. + + cuda: Defines whether the backend is GPU. Otherwise the backend is CPU. + + triton: Defines whether the GPU backend uses Triton for codegen. + Otherwise it uses the CUDA language for codegen. + Only valid when cuda == True. + """ + if gpu: + call_args_str = ", ".join(pexpr(item) for item in call_args) + stream_name = self.write_get_raw_stream( + V.graph.scheduler.current_device.index, V.graph + ) + self.writeline( + f"{kernel_name}.run({call_args_str}, stream={stream_name})" + ) + else: + self.writeline(self.wrap_kernel_call(kernel_name, call_args)) diff --git a/torch_npu/_inductor/ascend_npu_ir/ascend_npu_ir/npu/inductor_patch/__init__.py b/torch_npu/_inductor/ascend_npu_ir/ascend_npu_ir/npu/inductor_patch/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..969627517014411916ec7cb99e352755fa3c6989 --- /dev/null +++ b/torch_npu/_inductor/ascend_npu_ir/ascend_npu_ir/npu/inductor_patch/__init__.py @@ -0,0 +1,40 @@ +import os +import importlib +import inspect +import pkgutil + + +__all__ = list(module for _, module, _ in pkgutil.iter_modules([os.path.dirname(__file__)])) + +from . import ir +from . 
import lowering as npu_lowering +from torch._inductor import lowering +import sys + +def get_functions_from_module(module): + functions = {} + members = inspect.getmembers(module, inspect.isfunction) + + for name, func in members: + if inspect.getmodule(func) == module: + functions[name] = func + + return functions + +npu_functions = get_functions_from_module(npu_lowering) +functions = get_functions_from_module(lowering) +for name, _ in functions.items(): + setattr(lowering, name, npu_functions[name]) + +extra_lowerings = set(lowering.lowerings.keys()) - set(npu_lowering.lowerings.keys()) +npu_lowering.lowerings.update({k: lowering.lowerings[k] for k in extra_lowerings}) +lowering.lowerings = npu_lowering.lowerings +lowering._maybe_layout_constraints = npu_lowering._maybe_layout_constraints +lowering.fallbacks = npu_lowering.fallbacks +lowering.needs_realized_inputs = npu_lowering.needs_realized_inputs +lowering.foreach_ops = npu_lowering.foreach_ops +lowering.inplace_foreach_ops = npu_lowering.inplace_foreach_ops +lowering.inplaceable_foreach_ops = npu_lowering.inplaceable_foreach_ops + +from torch._inductor import graph +importlib.reload(graph) \ No newline at end of file diff --git a/torch_npu/_inductor/ascend_npu_ir/ascend_npu_ir/npu/inductor_patch/fake_tensor.py b/torch_npu/_inductor/ascend_npu_ir/ascend_npu_ir/npu/inductor_patch/fake_tensor.py new file mode 100644 index 0000000000000000000000000000000000000000..ad9104cd0dea2d5307e406fc18ccc643d2112de6 --- /dev/null +++ b/torch_npu/_inductor/ascend_npu_ir/ascend_npu_ir/npu/inductor_patch/fake_tensor.py @@ -0,0 +1,84 @@ +import torch +from torch._subclasses import fake_tensor +from torch._ops import OpOverload +from torch._subclasses.fake_tensor import ( + Tensor, + FakeTensor, + FakeTensorMode, + _StoragePointer, + Sequence, + PyTree, + pytree, + no_dispatch, + is_sparse_any, + Union, + Set, + T +) + +def _npu_run_fallback_kernel( + fake_mode: FakeTensorMode, + func: OpOverload, + flat_args: Sequence[object], + args_spec: PyTree, + orig_not_implemented_exception: RuntimeError, +) -> FakeTensor: + # these should all be supported, just to be safe + # avoid fallback for operators which inplace modify metadata + # because the input fake tensors would be umodified + if torch.Tag.inplace_view in func.tags: + raise orig_not_implemented_exception + + inp_impls = {} + + # Don't use in_kernel_invocation_manager(fake_mode) as we want to do + # REAL compute (not with meta device) + with no_dispatch(): + + def to_real_tensor(e: T) -> Union[T, Tensor]: + if fake_mode.is_our_fake(e): + out = torch.zeros(e.shape, dtype=e.dtype, device=e.fake_device) + if e.is_sparse: + out._coalesced_(e.is_coalesced()) + inp_impls[id(out)] = e + return out + return e + + flat_args = [to_real_tensor(a) for a in flat_args] + args, kwargs = pytree.tree_unflatten(flat_args, args_spec) + + r = func(*args, **kwargs) + + storages: Set[_StoragePointer] = set() + + for e in flat_args: + if isinstance(e, Tensor): + if not is_sparse_any(e): + storages.add(e._typed_storage()._cdata) + + # TODO: also check metadata change on inputs + # proper aliasing/metadata relationship between outputs and inputs will + # not be set up, bc of conversion to device, unless we can reuse an + # input impl + + def map_out(e: T) -> Union[T, FakeTensor]: + if id(e) not in inp_impls and ( + isinstance(e, Tensor) + and not is_sparse_any(e) + and e._typed_storage()._cdata in storages + ): + raise orig_not_implemented_exception + + if isinstance(e, Tensor): + if id(e) in inp_impls: + return 
inp_impls[id(e)] + else: + return fake_mode.fake_tensor_converter.from_real_tensor(fake_mode, e) + else: + return e + + return pytree.tree_map(map_out, r) + + +def _patch_fake_tensor(): + fake_tensor.run_fallback_kernel = _npu_run_fallback_kernel \ No newline at end of file diff --git a/torch_npu/_inductor/ascend_npu_ir/ascend_npu_ir/npu/inductor_patch/ir.py b/torch_npu/_inductor/ascend_npu_ir/ascend_npu_ir/npu/inductor_patch/ir.py new file mode 100644 index 0000000000000000000000000000000000000000..0e6196e1b369706f94d25432e103bf5cae154003 --- /dev/null +++ b/torch_npu/_inductor/ascend_npu_ir/ascend_npu_ir/npu/inductor_patch/ir.py @@ -0,0 +1,707 @@ +import traceback +from unittest.mock import patch + +from typing import ( + Any, + Callable, + ClassVar, + Dict, + Iterable, + List, + Optional, + Sequence, + Set, + Tuple, + TYPE_CHECKING, + Union, +) + +import functools + +import sympy +from sympy import Expr, Integer + +import torch +from torch._inductor import ir +from torch._inductor import config + +from torch._inductor.virtualized import ops, V +from torch._subclasses import FakeTensor +from torch.utils._ordered_set import OrderedSet + +from .lowering import ( + fetch_graphs, + merge_traced_graphs, + node_id, + clone, + create_fake_input, + subtract_graph +) + +def _patch_loops_get_name(self): + return self.node_name + +def _patch_loops_get_traced_graph(self): + return self.traced_graph + +@classmethod +def _patch_loops_create(cls, *args, **kwargs): + origin_node = kwargs.pop("origin_node", None) + traced_graph = kwargs.pop("traced_graph", None) + node_name = kwargs.pop("node_name", None) + tb = kwargs.pop("traceback", None) + r = cls(*args, **kwargs) + # Need to explicitly set origin_node here to propagate it down. + # todo(chilli): I think it would be better for IRNode to directly set + # origin_node + r._post_init_setattr("origin_node", origin_node) + r._post_init_setattr("traceback", tb or r.traceback) + r._post_init_setattr("traced_graph", traced_graph) + r._post_init_setattr("node_name", node_name) + return ir.TensorBox.create(r) + +ir.Loops.get_name = _patch_loops_get_name +ir.Loops.get_traced_graph = _patch_loops_get_traced_graph +ir.Loops.create = _patch_loops_create + + +def _patch_pointwise_constant_to_device(self, device, traced_graph=None, node_name=None): + """Move this to a given device. Requires that all reads are to constants.""" + loader = self.make_loader() + loader = patch.object(ir.ConstantBuffer, "override_device", device)(loader) + + r = ir.Pointwise(device, self.dtype, loader, self.ranges) + r._post_init_setattr("traced_graph", traced_graph) + r._post_init_setattr("node_name", node_name) + return r + +ir.Pointwise.constant_to_device = _patch_pointwise_constant_to_device + +@classmethod +def _patch_reduction_create( # type: ignore[override] + cls, + device: torch.device, + dst_dtype: torch.dtype, + src_dtype: torch.dtype, + inner_fn: Callable[..., Any], + ranges: List[Expr], + reduction_ranges: List[Expr], + reduction_type: str, + reduction_hint: ir.ReductionHint = ir.ReductionHint.DEFAULT, + input_node: Optional[ir.IRNode] = None, + traced_graph = None, + node_name: str = None + +): + reduction_numel = V.graph.sizevars.simplify(ir.sympy_product(reduction_ranges)) + + if reduction_numel == 0: + # N.B. 
This is a hack to generate the literal of the given type + # Ideally, we should be fixing `def constant` in triton.py + # but it breaks due to hardcoded dtypes in other places + def py_cnst(val): + return ( + bool(val) + if dst_dtype == torch.bool + else float(val) + if dst_dtype.is_floating_point + else int(val) + ) + + rtypes_to_inits = { + "sum": py_cnst(0), + "xor_sum": py_cnst(0), + "prod": py_cnst(1), + "any": py_cnst(0), + # "all" is desugared to `!any(!val)` + } + + assert ( + reduction_type in rtypes_to_inits.keys() + ), f"{reduction_type} not supported for zero-dimension tensors!" + + def const_fn(index): + return ops.constant(rtypes_to_inits[reduction_type], dst_dtype) + + + return ir.Pointwise.create( + device=device, + dtype=src_dtype, + inner_fn=const_fn, + ranges=list(ranges), + traced_graph=traced_graph, + node_name=node_name + ) + + if reduction_numel == 1: + # this reduction is actually a pointwise op + if reduction_type in ("argmin", "argmax"): + + def fn(index): + return ops.constant(0, dst_dtype) + + else: + + def fn(index): + reduction_index = [sympy.Integer(0) for _ in reduction_ranges] + return inner_fn(index, reduction_index) + + return ir.Pointwise.create( + device=device, + dtype=dst_dtype, + inner_fn=fn, + ranges=ranges, + traced_graph=traced_graph, + node_name=node_name) + + if ( + isinstance(reduction_numel, sympy.Integer) + and V.graph.sizevars.size_hint(reduction_numel) + < config.unroll_reductions_threshold + and ir.sympy_product(ranges) != 1 + ): + return ir.Pointwise.create( + device=device, + dtype=dst_dtype, + inner_fn=cls._unroll_reduction_fn( + inner_fn, reduction_ranges, reduction_type, src_dtype + ), + ranges=ranges, + traced_graph=traced_graph, + node_name=node_name + ) + + r = ir.Reduction( + device=device, + dtype=dst_dtype, + inner_fn=inner_fn, + ranges=ranges, + reduction_ranges=reduction_ranges, + reduction_type=reduction_type, + src_dtype=src_dtype, + reduction_hint=reduction_hint, + ) + r._post_init_setattr("traced_graph", traced_graph) + r._post_init_setattr("node_name", node_name) + + return ir.TensorBox.create(r) + +ir.Reduction.create = _patch_reduction_create + +def _patch_baseview_get_traced_graph(self): + if hasattr(self, 'traced_graph') and self.traced_graph is not None: + return self.traced_graph + return self.data.get_traced_graph() + +ir.BaseView.get_traced_graph = _patch_baseview_get_traced_graph + +def _patch_base_view_get_reads(self): + with patch.object(ir.FlexibleLayout, "allow_indexing", True): + r = ir.extract_read_writes( + self.make_loader(), + self.get_size(), + ).reads + for md in r: + if md.index.has(ir.ModularIndexing): + if md.index.has(ir.FloorDiv): + self.realize() + return r + else: + for m in md.index.find(ir.ModularIndexing): + for arg in m.args: + if arg.has(ir.ModularIndexing): + self.realize() + return r + return r + +ir.BaseView.get_reads = _patch_base_view_get_reads + +def try_get_buffer(inp): + if not hasattr(inp, 'data'): + return False + if isinstance(inp.data, ir.Buffer): + return inp.data + return try_get_buffer(inp.data) + +def _patch_baseview_realize(self): + if hasattr(self, 'traced_graph') and self.traced_graph is not None: + r = self.data.realize() + buffer = try_get_buffer(self) + if not buffer: + return r + if isinstance(buffer, (ir.MultiOutput, ir.InputBuffer, ir.ConcatKernel)): + return r + traced_graph = buffer.data.get_traced_graph() + buf_name = buffer.get_name() + new_traced_graph, placeholder = subtract_graph(self.traced_graph, traced_graph, node_name=buf_name) + if placeholder is not 
None: + placeholder.name = buf_name + device = buffer.get_device() + dtype = buffer.get_dtype() + size = buffer.get_size() + stride = buffer.get_stride() + fake_input = create_fake_input(size, stride, device, dtype) + placeholder.meta['val'] = fake_input + self._post_init_setattr("traced_graph", new_traced_graph) + return r + else: + return self.data.realize() + +def _patch_baseview_realize_hint(self): + if hasattr(self, 'traced_graph') and self.traced_graph is not None: + r = self.data.realize_hint() + buffer = try_get_buffer(self) + if not buffer: + return r + if isinstance(buffer, (ir.MultiOutput, ir.InputBuffer, ir.ConcatKernel)): + return r + traced_graph = buffer.data.get_traced_graph() + buf_name = buffer.get_name() + new_traced_graph, placeholder = subtract_graph(self.traced_graph, traced_graph, node_name=buf_name) + if placeholder is not None: + placeholder.name = buf_name + device = buffer.get_device() + dtype = buffer.get_dtype() + size = buffer.get_size() + stride = buffer.get_stride() + fake_input = create_fake_input(size, stride, device, dtype) + placeholder.meta['val'] = fake_input + self._post_init_setattr("traced_graph", new_traced_graph) + return r + else: + return self.data.realize_hint() + +def _patch_mark_reuse(self, users): + if hasattr(self, 'traced_graph') and self.traced_graph is not None: + r = self.data.mark_reuse(users) + buffer = try_get_buffer(self) + if not buffer: + return r + if isinstance(buffer, (ir.MultiOutput, ir.InputBuffer, ir.ConcatKernel)): + return r + traced_graph = buffer.data.get_traced_graph() + buf_name = buffer.get_name() + new_traced_graph, placeholder = subtract_graph(self.traced_graph, traced_graph, node_name=buf_name) + if placeholder is not None: + placeholder.name = buf_name + device = buffer.get_device() + dtype = buffer.get_dtype() + size = buffer.get_size() + stride = buffer.get_stride() + fake_input = create_fake_input(size, stride, device, dtype) + placeholder.meta['val'] = fake_input + self._post_init_setattr("traced_graph", new_traced_graph) + return r + else: + return self.data.mark_reuse(users) + +ir.BaseView.realize = _patch_baseview_realize +ir.BaseView.realize_hint = _patch_baseview_realize_hint +ir.BaseView.mark_reuse = _patch_mark_reuse + +@classmethod +def _patch_expandview_create(cls, x, new_size, traced_graph=None, node_name=None): + new_size = cls._normalize_size(x, new_size) + + if ir.is_storage_and_layout(x): + storage, old_layout = ir.as_storage_and_layout(x) + skip = len(new_size) - len(old_layout.size) + assert skip >= 0 + new_stride = [sympy.Integer(0)] * skip + for stride, size in zip(old_layout.stride, old_layout.size): + new_stride.append( + stride + if not V.graph.sizevars.shape_env.evaluate_expr( + sympy.Eq(size, 1), size_oblivious=True + ) + else sympy.Integer(0) + ) + new_layout = ir.FixedLayout( + old_layout.device, + old_layout.dtype, + list(new_size), + new_stride, + old_layout.offset, + ) + + r = ir.ReinterpretView(data=storage, layout=new_layout) + r._post_init_setattr("traced_graph", traced_graph) + r._post_init_setattr("node_name", node_name) + return r + + r = ir.ExpandView(data=x, size=new_size) + r._post_init_setattr("traced_graph", traced_graph) + r._post_init_setattr("node_name", node_name) + + return r + +ir.ExpandView.create = _patch_expandview_create + +@classmethod +def _patch_permuteview_create(cls, x, dims, traced_graph=None, node_name=None): + dims = cls._map_neg_dims(dims) + assert OrderedSet(dims) == OrderedSet(range(len(dims))) + + if ir.is_storage_and_layout(x): + storage, old_layout 
= ir.as_storage_and_layout(x) + new_layout = ir.FixedLayout( + old_layout.device, + old_layout.dtype, + [old_layout.size[i] for i in dims], + [old_layout.stride[i] for i in dims], + old_layout.offset, + ) + r = ir.ReinterpretView(data=storage, layout=new_layout) + r._post_init_setattr("traced_graph", traced_graph) + r._post_init_setattr("node_name", node_name) + return r + + + r = ir.PermuteView(data=x, dims=dims) + r._post_init_setattr("traced_graph", traced_graph) + r._post_init_setattr("node_name", node_name) + return r + +ir.PermuteView.create = _patch_permuteview_create + + +@classmethod +def _patch_view_create(cls, x, new_size, traced_graph=None, node_name=None): + assert isinstance(new_size, (tuple, list)) + old_size, new_size = cls.resolve_negative_size(x.get_size(), new_size) + # Skip pointless views + if V.graph.sizevars.statically_known_list_equals(old_size, new_size): + return x + + unbacked_symbols_in_sizes = False + if ( + len(ir.free_unbacked_symbols(old_size)) > 0 + or len(ir.free_unbacked_symbols(new_size)) > 0 + ): + unbacked_symbols_in_sizes = True + + if 0 in new_size: + + def fake_reindex(index): + return tuple([0] * len(old_size)) + + r = cls(x, list(new_size), fake_reindex) + r._post_init_setattr("traced_graph", traced_graph) + r._post_init_setattr("node_name", node_name) + return r + + # TODO: a new class for FixedTransferLayout that output layout is constrained by input layout + elif ir.is_contiguous_storage_and_layout(x) or unbacked_symbols_in_sizes: + if unbacked_symbols_in_sizes and (not ir.is_contiguous_storage_and_layout(x)): + # realize x; otherwise, the dynamic_reshape_indexer below will fail + # due to the size_hint's inability to process unbacked SymInts + x = ir.ExternKernel.realize_input(x) + + storage, old_layout = ir.as_contiguous_storage_and_layout(x) + new_layout = ir.FixedLayout( + old_layout.device, + old_layout.dtype, + new_size, + ir.FlexibleLayout.contiguous_strides(new_size), + old_layout.offset, + ) + + r = ir.ReinterpretView(data=storage, layout=new_layout) + r._post_init_setattr("traced_graph", traced_graph) + r._post_init_setattr("node_name", node_name) + return r + + reindex = cls.dynamic_reshape_indexer(old_size, new_size) + + r = cls(data=x, size=list(new_size), reindex=reindex) + r._post_init_setattr("traced_graph", traced_graph) + r._post_init_setattr("node_name", node_name) + return r + +ir.View.create = _patch_view_create + + +@classmethod +def _patch_sliceview_create(cls, x, dim, start, end, step=1, clamp=True, traced_graph=None, node_name=None): # TODO: crm, clamp=True + step = sympy.expand(step) + assert isinstance(step, sympy.Expr) or step > 0 + try: + if start == 0 and end >= 2**63 - 1 and step == 1: + return x + except TypeError: + pass + sizevars = V.graph.sizevars + new_size = list(x.get_size()) + + if clamp: + start, end = cls.normalize_start_end(x, dim, start, end) + + new_size[dim] = ir.FloorDiv(end - start + (step - 1), step) + + if ir.is_storage_and_layout(x): + # Fast path + storage, old_layout = ir.as_storage_and_layout(x) + new_stride = list(old_layout.stride) + new_stride[dim] = new_stride[dim] * step + new_layout = ir.FixedLayout( + old_layout.device, + old_layout.dtype, + new_size, + new_stride, + old_layout.offset + old_layout.stride[dim] * start, + ) + r = ir.ReinterpretView(data=storage, layout=new_layout) + r._post_init_setattr("traced_graph", traced_graph) + r._post_init_setattr("node_name", node_name) + return r + + def reindex(index): + assert len(index) == len(new_size), f"wrong ndim {index} {new_size}" + 
index = list(index) + index[dim] = index[dim] * step + start + return index + + # redirect to a generic view + r = ir.SliceView(data=x, size=new_size, reindex=reindex) + r._post_init_setattr("traced_graph", traced_graph) + r._post_init_setattr("node_name", node_name) + return r + +ir.SliceView.create = _patch_sliceview_create + +def _patch_buffer_get_traced_graph(self): + return self.traced_graph + +ir.Buffer.traced_graph = None +ir.Buffer.get_traced_graph = _patch_buffer_get_traced_graph + +def _patch_concatkernel_get_traced_graph(self): + return None + +@classmethod +def _patch_concatkernel_realize_into(cls, src, dst): + # Attempt to turn this into a ReinterpretView rather than assert. + # This has concessions around layout, as as_storage_and_layout + # can cause us to go from flexible to fixed layout. + if not isinstance(dst, ir.ReinterpretView): + if ir.is_storage_and_layout(dst): + storage, layout = ir.as_storage_and_layout(dst) + dst = ir.ReinterpretView(data=storage, layout=layout) + assert isinstance(dst, ir.ReinterpretView), dst + if isinstance(src, ir.TensorBox): + # unwrap a TensorBox + return cls.realize_into(src.data, dst) + if isinstance(src, ir.StorageBox): + src.realize() + # ExternKernelAlloc has specific requirements for output layout, should create a copy + assert hasattr(src.data, "layout") + if cls.can_realize_into_without_copy(src): + src.data.layout = ir.NonOwningLayout(dst) + return src.data + # introduce a copy + input_graphs = fetch_graphs(src) + node_name = f'clone_{next(node_id)}' + new_graph = merge_traced_graphs(input_graphs, torch.ops.aten.clone, node_name) + pw = ir.Pointwise.create( + device=src.get_device(), + dtype=src.get_dtype(), + inner_fn=src.make_loader(), + ranges=[ + V.graph.sizevars.guard_equals(a, b) + for a, b in zip(src.get_size(), dst.get_size()) + ], + traced_graph = new_graph, + node_name = node_name + ) + return cls.realize_into(pw, dst) + +ir.ConcatKernel.get_traced_graph = _patch_concatkernel_get_traced_graph +ir.ConcatKernel.realize_into = _patch_concatkernel_realize_into + +def _patch_externkernel_copy_input(x): + traced_graph = x.get_traced_graph() + node_name = x.get_name() + if traced_graph is None: + traced_graph = fetch_graphs([x])[0] + node_name = f'getitem_{next(node_id)}' + pw = ir.Pointwise.create( + device=x.get_device(), + dtype=x.get_dtype(), + inner_fn=x.make_loader(), + ranges=x.get_size(), + origin_node=x.get_origin_node(), + traceback=x.get_traceback(), + traced_graph=traced_graph, + node_name=node_name + ) + pw.realize() + return pw + +ir.ExternKernel.copy_input = _patch_externkernel_copy_input + +@classmethod +def _patch_externkernel_convert_to_reinterpret_view(cls, x): + """ + In order to pass this to an extern kernel we need a + ReinterpretView not a View. This allows us to avoid some + unneeded copies. + """ + assert isinstance(x, ir.BaseView) + if isinstance(x, ir.ReinterpretView): + return x + + # NOTE: Don't use extract_read_writes here as it fails when + # make_loader() inlines the computation + x_unwrap_view = x.unwrap_view() + buf = V.graph.get_buffer(x_unwrap_view.get_name()) + assert buf is not None + x_unwrap_view_fx_node = buf.get_origin_node() + # Prefer channels last format according to how the format is set from eager. 
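+    # i.e. when the unwrapped buffer's layout is still flexible and the eager
+    # value was channels-last (2d or 3d) contiguous, freeze it with
+    # channels-last strides; otherwise freeze the current order as-is.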
+ if ( + x_unwrap_view_fx_node is not None + and "val" in x_unwrap_view_fx_node.meta + and isinstance(x_unwrap_view.layout, ir.FlexibleLayout) + and ( + x_unwrap_view_fx_node.meta["val"].is_contiguous( + memory_format=torch.channels_last + ) + or x_unwrap_view_fx_node.meta["val"].is_contiguous( + memory_format=torch.channels_last_3d + ) + ) + ): + x_unwrap_view.freeze_layout_with_same_order( + ir.make_channels_last_strides_for(x_unwrap_view.get_size()) + ) + else: + x_unwrap_view.freeze_layout() + + index_args, var_ranges = ir.dependencies.index_vars_squeeze( + x.get_size(), prefix="r" + ) + range_vars = index_args[0] + index = x.make_indexer()(range_vars) + + index = V.graph.sizevars.simplify_with_ranges(index, var_ranges) + strides = V.graph.sizevars.stride_vars(index, range_vars) + offset = V.graph.sizevars.offset_var(index, range_vars) + expected = ir.sympy_dot(range_vars, strides) + offset + + if index != expected: + ir.log.debug( + "convert_to_reinterpret_view failed: stride=%s offset=%s index=%s", + strides, + offset, + index, + ) + raise NotImplementedError + + r = ir.ReinterpretView( + data=x.data, + layout=ir.FixedLayout( + device=x.get_device(), + dtype=x.get_dtype(), + size=x.get_size(), + stride=strides, + offset=offset, + ), + ) + r._post_init_setattr("traced_graph", x.get_traced_graph()) + r._post_init_setattr("node_name", x.get_name()) + return r + +ir.ExternKernel.convert_to_reinterpret_view = _patch_externkernel_convert_to_reinterpret_view + +@classmethod +def _patch_devicecopy_create(cls, x, device, non_blocking, traced_graph=None, node_name=None): + if ( + not x.is_extern() + and all(r in V.graph.constants for r in x.get_read_names()) + and not config.aot_inductor.use_runtime_constant_folding + ): + return x.constant_to_device(device) + + V.graph.add_device_info(device) + V.graph.add_device_info(x.get_device()) + + ir.developer_warning("DeviceCopy in input program") + constant_args = (non_blocking,) + r = ir.DeviceCopy( + ir.FlexibleLayout( + device=device, + dtype=x.get_dtype(), + size=x.get_size(), + ), + [cls.realize_input(x)], + constant_args, + ) + + r._post_init_setattr("traced_graph", traced_graph) + r._post_init_setattr("node_name", node_name) + return r + + +def _patch_devicecopy_get_traced_graph(self): + return self.traced_graph + +ir.DeviceCopy.create = _patch_devicecopy_create +ir.DeviceCopy.get_traced_graph = _patch_devicecopy_get_traced_graph + + +def _patch_multioutput_get_traced_graph(self): + return None + +ir.MultiOutput.get_traced_graph = _patch_multioutput_get_traced_graph + +def _patch_mutablebox_get_name(self): + return self.data.get_name() + +def _patch_mutablebox_get_traced_graph(self): + return self.data.get_traced_graph() + +ir.MutableBox.get_name = _patch_mutablebox_get_name +ir.MutableBox.get_traced_graph = _patch_mutablebox_get_traced_graph + +@classmethod +def _patch_mutationlayout_realize_into(cls, src, dst, unsafe_alias=False): + dst.realize() + # NOTE: We must realize users of `dst` before we realize `src`, since + # realization order determines scheduling order. Otherwise, src's + # mutation would be scheduled before the existing users of dst! + V.graph.mark_buffer_mutated(dst.get_name()) + + if isinstance(src, ir.TensorBox): + src = src.data + + # We copy the contents of src into dst. In most cases this should + # be fused into a single kernel by the scheduler. + # NOTE: We cannot change src's layout to mutate dst directly as this + # would alias src to dst, which is not correct as further s to + # dst would effect users of src. 
However if there are no more users of + # dst, we can alias src to dst. + src.realize_hint() + + if not unsafe_alias: + + input_graphs = fetch_graphs([dst, src]) + node_name = f'copy__{next(node_id)}' + new_graph = merge_traced_graphs(input_graphs, torch.ops.aten.copy, node_name) + + src = ir.Pointwise.create( + device=src.get_device(), + dtype=src.get_dtype(), + inner_fn=src.make_loader(), + ranges=[ + V.graph.sizevars.guard_equals(a, b) + for a, b in zip(src.get_size(), dst.get_size()) + ], + traced_graph=new_graph, + node_name=node_name, + ).data + + src.realize() + assert isinstance(src.data.layout, ir.FlexibleLayout) + src.data.layout = ir.MutationLayoutSHOULDREMOVE(dst) + return src.data + +ir.MutationLayoutSHOULDREMOVE.realize_into = _patch_mutationlayout_realize_into diff --git a/torch_npu/_inductor/ascend_npu_ir/ascend_npu_ir/npu/inductor_patch/lowering.py b/torch_npu/_inductor/ascend_npu_ir/ascend_npu_ir/npu/inductor_patch/lowering.py new file mode 100644 index 0000000000000000000000000000000000000000..72f5604bba7f21ca00afbf0c887122f3920ffb2b --- /dev/null +++ b/torch_npu/_inductor/ascend_npu_ir/ascend_npu_ir/npu/inductor_patch/lowering.py @@ -0,0 +1,6946 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs +import dataclasses +import functools +import itertools +import logging +import math +import operator +import os +import warnings +from collections import defaultdict +from copy import deepcopy +from typing import ( + Any, + Callable, + Dict, + List, + Optional, + Set, + Tuple, + Union, + ) +from unittest.mock import patch + +import sympy +from sympy.core import Expr, Integer, Symbol + +import torch +import torch.ao.quantization.fx._decomposed +import torch.fx +import torch.utils._pytree as pytree +from torch._higher_order_ops.associative_scan import associative_scan_op +from torch._higher_order_ops.triton_kernel_wrap import triton_kernel_wrapper_mutation +from torch._prims_common import ( + canonicalize_dim, + canonicalize_dims, + check, + dtype_to_type, + elementwise_dtypes, + ELEMENTWISE_TYPE_PROMOTION_KIND, + get_computation_dtype, + is_boolean_dtype, + is_float_dtype, + is_integer_dtype, + Number, +) +from torch.fx.experimental.sym_node import magic_methods, method_to_operator +from torch.utils._sympy.functions import ( + CeilDiv, + FloorDiv, + Identity, + IntTrueDiv, + ModularIndexing, +) + +from torch._dynamo.utils import import_submodule + +from torch._inductor import config, inductor_prims, ir, test_operators # NOQA: F401 +from torch._inductor.decomposition import decompositions, get_decompositions +from torch._inductor.ir import ( + DtypeView, + ExpandView, + IndexingConstant, + is_triton, + ops_wrapper, + PermuteView, + Pointwise, + Reduction, + SqueezeView, + TensorBox, + validate_ir, + View, +) +from torch._inductor.utils import ( + ceildiv, + decode_device, + is_dynamic, + is_pointwise_use, + pad_listlike, + parallel_num_threads, + sympy_product, + use_scatter_fallback, +) +from torch._inductor.virtualized import ops, V + +from ... 
import config as anir_config + + +log = logging.getLogger(__name__) +lowerings: Dict[torch._ops.OpOverload, Callable[..., Any]] = {} +# Use maybe_layout_constraints to access this dict, we lazily register tag-based layout constraints +_maybe_layout_constraints: Dict[ + torch._ops.OpOverload, Optional[Callable[..., Any]] +] = {} +fallbacks: Set[torch._ops.OpOverload] = set() +aten = torch.ops.aten +tr_c10d = torch.ops.tr_c10d +prims = torch.ops.prims +needs_realized_inputs: Set[torch._ops.OpOverload] = set() +foreach_ops: Set[torch._ops.OpOverload] = set() +inplace_foreach_ops: Set[torch._ops.OpOverload] = set() +inplaceable_foreach_ops: Dict[torch._ops.OpOverload, torch._ops.OpOverload] = {} +quantized_decomposed = torch.ops.quantized_decomposed + +fn_to_aten_fn = {} +node_id = itertools.count(0) + +def register_fn_to_aten_fn(fn, aten_fn=None): + if fn not in fn_to_aten_fn: + fn_to_aten_fn[fn] = aten_fn + return fn + +def register_to_aten(aten_fn=None): + def decorator(fn): + if fn not in fn_to_aten_fn: + fn_to_aten_fn[fn] = aten_fn + return fn + return decorator + +reduction_type_to_aten_fn = { + "sum": aten.sum, + "prod": aten.prod, + "xor_sum": prims.xor_sum, + "any": aten.any, + "max": aten.amax, + "min": aten.amin, + "argmax": aten.argmax, + "argmin": aten.argmin +} + +operator_to_string = { + '+': 'a', + '-': 'sub', + '*': 'm', + '/': 'd', + '(': 'l', + ')': 'r', + '.': 'p' +} + +string_to_operator = {v: k for k, v in operator_to_string.items()} + +def map_operators_to_strings(expr_str: str): + expr_str = expr_str.replace(' ', '') + for op, string in operator_to_string.items(): + expr_str = expr_str.replace(op, string) + return '_' + expr_str + +def map_strings_to_operators(expr_str: str): + for op, string in string_to_operator.items(): + expr_str = expr_str.replace(op, string) + return expr_str[1:] + + + +class TracedGraph: + def __init__(self): + self.graph = torch.fx.Graph() + self.last_node: Optional[torch.fx.Node] = None + self.sym_nodes: Dict[str, torch.fx.Node] = {} + + def __str__(self): + return str(self.graph) + + def get_placeholder_names(self): + placeholder_names = set() + for node in self.graph.nodes: + if node.op == 'placeholder' and node.name not in self.sym_nodes: + placeholder_names.add(node.name) + return placeholder_names + + __repr__ = __str__ + + + +def create_fake_input(size, stride, device, dtype): + size = [V.graph.sizevars.shape_env.create_symintnode(s, hint=None) \ + if isinstance(s, Expr) and not isinstance(s, Integer) else s for s in size] + stride = [V.graph.sizevars.shape_env.create_symintnode(s, hint=None) \ + if isinstance(s, Expr) and not isinstance(s, Integer) else s for s in stride] + with V.graph.fake_mode: + fake_input = torch.empty_strided(size, stride, device=device, dtype=dtype) + return fake_input + + +def create_sym_inputs(traced_graph: TracedGraph, size: List[Expr]): + for s in size: + if isinstance(s, (List, Tuple)): + create_sym_inputs(traced_graph, s) + continue + if isinstance(s, Expr) and not isinstance(s, Integer): + s_name = str(s) + if not isinstance(s, Symbol): + s_name = map_operators_to_strings(s_name) + if s_name in traced_graph.sym_nodes: + continue + new_node = traced_graph.graph.placeholder(s_name) + new_node.meta['val'] = V.graph.sizevars.shape_env.create_symintnode(s, hint=None) + traced_graph.sym_nodes.update({s_name: new_node}) + + +def process_ir_constant(inp: ExpandView) -> Union[TracedGraph, int, float]: + skip = False + if isinstance(inp.data, IndexingConstant): + dtype = inp.data.dtype + inp = inp.data.index + # 
convert to original dtype. + if dtype in [torch.float32, torch.float16, torch.bfloat16]: + # sympy inputs + if isinstance(inp, Expr) and not isinstance(inp, Integer): + traced_graph = TracedGraph() + create_sym_inputs(traced_graph, [inp]) + s_name = str(inp) + if not isinstance(inp, Symbol): + s_name = map_operators_to_strings(str(inp)) + traced_graph.last_node = traced_graph.sym_nodes[s_name] + inp = traced_graph + else: + inp = float(inp) + elif isinstance(inp.data, ir.Constant): + dtype = inp.data.dtype + inp = inp.data.value + else: + skip = True + return inp, skip + + +def fetch_graphs(inputs: Optional[List[TensorBox]]): + if isinstance(inputs, (TensorBox, ir.StorageBox, ir.View, sympy.Symbol, ir.Constant, ir.ReinterpretView)): + inputs = [inputs] + input_graphs = [] + for inp in inputs: + if isinstance(inp, List): + input_graphs.append(fetch_graphs(inp)) + continue + if not isinstance(inp, (TensorBox, ir.StorageBox, ir.View, ir.ReinterpretView, ir.PermuteView, ir.SliceView, ir.ExpandView)): + input_graphs.append(inp) + continue + if isinstance(inp, ExpandView): + inp, skip = process_ir_constant(inp) + if not skip: + input_graphs.append(inp) + continue + name = inp.get_name() + traced_graph = inp.get_traced_graph() + if traced_graph is not None: + input_graphs.append(traced_graph) + continue + traced_graph = TracedGraph() + device = inp.get_device() + dtype = inp.get_dtype() + size = inp.get_size() + stride = inp.get_stride() + new_node = traced_graph.graph.placeholder(name) + fake_input = create_fake_input(size, stride, device, dtype) + new_node.meta['val'] = fake_input + traced_graph.last_node = new_node + input_graphs.append(traced_graph) + return input_graphs + + +def merge_traced_graphs(input_graphs: List[TracedGraph], origin_fn, node_name, **kwargs): + new_graph = TracedGraph() + exist_nodes: Dict[str, torch.fx.Node] = {} + def merge_graph(input_graphs: List[TracedGraph]): + for input_graph in input_graphs: + if isinstance(input_graph, List): + merge_graph(input_graph) + continue + if not isinstance(input_graph, TracedGraph): + continue + for node in input_graph.graph.nodes: + if node.name in exist_nodes: + continue + new_node = new_graph.graph.node_copy(node, lambda n: exist_nodes[n.name]) + exist_nodes[node.name] = new_node + if node.name in input_graph.sym_nodes: + new_graph.sym_nodes.update({node.name: new_node}) + + def parse_args(input_graphs, exist_nodes): + args = [] + for input_graph in input_graphs: + if isinstance(input_graph, TracedGraph): + args.append(exist_nodes[input_graph.last_node.name]) + elif isinstance(input_graph, (List, Tuple)): + args.append(parse_args(input_graph, exist_nodes)) + else: + if isinstance(input_graph, Expr) and not isinstance(input_graph, Integer): + if not isinstance(input_graph, Symbol): + input_graph = map_operators_to_strings(str(input_graph)) + args.append(new_graph.sym_nodes[str(input_graph)]) + else: + args.append(input_graph) + return args + + num_args = len(input_graphs) + + for k, v in kwargs.items(): + if isinstance(v, Expr) and not isinstance(v, Integer): + traced_graph = TracedGraph() + create_sym_inputs(traced_graph, [v]) + s_name = str(v) + if not isinstance(v, Symbol): + s_name = map_operators_to_strings(str(v)) + traced_graph.last_node = traced_graph.sym_nodes[s_name] + kwargs[k] = traced_graph.sym_nodes[s_name] + input_graphs.append(traced_graph) + merge_graph(input_graphs) + input_graphs = input_graphs[:num_args] + # if inputs do not have any valid graphs, like full/iota + create_sym_inputs(new_graph, input_graphs) + args 
= parse_args(input_graphs, exist_nodes) + with new_graph.graph.inserting_after(new_graph.last_node): + new_node = new_graph.graph.call_function(origin_fn, args=tuple(args), kwargs=kwargs) + new_node.name = node_name + new_graph.last_node = new_node + return new_graph + +def merge_fx_graphs(traced_graphs: List[TracedGraph]): + new_graph = TracedGraph() + exist_nodes: Dict[str, torch.fx.Node] = {} + last_nodes = [] + def merge_graph(input_graphs: List[TracedGraph]): + for input_graph in input_graphs: + if isinstance(input_graph, List): + merge_graph(input_graph) + continue + if not isinstance(input_graph, TracedGraph): + continue + for node in input_graph.graph.nodes: + if node.name in exist_nodes: + continue + new_node = new_graph.graph.node_copy(node, lambda n: exist_nodes[n.name]) + exist_nodes[node.name] = new_node + last_nodes.append(exist_nodes[input_graph.last_node.name]) + merge_graph(traced_graphs) + new_graph.last_node = last_nodes + return new_graph + +def subtract_graph(graph1: TracedGraph, graph2: TracedGraph, node_name=None) -> Tuple[TracedGraph, torch.fx.Node]: + new_graph = TracedGraph() + last_node2 = graph2.last_node + graph1_node_names = {node.name for node in graph1.graph.nodes} + graph2_node_names = {node.name for node in graph2.graph.nodes} + placeholder = None + exist_nodes: Dict[str, torch.fx.Node] = {} + if node_name not in graph1_node_names: + placeholder = new_graph.graph.placeholder(last_node2.name if node_name is None else node_name) + exist_nodes[last_node2.name] = placeholder + for node in graph1.graph.nodes: + if node.name in graph2_node_names and node.name not in graph1.sym_nodes: + continue + new_node = new_graph.graph.node_copy(node, lambda n: exist_nodes[n.name]) + exist_nodes[node.name] = new_node + new_graph.last_node = exist_nodes[graph1.last_node.name] + new_graph.sym_nodes = graph1.sym_nodes + return new_graph, placeholder + +def maybe_layout_constraints(fn: Callable[..., Any]) -> Optional[Callable[..., Any]]: + """Get layout constraints. Returns None if there are no layout constraints.""" + if not isinstance(fn, torch._ops.OpOverload): + # Only OpOverloads have layout constraints. + return None + if fn in _maybe_layout_constraints: + return _maybe_layout_constraints[fn] + # OpOverload with custom lowerings override tag-based layout constraints + if fn in lowerings: + _maybe_layout_constraints[fn] = None + return None + # We lazily register tag-based layout constraints. 
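+    # needs_fixed_stride_order -> constrain_to_fx_strides, flexible_layout -> no
+    # constraint; the resolved constraint is cached in _maybe_layout_constraints.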
+ + def handle_layout_constraint_tag(tag): + if tag is torch._C.Tag.needs_fixed_stride_order: + _maybe_layout_constraints[fn] = constrain_to_fx_strides + return _maybe_layout_constraints[fn] + elif tag is torch._C.Tag.flexible_layout: + _maybe_layout_constraints[fn] = None + return None + else: + raise AssertionError(f"Unknown layout constraint tag: {tag}") + + tag = get_layout_constraint_tag(fn) + return handle_layout_constraint_tag(tag) + + +def get_layout_constraint_tag(fn): + tags_by_priority = [ + torch._C.Tag.needs_fixed_stride_order, + torch._C.Tag.flexible_layout, + ] + for tag in tags_by_priority: + if tag in fn.tags: + return tag + if torch._library.utils.is_builtin(fn): + return torch._C.Tag.flexible_layout + return getattr(torch._C.Tag, config.custom_op_default_layout_constraint) + + +def assert_nyi(cond, msg): + if not cond: + raise NotImplementedError(f"inductor does not support {msg}") + + +def add_needs_realized_inputs(fn): + if isinstance(fn, (list, tuple, set)): + return [add_needs_realized_inputs(x) for x in fn] + needs_realized_inputs.add(fn) + if isinstance(fn, torch._ops.OpOverloadPacket): + needs_realized_inputs.update( + getattr(fn, overload) for overload in fn.overloads() + ) + + +def add_layout_constraint(fn, constraint): + if isinstance(fn, torch._ops.OpOverloadPacket): + for overload in fn.overloads(): + _maybe_layout_constraints[getattr(fn, overload)] = constraint + else: + _maybe_layout_constraints[fn] = constraint + + +add_needs_realized_inputs( + [ + aten.as_strided, + aten.as_strided_copy, + aten.avg_pool2d, + aten.avg_pool2d_backward, + aten.bmm, + aten.convolution, + aten.convolution_backward, + aten.max_pool2d_with_indices, + aten.max_pool2d_with_indices_backward, + aten.mm, + aten.upsample_nearest2d, + aten._upsample_nearest_exact2d, + aten._int_mm, + ] +) + +# TODO(jansel): ezyang says we won't need this in the future, try removing it +# based on https://github.com/pytorch/pytorch/blob/9e3eb329df8f701/c10/core/ScalarType.h#L28 +DTYPE_ID_LOOKUP = { + 0: torch.uint8, + 1: torch.int8, + 2: torch.int16, + 3: torch.int32, + 4: torch.int64, + 5: torch.float16, + 6: torch.float32, + 7: torch.float64, + 8: torch.complex32, + 9: torch.complex64, + 10: torch.complex32, + 11: torch.bool, + 15: torch.bfloat16, + # TODO(jansel): add quantized types? 
+ # _(c10::qint8, QInt8) /* 12 */ + # _(c10::quint8, QUInt8) /* 13 */ + # _(c10::qint32, QInt32) /* 14 */ + # _(c10::quint4x2, QUInt4x2) /* 16 */ + # _(c10::quint2x4, QUInt2x4) /* 17 */ +} + + +def decode_dtype(dtype: int): + if not isinstance(dtype, int): + return dtype + assert dtype in DTYPE_ID_LOOKUP, f"id {dtype} missing from DTYPE_ID_LOOKUP" + dtype = DTYPE_ID_LOOKUP[dtype] + return dtype + + +def is_integer_type(x): + if isinstance(x, TensorBox): + return is_integer_dtype(x.get_dtype()) or is_boolean_dtype(x.get_dtype()) + elif isinstance(x, sympy.Expr): + return x.is_integer is True # type: ignore[attr-defined] + else: + return isinstance(x, int) + + +def is_boolean_type(x): + if isinstance(x, TensorBox): + return is_boolean_dtype(x.get_dtype()) + else: + return isinstance(x, bool) + + +def get_promoted_dtype(*args, type_promotion_kind: ELEMENTWISE_TYPE_PROMOTION_KIND): + def construct_input(inp): + if isinstance(inp, (Number, sympy.Basic)): + return inp + else: + assert hasattr(inp, "get_dtype") + dim = len(inp.get_size()) + # construct a tmp tensor to feed into torch.result_type + return torch.zeros([1] * dim, dtype=inp.get_dtype()) + + inps = [construct_input(arg) for arg in args] + _, dtype = elementwise_dtypes(*inps, type_promotion_kind=type_promotion_kind) + return dtype + + +def get_overloads(aten_fn): + if not isinstance(aten_fn, (list, tuple)): + aten_fn = [aten_fn] + else: + aten_fn = list(aten_fn) + + for fn in list(aten_fn): + if isinstance(fn, torch._ops.OpOverloadPacket): + for overload in fn.overloads(): + other_fn = getattr(fn, overload) + if other_fn not in lowerings: + aten_fn.append(other_fn) + + return aten_fn + + +def in_namespace(op, namespace): + if isinstance(op, torch._ops.OpOverloadPacket): + return namespace in op._qualified_op_name + elif isinstance(op, torch._ops.OpOverload): + return namespace in op.name() + return False + + +def transform_args( + args: List[Any], + kwargs: Dict[str, Any], + broadcast: bool, + type_promotion_kind: Optional[ELEMENTWISE_TYPE_PROMOTION_KIND], + convert_input_to_bool: bool, +) -> Tuple[List[Any], Dict[str, Any]]: + args_indices = [i for i, x in enumerate(args) if isinstance(x, TensorBox)] + kwargs_indices = [k for k, v in kwargs.items() if isinstance(v, TensorBox)] + # check that there's something to transform + if not args_indices and not kwargs_indices: + return args, kwargs + + if type_promotion_kind or convert_input_to_bool: + if convert_input_to_bool: + dtype = torch.bool + else: + # FIXME this is a crude approximation for promoting args + promoting_args = [ + a + for a in args + if isinstance(a, (Number, sympy.Basic)) or hasattr(a, "dtype") + ] + # only consider tensor kwargs for promotion, for now + promoting_args.extend(a for a in kwargs.values() if hasattr(a, "dtype")) + dtype = get_promoted_dtype( + *promoting_args, type_promotion_kind=type_promotion_kind # type: ignore[arg-type] + ) + + device = ( + args[args_indices[0]] if args_indices else kwargs[kwargs_indices[0]] + ).get_device() + + # sometimes args are an immutable list so we can't mutate them + def promote(arg): + if isinstance(arg, TensorBox): + return to_dtype(arg, dtype) + elif isinstance(arg, ir.Constant): + return ir.Constant(value=arg.value, dtype=dtype, device=device) + else: + return arg + + args = [promote(a) for a in args] + kwargs = {k: promote(v) for k, v in kwargs.items()} + + if broadcast: + broadcasted = broadcast_tensors( + *list( + itertools.chain( + (args[i] for i in args_indices), + (kwargs[k] for k in kwargs_indices), + ) + ) + ) + 
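+        # write the broadcasted tensors back into their original arg/kwarg slots,
+        # then expand any remaining ir.Constant args to the broadcast size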
size = list(broadcasted[0].get_size()) + + for i, x in zip(args_indices, broadcasted[: len(args_indices)]): + args[i] = x + for k, x in zip(kwargs_indices, broadcasted[len(args_indices) :]): + kwargs[k] = x + + for i in range(len(args)): + if isinstance(args[i], ir.Constant): + args[i] = ExpandView.create(args[i], size) + for k in kwargs: + if isinstance(kwargs[k], ir.Constant): + kwargs[k] = ExpandView.create(kwargs[k], size) + + return args, kwargs + + +def _register_foreach_lowering(aten_fn, decomp_fn): + """ + Add a foreach lowering to lowerings dict. + + Arguments: + aten_fn: torch.ops.aten.* fn we are lowering + decomp_fn: alternate implementation on our IR + broadcast: True to apply broadcasting to tensor inputs + type_promotion_kind: kind of type promotion applied to tensor inputs, `None` means no type promotion + convert_input_to_bool: some logical ops require inputs are converted to bool + """ + + @functools.wraps(decomp_fn) + def wrapped(*args, **kwargs): + assert len(args) <= 2 + out = decomp_fn(*args, **kwargs) + validate_ir(out) + return out + + aten_fns = get_overloads(aten_fn) + foreach_ops.update(aten_fns) + lowerings.update(dict.fromkeys(aten_fns, wrapped)) + return wrapped + + +def _register_lowering( + aten_fn, + decomp_fn, + broadcast, + type_promotion_kind: Optional[ELEMENTWISE_TYPE_PROMOTION_KIND], + convert_input_to_bool, +): + """ + Add a lowering to lowerings dict + + Arguments: + aten_fn: torch.ops.aten.* fn we are lowering + decomp_fn: alternate implementation on our IR + broadcast: True to apply broadcasting to tensor inputs + type_promotion_kind: kind of type promotion applied to tensor inputs, `None` means no type promotion + convert_input_to_bool: some logical ops require inputs are converted to bool + """ + + @functools.wraps(decomp_fn) + def wrapped(*args, **kwargs): + args: List[Any] = list(args) + kwargs: Dict[str, Any] = dict(kwargs) + unpacked = False + # TODO maybe we need to use pytrees here + if len(args) == 1 and isinstance(args[0], (list, tuple)): + unpacked = True + args = list(args[0]) + + if not all( + (fn in fallbacks or in_namespace(fn, "_c10d_functional")) for fn in aten_fn + ): + # explicitly assert for "out=" ops for better error messages + assert not any( + x == "out" for x in kwargs.keys() + ), "out= ops aren't yet supported" + + args, kwargs = transform_args( + args, kwargs, broadcast, type_promotion_kind, convert_input_to_bool + ) + + if unpacked: + args = [args] + + out = decomp_fn(*args, **kwargs) + validate_ir(out) + + return out + + aten_fn = get_overloads(aten_fn) + + lowerings.update(dict.fromkeys(aten_fn, wrapped)) + return wrapped + + +def register_lowering( + aten_fn, + broadcast=False, + type_promotion_kind: Optional[ + ELEMENTWISE_TYPE_PROMOTION_KIND + ] = ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, + convert_input_to_bool=False, +): + """ + Shim to support decorator syntax. + """ + return functools.partial( + _register_lowering, + aten_fn, + broadcast=broadcast, + type_promotion_kind=type_promotion_kind, + convert_input_to_bool=convert_input_to_bool, + ) + + +def broadcast_symbolic_shapes(a, b): + """ + Broadcasting logic based on symbolic shapes. + + We give the shapes 0 and 1 concrete values, while all other shapes + are symbolic sympy formulas. 
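+    For example, broadcasting (s0, 1) against (1, s1) yields (s0, s1); when
+    neither size is statically 1, the two expressions are guarded equal and the
+    formula with fewer free symbols is kept.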
+ """ + output = [] + for x, y in itertools.zip_longest( + reversed(a), reversed(b), fillvalue=sympy.Integer(1) + ): + if V.graph.sizevars.shape_env.evaluate_expr( + sympy.Eq(y, 1), size_oblivious=True + ): + output.append(x) + elif V.graph.sizevars.shape_env.evaluate_expr( + sympy.Eq(x, 1), size_oblivious=True + ): + output.append(y) + else: + V.graph.sizevars.guard_equals(x, y) + if len(sympy.expand(y).free_symbols) < len(sympy.expand(x).free_symbols): + output.append(y) # prefer shorter formula + else: + output.append(x) + return tuple(reversed(output)) + + +def promote_constants(inputs, override_return_dtype=None, type_promotion_kind=None): + assert ( + override_return_dtype is None or type_promotion_kind is None + ), "only one of override_return_dtype or type_promotion_kind may be given" + + if override_return_dtype is None and type_promotion_kind is None: + type_promotion_kind = ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + + if not any(isinstance(x, (sympy.Basic, int, float)) for x in inputs): + return inputs + if all(isinstance(x, (int, float, sympy.Basic)) for x in inputs): + dtype = override_return_dtype or get_promoted_dtype( + *inputs, type_promotion_kind=type_promotion_kind + ) + + def const_func(x): + if isinstance(x, sympy.Basic): + return ir.IndexingConstant( + index=x, dtype=dtype, device=decode_device(None) + ) + else: + return ir.Constant(value=x, dtype=dtype, device=decode_device(None)) + + return [const_func(x) for x in inputs] + ex = next(x for x in inputs if isinstance(x, (TensorBox, ExpandView, ir.Constant))) + out = [] + for x in inputs: + if isinstance(x, (int, float)): + out.append( + ExpandView.create( + ir.Constant(value=x, dtype=ex.get_dtype(), device=ex.get_device()), + list(ex.get_size()), + ) + ) + elif isinstance(x, sympy.Basic): + out.append( + ExpandView.create( + IndexingConstant( + index=x, dtype=ex.get_dtype(), device=ex.get_device() + ), + list(ex.get_size()), + ) + ) + else: + out.append(x) + + return out + + +def make_pointwise( + fn, + override_return_dtype=None, + override_device=None, + override_fn_when_input_bool=None, + override_fn_when_gpu_float64=None, + allow_alpha=False, + triton_fallback=None, + **kwargs +): + def inner(*inputs: List[TensorBox], alpha=None): + if triton_fallback is not None and any(map(is_triton, inputs)): + assert not allow_alpha # not implemented + return triton_fallback(*inputs) + + inputs = promote_constants(inputs, override_return_dtype) + if allow_alpha: + if alpha is not None and alpha != 1: + inputs = list(inputs) + inputs[-1] = mul(inputs[-1], alpha) + else: + assert alpha is None + loaders = [x.make_loader() for x in inputs] + ranges = inputs[0].get_size() + dtype = override_return_dtype or inputs[0].get_dtype() + is_gpu_device = ir.is_gpu(decode_device(inputs[0].get_device()).type) + + for other in inputs[1:]: + assert isinstance(other, ir.BaseConstant) or len(ranges) == len( + other.get_size() + ), f"ndim mismatch {fn} {ranges} {other.get_size()}" + + # in tracing, we will annotate pointwise nodes that correspond to the output of + # a pointwise node that would have been run in eager. intermediary pointwise nodes + # during decompositions are not annotated. 
+ emulate_precision_casts = ( + V.graph is not None + and getattr(V.graph, "current_node", None) is not None + and V.graph.current_node.meta is not None + and V.graph.current_node.meta.get("low_precision_pointwise_barrier", False) + and dtype in (torch.bfloat16, torch.float16) + ) + + def inner_fn(index): + assert len(index) == len(ranges), f"wrong ndim {index} {ranges}" + if dtype == torch.bool and override_fn_when_input_bool is not None: + return override_fn_when_input_bool(*[load(index) for load in loaders]) + elif ( + override_fn_when_gpu_float64 + and is_gpu_device + and dtype == torch.float64 + ): + return override_fn_when_gpu_float64(*[load(index) for load in loaders]) + else: + inputs_loaded = [] + for load in loaders: + out = load(index) + if emulate_precision_casts: + downcast = ops.to_dtype(out, dtype, use_compute_types=False) + out = ops.to_dtype(downcast, dtype) + inputs_loaded.append(out) + + out = fn(*inputs_loaded) + if emulate_precision_casts: + # fp16/bf16 kernels are computed in fp32. Casting down to fp16/bf16 here, + # then upcasting again, to emulate casts that eager would do. + downcast = ops.to_dtype(out, dtype, use_compute_types=False) + return ops.to_dtype(downcast, dtype) + return out + + if not override_device: + device = None + for i in inputs: + if ir.is_gpu(i.get_device().type): + device = i.get_device() + break + if not device: + device = inputs[0].get_device() + + device = override_device or device + + input_graphs = fetch_graphs(inputs) + node_name = f'pointwise_{next(node_id)}' + origin_fn = fn_to_aten_fn[fn] + new_graph = merge_traced_graphs(input_graphs, origin_fn, node_name, **kwargs) + + return Pointwise.create( + device=device, + dtype=dtype, + inner_fn=inner_fn, + ranges=ranges, + node_name=node_name, + traced_graph=new_graph, + ) + + return inner + + +def make_foreach_pointwise(pw_fn, allow_alpha=False): + def inner(*inputs: List[List[TensorBox]], alpha=1): + # group by device, whether any of the inputs are dynamic, and whether their types match + # (proxy for type promotion) + def group_args(arg_pairs): + out = defaultdict(list) + for i, args in enumerate(arg_pairs): + use_foreach = ( + not is_dynamic(*args) or config.combo_kernel_foreach_dynamic_shapes + ) + device = None + for t in args: + if isinstance(t, TensorBox): + device = t.data.get_device() + break + assert ( + device is not None + ), "foreach op should have at least one tensor arg" + out[(device, use_foreach)].append((i, args)) + return out + + realize_outputs = ( + len(V.graph.current_node.users) == 0 + or V.graph.current_node.target in inplace_foreach_ops + ) + for node in V.graph.current_node.users: + for user in node.users: + if not (user.op == "call_function" and (user.target in foreach_ops)): + realize_outputs = True + + a_list_input = None + for input in inputs: + if isinstance(input, (list, tuple)): + a_list_input = input + break + assert ( + a_list_input is not None + ), "at least one input must be a list to a foreach op" + + # broadcast scalar inputs to match length of list inputs + broadcast_inputs = [] + for input in inputs: + if not isinstance(input, (list, tuple)): + broadcast_inputs.append([input] * len(a_list_input)) + else: + broadcast_inputs.append(input) + + groups = group_args(zip(*broadcast_inputs)) + + outputs = [None] * len(a_list_input) + for (device, use_foreach), group in groups.items(): + operation_list: List[str] = [] + for ( + output_ind, + args, + ) in group: + if allow_alpha: + output = pw_fn(*args, alpha=alpha) + else: + output = pw_fn(*args) + + 
outputs[output_ind] = output + + if ( + V.graph.has_feature(device, BackendFeature.FOREACH) + and use_foreach + and realize_outputs + ): + output.realize() + operation_list.append(output.get_operation_name()) + + if operation_list: + V.graph.register_operation_list(operation_list) + + assert all(x is not None for x in outputs) + return outputs + + return inner + + +def to_dtype(x: TensorBox, dtype: torch.dtype, copy=False): + src_dtype = x.get_dtype() + if src_dtype == dtype: + return clone(x) if copy else x + + def _to_dtype(x): + return ops.to_dtype(x, dtype, src_dtype=src_dtype) + register_fn_to_aten_fn(_to_dtype, aten.to.dtype) + return make_pointwise(_to_dtype, override_return_dtype=dtype, dtype=dtype)(x) + +@register_lowering(torch.ops.npu.npu_dtype_cast_backward, type_promotion_kind=None) +@register_lowering(torch.ops.npu.npu_dtype_cast, type_promotion_kind=None) +@register_lowering(torch.ops.npu._npu_dtype_cast_backward, type_promotion_kind=None) +@register_lowering(torch.ops.npu._npu_dtype_cast, type_promotion_kind=None) +@register_lowering(prims.convert_element_type, type_promotion_kind=None) +def _convert_element_type(x: TensorBox, dtype: torch.dtype): + if dtype.is_complex or x.get_dtype().is_complex: + if x.get_size(): + # Decompose since aa aten fallback is more friendly for c++ codegen. + # This decomposition doesn't work for empty tensor, which needs more investigation. + dst = empty_like(x, dtype=dtype) + ir.InplaceCopyFallback.create(dst, x) + return dst + else: + return fallback_handler( + prims.convert_element_type.default, add_to_fallback_set=False + )(x, dtype) + return to_dtype(x, dtype, copy=True) + + +def to_dtype_bitcast(x: TensorBox, dtype: torch.dtype, *, copy=False): + x_dtype = x.get_dtype() + if x_dtype == dtype: + return clone(x) if copy else x + + def _get_primitive_bitwidth(dtype): + if dtype.is_floating_point: + return torch.finfo(dtype).bits + else: + return torch.iinfo(dtype).bits + + src_bits = _get_primitive_bitwidth(x_dtype) + dst_bits = _get_primitive_bitwidth(dtype) + if src_bits != dst_bits: + # fallback to aten eager implementation for differing bitwidths + return fallback_handler(aten.view.dtype)(x, dtype) + else: + return TensorBox(DtypeView.create(x, dtype)) + + +@register_lowering(aten.view.dtype, type_promotion_kind=None) +def _view_dtype(x: TensorBox, dtype: torch.dtype): + if dtype.is_complex or x.get_dtype().is_complex: + return TensorBox.create( + ir.ComplexView.create(torch.ops.aten.view.dtype, x, dtype) + ) + return to_dtype_bitcast(x, dtype) + + +def to_device(x: TensorBox, device: torch.device, *, copy=False, non_blocking=False): + src_dtype = x.get_dtype() + device = decode_device(device) + if x.get_device() == device: + return clone(x) if copy else x + + input_graphs = fetch_graphs([x, device]) + node_name = f'to_device_{next(node_id)}' + new_graph = merge_traced_graphs(input_graphs, aten.to.device, node_name, dtype=src_dtype, copy=copy) + return TensorBox.create(ir.DeviceCopy.create(x, device, non_blocking, traced_graph=new_graph, node_name=node_name)) + + +@register_lowering(prims.device_put, type_promotion_kind=None) +def _device_put(x: TensorBox, device: torch.device, non_blocking=False): + return to_device(x, device, copy=True, non_blocking=non_blocking) + + +def register_pointwise( + aten_fn, + name=None, + broadcast=True, + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, + convert_input_to_bool=False, + override_return_dtype=None, + override_fn_when_input_bool=None, + allow_alpha=False, + 
use_libdevice_for_f64=False, + triton_fallback=None, +): + """A pointwise function that maps ops.{name} to inputs""" + name = name or aten_fn.__name__ + fn = ops_wrapper(name) + if use_libdevice_for_f64: + fn_libdevice = ops_wrapper("libdevice_" + name) + if override_fn_when_input_bool is not None: + override_fn_when_input_bool = ops_wrapper(override_fn_when_input_bool) + + fn = register_fn_to_aten_fn(fn, aten_fn) + + fn = make_pointwise( + fn, + override_return_dtype=override_return_dtype, + override_fn_when_input_bool=override_fn_when_input_bool, + override_fn_when_gpu_float64=fn_libdevice if use_libdevice_for_f64 else None, # type: ignore[possibly-undefined] + allow_alpha=allow_alpha, + triton_fallback=triton_fallback, + ) + fn = register_lowering( + aten_fn, + broadcast=broadcast, + type_promotion_kind=type_promotion_kind, + convert_input_to_bool=convert_input_to_bool, + )(fn) + + if hasattr(prims, name): + register_lowering( + getattr(prims, name), + type_promotion_kind=None, + convert_input_to_bool=convert_input_to_bool, + )(fn) + return fn + + +def register_frexp(): + """A pointwise function that maps ops.frexp to inputs""" + name = "frexp" + frexp = ops_wrapper("frexp") + + def frexp0(*args, **kwargs): + return frexp(*args, **kwargs)[0] # type: ignore[index] # next PR + + def frexp1(*args, **kwargs): + return frexp(*args, **kwargs)[1] # type: ignore[index] # next PR + + pw_fns = [ + make_pointwise(frexp0), + make_pointwise(frexp1, override_return_dtype=torch.int32), + ] + + def fn(*args, **kwargs): + return pw_fns[0](*args, **kwargs), pw_fns[1](*args, **kwargs) + + fn = register_lowering( + aten.frexp, + )(fn) + + if hasattr(prims, name): + register_lowering( + getattr(prims, name), + type_promotion_kind=None, + )(fn) + return fn + + +register_frexp() + + +def register_foreach_pointwise( + aten_fn, + pointwise_lowering_fn, + allow_alpha=False, +): + fn = make_foreach_pointwise(pointwise_lowering_fn, allow_alpha=allow_alpha) + fn = _register_foreach_lowering(aten_fn, fn) + return fn + + +@register_lowering(aten.where, broadcast=False, type_promotion_kind=None) +def where(cond, a, b): + def fn(*args): + return ops.where(*args) + + if isinstance(a, (float, int)): + a = constant_like(a)(b) + if isinstance(b, (float, int)): + b = constant_like(b)(a) + + args = [cond, a, b] + dtype = get_promoted_dtype( + args[1], args[2], type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ) + indices = [i for i, x in enumerate(args) if isinstance(x, TensorBox)] + for i, x in zip(indices, broadcast_tensors(*[args[i] for i in indices])): + args[i] = x + for i in range(len(args)): + if isinstance(args[i], ir.Constant): + args[i] = ExpandView.create(args[i], list(args[indices[0]].get_size())) + register_fn_to_aten_fn(fn, aten.where) + return make_pointwise(fn, override_return_dtype=dtype)( + args[0], to_dtype(args[1], dtype), to_dtype(args[2], dtype) + ) + + +@register_lowering(aten.broadcast_tensors, broadcast=False, type_promotion_kind=None) +def broadcast_tensors(*inputs): + if len(inputs) == 1 and isinstance(inputs[0], (list, tuple)): + return broadcast_tensors(*inputs[0]) + target: List[sympy.Expr] = functools.reduce( + broadcast_symbolic_shapes, [x.get_size() for x in inputs], [] + ) + outputs = [] + for x in inputs: + sizes = x.get_size() + if len(sizes) != len(target) or any( + ( + ( + V.graph.sizevars.shape_env.evaluate_expr( + sympy.Eq(a, 1), size_oblivious=True + ) + and not V.graph.sizevars.shape_env.evaluate_expr( + sympy.Eq(b, 1), size_oblivious=True + ) + ) + or ( + not 
V.graph.sizevars.shape_env.evaluate_expr( + sympy.Eq(a, 1), size_oblivious=True + ) + and V.graph.sizevars.shape_env.evaluate_expr( + sympy.Eq(b, 1), size_oblivious=True + ) + ) + ) + for a, b in zip(sizes, target) + ): + x = expand(x, target) + outputs.append(x) + return outputs + + +@register_lowering([aten.alias, aten.detach, aten.detach_, aten.lift, prims.view_of]) +def nop(x): + return x # AOT autograd handles this for us + + +if hasattr(aten, "lift_fresh"): + register_lowering(aten.lift_fresh)(nop) + + +@register_lowering(aten.squeeze, type_promotion_kind=None) +def squeeze(x, dim=None): + assert isinstance(x, TensorBox) + if dim is None: + return TensorBox(SqueezeView.create(x.data)) + + dim = ( + V.graph.sizevars.evaluate_static_shape(dim) + if isinstance(dim, (int, sympy.Expr)) + else tuple(V.graph.sizevars.evaluate_static_shape(d) for d in dim) + ) + dim = canonicalize_dims(len(x.get_size()), dim) # type: ignore[call-overload] + dims = set((dim,) if not isinstance(dim, tuple) else dim) + + new_shape = [] + for d, s in enumerate(x.get_size()): + if not (d in dims and V.graph.sizevars.evaluate_expr(sympy.Eq(s, 1))): + new_shape.append(s) + + # squeeze does nothing if the size isn't 1 + return view(x, new_shape) if new_shape != x.get_size() else x + + +@register_lowering(aten.squeeze_copy, type_promotion_kind=None) +def squeeze_copy(x, dim=None): + return clone(squeeze(x, dim)) + + +@register_lowering([aten.squeeze_]) +def squeeze_(x, dim=None): + val = squeeze(x, dim) + assert isinstance(x, TensorBox) + assert isinstance(val, TensorBox) + x.data = val.data + return x + + +@register_lowering(aten.isinf) +def isinf(x): + if is_integer_type(x): + return full_like(x, False, dtype=torch.bool) + fn = ops_wrapper("isinf") + register_fn_to_aten_fn(fn, aten.isinf) + return make_pointwise(fn, override_return_dtype=torch.bool)(x) + + +@register_lowering(aten.isnan) +def isnan(x): + if is_integer_type(x): + return full_like(x, False, dtype=torch.bool) + fn = ops_wrapper("isnan") + return make_pointwise(fn, override_return_dtype=torch.bool)(x) + + +@register_lowering(aten.ceil) +def ceil(x): + if is_integer_type(x): + return clone(x) + fn = ops_wrapper("ceil") + return make_pointwise(fn)(x) + + +@register_lowering(aten.floor) +def floor(x): + if is_integer_type(x): + return clone(x) + fn = ops_wrapper("floor") + return make_pointwise(fn)(x) + + +@register_lowering(aten.round.default) +def round(x): + if is_integer_type(x): + return clone(x) + else: + fn = ops_wrapper("round") + return make_pointwise(fn)(x) + + +@register_lowering(aten.trunc) +def trunc(x): + if is_integer_type(x): + return clone(x) + fn = ops_wrapper("trunc") + return make_pointwise(fn)(x) + + +@register_lowering(aten.expand, type_promotion_kind=None) +def expand(x, sizes): + from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols + + (x,) = promote_constants([x]) + if isinstance(x, ir.BaseConstant): + return ExpandView.create(x, tuple(sizes)) + assert isinstance(x, TensorBox) + assert isinstance(sizes, (list, tuple)) + if tuple(x.get_size()) == tuple(sizes): + return x + + if not free_unbacked_symbols(x.get_size()): + x_size_product = V.graph.sizevars.size_hint(sympy_product(x.get_size())) + # TODO: It would be better to realize the input if any of its sizes + # are unbacked, because typically the size will be non-zero. 
However, + # this cannot be done directly as below as we'll choke on the size_hint + # here + if x_size_product > 0 and not free_unbacked_symbols(sizes): + # maybe realize input before broadcasting it + x.mark_reuse( + V.graph.sizevars.size_hint(sympy_product(sizes)) // x_size_product + ) + input_graphs = fetch_graphs([x.data, tuple(sizes)]) + node_name = f'expand_{next(node_id)}' + new_graph = merge_traced_graphs(input_graphs, aten.expand, node_name) + return TensorBox(ExpandView.create(x.data, tuple(sizes), traced_graph=new_graph, node_name=node_name)) + + +@register_lowering(prims.broadcast_in_dim, type_promotion_kind=None) +def broadcast_in_dim(a, shape, broadcast_dimensions): + s = list(shape) + for broadcast_dimension in broadcast_dimensions: + s[broadcast_dimension] = -1 + + v = a + for idx, x in enumerate(s): + if x != -1: + v = unsqueeze(v, idx) + + return expand(v, shape) + + +@register_lowering(aten.expand_as, type_promotion_kind=None) +def expand_as(x, y): + return expand(x, y.get_size()) + + +@register_lowering(aten.repeat) +def repeat(x, repeats): + + input_graphs = fetch_graphs([x, repeats]) + node_name = f'repeat_{next(node_id)}' + new_graph = merge_traced_graphs(input_graphs, aten.repeat, node_name) + + old_size = list(x.get_size()) + if len(repeats) > len(old_size): + old_size = [sympy.Integer(1)] * (len(repeats) - len(old_size)) + old_size + x = view(x, list(old_size)) + assert len(repeats) == len(x.get_size()) + + new_size = list(x.get_size()) + + zero_tensor = False + for i in range(len(repeats)): + if repeats[i] == 0: + zero_tensor = True + new_size[i] = new_size[i] * repeats[i] + + if zero_tensor: + return empty(new_size, dtype=x.get_dtype(), device=x.get_device()) + if all((a == 1 or b == 1) for a, b in zip(repeats, old_size)): + return clone(expand(x, new_size)) + + x_loader: Callable[[Any], Any] + + def inner_fn(index): + assert len(index) == len(repeats) + index = list(index) + for i in range(len(repeats)): + if repeats[i] != 1: + if old_size[i] == 1: + index[i] = sympy.Integer(0) + else: + index[i] = ModularIndexing(index[i], 1, old_size[i]) + return x_loader(index) + + old_size_product = V.graph.sizevars.size_hint(sympy_product(old_size)) + if old_size_product > 0: + # maybe realize the input + x.mark_reuse( + V.graph.sizevars.size_hint(sympy_product(new_size)) // old_size_product + ) + x_loader = x.make_loader() + return Pointwise.create( + device=x.get_device(), + dtype=x.get_dtype(), + inner_fn=inner_fn, + ranges=list(new_size), + traced_graph=new_graph, + node_name=node_name + ) + + +@register_lowering(aten._unsafe_view, type_promotion_kind=None) +@register_lowering(aten.view, type_promotion_kind=None) +@register_lowering(aten.reshape, type_promotion_kind=None) +def view(x, sizes): + assert isinstance(x, TensorBox) + assert isinstance(sizes, (list, tuple)) + input_graphs = fetch_graphs([x.data, sizes]) + node_name = f'view_{next(node_id)}' + new_graph = merge_traced_graphs(input_graphs, aten.reshape, node_name) + return TensorBox(View.create(x.data, sizes, traced_graph=new_graph, node_name=node_name)) + + +@register_lowering(aten.permute, type_promotion_kind=None) +def permute(x, dims): + assert isinstance(x, TensorBox) + assert isinstance(dims, (list, tuple)) + input_graphs = fetch_graphs([x.data, dims]) + node_name = f'permute_{next(node_id)}' + new_graph = merge_traced_graphs(input_graphs, aten.permute, node_name) + return TensorBox(PermuteView.create(x.data, tuple(dims), traced_graph=new_graph, node_name=node_name)) + + +@register_lowering(aten.slice, 
type_promotion_kind=None) +def slice_(x, dim=0, start=0, end=2**63, step=1, clamp=True): + assert isinstance(x, TensorBox) + dim = _validate_dim(x, dim, 0) + x.realize_hint() + input_graphs = fetch_graphs([x.data]) + node_name = f'slice_{next(node_id)}' + new_graph = merge_traced_graphs(input_graphs, aten.slice, node_name, dim=dim, start=start, end=end, step=step) + + return TensorBox(ir.SliceView.create(x.data, dim, start, end, step, traced_graph=new_graph, node_name=node_name)) + + +@register_lowering(aten.as_strided, type_promotion_kind=None) +def as_strided(x, size, stride, storage_offset=None): + if isinstance(x, TensorBox) and isinstance(x.data, ir.BaseView): + # as_strided ignores views + x = x.data.unwrap_view() + x.realize() + if not ir.is_storage_and_layout(x): + raise NotImplementedError(f"unrealized as_strided({x}, ...)") + storage, old_layout = ir.as_storage_and_layout(x) + new_layout = ir.FixedLayout( + old_layout.device, + old_layout.dtype, + [sympy.expand(s) for s in size], + [sympy.expand(s) for s in stride], + sympy.expand(storage_offset or 0), + ) + return TensorBox(ir.ReinterpretView(data=storage, layout=new_layout)) + + +@register_lowering(aten.as_strided_, type_promotion_kind=None) +def as_strided_(x, size, stride, storage_offset=None): + assert isinstance(x, TensorBox) + x.data = as_strided(x, size, stride, storage_offset).data + return x + + +@register_lowering(aten.as_strided_copy, type_promotion_kind=None) +def as_strided_copy(x, size, stride, storage_offset=None): + result = as_strided(x, size, stride, storage_offset) + return clone(result) + + +def pointwise_cat(inputs, dim=0): + # (inclusive, exclusive) + inputs_ranges: List[Tuple[sympy.Expr, sympy.Expr]] = [] + prev_end = 0 + for inp in inputs: + inputs_ranges.append((prev_end, prev_end + inp.get_size()[dim])) # type: ignore[arg-type] + prev_end = inputs_ranges[-1][-1] # type: ignore[assignment] + + inputs_loaders = [inp.make_loader() for inp in inputs] + + def inner_fn(idx): + idx_dim = ops.index_expr(idx[dim], torch.int64) + + masks = [] + masked_loads = [] + for i in range(len(inputs)): + start = ( + ops.constant(0, torch.int64) + if i == 0 + else ops.index_expr(inputs_ranges[i][0], torch.int64) + ) + end = ops.index_expr(inputs_ranges[i][1], torch.int64) + + start_cond = ops.ge(idx_dim, start) + end_cond = ops.lt(idx_dim, end) + if i == 0: + mask = end_cond + elif i == len(inputs) - 1: + mask = start_cond + else: + mask = ops.and_(start_cond, end_cond) + + masks.append(mask) + idx_load = list(idx) + + # if we're concatting [4], [2] + # when we index the second tensor for 5 we want to index 5 - 4 + # Use Identity to prevent expansion of index * stride to keep expression + # in same int bitwidth as shape + idx_load[dim] = Identity(idx_load[dim] - inputs_ranges[i][0]) + + masked_loads.append( + ops.masked( + mask, + lambda: inputs_loaders[i](idx_load), + 0.0, # this value should be unused + ), + ) + + next_val = masked_loads[-1] + for i in range((len(inputs)) - 2, -1, -1): + next_val = ops.where( + masks[i], + masked_loads[i], + next_val, + ) + return next_val + + new_size = list(inputs[0].get_size()) + new_size[dim] = inputs_ranges[-1][-1] + + input_graphs = fetch_graphs([inputs]) + node_name = f'cat_{next(node_id)}' + new_graph = merge_traced_graphs(input_graphs, aten.cat, node_name, dim=dim) + + return Pointwise.create( + device=inputs[0].get_device(), + dtype=inputs[0].get_dtype(), + inner_fn=inner_fn, + ranges=new_size, + traced_graph=new_graph, + node_name=node_name + ) + 
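+# Illustrative note (not upstream code): the view/shape lowerings above all follow
+# the same tracing pattern when graph dumping is enabled, e.g. for `view`:
+#
+#     input_graphs = fetch_graphs([x.data, sizes])
+#     node_name = f'view_{next(node_id)}'
+#     new_graph = merge_traced_graphs(input_graphs, aten.reshape, node_name)
+#     return TensorBox(View.create(x.data, sizes,
+#                                  traced_graph=new_graph, node_name=node_name))
+#
+# so every lowered IR node carries an fx graph describing how it was derived from
+# the realized buffers feeding it; the BaseView patches earlier in this diff use
+# subtract_graph to split that graph back out per buffer.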
+@register_lowering(quantized_decomposed.quantize_per_channel, type_promotion_kind=None) +def quantized_decomposed_quantize_per_channel( + input: TensorBox, + scales: TensorBox, + zero_points: TensorBox, + axis: int, + quant_min: int, + quant_max: int, + dtype: torch.dtype, +) -> TensorBox: + assert len(scales.get_size()) == 1, "expect scales 1 dim" + assert len(zero_points.get_size()) == 1, "expect zero_points 1 dim" + + if input.get_dtype() == torch.bfloat16: + input = to_dtype(input, torch.float32) + assert ( + input.get_dtype() == torch.float32 + ), f"Expecting input to have dtype torch.float32, but got dtype: {input.get_dtype()}" + assert axis < len( + input.get_size() + ), f"Expecting axis to be < {len(input.get_size())}" + + input_loader = input.make_loader() + scales_loader = scales.make_loader() + zero_points_loader = zero_points.make_loader() + + def inner_fn(idx): + channel_idx = (idx[axis],) + + input = input_loader(idx) + scale = scales_loader(channel_idx) + zero_point = zero_points_loader(channel_idx) + qmin, qmax = _create_constants(quant_min, quant_max, dtype=torch.float32) + + if scales.dtype != torch.float32: + scale = ops.to_dtype(scale, torch.float32) + if zero_points.dtype != torch.int32: + zero_point = ops.to_dtype(zero_point, torch.int32) + inv_scale = ops.reciprocal(scale) + val = ops.round(input * inv_scale) + zero_point + clamped = ops.maximum(qmin, ops.minimum(qmax, val)) + return ops.to_dtype(clamped, dtype) + + return Pointwise.create( + device=input.get_device(), + dtype=dtype, + inner_fn=inner_fn, + ranges=input.get_size(), + ) + + +@register_lowering( + quantized_decomposed.dequantize_per_channel, type_promotion_kind=None +) +def quantized_decomposed_dequantize_per_channel( + input: TensorBox, + scales: TensorBox, + zero_points: TensorBox, + axis: int, + quant_min: int, + quant_max: int, + dtype: torch.dtype, +) -> TensorBox: + assert len(scales.get_size()) == 1, "expect scales 1 dim" + assert len(zero_points.get_size()) == 1, "expect zero_points 1 dim" + assert ( + input.get_dtype() == dtype + ), f"Expecting input to have dtype {dtype}, but got dtype: {input.get_dtype()}" + assert axis < len( + input.get_size() + ), f"Expecting axis to be < {len(input.get_size())}" + + input_loader = input.make_loader() + scales_loader = scales.make_loader() + zero_points_loader = zero_points.make_loader() + + def inner_fn(idx): + channel_idx = (idx[axis],) + + input = input_loader(idx) + scale = scales_loader(channel_idx) + zero_point = zero_points_loader(channel_idx) + + if scales.dtype != torch.float32: + scale = ops.to_dtype(scale, torch.float32) + if zero_points.dtype != torch.float32: + zero_point = ops.to_dtype(zero_point, torch.float32) + val = ops.sub(ops.to_dtype(input, torch.float32), zero_point) * scale + return val + + return Pointwise.create( + device=input.get_device(), + dtype=torch.float32, + inner_fn=inner_fn, + ranges=input.get_size(), + ) + + +@register_lowering( + quantized_decomposed.quantize_per_tensor.default, type_promotion_kind=None +) +def quantized_decomposed_quantize_per_tensor_default( + input: TensorBox, + scale: float, + zero_point: int, + quant_min: int, + quant_max: int, + dtype: torch.dtype, +) -> TensorBox: + if input.get_dtype() == torch.bfloat16: + input = to_dtype(input, torch.float32) + assert ( + input.get_dtype() == torch.float32 + ), f"Expecting input to have dtype torch.float32, but got dtype: {input.get_dtype()}" + + input_loader = input.make_loader() + + def inner_fn(idx, scale, zero_point): + input = input_loader(idx) + 
inv_scale, zero_point = _create_constants( + 1.0 / scale, zero_point, dtype=torch.float32 + ) + val = ops.round(input * inv_scale) + zero_point + qmin, qmax = _create_constants(quant_min, quant_max, dtype=torch.float32) + clamped = ops.minimum(ops.maximum(val, qmin), qmax) + return ops.to_dtype(clamped, dtype) + + return Pointwise.create( + device=input.get_device(), + dtype=dtype, + inner_fn=functools.partial( + inner_fn, scale=float(scale), zero_point=int(zero_point) + ), + ranges=input.get_size(), + ) + + +@register_lowering( + quantized_decomposed.dequantize_per_tensor.default, type_promotion_kind=None +) +def quantized_decomposed_dequantize_per_tensor_default( + input: TensorBox, + scale: float, + zero_point: int, + quant_min: int, + quant_max: int, + dtype: torch.dtype, +) -> TensorBox: + assert ( + input.get_dtype() == dtype + ), f"Expecting input to have dtype {dtype}, but got dtype: {input.get_dtype()}" + + input_loader = input.make_loader() + + def inner_fn(idx, scale, zero_point): + input = input_loader(idx) + scale, zero_point = _create_constants(scale, zero_point, dtype=torch.float32) + val = ops.sub(ops.to_dtype(input, torch.float32), zero_point) * scale + return val + + return Pointwise.create( + device=input.get_device(), + dtype=torch.float32, + inner_fn=functools.partial( + inner_fn, scale=float(scale), zero_point=int(zero_point) + ), + ranges=input.get_size(), + ) + + +@register_lowering( + quantized_decomposed.quantize_per_tensor.tensor, type_promotion_kind=None +) +def quantized_decomposed_quantize_per_tensor_tensor( + input: TensorBox, + scale: TensorBox, + zero_point: TensorBox, + quant_min: int, + quant_max: int, + dtype: torch.dtype, +) -> TensorBox: + if input.get_dtype() == torch.bfloat16: + input = to_dtype(input, torch.float32) + assert ( + input.get_dtype() == torch.float32 + ), f"Expecting input to have dtype torch.float32, but got dtype: {input.get_dtype()}" + assert len(scale.get_size()) == 0 or ( + len(scale.get_size()) == 1 and scale.get_size()[0] == 1 + ), "expect scale as scalar tensor" + assert len(zero_point.get_size()) == 0 or ( + len(zero_point.get_size()) == 1 and zero_point.get_size()[0] == 1 + ), "expect zero_point as scalar tensor" + + input_loader = input.make_loader() + scale_loader = scale.make_loader() + zero_point_loader = zero_point.make_loader() + + def inner_fn(idx): + input = input_loader(idx) + _scale = scale_loader((0,) if len(scale.get_size()) == 1 else ()) + _zero_point = zero_point_loader((0,) if len(scale.get_size()) == 1 else ()) + if scale.dtype != torch.float32: + _scale = ops.to_dtype(_scale, torch.float32) + if zero_point.dtype != torch.float32: + _zero_point = ops.to_dtype(_zero_point, torch.float32) + val = ops.round(input * ops.reciprocal(_scale)) + _zero_point + qmin, qmax = _create_constants(quant_min, quant_max, dtype=torch.float32) + clamped = ops.minimum(ops.maximum(val, qmin), qmax) + return ops.to_dtype(clamped, dtype) + + return Pointwise.create( + device=input.get_device(), + dtype=dtype, + inner_fn=inner_fn, + ranges=input.get_size(), + ) + + +@register_lowering( + quantized_decomposed.dequantize_per_tensor.tensor, type_promotion_kind=None +) +def quantized_decomposed_dequantize_per_tensor_tensor( + input: TensorBox, + scale: TensorBox, + zero_point: TensorBox, + quant_min: int, + quant_max: int, + dtype: torch.dtype, +) -> TensorBox: + assert len(scale.get_size()) == 0 or ( + len(scale.get_size()) == 1 and scale.get_size()[0] == 1 + ), "expect scale as scalar tensor" + assert len(zero_point.get_size()) == 0 or ( 
+ len(zero_point.get_size()) == 1 and zero_point.get_size()[0] == 1 + ), "expect zero_point as scalar tensor" + assert ( + input.get_dtype() == dtype + ), f"Expecting input to have dtype {dtype}, but got dtype: {input.get_dtype()}" + + input_loader = input.make_loader() + scale_loader = scale.make_loader() + zero_point_loader = zero_point.make_loader() + + def inner_fn(idx): + input = input_loader(idx) + _scale = scale_loader((0,) if len(scale.get_size()) == 1 else ()) + _zero_point = zero_point_loader((0,) if len(scale.get_size()) == 1 else ()) + if scale.dtype != torch.float32: + _scale = ops.to_dtype(_scale, torch.float32) + if zero_point.dtype != torch.float32: + _zero_point = ops.to_dtype(_zero_point, torch.float32) + val = ops.sub(ops.to_dtype(input, torch.float32), _zero_point) * _scale + return val + + return Pointwise.create( + device=input.get_device(), + dtype=torch.float32, + inner_fn=inner_fn, + ranges=input.get_size(), + ) + + +@register_lowering(aten.cat) +def cat(inputs, dim=0): + cpu_device = inputs[0].get_device().type == "cpu" + if cpu_device and all( + input.get_dtype() in [torch.int8, torch.uint8] for input in inputs + ): + # TODO Remove this fallback when we support vectorization + # code gen with uint8 data type directly. + for input in inputs: + input.realize() + if all(len(input.get_size()) == 4 for input in inputs): + inputs, _ = require_channels_last(aten.cat, *inputs) + return fallback_handler(aten.cat.default)(inputs, dim) + + if len(inputs) == 1: + return clone(inputs[0]) + + dim = _validate_dim(inputs[0], dim, 0) + dtype = get_promoted_dtype( + *inputs, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ) + inputs = [to_dtype(inp, dtype) for inp in inputs] + + def unwrap_tensor(x: Union[TensorBox, ir.StorageBox]) -> ir.IRNode: + if isinstance(x, TensorBox): + if isinstance(x.data, ir.BaseView): + return x.data.unwrap_view() + else: + return x.data + + if isinstance(x, ir.StorageBox): + return x.data + + return x + + def is_reduction(t): + return isinstance(t, ir.ComputedBuffer) and isinstance(t.data, ir.Reduction) + + def can_fuse_reduction(t): + if isinstance(t, (TensorBox, ir.StorageBox)): + return can_fuse_reduction(unwrap_tensor(t)) + return ( + is_reduction(t) + or isinstance(t, ir.Pointwise) + and any( + can_fuse_reduction(V.graph.get_buffer(read)) + for read in t.get_read_names() + ) + ) + + # fusing reducutions into computed concat buffer can cause regressions. + fusable_reduction = any(can_fuse_reduction(t) for t in inputs) + + def should_lower_cat_input(x) -> bool: + # Unrealized inputs will not be storage and layouts, and we dont want to realize + # them in case we want to fuse + if ir.is_storage_and_layout(x): + storage, _ = ir.as_storage_and_layout(x, freeze=False) + return not ir.ConcatKernel.can_realize_into_without_copy(storage) + + if isinstance(x, (TensorBox, ir.StorageBox)): + return should_lower_cat_input(unwrap_tensor(x)) + + if isinstance(x, ir.Pointwise): + return True + + return False + + # TODO: We observed negative performance impact of pointwise_cat optimization on CPU so disabled it. + # We will revisit this later after enabling vectorization on index_expr. 
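    # The per-tensor quantize/dequantize lowerings above all compute the same
    # affine mapping; a minimal eager-mode sketch with placeholder values
    # (assuming a float32 input quantized to uint8), kept as a comment only to
    # clarify the inner_fn math:
    #
    #     import torch
    #     x, scale, zp, qmin, qmax = torch.randn(8), 0.05, 3, 0, 255
    #     q = torch.clamp(torch.round(x / scale) + zp, qmin, qmax).to(torch.uint8)
    #     dq = (q.to(torch.float32) - zp) * scale   # recovers x up to rounding error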
+ if cpu_device: + return TensorBox(ir.ConcatKernel.create(inputs, dim)) + + def op_count(x): + if isinstance(x, (TensorBox, ir.StorageBox)): + return op_count(unwrap_tensor(x)) + + # this will correspond to a direct memory read + if not isinstance(x, ir.Pointwise): + return 0 + + count = x.inner_fn_opcount().num_ops + for read in x.get_read_names(): + count += op_count(V.graph.get_buffer(read)) + + return count + + # as of inputs increase, possibility for register spilling also increases + # past a certain threshold of inputs we only fuse if the if the input kernels + # are simple + # not sure if we want to expose to users via config since logic may change in future + MAX_COMPLEX_POINTWISE_CAT = 8 + MAX_SIMPLE_OP_COUNT = 2 + + def additional_pointwise_ops(op: torch._ops.OpOverload): + return op in (aten.cat.default, aten.constant_pad_nd.default) + + if len(inputs) <= MAX_COMPLEX_POINTWISE_CAT or ( + (len(inputs) <= config.max_pointwise_cat_inputs) + and all(op_count(t) <= MAX_SIMPLE_OP_COUNT for t in inputs) + ): + pointwise_uses = all( + is_pointwise_use(use, additional_pointwise_ops) + for use in V.current_node.users + ) + # fuse in case we will be used in a pointwise node, and there are any inputs we + # we can prevent materialization of. + fuse_pointwise_use = ( + any(should_lower_cat_input(inp) for inp in inputs) and pointwise_uses + ) + + # horizontal fuse in case all inputs will require a copy kernel anyway. + # only horizontally fuse pointwise kernels + horizontal_fuse_cat = all( + should_lower_cat_input(inp) for inp in inputs + ) and not any(can_fuse_reduction(t) for t in inputs) + if fuse_pointwise_use or (horizontal_fuse_cat and not fusable_reduction): + return pointwise_cat(inputs, dim) + + return TensorBox(ir.ConcatKernel.create(inputs, dim)) + + +@register_lowering(aten.diagonal, type_promotion_kind=None) +def diagonal(input, offset: int = 0, dim1: int = 0, dim2: int = 1): + original_shape = input.get_size() + num_dims = len(original_shape) + dim1 = canonicalize_dim(idx=dim1, rank=num_dims) + dim2 = canonicalize_dim(idx=dim2, rank=num_dims) + + check( + dim1 != dim2, lambda: f"diagonal dimensions cannot be identical {dim1}, {dim2}" + ) + + offset_negative = V.graph.sizevars.evaluate_expr(sympy.Lt(offset, 0)) + if offset_negative: + diag_size = V.graph.sizevars.evaluate_max( + V.graph.sizevars.evaluate_min( + original_shape[dim1] + offset, original_shape[dim2] + ), + 0, # type: ignore[arg-type] + ) + else: + diag_size = V.graph.sizevars.evaluate_max( + V.graph.sizevars.evaluate_min( + original_shape[dim1], original_shape[dim2] - offset + ), + 0, # type: ignore[arg-type] + ) + + base_idx = (0, 0) + if offset_negative: + base_idx = (-offset, 0) + else: + base_idx = (0, offset) + + sizes = [s for i, s in enumerate(original_shape) if i not in (dim1, dim2)] + sizes.append(diag_size) + + def reindexer(idx): + diag_idx = idx[-1] + original_idx = [0] * len(original_shape) + cur_dim = 0 + for d in range(num_dims): + if d == dim1: + original_idx[d] = diag_idx + base_idx[0] + elif d == dim2: + original_idx[d] = diag_idx + base_idx[1] + else: + original_idx[d] = idx[cur_dim] + cur_dim += 1 + + assert cur_dim == len(original_shape) - 2 + return original_idx + + return TensorBox(ir.GenericView.create(input, sizes, reindexer)) + + +@register_lowering(aten.diagonal_copy, type_promotion_kind=None) +def diagonal_copy(input, offset: int = 0, dim1: int = 0, dim2: int = 1): + return clone(diagonal(input, offset, dim1, dim2)) + + +@register_lowering(aten.diagonal_scatter, type_promotion_kind=None) 
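# diagonal_scatter below follows the eager semantics: clone the input, then
# write `src` through the diagonal view selected by (offset, dim1, dim2).
# A small illustrative sketch with placeholder values:
#
#     import torch
#     x, src = torch.zeros(3, 3), torch.tensor([1.0, 2.0])
#     torch.diagonal_scatter(x, src, offset=1)   # src lands at (0, 1) and (1, 2)
#
# This mirrors the reindexer above: diagonal element i maps to (i, i + offset)
# for offset >= 0 and to (i - offset, i) for offset < 0.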
+def diagonal_scatter(input, src, offset: int = 0, dim1: int = 0, dim2: int = 1): + output = clone(input) + target = diagonal(output, offset, dim1, dim2) + mutate_to(target, src) + return output + + +@register_lowering(aten.select, type_promotion_kind=None) +def select(x, dim, idx): + idx = View.handle_negative_index(idx, x.get_size()[dim]) + return squeeze(slice_(x, dim, idx, idx + 1), dim) + + +@register_lowering(aten.split, type_promotion_kind=None) +def split(x, sizes, dim=0, clamp=True): + dim = _validate_dim(x, dim, 0) + if isinstance(sizes, sympy.Expr): + # TODO: We don't have to guard on sizes per se, but the number + # of splits must stay constant + sizes = V.graph.sizevars.evaluate_static_shape(sizes) + if isinstance(sizes, (int, sympy.Integer)): + x_size = V.graph.sizevars.evaluate_static_shape(x.get_size()[dim]) + sizes = [sizes] * ((x_size + sizes - 1) // sizes) + result = [] + start = 0 + for size in sizes: + end = start + size + result.append(slice_(x, dim, start, end, clamp=clamp)) + start = end + return result + + +@register_lowering(aten.split_with_sizes, type_promotion_kind=None) +def split_with_sizes(x, sizes, dim=0): + return split(x, sizes, dim, clamp=False) + + +@register_lowering(aten.unbind, type_promotion_kind=None) +def unbind(x, dim=0): + dim = _validate_dim(x, dim, 0) + x_size = V.graph.sizevars.evaluate_static_shape(x.get_size()[dim]) + result = [] + for i in range(x_size): + result.append(select(x, dim, i)) + return result + + +@register_lowering(aten.unfold, type_promotion_kind=None) +def unfold(x, dimension, size, step): + sizes = x.get_size() + ndim = len(sizes) + dim = canonicalize_dim(ndim, dimension) + + if ndim == 0: + return slice_(unsqueeze(x, 0), end=size) + + dim_size = sizes[dim] + sizevars = V.graph.sizevars + sizevars.guard_leq(size, dim_size) + sizevars.guard_lt(0, step) # type: ignore[arg-type] + + new_dim_size = FloorDiv(dim_size - size, step) + 1 + if sizevars.size_hint(dim_size) > 0: + x.mark_reuse(sizevars.size_hint(CeilDiv(new_dim_size * size, dim_size))) + + out_size = [*sizes[:dim], new_dim_size, *sizes[dim + 1 :], size] + + def reindexer(idx): + dim_idx = idx[-1] + idx[dim] * step + return (*idx[:dim], dim_idx, *idx[dim + 1 : -1]) + + return TensorBox(ir.GenericView.create(x, out_size, reindexer)) + + +@register_lowering(aten.unsqueeze, type_promotion_kind=None) +def unsqueeze(x, dim): + dim = _validate_dim(x, dim, 1) + new_shape = list(x.get_size()) + new_shape.insert(dim, sympy.Integer(1)) + return view(x, new_shape) + + +@register_lowering(aten.unsqueeze_, type_promotion_kind=None) +def unsqueeze_(x, dim): + val = unsqueeze(x, dim) + assert isinstance(x, TensorBox) + assert isinstance(val, TensorBox) + x.data = val.data + return x + + +def _validate_dim(x, dim, offset=0): + dim = V.graph.sizevars.shape_env.evaluate_expr(sympy.sympify(dim)) + ndim = len(x.get_size()) + if dim < 0: + dim += ndim + offset + assert 0 <= dim < ndim + offset + return dim + + +@register_lowering(aten.glu) +def glu(x, dim=-1): + dim = _validate_dim(x, dim, 0) + # TODO: don't guard on static shape here + new_len = V.graph.sizevars.evaluate_static_shape(x.get_size()[dim]) // 2 + a = slice_(x, dim, 0, new_len) + b = slice_(x, dim, new_len, new_len * 2) + return mul(a, sigmoid(b)) + + +def fallback_handler(kernel, add_to_fallback_set=True): + if add_to_fallback_set: + fallbacks.add(kernel) + + def handler(*args, **kwargs): + def wrap_tensors(x): + return TensorBox.create(x) if isinstance(x, ir.IRNode) else x + + return pytree.tree_map( + wrap_tensors, 
ir.FallbackKernel.create(kernel, *args, **kwargs) + ) + + return handler + + +@functools.lru_cache(None) +def _warn_complex_not_supported(): + warnings.warn( + "Torchinductor does not support code generation for complex operators. Performance may be worse than eager." + ) + + +# There are some types (CPU) which we accept as input but not as +# output. +def unsupported_input_tensor(t: torch.Tensor, parent=None): + "Do not support reading or writing to this tensor" + if t.is_complex(): + # Complex views are supported with IR ComplexView + if parent and parent.target in ( + torch.ops.aten.view.dtype, + torch.ops.prims.convert_element_type.default, + ): + return False + _warn_complex_not_supported() + return True + return False + + +def unsupported_output_tensor(t: torch.Tensor, parent=None): + "Do not support writing tensor but can read from it" + if unsupported_input_tensor(t, parent): + return True + return t.is_cpu and config.disable_cpp_codegen + + +def fallback_node_due_to_unsupported_type(node: torch.fx.Node, allow_cpu_inputs=True): + # Custom fallback lowering + if node.target is aten.view_as_complex.default: + return False + + # We should be able to remove this special case once `disable_cpp_codegen` is killed. + if node.target is aten.lift_fresh_copy.default: + return False + + def check_skip_condition(node, parent, is_output): + if not isinstance(node, torch.fx.Node): + return False + + if "val" not in node.meta: + return False + + for meta in pytree.tree_leaves(node.meta["val"]): + if not isinstance(meta, torch._subclasses.FakeTensor): + continue + + if is_output: + if unsupported_output_tensor(meta, parent): + return True + else: + if unsupported_input_tensor(meta, parent): + return True + + return False + + # only skip codegen if there is a cpu output, not input + for arg in pytree.arg_tree_leaves(*node.args, **node.kwargs): + if check_skip_condition(arg, node, is_output=False): + return True + + return check_skip_condition(node, node, is_output=True) + + +def make_fallback(op, layout_constraint=None, warn=True): + assert op not in decompositions, f"both a fallback and a decomp for same op: {op}" + if ( + warn + and bool(os.getenv("CI")) + and get_decompositions([op]) + # if fallback_random, we allow not decomposing random + and not ( + config.fallback_random + and op in torch._decomp.decompositions_for_rng.extra_random_decomps + ) + ): + # Note: 'warn' is holdover from when this was a warning, but for ops that previously + # set warn=False we do not want a CI error. + # Ignore the 'suppress errors' configs in CI, as this particular warning happens on startup anyway and is not + # likely to be triggered preferentially on one CI config over another. + if torch._dynamo.config.suppress_errors: + torch._dynamo.config.suppress_errors = False + log.warning( + "A make_fallback error occurred in suppress_errors config," + " and suppress_errors is being disabled to surface it." + ) + raise AssertionError( + f"make_fallback({op}): a decomposition exists, we should switch to it." + " To fix this error, either add a decomposition to core_aten_decompositions (preferred)" + " or inductor_decompositions, and delete the corresponding `make_fallback` line." 
+ " Get help from the inductor team if unsure, don't pick arbitrarily to unblock yourself.", + ) + + def register_fallback(op_overload): + add_needs_realized_inputs(op_overload) + if layout_constraint is not None: + add_layout_constraint(op_overload, layout_constraint) + return register_lowering(op_overload, type_promotion_kind=None)( + fallback_handler(op_overload) + ) + + if isinstance(op, torch._ops.OpOverloadPacket): + for ol in op.overloads(): + op_overload = getattr(op, ol) + register_fallback(op_overload) + elif isinstance(op, (torch._ops.OpOverload, torch._ops.HigherOrderOperator)): + register_fallback(op) + else: + raise RuntimeError(f"Unsupported fallback {op} with type {type(op)}") + + +def philox_rand_offset(shape): + """ + TorchInductor offset calculation differs from PyTorch eager offset + calculation for random ops (tl.rand vs torch.rand). In future, we should + strive for same impl for tl.rand and torch.rand. + """ + numel = 1 + for s in shape: + numel = numel * s + return tensor(numel, dtype=torch.int64) + + +@register_lowering(torch.ops.rngprims.philox_rand, type_promotion_kind=None) +def philox_rand(size, seed, offset, stride, device, dtype): + # stride arg is optional and will be used in future for distributed random + # ops. Currently, its unused. + random_pos = ir.FixedLayout( + device, + dtype, + size, + ir.FlexibleLayout.contiguous_strides(size), + ).make_indexer() + seed_loader = seed.make_loader() + offset_loader = offset.make_loader() + + def inner_fn(index): + # Both seed and offset in the philox_rand op are tensors. + # torch seed and offsets are of type int64, but tl.rand accepts int32 + seed_index_expr = ops.to_dtype(seed_loader([]), torch.int32) + offset_index_expr = ops.to_dtype(offset_loader([]), torch.int32) + # Get the offset'd position + rand_index_expr = ops.add( + ops.index_expr(random_pos(index), torch.int32), offset_index_expr + ) + result = ops.rand( + seed_index_expr, + rand_index_expr, + ) + return ops.to_dtype(result, dtype) + + random_values_node = Pointwise.create( + device=device, + dtype=dtype, + inner_fn=inner_fn, + ranges=list(size), + ) + + offset_node = philox_rand_offset(size) + return random_values_node, offset_node + + +@register_lowering(aten.native_dropout, type_promotion_kind=None) +def native_dropout(x, p, train): + if config.fallback_random: + return pytree.tree_map( + TensorBox.create, + ir.FallbackKernel.create(aten.native_dropout.default, x, p, train), + ) + else: + raise AssertionError("should be handled in replace_random.py") + + +@register_lowering(aten.bernoulli_, type_promotion_kind=None) +def bernoulli_(x, *args): + assert config.fallback_random or x.get_device() == torch.device( + "cpu" + ), "this should be handled in decomps unless config.fallback_random or the device is CPU" + x.realize() + op_overload = ( + aten.bernoulli_.float + if len(args) == 0 or isinstance(args[0], float) + else aten.bernoulli_.Tensor + ) + ir.InplaceBernoulliFallback(op_overload, x, *args) + return x + + +@register_lowering(aten.bernoulli.p, type_promotion_kind=None) +def bernoulli_p(x, *args): + assert config.fallback_random or x.get_device() == torch.device( + "cpu" + ), "this should be handled in decomps unless config.fallback_random or the device is CPU" + return bernoulli_(clone(x), *args) + + +# This shouldn't be called in general +@register_lowering(aten._foobar) +def _foobar(_): + raise AssertionError + + +@functools.lru_cache(1) +def _warn_triton_random(salt): + log.info("using triton random, expect difference from eager") + + +def 
warn_triton_random(): + # only warn once per graph + _warn_triton_random(V.graph.creation_time) + + +fallback_rand_default = fallback_handler(aten.rand.default) +fallback_rand_generator = fallback_handler(aten.rand.generator) +fallback_randn_default = fallback_handler(aten.randn.default) +fallback_randn_generator = fallback_handler(aten.randn.generator) +make_fallback(aten.randint) + + +@register_lowering(aten.rand) +def rand(*args, **kwargs): + if kwargs.get("generator", None) is not None: + return fallback_rand_generator(*args, **kwargs) + elif config.fallback_random: + kwargs.pop("generator", None) + return fallback_rand_default(*args, **kwargs) + raise AssertionError("should have been handled in replace_random.py") + + +@register_lowering(aten.randn) +def randn(*args, **kwargs): + if kwargs.get("generator", None) is not None: + return fallback_randn_generator(*args, **kwargs) + elif config.fallback_random: + kwargs.pop("generator", None) + return fallback_randn_default(*args, **kwargs) + raise AssertionError("should have been handled in replace_random.py") + + +@register_lowering(inductor_prims.force_stride_order, type_promotion_kind=None) +def inductor_force_stride_order(input_tensor, stride): + stride_order = ir.get_stride_order(stride) + return ir.ExternKernel.require_stride_order(input_tensor, stride_order) + + +@register_lowering(inductor_prims.seed, type_promotion_kind=None) +def inductor_seed(device: torch.device): + raise AssertionError("should be handled in fuse_seed_creation_pass()") + + +@register_lowering(inductor_prims.seeds, type_promotion_kind=None) +def inductor_seeds(count, device): + warn_triton_random() + return TensorBox.create(ir.RandomSeeds(count, decode_device(device))) + + +@register_lowering(inductor_prims.lookup_seed, type_promotion_kind=None) +def inductor_lookup_seed(seeds, index): + def inner_fn(_): + return ops.load_seed(seeds.get_name(), index) + + return Pointwise.create( + device=seeds.get_device(), + dtype=seeds.get_dtype(), + inner_fn=inner_fn, + ranges=[], + ) + + +@register_lowering(inductor_prims.random, type_promotion_kind=None) +def inductor_random(size: List[int], seed: TensorBox, mode: str, *, offset: int = 0): + assert not config.fallback_random + assert mode in ("rand", "randn") + size = [*size] + dtype = torch.float32 + device = seed.get_device() + random_pos = ir.FixedLayout( + device, dtype, size, ir.FlexibleLayout.contiguous_strides(size), offset=offset + ).make_indexer() + seed_loader = seed.make_loader() + + def inner_fn(index): + return getattr(ops, mode)( + seed_loader([]), + ops.index_expr(random_pos(index), torch.int32), + ) + + result = Pointwise.create( + device=device, + dtype=dtype, + inner_fn=inner_fn, + ranges=[*size], + ) + result.realize() + return result + + +@register_lowering(inductor_prims.randint, type_promotion_kind=None) +def inductor_randint( + low: int, high: int, size: List[int], seed: TensorBox, *, offset: int = 0 +): + assert not config.fallback_random + size = [*size] + dtype = torch.int64 + device = seed.get_device() + random_pos = ir.FixedLayout( + device, dtype, size, ir.FlexibleLayout.contiguous_strides(size), offset=offset + ).make_indexer() + seed_loader = seed.make_loader() + + def inner_fn(index): + return ops.randint64( + seed_loader([]), + ops.index_expr(random_pos(index), torch.int32), + ops.index_expr(low, torch.int64), + ops.index_expr(high, torch.int64), + ) + + return Pointwise.create( + device=device, + dtype=dtype, + inner_fn=inner_fn, + ranges=[*size], + ) + + +def _boundaries_helper(tb: 
TensorBox) -> Tuple[str, sympy.Expr, sympy.Expr, sympy.Expr]: + return ( + tb.get_name(), + tb.get_size()[-1], + tb.get_size()[0] * tb.get_stride()[0], + tb.get_stride()[-1], + ) + + +def _sorter_helper(tb: TensorBox) -> Tuple[str, sympy.Expr]: + return tb.get_name(), tb.get_stride()[-1] + + +@register_lowering(aten.searchsorted.Tensor, type_promotion_kind=None) +def searchsorted( + sorted_sequence: TensorBox, + self: TensorBox, + *, + out_int32: bool = False, + right: bool = False, + side: Optional[str] = None, + sorter: Optional[TensorBox] = None, +) -> TensorBox: + validate_bucketize = lambda tb: V.graph.has_feature( # noqa: E731 + tb, BackendFeature.BUCKETIZE + ) + if ( + not validate_bucketize(sorted_sequence) + or not validate_bucketize(self) + or (sorter is not None and not validate_bucketize(sorter)) + ): + return fallback_handler(aten.searchsorted.Tensor, add_to_fallback_set=False)( + sorted_sequence, + self, + out_int32=out_int32, + right=right, + side=side, + sorter=sorter, + ) + + # If side is present, override the value of right if needed. This assumes that + # validation of the two options being non-contradictory is already done by the + # searchsorted meta-function. + if side is not None and side == "right": + right = True + + index_dtype = torch.int32 if out_int32 else torch.int64 + values_loader = self.make_loader() + + # The entire sorted_sequence tensor needs to be used by ops.bucketize, so we need to + # realize it into global memory; or in other words, we can't guarantee that + # sorted_sequence.get_name() (used below) will exist unless we call + # sorted_sequence.realize(). + sorted_sequence.realize() + + if sorter is not None: + sorter.realize() + + if len(sorted_sequence.get_size()) == 1: + + def inner_fn(idx): + val = values_loader(idx) + return ops.bucketize( + val, + _boundaries_helper(sorted_sequence), + 0, + index_dtype, + right, + sorter=None if sorter is None else _sorter_helper(sorter), + sorter_indices=None if sorter is None else 0, + ) + + else: + + def inner_fn(idx): + val = values_loader(idx) + + # Get index to the beginning of the sorted sequence within a flattened + # version of the array. + def get_flattened_index(tb: TensorBox): + strides = tb.get_stride() + return ops.index_expr( + functools.reduce( + operator.add, (s * i for s, i in zip(strides[:-1], idx[:-1])) + ), + index_dtype, + ) + + return ops.bucketize( + val, + _boundaries_helper(sorted_sequence), + get_flattened_index(sorted_sequence), + index_dtype, + right, + sorter=None if sorter is None else _sorter_helper(sorter), + sorter_indices=None if sorter is None else get_flattened_index(sorter), + ) + + device = self.get_device() + return Pointwise.create( + device=device, + dtype=index_dtype, + inner_fn=inner_fn, + ranges=self.shape, + ) + + +@register_lowering(aten.bucketize, type_promotion_kind=None) +def bucketize( + input: TensorBox, + boundaries: TensorBox, + *, + out_int32: bool = False, + right: bool = False, +): + assert len(boundaries.get_size()) == 1 + + if not ( + V.graph.has_feature(input, BackendFeature.BUCKETIZE) + and V.graph.has_feature(boundaries, BackendFeature.BUCKETIZE) + ): + return fallback_handler(aten.bucketize.Tensor, add_to_fallback_set=False)( + input, boundaries, out_int32=out_int32, right=right + ) + + # The entire boundaries tensor needs to be used by ops.bucketize, so we + # need to realize it into global memory; or in other words, we can't + # guarantee that boundaries.get_name() (used below) will exist unless + # we call boundaries.realize(). 
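    # For reference, the eager semantics this lowering reproduces (sketch with
    # placeholder values; `right=True` selects the upper rather than the lower
    # bound of the bucket):
    #
    #     import torch
    #     b = torch.tensor([1, 3, 5, 7])
    #     torch.bucketize(torch.tensor([3, 6]), b)              # tensor([1, 3])
    #     torch.bucketize(torch.tensor([3, 6]), b, right=True)  # tensor([2, 3])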
+ boundaries.realize() + device = input.get_device() + input_loader = input.make_loader() + + index_dtype = torch.int32 if out_int32 else torch.int64 + + def inner_fn(index): + val = input_loader(index) + indices = ops.bucketize( + val, + _boundaries_helper(boundaries), + 0, + index_dtype, + right, + ) + + return indices + + return Pointwise.create( + device=device, + dtype=index_dtype, + inner_fn=inner_fn, + ranges=input.get_size(), + ) + + +def require_dense(_, *args, **kwargs): + args, kwargs = pytree.tree_map_only( + ir.IRNode, ir.ExternKernel.require_stride1, (args, kwargs) + ) + return args, kwargs + + +def require_contiguous(_, *args, **kwargs): + args, kwargs = pytree.tree_map_only( + ir.IRNode, ir.ExternKernel.require_contiguous, (args, kwargs) + ) + return args, kwargs + + +def require_channels_last(_, *args, **kwargs): + args, kwargs = pytree.tree_map_only( + ir.IRNode, ir.ExternKernel.require_channels_last, (args, kwargs) + ) + return args, kwargs + + +def constrain_to_fx_strides(fx_node, *args, **kwargs): + def apply_constraint(arg, fx_arg): + if isinstance(arg, ir.IRNode): + stride_order = ir.get_stride_order(fx_arg.meta["val"].stride()) + return ir.ExternKernel.require_stride_order(arg, stride_order) + if isinstance(arg, dict): + return {key: apply_constraint(arg[key], fx_arg[key]) for key in arg.keys()} + return arg + + args = tuple( + apply_constraint(arg, fx_arg) for arg, fx_arg in zip(args, fx_node.args) + ) + kwargs = {k: apply_constraint(v, fx_node.kwargs[k]) for k, v in kwargs.items()} + return args, kwargs + + +# TODO(jansel): we should implement decomps or lowerings for these +# https://github.com/pytorch/torchdynamo/issues/327 +FALLBACK_ALLOW_LIST = { + "torchvision::roi_align", +} + + +def sdpa_constraint(fx_node, *args, **kwargs): + # sdpa requires dense last dimension] + + def apply_constraint(idx, arg, fx_arg): + if not isinstance(arg, ir.IRNode): + return arg + + meta_val = fx_arg.meta["val"] + meta_stride = meta_val.stride() + + stride_order = ir.get_stride_order(meta_stride) + + if stride_order and stride_order[-1] != 0: + # contiguous stride order + stride_order = list(reversed(range(len(arg.get_size())))) + + if ( + fx_node.target + == aten._scaled_dot_product_efficient_attention_backward.default + and idx in (0, 5) + ): + assert len(stride_order) == 4 + # The 0 and 5th arguments for aten._scaled_dot_product_efficient_attention_backward.default + # are for out and gradient_out. They have to be in + # (3, 1, 2, 0) stride order. Otherwise the kernel will crash. + # Check https://github.com/pytorch/pytorch/issues/138772 + stride_order = (3, 1, 2, 0) + + if not meta_val.is_cuda: + return ir.ExternKernel.require_stride_order(arg, stride_order) + + # This is the minimum alignment required by SDPA kernels for attention_bias. 
+ # This value can be found in pytorch/aten/src/ATen/native/transformers/attention.cpp preprocess_mask + ALIGNMENT = 8 + + assert isinstance(arg, TensorBox) + if len(arg.get_size()) not in (3, 4): + return arg + + def is_aligned_realized_tensor(x): + aligned_strides = all( + (V.graph.sizevars.size_hint(x.get_stride()[i]) % ALIGNMENT) == 0 + for i in range(len(x.get_stride()) - 1) + ) + return ( + V.graph.sizevars.size_hint(x.get_stride()[-1]) + ) == 1 and aligned_strides + + try: + arg.get_stride() + if is_aligned_realized_tensor(arg): + return V.graph.try_match_insignificant_strides( + ir.ExternKernel.realize_input(arg), meta_stride + ) + except AttributeError: + pass + + def is_aligned(x): + return (V.graph.sizevars.size_hint(x.get_size()[-1]) % ALIGNMENT) == 0 + + if isinstance(arg.data, ir.BaseView): + if not is_aligned(arg): + if is_aligned(arg.unwrap_view()): + return V.graph.try_match_insignificant_strides( + ir.ExternKernel.realize_input(arg), meta_stride + ) + + return ir.ExternKernel.require_stride_order(arg, stride_order) + + args = tuple( + apply_constraint(idx, arg, fx_arg) + for idx, (arg, fx_arg) in enumerate(zip(args, fx_node.args)) + ) + kwargs = {k: apply_constraint(-1, v, fx_node.kwargs[k]) for k, v in kwargs.items()} + return args, kwargs + + +# WIP +make_fallback(aten._adaptive_avg_pool3d) # @isuruf +make_fallback(aten.adaptive_max_pool3d) # @isuruf +make_fallback(aten.fractional_max_pool3d) # @isuruf +make_fallback(aten.max_pool3d_with_indices) # @isuruf (can this one be implemented?) + + +# 1) Easy +make_fallback(aten.uniform, warn=False) +make_fallback(aten.exponential.default, warn=False) # (fails accuracy on test_torch.py) +make_fallback(aten._pdist_forward) # Has decomp. Needs benchmarks +make_fallback(aten.soft_margin_loss_backward, warn=False) # py_impl? + + +# 1.5) Easy or Impossible +make_fallback(aten._cdist_forward) # p=2 should be feasible +make_fallback(aten._cdist_backward) + +# 2) Medium +make_fallback(aten._trilinear) + + +# 3) Difficult +# Scans +# See the discussion at +# https://dev-discuss.pytorch.org/t/pytorch-sparse-gnn-compiler-rfc/1644/19 +make_fallback(aten.segment_reduce.default) +make_fallback(aten._segment_reduce_backward.default) + +# Histogram (need to implement Histogram IR) +make_fallback(aten.histc) +make_fallback(aten.histogram.bin_ct) +make_fallback(aten._histogramdd_bin_edges.default) +make_fallback(aten._histogramdd_from_bin_cts.default) + +# Need templated kernel +make_fallback(aten.addbmm) +make_fallback(aten._addmm_activation, warn=False) + +# Need templated kernel. 
Probably impossible to write efficiently +make_fallback(aten.convolution_backward, constrain_to_fx_strides) +make_fallback(aten._cudnn_rnn, require_dense) +make_fallback(aten._cudnn_rnn_backward, require_contiguous) + +# Haven't checked but sound difficult / impossible +make_fallback(aten._embedding_bag, require_contiguous) +make_fallback(aten._embedding_bag_forward_only, require_contiguous) +make_fallback(aten._embedding_bag_backward) +make_fallback(aten._embedding_bag_per_sample_weights_backward) +make_fallback(aten._embedding_bag_per_sample_weights_backward) +make_fallback(aten._fused_moving_avg_obs_fq_helper) +make_fallback(aten._fused_moving_avg_obs_fq_helper_functional) + + +# 4) Backwards (try py_impl'ing them) when fwd is written as a decomp +make_fallback(aten.max_pool3d_with_indices_backward) +make_fallback(aten._adaptive_avg_pool2d_backward, require_dense) +make_fallback(aten._adaptive_avg_pool3d_backward) +make_fallback(aten.adaptive_max_pool2d_backward) +make_fallback(aten.adaptive_max_pool3d_backward) +make_fallback(aten.fractional_max_pool2d_backward) +make_fallback(aten.fractional_max_pool3d_backward) +make_fallback(aten.replication_pad1d_backward) +make_fallback(aten.replication_pad2d_backward) +make_fallback(aten.upsample_linear1d_backward) +make_fallback(aten.upsample_bicubic2d_backward, require_contiguous) +make_fallback(aten.upsample_trilinear3d_backward) +make_fallback(aten.grid_sampler_2d_backward, require_dense) +make_fallback(aten._pdist_backward) + + +# 5) Impossible (missing triton/CPU features) + +# Sorting / Sorting-like +make_fallback(aten.sort) +make_fallback(aten.sort.stable) +make_fallback(aten.kthvalue) +make_fallback(aten.topk) +make_fallback(aten.mode) +make_fallback(aten.median) +make_fallback(aten.nanmedian) +make_fallback(aten.randperm) +# see: https://github.com/pytorch/pytorch/pull/121354 +make_fallback(aten.resize_) +make_fallback(aten.resize_as_) + +# Linalg +make_fallback(aten._linalg_det) +make_fallback(aten.linalg_householder_product) +make_fallback(aten.linalg_inv_ex) +make_fallback(aten.linalg_ldl_factor_ex) +make_fallback(aten.linalg_ldl_solve) +make_fallback(aten.linalg_lu) +make_fallback(aten.linalg_lu_factor_ex) +make_fallback(aten.linalg_lu_solve) +make_fallback(aten.linalg_matrix_exp) +make_fallback(aten.linalg_qr) +make_fallback(aten._linalg_slogdet) +make_fallback(aten._linalg_solve_ex) +make_fallback(aten.linalg_solve_triangular) +make_fallback(aten._linalg_svd) +make_fallback(aten.lu_unpack) +make_fallback(aten.ormqr) +make_fallback(aten._linalg_check_errors) +make_fallback(aten.linalg_pinv.atol_rtol_tensor) +make_fallback(aten._linalg_eigh) +make_fallback(aten.triangular_solve) +make_fallback(aten.linalg_cholesky_ex) +make_fallback(aten.cholesky_inverse) +make_fallback(aten.cholesky_solve) +make_fallback(aten.geqrf) +make_fallback(aten._fft_r2c) # needs complex as well + +# Data dependent (are these necessary?) +make_fallback(aten.nonzero.default) + +# Misc +make_fallback(aten.gcd.default, warn=False) +make_fallback(aten._thnn_fused_lstm_cell, require_dense) +make_fallback(torch._prims.rng_prims.run_and_save_rng_state) +make_fallback(torch._prims.rng_prims.run_with_rng_state) + +# Implmented / Half implemented +# Scans. 
Implemented for CUDA, missing CPU +make_fallback(aten.masked_scatter) +make_fallback(aten.masked_scatter_backward) + +# Complex number support +make_fallback(aten.view_as_complex, require_contiguous) +make_fallback(aten.angle) # needs complex + +# Needs efficentzerotensor +make_fallback(aten._efficientzerotensor) + +# Needs Sparse +make_fallback(aten._sparse_coo_tensor_with_dims_and_tensors) +make_fallback(aten.to_sparse) +make_fallback(aten._to_sparse) + +# Needs dimname support +make_fallback(aten.zeros.names) + +# 6) Pattern-matched +make_fallback( + aten._scaled_dot_product_efficient_attention.default, + sdpa_constraint, + warn=False, +) +make_fallback( + aten._scaled_dot_product_efficient_attention_backward.default, + sdpa_constraint, + warn=False, +) +make_fallback( + aten._scaled_dot_product_flash_attention.default, + sdpa_constraint, + warn=False, +) +make_fallback( + aten._scaled_dot_product_flash_attention_backward.default, + sdpa_constraint, + warn=False, +) +make_fallback( + aten._scaled_dot_product_cudnn_attention.default, + sdpa_constraint, + warn=False, +) +make_fallback( + aten._scaled_dot_product_cudnn_attention_backward.default, + sdpa_constraint, + warn=False, +) +make_fallback( + aten._scaled_dot_product_flash_attention_for_cpu.default, + sdpa_constraint, + warn=False, +) +make_fallback( + aten._scaled_dot_product_flash_attention_for_cpu_backward.default, + sdpa_constraint, + warn=False, +) +make_fallback(aten._flash_attention_forward.default, sdpa_constraint) +make_fallback(aten._flash_attention_backward.default, sdpa_constraint) +make_fallback(aten._efficient_attention_forward.default, sdpa_constraint) +make_fallback(aten._efficient_attention_backward.default, sdpa_constraint) + +# index_reduce requires fallback when use_scatter_fallback(...) returns True +make_fallback(aten.index_reduce) + + +# Register with type_promotion_kind None. +# For example, fp16.copy_(fp32) should **not** promote the first input's dtype. 
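# Eager-mode reference for the non-promoting behaviour described above
# (placeholder tensors): the destination dtype wins, even when the source is wider.
#
#     import torch
#     dst = torch.zeros(4, dtype=torch.float16)
#     src = torch.randn(4, dtype=torch.float32)
#     dst.copy_(src)
#     assert dst.dtype == torch.float16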
+@register_lowering(aten.copy, type_promotion_kind=None) +def copy(self, src, non_blocking=False): + x = src + if self.get_device() != src.get_device(): + x = to_device(x, self.get_device()) + if self.get_dtype() != src.get_dtype(): + x = to_dtype(x, self.get_dtype()) + + if self.get_size() != src.get_size(): + out = expand(x, self.get_size()) + return clone(out) + return clone(x) + + +@register_lowering(aten.clone) +def clone(x, *, memory_format=None): + # TODO(jansel): memory format + input_graphs = fetch_graphs(x) + node_name = f'clone_{next(node_id)}' + new_graph = merge_traced_graphs(input_graphs, aten.clone, node_name) + return Pointwise.create( + device=x.get_device(), + dtype=x.get_dtype(), + inner_fn=x.make_loader(), + ranges=list(x.get_size()), + traced_graph=new_graph, + node_name=node_name + ) + + +def clone_preserve_reinterpret_view(x): + reinterpret_view_layouts = [] + if isinstance(x, TensorBox) and isinstance(x.data, ir.ReinterpretView): + x = x.data # unwrap TensorBox + while isinstance(x, ir.ReinterpretView): + reinterpret_view_layouts.append(x.get_layout()) + x = x.data + x = TensorBox(x) + + x = clone(x) + + if reinterpret_view_layouts: + x = x.data # unwrap TensorBox + for layout in reinterpret_view_layouts[::-1]: + x = ir.ReinterpretView(data=x, layout=layout) + x = TensorBox(x) + + return x + + +if hasattr(aten, "lift_fresh_copy"): + register_lowering(aten.lift_fresh_copy)(clone) + + +@register_lowering(prims.iota) +def iota( + length, + *, + start, + step, + dtype, + device, + requires_grad, +): + def fn(index): + return ops.index_expr(step * index[0] + start, dtype=dtype) + + node_name = f'iota_{next(node_id)}' + new_graph = merge_traced_graphs([length], prims.iota, node_name, \ + start=start, step=step, \ + dtype=dtype, device=device, \ + requires_grad=requires_grad) + return Pointwise.create( + device=decode_device(device), + dtype=dtype, + inner_fn=fn, + ranges=[length], + traced_graph=new_graph, + node_name=node_name + ) + + +@register_lowering(aten.select_scatter, type_promotion_kind=None) +def select_scatter(x, src, dim: int, index: int): + assert x.get_dtype() == src.get_dtype() + input_graphs = fetch_graphs([x, src, dim, index]) + node_name = f'select_scatter_{next(node_id)}' + new_graph = merge_traced_graphs(input_graphs, aten.select_scatter, node_name) + x_loader = x.make_loader() + dim = _validate_dim(x, dim, 0) + if V.graph.sizevars.evaluate_expr(sympy.Lt(index, 0)): + index = index + x.get_size()[dim] + V.graph.sizevars.guard_leq(0, index) # type: ignore[arg-type] + V.graph.sizevars.guard_lt(index, x.get_size()[dim]) # type: ignore[arg-type] + src = expand(unsqueeze(src, dim), x.get_size()) + src_loader = src.make_loader() + + def inner_fn(idx): + return ops.where( + ops.eq( + ops.index_expr(idx[dim], torch.int32), + ops.index_expr(index, torch.int32), + ), + src_loader(idx), + x_loader(idx), + ) + + return Pointwise.create( + device=x.get_device(), + dtype=x.get_dtype(), + inner_fn=inner_fn, + ranges=list(x.get_size()), + traced_graph=new_graph, + node_name=node_name + ) + + +@register_lowering(aten.slice_scatter, type_promotion_kind=None) +def slice_scatter(x, src, dim=0, start=None, end=None, step=1): + assert x.get_dtype() == src.get_dtype() + input_graphs = fetch_graphs([x, src]) + node_name = f'slice_scatter_{next(node_id)}' + new_graph = merge_traced_graphs(input_graphs, aten.slice_scatter, node_name, \ + dim=dim, + start=start, + end=end, + step=step) + x_loader = x.make_loader() + dim = _validate_dim(x, dim, 0) + dim_size = x.get_size()[dim] + 
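    # Eager-mode reference for the mask-and-select built below (placeholder
    # values): indices in [start, end) that land on the step grid read from
    # `src`; everything else passes `x` through unchanged.
    #
    #     import torch
    #     x, src = torch.zeros(6), torch.tensor([1.0, 2.0, 3.0])
    #     torch.slice_scatter(x, src, start=0, end=6, step=2)
    #     # tensor([1., 0., 2., 0., 3., 0.])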
+ start, end = ir.SliceView.normalize_start_end(x, dim, start, end) + + src_size = list(x.get_size()) + src_size[dim] = FloorDiv(end - start + (step - 1), step) + src = expand(src, src_size) + src_loader = src.make_loader() + + def inner_fn(idx): + if start == 0 and end == dim_size and step == 1: + # selecting every element is the same as just src.clone() + return src_loader(idx) + + idx_dim = ops.index_expr(idx[dim], torch.int64) + src_idx = list(idx) + src_idx[dim] = FloorDiv(idx[dim] - start, step) + + mask = [] + if start != 0: + mask.append( + ops.ge( + idx_dim, + ops.index_expr(sympy.expand(start), torch.int64), + ) + ) + if end != dim_size: + mask.append( + ops.lt( + idx_dim, + ops.index_expr(sympy.expand(end), torch.int64), + ) + ) + if step != 1: + mask.append( + ops.eq( + ops.index_expr( + ModularIndexing(idx[dim] - start, 1, step), torch.int64 + ), + ops.constant(0, torch.int64), + ) + ) + assert mask + mask = functools.reduce(ops.and_, mask) + src_val = ops.masked( + mask, + lambda: src_loader(src_idx), + 0 if is_integer_type(x) else 0.0, + ) + return ops.where( + mask, + src_val, + x_loader(idx), + ) + + return Pointwise.create( + device=x.get_device(), + dtype=x.get_dtype(), + inner_fn=inner_fn, + ranges=list(x.get_size()), + traced_graph=new_graph, + node_name=node_name + ) + + +def _unwrap(x): + if isinstance(x, (list, tuple)) and len(x) > 0: + return _unwrap(x[0]) + return x + + +@register_lowering([torch.tensor, aten.scalar_tensor]) +def tensor(data, *, dtype=None, device=None, layout=None, pin_memory=False): + assert_nyi(layout in (None, torch.strided), f"layout={layout}") + assert_nyi(not pin_memory, "pin_memory") + input_graphs = fetch_graphs([data]) + node_name = f'tensor_{next(node_id)}' + new_graph = merge_traced_graphs(input_graphs, torch.tensor, node_name, \ + dtype=dtype, + device='npu', + pin_memory=False) + if isinstance(_unwrap(data), int): + dtype = dtype or torch.int64 + else: + dtype = dtype or torch.get_default_dtype() + + ranges: List[sympy.Expr] = [] + + if isinstance(data, sympy.Basic): + + def inner_fn(index): + return ops.index_expr(data, dtype) + + elif isinstance(data, (float, int)): + + def inner_fn(index): + return ops.constant(data, dtype) + + elif len(data) == 0 or isinstance(data[0], (float, int)) and len(data) <= 8: + # inline small tensors + ranges.append(sympy.Integer(len(data))) + + def inner_fn(index): + def binary_search(start, end): + assert start < end + if end - start == 1: + return ops.constant(data[start], dtype) + mid = (end - start) // 2 + start + return ops.where( + ops.lt( + ops.index_expr(index[0], torch.int64), + ops.constant(mid, torch.int64), + ), + binary_search(start, mid), + binary_search(mid, end), + ) + + if len(data) == 0: + return ops.constant(0, dtype) + return binary_search(0, len(data)) + + else: + return V.graph.add_tensor_constant( + torch.tensor(data, dtype=dtype, device=device) + ) + return Pointwise.create( + device=decode_device(device), + dtype=dtype, + inner_fn=inner_fn, + ranges=ranges, + traced_graph=new_graph, + node_name=node_name + ) + + +@register_lowering(torch.as_tensor) +def as_tensor(data, dtype=None, device=None): + if isinstance(data, TensorBox): + if dtype is not None: + data = to_dtype(data, dtype) + if device is not None: + data = to_device(data, device) + return data + return tensor(data, dtype=dtype, device=device) + + +@register_lowering(torch.LongTensor) +def long_tensor(data): + return tensor(data, dtype=torch.int64) + + +@register_lowering(aten._local_scalar_dense) +def 
_local_scalar_dense(data): + from torch.fx.experimental.symbolic_shapes import resolve_unbacked_bindings + + # This is interesting! Most lowerings return tensors, so you can just + # return the buffer you allocated and it will get used (or not used, if + # it's dead.) But _local_scalar_dense (aka item) returns an int, + # not a Tensor, so you would have a type mismatch if you return a buffer; + # we are obligated to return a sympy expression instead. However, + # we need to actually codegen the .item() call somehow. We do this + # by registering a faux buffer for the DynamicScalar IR node, which is + # solely responsible for generating this .item(). The buffer is + # not used for anything (notice we discard it); at codegen time, + # the "buffer" just gets assigned None. + unbacked_bindings = resolve_unbacked_bindings( + V.graph.sizevars.shape_env, V.graph.current_node.meta["unbacked_bindings"] + ) + assert unbacked_bindings is not None + assert len(unbacked_bindings) == 1, unbacked_bindings + # NB: Have to be very careful here. V.graph.current_node.meta["val"] + # seemingly also contains a symbol which you want to do binding for, + # but it actually isn't. In particular, if we have later performed + # a deferred runtime assert saying that u0 == s0, you will actually + # see s0 from expr! This is bad because we need to actually generate + # the assert that says u0 == s0, so we need to know where to get u0 + # from (this call). In particular, we must use unbacked_bindings, which + # is guaranteed to have the original, unreplaced symbol in question. + # + # NB2: Another thing we have to be very careful about are symbol bindings + # that require nontrivial refinement, e.g., when you have a binding site + # x: Sym(u0 * 4) = y.item(). Here, the code generation must do a division + # in order to appropriately bind u0. This is communicated via the keypath + # in unbacked_bindings, and we need to hold onto it in order to generate + # code appropriately for this case. + binding_sym, keypath = next(iter(unbacked_bindings.items())) + buffer = ir.DynamicScalar(binding_sym, keypath, data) + buffer.name = V.graph.register_buffer(buffer) + V.graph.register_operation(buffer) + # NB: the replaced expr is OK to use directly downstream, we want + # simplifications in this case! 
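    # A concrete (hypothetical) user-level pattern that reaches this lowering:
    #
    #     def f(x):
    #         n = x.sum().to(torch.int64).item()   # -> aten._local_scalar_dense
    #         return torch.zeros(n)                # sized by the unbacked symbol, e.g. u0
    #
    # The DynamicScalar buffer registered above is what codegens the `.item()`
    # call; the sympy expression returned below is what downstream sizes consume.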
+ val = V.graph.current_node.meta["val"] + if isinstance(val, (torch.SymInt, torch.SymFloat, torch.SymBool)): + return val.node.expr + else: + return sympy.sympify(val) + + +@register_lowering(aten._assert_scalar) +def _assert_scalar(data, msg): + # NB: These will be handled at codegen time + # Not sure if we are guaranteed to be able to serve out truth from the + # deferred_runtime_asserts, TODO: try this assert out + # assert bool(data.scalar), data + return None + + +def _full(fill_value, device, dtype, size): + value = fill_value + dtype = torch.int32 if dtype == torch.int64 else dtype + if not isinstance(fill_value, (int, float)) and hasattr(value, "value"): + value = value.value + + if isinstance(value, (int, float)): + + def inner_fn(index): + return ops.constant(value, dtype) + + elif isinstance(value, sympy.Basic): + + def inner_fn(index): + return ops.index_expr(value, dtype) + + else: + assert len(value.get_size()) == 0 + value_loader = value.make_loader() + + def inner_fn(index): + return value_loader([]) + + node_name = f'full_{next(node_id)}' + new_graph = merge_traced_graphs([size, fill_value], aten.full.default, node_name, \ + device='npu', dtype=dtype, layout = torch.strided, pin_memory = False) + + return Pointwise.create( + device=device, + dtype=dtype, + inner_fn=inner_fn, + ranges=list(size), + traced_graph=new_graph, + node_name=node_name + ) + + +@register_lowering(aten.full_like, type_promotion_kind=None) +def full_like(x, fill_value, **kwargs): + return create_tensor_like(tensor_constructor(fill_value))(x, **kwargs) + + +def tensor_constructor(fill_value): + # torch.zeros, torch.ones, etc + def inner( + *size, + names=None, + dtype=None, + device=None, + layout=None, + pin_memory=False, + memory_format=None, + ): + assert_nyi(names is None, "named tensors") + assert_nyi(layout in (None, torch.strided), f"layout={layout}") + assert_nyi(not pin_memory, "pin_memory") + device = decode_device(device) + dtype = dtype or torch.get_default_dtype() + if len(size) == 1 and isinstance(size[0], (list, tuple, torch.Size)): + size = tuple(size[0]) + # See https://github.com/pytorch/pytorch/issues/118102 + # All sizes at lowering time should be sympy.Symbol, not SymInt! + for s in size: + assert not isinstance(s, torch.SymInt) + size = [sympy.expand(s) for s in size] + return _full(fill_value, device, dtype, size) + + return inner + + +@register_lowering([torch.empty, aten.empty]) +def empty( + *size, + names=None, + dtype=None, + layout=None, + device=None, + pin_memory=None, + memory_format=None, +): + assert_nyi(names is None, "named tensors") + device = decode_device(device) + if len(size) == 1 and isinstance(size[0], (list, tuple, torch.Size)): + size = tuple(size[0]) + return empty_strided( + size, None, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory + ) + + +def create_tensor_like(creation_fn): + """ + Shim to convert X_like(...) into X(...). For example zeros_like() into zeros(). 
+ """ + + def _constant_like( + x, *, dtype=None, device=None, layout=None, pin_memory=False, memory_format=None + ): + assert_nyi(not pin_memory, "pin_memory") + assert_nyi(layout in (None, torch.strided), f"layout={layout}") + if dtype is None: + dtype = x.get_dtype() + else: + dtype = decode_dtype(dtype) + device = device or x.get_device() + size = list(x.get_size()) + return creation_fn( + size, dtype=dtype, device=device, layout=layout, pin_memory=pin_memory + ) + + return _constant_like + + +def constant_like(fill_value): + return create_tensor_like(tensor_constructor(fill_value)) + + +empty_like = register_lowering(aten.empty_like)(create_tensor_like(empty)) +ones_like = create_tensor_like(tensor_constructor(1)) +zeros_like = create_tensor_like(tensor_constructor(0)) + + +def new_constant(fill_value): + def _new_constant( + x, size, *, dtype=None, layout=None, device=None, pin_memory=None + ): + assert isinstance(size, (list, tuple)) + assert_nyi(not pin_memory, "pin_memory") + assert_nyi(layout in (None, torch.strided), f"layout={layout}") + dtype = decode_dtype(dtype) or x.get_dtype() + device = device or x.get_device() + size = [sympy.Integer(s) for s in size] + return _full(fill_value, device, dtype, size) + + return _new_constant + + +@register_lowering(aten.new_empty) +def new_empty(x, size, *, dtype=None, layout=None, device=None, pin_memory=None): + if dtype is None: + dtype = x.get_dtype() + if device is None: + device = x.get_device() + return empty_strided( + size, None, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory + ) + + +@register_lowering(aten.empty_strided) +def empty_strided( + size, stride, *, dtype=None, layout=None, device=None, pin_memory=None +): + assert isinstance(size, (list, tuple)) + assert isinstance(stride, (list, tuple, type(None))) + assert_nyi(not pin_memory, "pin_memory") + assert_nyi(layout in (None, torch.strided), f"layout={layout}") + dtype = decode_dtype(dtype) or torch.get_default_dtype() + device = device or torch.tensor(0.0).device + pointwise = _full(fill_value=0, device=device, dtype=dtype, size=size) + pointwise.realize() + buffer = pointwise.data.data + # explicitly set ranges to zeros in order to make a NopKernelSchedulerNode + buffer.data = dataclasses.replace(buffer.data, ranges=[0] * len(size)) + assert isinstance(buffer, ir.ComputedBuffer) + size = [sympy.expand(s) for s in size] + stride = ( + [sympy.expand(s) for s in stride] + if stride + else ir.FlexibleLayout.contiguous_strides(size) + ) + buffer.layout = ir.FixedLayout( + device=device, + dtype=dtype, + size=size, + stride=stride, + ) + return pointwise + + +@register_lowering(aten.new_empty_strided) +def new_empty_strided( + x, size, stride, *, dtype=None, layout=None, device=None, pin_memory=None +): + if dtype is None: + dtype = x.get_dtype() + if device is None: + device = x.get_device() + return empty_strided( + size, stride, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory + ) + + +@register_lowering(prims.copy_strided.default) +def copy_strided(x, stride): + stride = [V.graph.sizevars.size_hint(s) for s in stride] + stride_order = sorted(range(len(stride)), key=stride.__getitem__) + return ir.ExternKernel.require_stride_order(x, stride_order) + + +@register_lowering([torch.full, aten.full]) +def full(size, fill_value, **kwargs): + assert kwargs.get("dtype") is not None, "dtype should be handled by decomposition" + return tensor_constructor(fill_value)(size, **kwargs) + + +@register_lowering(aten.gather, type_promotion_kind=None) +def 
gather(x, dim, index, sparse_grad=False): + # sparse_grad doesn't affect forward computation, + # and backward tracing is taken care of by AOT Autograd + assert isinstance(x, TensorBox) + if index.get_numel() == 0: + # Empty index case. Return an empty array with the same shape + return new_empty(x, index.get_size()) + + assert index.get_dtype() == torch.int64 + size = x.get_size() + offset = len(size) == 0 + dim = _validate_dim(x, dim, offset) + + if offset: + x = expand(x, [1]) + size = [1] + + x_loader = x.make_loader() + index_loader = index.make_loader() + + def fn(idx): + idx = list(idx) + gather_idx = ops.indirect_indexing(index_loader(idx), size[dim]) + if len(idx) == 0: + idx = [gather_idx] + else: + idx[dim] = gather_idx + return x_loader(idx) + + return Pointwise.create( + device=x.get_device(), + dtype=x.get_dtype(), + inner_fn=fn, + ranges=index.get_size(), + ) + + +@register_lowering(aten.embedding, type_promotion_kind=None) +def embedding(weight, indices, padding_idx=-1, scale_grad_by_freq=False, sparse=False): + assert not sparse + assert isinstance(weight, TensorBox) + assert isinstance(indices, TensorBox) + assert "int" in str(indices.get_dtype()) + + weight_loader = weight.make_loader() + indices_loader = indices.make_loader() + indices_ndim = len(indices.get_size()) + weight_size = weight.get_size() + new_size = [*indices.get_size(), *weight_size[1:]] + + def fn(idx): + assert len(idx) == len(new_size), f"{idx} != {new_size}" + var_index = indices_loader(idx[:indices_ndim]) + weight_idx = [ops.indirect_indexing(var_index, weight_size[0])] + [ + *idx[indices_ndim:] + ] + return weight_loader(weight_idx) + + input_graphs = fetch_graphs([weight, indices]) + node_name = f'embedding_{next(node_id)}' + new_graph = merge_traced_graphs(input_graphs, aten.embedding, node_name, \ + padding_idx=padding_idx, + scale_grad_by_freq=scale_grad_by_freq, + sparse=sparse) + + return Pointwise.create( + device=weight.get_device(), + dtype=weight.get_dtype(), + inner_fn=fn, + ranges=new_size, + traced_graph=new_graph, + node_name=node_name + ) + + +def check_and_broadcast_indices(indices, device): + assert all( + i.get_dtype() in (torch.int64, torch.int32, torch.bool, torch.uint8) + for i in indices + if i is not None + ), f"indices must be int64, byte or bool. Got {[i.get_dtype() for i in indices if i is not None]}" + if any( + i.get_dtype() in (torch.bool, torch.uint8) for i in indices if i is not None + ): + raise NotImplementedError("Fallback for bool indices") + + valid_idxs = [i for i, x in enumerate(indices) if isinstance(x, TensorBox)] + assert len(valid_idxs) > 0, "requires at least 1 non-None index" + new_indices = [None] * len(indices) + for i, x in zip(valid_idxs, broadcast_tensors(*[indices[i] for i in valid_idxs])): + # Eager allows indices to be CPU tensor when running on CUDA + # FIXME: Calling to_device(x, device) should work but + # test_advancedindex_mixed_cpu_devices still fails + if x.get_device() != device: + raise NotImplementedError("Fallback when indices is on a different device") + new_indices[i] = x + return new_indices, valid_idxs + + +def index_output_size_and_inner_fn( + x_size, + indices, + tensor_indices, + tensor_size, + indices_loaders, + indexed_size, + x_loader, + check, + wrap_neg=True, +): + # Note that behavior of indexing differs when there are non consecutive + # tensors. In this case, the tensor index is pulled to the beginning. 
+ # + # Suppose a = torch.arange(3 * 4 * 5 * 6 * 7).view(3, 4, 5, 6, 7) + # x = torch.tensor[1,2] + # Then, a[:,x,:,x,:] will have shape 2,3,5,7 as due to x,:,x then 2 will + # be pulled to the front. + non_consecutive_tensors = False + for previous, current in zip(tensor_indices, tensor_indices[1:]): + if current - previous != 1: + non_consecutive_tensors = True + + output_size = [x_size[i] for i, val in enumerate(indices) if val is None] + output_size = [*output_size, *x_size[len(output_size) + len(tensor_indices) :]] + + first_tensor_index = tensor_indices[0] + if non_consecutive_tensors: + output_size = tensor_size + output_size + else: + output_size = ( + output_size[:first_tensor_index] + + tensor_size + + output_size[first_tensor_index:] + ) + + def fn(idx): + assert len(idx) == len(output_size) + assert len(indices_loaders) == len(indexed_size) + + rank = len(tensor_size) + new_index = [] + first_tensor_index = tensor_indices[0] + start_offset = 0 if non_consecutive_tensors else first_tensor_index + next_idx = 0 + for i in range(tensor_indices[-1] + 1): + if i == start_offset: + next_idx += rank + if indices[i] is None: + assert next_idx < len(idx) + new_index.append(idx[next_idx]) + next_idx += 1 + else: + loader = indices_loaders[i] + assert loader is not None + size = indexed_size[i] + new_index.append( + ops.indirect_indexing( + loader(idx[start_offset : start_offset + rank]), + size, + check=check, + wrap_neg=wrap_neg, + ) + ) + new_index = [ + *new_index, + *idx[next_idx:], + ] + return new_index if x_loader is None else x_loader(new_index) + + return output_size, fn + + +def index_impl(x, indices, check): + output_size, inner_fn, _ = index_impl_helper(x, indices, check) + + return Pointwise.create( + device=x.get_device(), + dtype=x.get_dtype(), + inner_fn=inner_fn, + ranges=output_size, + ) + + +def index_impl_helper(x, indices, check, wrap_neg=True): + assert isinstance(indices, (list, tuple)) + x_loader = x.make_loader() + indices, tensor_indices = check_and_broadcast_indices(indices, x.get_device()) + assert len(tensor_indices) > 0, "Must have at least one valid idx" + + indices_loaders = [i.make_loader() if i is not None else None for i in indices] + # no guards on output size, all the guards are set in broadcast_tensors + + # We can use the first one since they are all required to be the same size + tensor_size = list(indices[tensor_indices[0]].get_size()) + + x_size = x.get_size() + + indexed_size = [x_size[i] for i in range(len(indices)) if indices[i] is not None] + if check and 0 in indexed_size and 0 not in tensor_size: + raise IndexError("index is out of bounds for dimension with size 0") + + indexed_size = [x_size[i] for i in range(len(indices))] + output_size, index_inner_fn = index_output_size_and_inner_fn( + x_size, + indices, + tensor_indices, + tensor_size, + indices_loaders, + indexed_size, + None, + check=check, + wrap_neg=wrap_neg, + ) + + def inner_fn(idx): + return x_loader(index_inner_fn(idx)) + + return output_size, inner_fn, index_inner_fn + + +@register_lowering(aten.index, type_promotion_kind=None) +def index(x, indices): + try: + return index_impl(x, indices, check=True) + except NotImplementedError: + # Fallback to ATen for boolean indexing + x.realize() + return fallback_handler(aten.index.Tensor, add_to_fallback_set=False)( + x, indices + ) + + +@register_lowering(aten._unsafe_index, type_promotion_kind=None) +def _unsafe_index(x, indices): + return index_impl(x, indices, check=False) + + +# All the indexing decompositions are written in terms 
of index, index_put, and index_put_ +# We cannot have this lowering as a decomposition as it introduces +# mutation in the graph, which is bad for Aot Autograd. Aot Autograd runs dead +# code elimination and common subexpression elimination optimizations, which +# assume graphs to be side-effect free. More details at +# https://github.com/pytorch/torchdynamo/issues/1235 +# and +# https://github.com/pytorch/torchdynamo/issues/1863 +@register_lowering(aten.index_put) +def index_put(x, indices, values, accumulate=False): + return index_put_(clone(x), indices, values, accumulate) + + +@register_lowering(aten._unsafe_index_put) +def _unsafe_index_put(x, indices, values, accumulate=False): + return index_put_impl_(clone(x), indices, values, accumulate, check=False) + + +def index_put_as_masked_fill(self, indices, value, accumulate): + if value.get_device() != self.get_device(): + value = to_device(value, self.get_device()) + if accumulate: + value = add(self, value) + return mutate_to(self, where(indices[0], value, self)) + + +def index_put_fallback(self, indices, values, accumulate): + deterministic = torch.are_deterministic_algorithms_enabled() + if is_triton(values) and (accumulate or deterministic): + msg = ( + "index put with accumulate." + if not deterministic + else "deterministic index put." + ) + if stack_trace := V.graph.current_node.meta.get("stack_trace", None): + msg = f"{msg} Found from : \n {stack_trace}" + V.graph.disable_cudagraphs_reason = msg + + ir.IndexPutFallback(V.graph.current_node.target, self, indices, values, accumulate) + return self + + +@register_lowering(aten.index_put_, type_promotion_kind=None) +def index_put_(self, indices, values, accumulate=False): + return index_put_impl_(self, indices, values, accumulate, check=True) + + +@register_lowering(inductor_prims._unsafe_index_put_, type_promotion_kind=None) +def _unsafe_index_put_(self, indices, values, accumulate=False): + return index_put_impl_(self, indices, values, accumulate, check=False) + + +def index_put_impl_(self, indices, values, accumulate, check): + # Dispatch to masked fill for single boolean index with single value + if ( + values.get_numel() == 1 + and len(indices) == 1 + and indices[0].get_dtype() in {torch.bool, torch.uint8} + ): + mask = indices[0] + for _ in range(len(mask.get_size()), len(self.get_size())): + mask = unsqueeze(mask, -1) + return index_put_as_masked_fill(self, [mask], values, accumulate) + + # Fallback in torch deterministic mode + if torch.are_deterministic_algorithms_enabled(): + return index_put_fallback(self, indices, values, accumulate) + + # Fallback if there is a boolean index + for index in indices: + if index is not None and index.get_dtype() in {torch.bool, torch.uint8}: + return index_put_fallback(self, indices, values, accumulate) + + x_size = self.get_size() + x_ndim = len(x_size) + + if accumulate and needs_fallback_due_to_atomic_add_limitations(self.get_dtype()): + # self is an scalar Tensor + if x_ndim == 0: + self = view(self, [1]) + self = index_put_fallback(self, indices, values, accumulate) + if x_ndim == 0: + self = view(self, []) + return self + + values = to_dtype(values, self.get_dtype()) + + try: + # Note that code will only get here when dtype is uint32 + indices, tensor_indices = check_and_broadcast_indices( + indices, self.get_device() + ) + except NotImplementedError: + return index_put_fallback(self, indices, values, accumulate) + + indices_loaders = [i.make_loader() if i is not None else None for i in indices] + + assert isinstance(self, TensorBox) 
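+        # The Scatter built further down writes into `self` in place through
+        # MutationLayoutSHOULDREMOVE, so `self` must be materialized as a real
+        # buffer before the mutation is registered.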
+ self.realize() + + # self is an scalar Tensor + if x_ndim == 0: + self = view(self, [1]) + + # We can use the first one since they are all required to be the same size + tensor_size = list(indices[tensor_indices[0]].get_size()) + indexed_size = [x_size[i] for i in range(len(indices))] + + expected_vals_size, inner_fn = index_output_size_and_inner_fn( + x_size, + indices, + tensor_indices, + tensor_size, + indices_loaders, + indexed_size, + None, + check=check, + ) + + values = expand(values, expected_vals_size) + # all guards are set above during broadcast_tensors and expand + + scatter = ir.Scatter( + device=self.get_device(), + dtype=self.get_dtype(), + inner_fn=values.make_loader(), + ranges=expected_vals_size, # iter_ranges, + output_indexer=inner_fn, + scatter_mode="atomic_add" if accumulate else None, + ) + buffer = ir.ComputedBuffer( + name=None, + layout=ir.MutationLayoutSHOULDREMOVE(self), + data=scatter, + ) + buffer.name = V.graph.register_buffer(buffer) + V.graph.register_operation(buffer) + + if x_ndim == 0: + self = view(self, []) + return self + + +fallback__unsafe_masked_index = fallback_handler( + aten._unsafe_masked_index.default, add_to_fallback_set=False +) + +fallback__unsafe_masked_index_put_accumulate = fallback_handler( + aten._unsafe_masked_index_put_accumulate.default, add_to_fallback_set=False +) + + +@register_lowering(aten._unsafe_masked_index, type_promotion_kind=None) +def _unsafe_masked_index(self, mask, indices, fill): + ranges, _, _unsafe_index_fn = index_impl_helper( + self, indices, check=False, wrap_neg=False + ) + mask_loader = mask.make_loader() + self_loader = self.make_loader() + + def inner_fn(idx): + if mask.dtype != torch.bool: + mask_val = ops.to_dtype(mask_loader(idx), torch.bool) + else: + mask_val = mask_loader(idx) + return ops.masked(mask_val, lambda: self_loader(_unsafe_index_fn(idx)), fill) + + return Pointwise.create( + device=self.get_device(), + dtype=self.get_dtype(), + inner_fn=inner_fn, + ranges=ranges, + ) + + +@register_lowering(aten._unsafe_masked_index_put_accumulate, type_promotion_kind=None) +def _unsafe_masked_index_put_accumulate(x, mask, indices, values): + masked_value = where(mask, values, 0) + shape = x.get_size() + clamped_indices = [ + clamp(indices[i], -shape[i], shape[i] - 1) if indices[i] else None + for i in range(len(indices)) + ] + # TODO: use a masked store for this. currently only triton + # supports masked stores and cpp backend does not. 
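+    # The clamp + where(mask, values, 0) combination above emulates a masked
+    # store under accumulate=True: out-of-range indices are clamped to a valid
+    # position, but their values were already zeroed out, so accumulating them
+    # leaves the destination unchanged at those positions.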
+ return _unsafe_index_put(x, clamped_indices, masked_value, accumulate=True) + + +@make_pointwise +def clamp(a, min, max): + return ops.maximum(min, ops.minimum(max, a)) + + +@register_lowering(aten.as_strided_scatter, type_promotion_kind=None) +def as_strided_scatter(self, src, size, stride, storage_offset=None): + output = clone(self) + output_view = as_strided(output, size, stride, storage_offset) + copy_(output_view, src) + return output + + +@register_lowering(aten.scatter, type_promotion_kind=None) +def scatter(x, dim: int, index, src, **kwargs): + return scatter_(clone(x), dim, index, src, **kwargs) + + +def scatter_fallback( + op_overload: torch._ops.OpOverload, + self, + dim: int, + index, + src, + *, + reduce: Optional[str] = None, + include_self: bool = True, +): + src_is_tensor = isinstance(src, TensorBox) + if use_scatter_fallback( + op_overload, + reduce, + self.get_dtype(), + src.get_dtype() if src_is_tensor else type(src), + src.get_device().type if src_is_tensor else "not impl", + src_is_tensor, + ): + ir.ScatterFallback( + op_overload, + self, + dim, + index, + src, + reduce=reduce, + include_self=include_self, + ) + return self + + return None + + +@register_lowering(aten.scatter_, type_promotion_kind=None) +def scatter_(self, dim: int, index, src, *, reduce: Optional[str] = None): + assert reduce in {None, "add", "multiply"} + if reduce is None: + op_overload = getattr(aten.scatter_, V.graph.current_node.target._overloadname) # type: ignore[union-attr] + fallback_result = scatter_fallback( + op_overload, self, dim, index, src, reduce=reduce + ) + if fallback_result is not None: + return fallback_result + + if reduce == "add": + reduce = "sum" + elif reduce == "multiply": + reduce = "prod" + return scatter_reduce_(self, dim, index, src, reduce) + + +@register_lowering(aten.scatter_add, type_promotion_kind=None) +def scatter_add(x, dim: int, index, src): + return scatter_add_(clone(x), dim, index, src) + + +@register_lowering(aten.scatter_add_, type_promotion_kind=None) +def scatter_add_(x, dim: int, index, src): + return scatter_reduce_(x, dim, index, src, "sum") + + +@register_lowering(aten.scatter_reduce, type_promotion_kind=None) +def scatter_reduce(x, dim: int, index, src, reduction_type, **kwargs): + return scatter_reduce_(clone(x), dim, index, src, reduction_type, **kwargs) + + +@register_lowering(aten.scatter_reduce_, type_promotion_kind=None) +def scatter_reduce_(self, dim: int, index, src, reduce, *, include_self: bool = True): + assert reduce in {None, "sum", "prod", "mean", "amax", "amin"} + assert ( + len(aten.scatter_reduce_.overloads()) == 1 + and "two" in aten.scatter_reduce_.overloads() + ), "aten.scatter_reduce_.two is not the unique overload of aten.scatter_reduce_" + + if isinstance(src, Number): + src = full_like(self, src) + + fallback_result = scatter_fallback( + aten.scatter_reduce_.two, + self, + dim, + index, + src, + reduce=reduce, + include_self=include_self, + ) + + if fallback_result: + return fallback_result + + assert isinstance(self, TensorBox) + assert "int" in str(index.get_dtype()) + + ndim = len(self.get_size()) + if ndim == 0: + self = view(self, [1]) + + if isinstance(src, TensorBox) and len(src.get_size()) == 0: + src = view(src, [1]) + + if isinstance(index, TensorBox) and len(index.get_size()) == 0: + index = view(index, [1]) + + if index.get_numel() == 0: + return self + + dim = _validate_dim(self, dim) + + self.realize() + index_loader = index.make_loader() + src_loader = src.make_loader() if isinstance(src, TensorBox) else 
None + + def output_indexer(idx): + # self is captured from the end of the function, so it may have 0 dim + shape = self.get_size() + ndim = len(shape) + indirect_idx = list(idx) + indirect_idx[dim] = ops.indirect_indexing( + index_loader(idx), 1 if ndim == 0 else shape[dim], wrap_neg=False + ) + return indirect_idx + + def fn(idx): + if src_loader: + return src_loader(idx) + else: + # src is a scalar + return ops.constant(src, self.get_dtype()) + + def backend_reduce_str(reduce): + if reduce == "sum": + return "atomic_add" + else: + # TODO: Need to support more reduction type + assert reduce is None + return None + + if not include_self: + # zero out the corresponding elements first + zero_out = ir.Scatter( + device=self.get_device(), + dtype=self.get_dtype(), + inner_fn=lambda index: ops.constant(0, self.get_dtype()), + ranges=index.get_size(), + output_indexer=output_indexer, + scatter_mode=None, + ) + buffer = ir.ComputedBuffer( + name=None, + layout=ir.MutationLayoutSHOULDREMOVE(self), + data=zero_out, + ) + buffer.name = V.graph.register_buffer(buffer) + V.graph.register_operation(buffer) + + # self[index[i][j][k]][j][k] += src[i][j][k] # if dim == 0 + # self[i][index[i][j][k]][k] += src[i][j][k] # if dim == 1 + # self[i][j][index[i][j][k]] += src[i][j][k] # if dim == 2 + scatter = ir.Scatter( + device=self.get_device(), + dtype=self.get_dtype(), + inner_fn=fn, + ranges=index.get_size(), + output_indexer=output_indexer, + scatter_mode=backend_reduce_str(reduce), + ) + buffer = ir.ComputedBuffer( + name=None, + layout=ir.MutationLayoutSHOULDREMOVE(self), + data=scatter, + ) + buffer.name = V.graph.register_buffer(buffer) + V.graph.register_operation(buffer) + + if ndim == 0: + self = view(self, []) + return self + + +def upsample_nearestnd( + x, + output_size, + scales_x: Tuple[Optional[float], ...], + n: int = 2, + exact: bool = False, +): + x.realize_hint() # elements are reused + x_loader = x.make_loader() + i_sizes = x.get_size()[-n:] + batch = x.get_size()[:-n] + i_sizes = [V.graph.sizevars.evaluate_static_shape(i) for i in i_sizes] + + assert len(scales_x) == n + o_sizes = output_size + + inv_scales = [i / o for i, o in zip(i_sizes, o_sizes)] + for i, scale in enumerate(scales_x): + if scale is not None: + inv_scales[i] = 1.0 / scale + + def scale_fn(x, scale, size): + # Nearest Exact: input_index = round(scale * (output_index + 0.5) - 0.5) + # = floor(scale * (output_index + 0.5)) + # Nearest: input_index = floor(scale * output_index) + x = ops.index_expr(x, torch.float32) + if exact: + x = ops.add(x, ops.constant(0.5, torch.float32)) + x = ops.mul(x, ops.constant(scale, torch.float32)) + x = ops.to_dtype(x, torch.int32) + return ops.indirect_indexing(x, size, check=False) + + def fn(idx): + x = idx[-n:] + b = idx[:-n] + return x_loader( + [*b, *[scale_fn(i, s, size) for i, s, size in zip(x, inv_scales, i_sizes)]] + ) + + return Pointwise.create( + device=x.get_device(), + dtype=x.get_dtype(), + inner_fn=fn, + ranges=[*batch, *o_sizes], + ) + + +@register_lowering(aten.upsample_nearest1d.default) +def upsample_nearest1d(x, output_size, scales: Optional[float] = None): + return upsample_nearestnd(x, output_size, (scales,), n=1) + + +@register_lowering(aten._upsample_nearest_exact1d.default) +def _upsample_nearest_exact1d(x, output_size, scales: Optional[float] = None): + return upsample_nearestnd(x, output_size, (scales,), n=1, exact=True) + + +@register_lowering(aten.upsample_nearest2d.default) +def upsample_nearest2d( + x, output_size, scales_h: Optional[float] = None, 
scales_w: Optional[float] = None +): + return upsample_nearestnd(x, output_size, (scales_h, scales_w), n=2) + + +@register_lowering(aten._upsample_nearest_exact2d.default) +def _upsample_nearest_exact2d( + x, output_size, scales_h: Optional[float] = None, scales_w: Optional[float] = None +): + return upsample_nearestnd(x, output_size, (scales_h, scales_w), n=2, exact=True) + + +@register_lowering(aten.upsample_nearest3d.default) +def upsample_nearest3d( + x, + output_size, + scales_d: Optional[float] = None, + scales_h: Optional[float] = None, + scales_w: Optional[float] = None, +): + return upsample_nearestnd(x, output_size, (scales_d, scales_h, scales_w), n=3) + + +@register_lowering(aten._upsample_nearest_exact3d.default) +def _upsample_nearest_exact3d( + x, + output_size, + scales_d: Optional[float] = None, + scales_h: Optional[float] = None, + scales_w: Optional[float] = None, +): + return upsample_nearestnd( + x, output_size, (scales_d, scales_h, scales_w), n=3, exact=True + ) + + +def _create_constants(*args, dtype): + return tuple(ops.constant(a, dtype) for a in args) + + +@register_lowering(prims.rev.default) +def rev(x, dims): + # note - dims pre-canonicalized + x_loader = x.make_loader() + sizes = x.get_size() + + def loader(idx): + idx = list(idx) + assert len(idx) == len(sizes) + for dim in dims: + idx[dim] = (sizes[dim] - 1) - idx[dim] + + return x_loader(idx) + + return Pointwise.create( + device=x.get_device(), + dtype=x.get_dtype(), + inner_fn=loader, + ranges=sizes, + ) + + +@register_lowering(aten.constant_pad_nd, type_promotion_kind=None) +def constant_pad_nd(x, padding, fill_value=0): + assert (len(padding) % 2) == 0 + if all(p == 0 for p in padding): + return clone(x) + + sizes = x.get_size() + + bounds = list(reversed(list(zip(padding[::2], padding[1::2])))) + n = len(sizes) - len(bounds) + + # if padding is a complicated expression, hoist it + bounds_precomp: List[Tuple[sympy.Symbol, Any]] = [] + for l, h in bounds: + bounds_precomp.append((V.graph.sizevars.lookup_precomputed_size(l), h)) # type: ignore[arg-type] + + output_size = list(sizes[:n]) + mask_sizes = [] + for (low, high), size in zip(bounds, sizes[n:]): + mask_sizes.append(size) + output_size.append(sympy.expand(size + low + high)) + assert len(output_size) == len(sizes) + fill_value = dtype_to_type(x.get_dtype())(fill_value) + + def mask(index): + mask = [] + for idx, (low, high), length in zip(index[n:], bounds, mask_sizes): + if low != 0: + mask.append(range_mask_low(idx, 0)) + if high != 0: + mask.append(range_mask_high(idx, length)) + mask = functools.reduce(ops.and_, mask) + return ops.masked(mask, lambda: x_loader(index), fill_value) + + def offset_fn(index): + new_index = list(index[:n]) + for idx, (low, high) in zip(index[n:], bounds_precomp): + new_index.append(idx - low) + assert len(new_index) == len(index) + return mask(new_index) + + x_loader = x.make_loader() + + input_graphs = fetch_graphs([x, padding]) + node_name = f'constand_pad_nd_{next(node_id)}' + new_graph = merge_traced_graphs(input_graphs, aten.constant_pad_nd, node_name, value=fill_value) + return Pointwise.create( + device=x.get_device(), + dtype=x.get_dtype(), + inner_fn=offset_fn, + ranges=output_size, + traced_graph=new_graph, + node_name=node_name + ) + + +def range_mask_low(i: sympy.Expr, low: Union[sympy.Expr, int]): + return ops.ge( + ops.index_expr(i, torch.int64), + ops.index_expr(sympy.Integer(low), torch.int64), + ) + + +def range_mask_high(i: sympy.Expr, high: sympy.Expr): + return ops.lt( + ops.index_expr(i, 
torch.int64), + ops.index_expr(high, torch.int64), + ) + + +def range_mask(i: sympy.Expr, high: sympy.Expr, low: sympy.Expr): + return ops.and_( + range_mask_low(i, low), + range_mask_high(i, high), + ) + + +def constant_boundary_condition( + x, fill_value, padding=None, pad_fill_value=1.0, dim=None +): + h = x.get_size()[-dim:] + x_loader = x.make_loader() + padding_h = padding or [0] * dim + + def load(index): + prefix = index[:-dim] + ih = index[-dim:] + + mask = functools.reduce( + ops.and_, + [range_mask(ih[i], h[i] + padding_h[i], -padding_h[i]) for i in range(dim)], + ) + return ( + ops.masked( + mask, + lambda: constant_boundary_condition(x, pad_fill_value, dim=dim)( + [*prefix, *ih] + ), + fill_value, + ) + if padding + else ops.masked(mask, lambda: x_loader([*prefix, *ih]), fill_value) + ) + + return load + + +def pooling_size(x, i, kernel_size, stride, padding, ceil_mode): + x_out = FloorDiv( + x + 2 * padding[i] - (kernel_size[i] - 1) + (stride[i] - 1), stride[i] + ) + + if ceil_mode: + x_alt = FloorDiv( + x + 2 * padding[i] - (kernel_size[i] - 1) + 2 * (stride[i] - 1), stride[i] + ) + if V.graph.sizevars.size_hint((x_alt - 1) * stride[i] - x - padding[i]) >= 0: + # Sliding windows must start within the input or left padding + x_alt -= 1 # type: ignore[assignment] + V.graph.sizevars.guard_leq(0, x_alt * stride[i] - x - padding[i]) # type: ignore[arg-type] + if V.graph.sizevars.size_hint(x_out - x_alt) == 0: + # ceil mode is actually a no-op, lets guard on that + V.graph.sizevars.guard_equals(x_out, x_alt) + ceil_mode = False + else: + x_out = x_alt + return x_out, ceil_mode + + +def should_fallback_max_pool2d_with_indices(kernel_size, dilation): + kernel_size = pad_listlike(kernel_size, 2) + window_size = kernel_size[0] * kernel_size[1] + return (window_size > 25) or any(d > 1 for d in dilation) + + +def max_pool2d_checks( + x, kernel_size, stride, padding, dilation, *, assert_fallback=None +): + if padding == 0: + padding = [0, 0] + if dilation == 1: + dilation = [1, 1] + if not stride: + stride = kernel_size + + kernel_size = pad_listlike(kernel_size, 2) + stride = pad_listlike(stride, 2) + padding = pad_listlike(padding, 2) + dilation = pad_listlike(dilation, 2) + + assert isinstance(x, TensorBox) + assert len(kernel_size) == 2 + assert len(stride) == 2 + assert len(padding) == 2 + assert len(dilation) == 2 + assert len(x.get_size()) in (3, 4) + + use_fallback = should_fallback_max_pool2d_with_indices(kernel_size, dilation) + if assert_fallback is not None: + assert use_fallback == assert_fallback + + return kernel_size, stride, padding, dilation, use_fallback + + +@register_lowering(prims._low_memory_max_pool2d_with_offsets, type_promotion_kind=None) +def _low_memory_max_pool2d_with_offsets( + x, + kernel_size, + stride, + padding, + dilation, + ceil_mode=False, +): + # assert we are not on a fallback path, the inductor decomp should have guaranteed this + kernel_size, stride, padding, dilation, _ = max_pool2d_checks( + x, kernel_size, stride, padding, dilation, assert_fallback=False + ) + + x.realize_hint() + *batch, h, w = x.get_size() + + h_out, ceil_mode1 = pooling_size(h, 0, kernel_size, stride, padding, ceil_mode) + w_out, ceil_mode2 = pooling_size(w, 1, kernel_size, stride, padding, ceil_mode) + + dtype = x.dtype + min_value = ( + False + if dtype is torch.bool + else (float("-inf") if dtype.is_floating_point else torch.iinfo(dtype).min) + ) + + new_size = list(batch) + [h_out, w_out] + if padding[0] or padding[1] or ceil_mode1 or ceil_mode2: + x_loader = 
constant_boundary_condition(x, min_value, dim=2) + else: + x_loader = x.make_loader() + + def fn(idx, return_index): + *prefix, bh, bw = idx + maxval = None + maxindex = None + for h_inc, w_inc in itertools.product( + range(kernel_size[0]), range(kernel_size[1]) + ): + ih = bh * stride[0] + h_inc - padding[0] + iw = bw * stride[1] + w_inc - padding[1] + val = x_loader([*prefix, ih, iw]) + if return_index: + index = ops.index_expr(h_inc * kernel_size[1] + w_inc, torch.int8) + if maxindex is None: + maxindex = index + else: + maxindex = ops.where(ops.gt(val, maxval), index, maxindex) + if maxval is None: + maxval = val + else: + maxval = ops.maximum(val, maxval) + if return_index: + return maxindex + else: + return maxval + + out = Pointwise.create( + device=x.get_device(), + dtype=x.get_dtype(), + inner_fn=functools.partial(fn, return_index=False), + ranges=new_size, + ) + offsets = Pointwise.create( + device=x.get_device(), + dtype=torch.int8, + inner_fn=functools.partial(fn, return_index=True), + ranges=new_size, + ) + return out, offsets + + +@register_lowering( + prims._low_memory_max_pool2d_offsets_to_indices, type_promotion_kind=None +) +def _low_memory_max_pool2d_offsets_to_indices( + offsets, kernel_width, input_width, stride, padding +): + # TODO: Generalize to other max pooling flavors, and arbitrary dim + + offsets_loader = offsets.make_loader() + + def increments_to_index(h_inc, w_inc, bh, bw): + w_in = ops.index_expr(input_width, torch.int64) + hbase = ops.index_expr(bh * stride[0] - padding[0], torch.int64) + wbase = ops.index_expr(bw * stride[1] - padding[1], torch.int64) + ih = hbase + h_inc + iw = wbase + w_inc + return ih * w_in + iw + + def offsets_to_indices(idx): + *prefix, bh, bw = idx + offset = offsets_loader([*prefix, bh, bw]) + kw_const = ops.constant(kernel_width, torch.int32) + h_inc = offset // kw_const + w_inc = offset - (h_inc * kw_const) + return increments_to_index(h_inc, w_inc, bh, bw) + + indices = Pointwise.create( + device=offsets.get_device(), + dtype=torch.int64, + inner_fn=offsets_to_indices, + ranges=offsets.get_size(), + ) + return indices + + +# Fallback selected when we do not decompose to the low-memory path. 
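+# Per should_fallback_max_pool2d_with_indices above, that happens when the
+# kernel window exceeds 25 elements or any dilation is > 1; the op then stays
+# as this ATen fallback instead of being decomposed.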
+make_fallback(aten.max_pool2d_with_indices) + + +fallback_max_pool2d_with_indices_backward = fallback_handler( + aten.max_pool2d_with_indices_backward.default, + add_to_fallback_set=False, +) + + +@register_lowering(aten.max_pool2d_with_indices_backward, type_promotion_kind=None) +def max_pool2d_with_indices_backward( + grad_output, x, kernel_size, stride, padding, dilation, ceil_mode, indices +): + if padding == 0: + padding = [0, 0] + if dilation == 1: + dilation = [1, 1] + if not stride: + stride = kernel_size + + assert isinstance(x, TensorBox) + assert len(kernel_size) == 2 + assert len(stride) == 2 + assert len(padding) == 2 + assert len(dilation) == 2 + assert len(x.get_size()) in (3, 4) + + # we will read this many times, so make sure it is computed + grad_output.realize_hint() + try: + gO_stride = grad_output.get_stride() + except AttributeError: + # some classes don't have `get_stride` + # TODO will need a better way of determining if inputs are channels-last + gO_stride = None + if isinstance(x, TensorBox) and isinstance(x.data.data, Pointwise): # type: ignore[attr-defined] + data = x.data.data # type: ignore[attr-defined] + x_buffer = ir.ComputedBuffer( + name=None, + layout=ir.FlexibleLayout( + device=data.get_device(), + dtype=data.get_dtype(), + size=data.get_size(), + ), + data=data, + ) + x_buffer.decide_layout() + x_stride = x_buffer.get_stride() + else: + try: + x_stride = x.get_stride() + except AttributeError: + x_stride = None + + is_channels_last = (x_stride is not None and x_stride[1] == 1) or ( + gO_stride is not None and gO_stride[1] == 1 + ) + if any(d != 1 for d in dilation): + # dilation NYI + return fallback_max_pool2d_with_indices_backward( + grad_output, x, kernel_size, stride, padding, dilation, ceil_mode, indices + ) + + *batch, height, width = x.get_size() + *_, pooled_height, pooled_width = grad_output.get_size() + + indices_loader = indices.make_loader() + grad_loader = grad_output.make_loader() + new_size = list(x.get_size()) + + h_window_size = max( + max(h // stride[0] - max(0, (h - kernel_size[0]) // stride[0]), 1) + for h in range(kernel_size[0] * 2) + ) + w_window_size = max( + max(w // stride[1] - max(0, (w - kernel_size[1]) // stride[1]), 1) + for w in range(kernel_size[1] * 2) + ) + + window_size = h_window_size * w_window_size + + if window_size > 25: + # Kernel size too big. Results in hard-to-optimize Triton code. Use fallback. 
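+        # h_window_size / w_window_size bound how many pooling windows can
+        # overlap a single input element; fn() below unrolls one
+        # gather-and-compare step per window, so more than 25 of them makes
+        # the generated kernel too large to optimize well.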
+ return fallback_max_pool2d_with_indices_backward( + grad_output, x, kernel_size, stride, padding, dilation, ceil_mode, indices + ) + + indices_size = indices.get_size() + + def fn(idx): + *prefix, h, w = idx + index_test = ops.index_expr(h * width + w, torch.int32) + h = h + padding[0] + w = w + padding[1] + phstart = ops.index_expr( + FloorDiv(h - kernel_size[0] + stride[0], stride[0]), torch.int32 + ) + pwstart = ops.index_expr( + FloorDiv(w - kernel_size[1] + stride[1], stride[1]), torch.int32 + ) + phend = ops.index_expr(FloorDiv(h, stride[0]) + 1, torch.int32) + pwend = ops.index_expr(FloorDiv(w, stride[1]) + 1, torch.int32) + + phstart = ops.maximum(phstart, ops.constant(0, torch.int32)) + pwstart = ops.maximum(pwstart, ops.constant(0, torch.int32)) + phend = ops.minimum(phend, ops.index_expr(pooled_height, torch.int32)) + pwend = ops.minimum(pwend, ops.index_expr(pooled_width, torch.int32)) + + gradient = None + for ph_ in range(h_window_size): + for pw_ in range(w_window_size): + ph = ops.add(phstart, ops.constant(ph_, torch.int32)) + pw = ops.add(pwstart, ops.constant(pw_, torch.int32)) + grad_index = [ + *prefix, + ops.indirect_indexing( + ops.minimum(ph, ops.sub(phend, ops.constant(1, torch.int32))), + indices_size[-2], + check=False, + ), + ops.indirect_indexing( + ops.minimum(pw, ops.sub(pwend, ops.constant(1, torch.int32))), + indices_size[-1], + check=False, + ), + ] + + index_actual = indices_loader(grad_index) + grad_part = grad_loader(grad_index) + check = ops.eq(index_actual, index_test) + + if gradient is None: + # don't need mask for 0, 0 + gradient = ops.where( + check, grad_part, ops.constant(0.0, torch.float32) + ) + else: + mask = ops.and_( + ops.and_( + ops.lt(ph, phend), + ops.lt(pw, pwend), + ), + check, + ) + gradient = ops.where(mask, ops.add(gradient, grad_part), gradient) + assert gradient is not None + return gradient + + out = Pointwise.create( + device=grad_output.get_device(), + dtype=grad_output.get_dtype(), + inner_fn=fn, + ranges=new_size, + ) + if is_channels_last: + return ir.ExternKernel.require_channels_last(out) + else: + return out + + +def pad_adaptive_loader(x, pad_val=0.0): + *_, h, w = x.get_size() + x_loader = x.make_loader() + + def load(prefix, increments, start_indices, end_indices): + ih, iw = increments + h_start_index, w_start_index = start_indices + h_end_index, w_end_index = end_indices + + mask = ops.and_( + ops.lt( + ops.index_expr(h_start_index + ih, torch.int64), + ops.index_expr(h_end_index, torch.int64), + ), + ops.lt( + ops.index_expr(w_start_index + iw, torch.int64), + ops.index_expr(w_end_index, torch.int64), + ), + ) + + return ops.masked( + mask, + lambda: x_loader([*prefix, h_start_index + ih, w_start_index + iw]), + pad_val, + ) + + return load + + +def compute_indices_adaptive_pooling(start_index, end_index, h_in, w_in, h_out, w_out): + h_start_index = functools.partial(start_index, out_dim=h_out, inp_dim=h_in) + h_end_index = functools.partial(end_index, out_dim=h_out, inp_dim=h_in) + + w_start_index = functools.partial(start_index, out_dim=w_out, inp_dim=w_in) + w_end_index = functools.partial(end_index, out_dim=w_out, inp_dim=w_in) + + return h_start_index, h_end_index, w_start_index, w_end_index + + +def _adaptive_pooling_fn( + start_index, end_index, kernel_maxes, in_sizes, out_sizes, pooling_fn +): + h_in, w_in = in_sizes + h_out, w_out = out_sizes + + ( + h_start_index_fn, + h_end_index_fn, + w_start_index_fn, + w_end_index_fn, + ) = compute_indices_adaptive_pooling( + start_index, end_index, h_in, w_in, 
h_out, w_out + ) + + def fn(idx, loader): + *prefix, bh, bw = idx + + h_start_index = h_start_index_fn(bh) + h_end_index = h_end_index_fn(bh) + + w_start_index = w_start_index_fn(bw) + w_end_index = w_end_index_fn(bw) + + result = None + for ih, iw in itertools.product(range(kernel_maxes[0]), range(kernel_maxes[1])): + val = loader( + prefix, + [ih, iw], + [h_start_index, w_start_index], + [h_end_index, w_end_index], + ) + if result is None: + result = val + else: + result = pooling_fn(val, result) + return result + + return fn + + +def _adaptive_pooling_fn_with_idx( + start_index, end_index, kernel_maxes, in_sizes, out_sizes, pooling_fn +): + h_in, w_in = in_sizes + h_out, w_out = out_sizes + + ( + h_start_index_fn, + h_end_index_fn, + w_start_index_fn, + w_end_index_fn, + ) = compute_indices_adaptive_pooling( + start_index, end_index, h_in, w_in, h_out, w_out + ) + + def fn(idx, loader): + *prefix, bh, bw = idx + + h_start_index = h_start_index_fn(bh) + h_end_index = h_end_index_fn(bh) + + w_start_index = w_start_index_fn(bw) + w_end_index = w_end_index_fn(bw) + + maxval = None + maxindex = None + for ih, iw in itertools.product(range(kernel_maxes[0]), range(kernel_maxes[1])): + val = loader( + prefix, + [ih, iw], + [h_start_index, w_start_index], + [h_end_index, w_end_index], + ) + + index = ops.index_expr( + (h_start_index + ih) * w_in + w_start_index + iw, torch.int64 + ) + + if maxindex is None: + maxindex = index + else: + maxindex = ops.where(ops.gt(val, maxval), index, maxindex) + + if maxval is None: + maxval = val + else: + maxval = pooling_fn(val, maxval) + + return maxindex + + return fn + + +fallback_adaptive_avg_pool2d = fallback_handler( + aten._adaptive_avg_pool2d.default, add_to_fallback_set=False +) + + +@register_lowering(aten._adaptive_avg_pool2d) +def _adaptive_avg_pool2d(x, output_size): + assert isinstance(x, TensorBox) + assert len(output_size) == 2 + x.realize_hint() + + *batch, h_in, w_in = x.get_size() + + h_in = V.graph.sizevars.evaluate_static_shape(h_in) + w_in = V.graph.sizevars.evaluate_static_shape(w_in) + + h_out, w_out = output_size + + # no-op if the same input and output + if h_in == h_out and w_in == w_out: + return clone(x) + + if h_out == 0 or w_out == 0: + o_size = [*batch, h_out, w_out] + return empty(o_size, dtype=x.get_dtype(), device=x.get_device()) + if h_in % h_out == 0 and w_in % w_out == 0: + kernel_size = [h_in // h_out, w_in // w_out] + return avg_pool2d(x, kernel_size) + + h_kernel_max = ceildiv((h_in + h_out - 1), h_out) + w_kernel_max = ceildiv((w_in + w_out - 1), w_out) + + new_size = list(batch) + [h_out, w_out] + dtype = x.get_dtype() + + window_size = h_kernel_max * w_kernel_max + if window_size > 25: + # Kernel size too big. Results in hard-to-optimize Triton code. Use fallback. 
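+        # h_kernel_max / w_kernel_max bound the adaptive window length. E.g.
+        # for h_in=5, h_out=3 the start_index/end_index helpers below give
+        # windows [0,2), [1,4), [3,5) (lengths 2, 3, 2), and
+        # h_kernel_max = ceildiv(5 + 3 - 1, 3) = 3 covers the widest of them.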
+ return fallback_adaptive_avg_pool2d(x, output_size) + + def start_index(index, out_dim, inp_dim): + return FloorDiv((index * inp_dim), out_dim) + + def end_index(index, out_dim, inp_dim): + return FloorDiv((index + 1) * inp_dim + out_dim - 1, out_dim) + + fn_sum = _adaptive_pooling_fn( + start_index=start_index, + end_index=end_index, + kernel_maxes=[h_kernel_max, w_kernel_max], + in_sizes=[h_in, w_in], + out_sizes=[h_out, w_out], + pooling_fn=ops.add, + ) + + ones_loader = pad_adaptive_loader(ones_like(x)) + + def fn(idx): + return ops.truediv( + fn_sum(idx, pad_adaptive_loader(x)), fn_sum(idx, ones_loader) + ) + + rv = Pointwise.create( + device=x.get_device(), + dtype=dtype, + inner_fn=fn, + ranges=new_size, + ) + # TODO: should we force these to be realized? + return rv + + +fallback_adaptive_max_pool2d = fallback_handler( + aten.adaptive_max_pool2d.default, add_to_fallback_set=False +) + + +@register_lowering(aten.adaptive_max_pool2d) +def adaptive_max_pool2d(x, output_size): + assert isinstance(x, TensorBox) + assert len(output_size) == 2 + x.realize_hint() + + *batch, h_in, w_in = x.get_size() + + h_in = V.graph.sizevars.evaluate_static_shape(h_in) + w_in = V.graph.sizevars.evaluate_static_shape(w_in) + + h_out, w_out = output_size + + if h_out == 0 or w_out == 0: + o_size = [*batch, h_out, w_out] + return empty(o_size, dtype=x.get_dtype(), device=x.get_device()), empty( + o_size, dtype=torch.int64, device=x.get_device() + ) + + if h_in % h_out == 0 and w_in % w_out == 0: + # This is handled by a decomposition + raise ValueError + + h_kernel_max = ceildiv((h_in + h_out - 1), h_out) + w_kernel_max = ceildiv((w_in + w_out - 1), w_out) + + new_size = list(batch) + [h_out, w_out] + dtype = x.get_dtype() + + window_size = h_kernel_max * w_kernel_max + if window_size > 25: + # Kernel size too big. Results in hard-to-optimize Triton code. Use fallback. 
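+        # Same unrolling bound as _adaptive_avg_pool2d above; here each window
+        # tap also updates a running argmax via where(), and the value and the
+        # index are emitted as two separate Pointwise kernels, so larger
+        # windows are left to ATen.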
+ return fallback_adaptive_max_pool2d(x, output_size) + + def start_index(index, out_dim, inp_dim): + return FloorDiv((index * inp_dim), out_dim) + + def end_index(index, out_dim, inp_dim): + return FloorDiv((index + 1) * inp_dim + out_dim - 1, out_dim) + + inner_func_max_val = _adaptive_pooling_fn( + start_index=start_index, + end_index=end_index, + kernel_maxes=[h_kernel_max, w_kernel_max], + in_sizes=[h_in, w_in], + out_sizes=[h_out, w_out], + pooling_fn=ops.maximum, + ) + + inner_func_max_idx = _adaptive_pooling_fn_with_idx( + start_index=start_index, + end_index=end_index, + kernel_maxes=[h_kernel_max, w_kernel_max], + in_sizes=[h_in, w_in], + out_sizes=[h_out, w_out], + pooling_fn=ops.maximum, + ) + + def inner_fn_max_val(idx): + return inner_func_max_val(idx, pad_adaptive_loader(x, float("-inf"))) + + def inner_fn_max_idx(idx): + return inner_func_max_idx(idx, pad_adaptive_loader(x, float("-inf"))) + + rv = Pointwise.create( + device=x.get_device(), + dtype=dtype, + inner_fn=inner_fn_max_val, + ranges=new_size, + ) + ri = Pointwise.create( + device=x.get_device(), + dtype=torch.int64, + inner_fn=inner_fn_max_idx, + ranges=new_size, + ) + return rv, ri + + +fallback_fractional_max_pool2d = fallback_handler( + aten.fractional_max_pool2d.default, add_to_fallback_set=False +) + + +def _fractional_pooling_offsets(samples, in_sz, out_sz, kernel_sz, dim): + out_sz = out_sz[dim] + in_sz = in_sz[dim] + kernel_sz = kernel_sz[dim] + alpha = IntTrueDiv(in_sz - kernel_sz, out_sz - 1) + samples_loader = samples.make_loader() + + def load(prefix, i): + sample = samples_loader([*prefix, dim]) + i_expr = ops.index_expr(i, samples.get_dtype()) + alpha_expr = ops.index_expr(alpha, samples.get_dtype()) + seq_i = ops.floor((i_expr + sample) * alpha_expr) - ops.floor( + sample * alpha_expr + ) + seq_i = ops.to_dtype(seq_i, torch.int64) + + mask = ops.lt( + i_expr, + ops.index_expr(out_sz - 1, torch.int64), + ) + return ops.where(mask, seq_i, ops.index_expr(in_sz - kernel_sz, torch.int64)) + + return load + + +@register_lowering(aten.fractional_max_pool2d) +def fractional_max_pool2d(x, kernel_size, output_size, random_samples): + x.realize_hint() + *batch, inp_h, inp_w = x.get_size() + kernel_h, kernel_w = kernel_size + h_out, w_out = output_size + + if kernel_h * kernel_w >= 25: + return fallback_fractional_max_pool2d( + x, kernel_size, output_size, random_samples + ) + + gen_offsets_for_dim = functools.partial( + _fractional_pooling_offsets, + samples=random_samples, + in_sz=[inp_h, inp_w], + out_sz=output_size, + kernel_sz=kernel_size, + ) + + h_index_fn = gen_offsets_for_dim(dim=0) + w_index_fn = gen_offsets_for_dim(dim=1) + x_loader = x.make_loader() + + def fn(idx, return_index): + *prefix, bh, bw = idx + + h_start_index = ops.indirect_indexing(h_index_fn(prefix, bh), inp_h) + w_start_index = ops.indirect_indexing(w_index_fn(prefix, bw), inp_w) + + maxval = None + maxindex = None + for ih, iw in itertools.product(range(kernel_size[0]), range(kernel_size[1])): + val = x_loader([*prefix, h_start_index + ih, w_start_index + iw]) + if return_index: + index = ops.index_expr( + (h_start_index + ih) * inp_w + w_start_index + iw, torch.int64 + ) + if maxindex is None: + maxindex = index + else: + maxindex = ops.where( + ops.or_(ops.gt(val, maxval), ops.isnan(val)), index, maxindex + ) + if maxval is None: + maxval = val + else: + maxval = ops.maximum(val, maxval) + if return_index: + return maxindex + else: + return maxval + + new_size = list(batch) + [h_out, w_out] + rv = Pointwise.create( + 
device=x.get_device(), + dtype=x.get_dtype(), + inner_fn=functools.partial(fn, return_index=False), + ranges=new_size, + ) + + ri = Pointwise.create( + device=x.get_device(), + dtype=torch.int64, + inner_fn=functools.partial(fn, return_index=True), + ranges=new_size, + ) + return rv, ri + + +@register_lowering(aten.upsample_nearest2d_backward.default) +def upsample_nearest2d_backward( + x, output_size=None, input_size=None, scales_h=None, scales_w=None +): + x.realize_hint() + + *batch, inp_h, inp_w = x.get_size() + inp_h = V.graph.sizevars.evaluate_static_shape(inp_h) + inp_w = V.graph.sizevars.evaluate_static_shape(inp_w) + + *batch, out_h, out_w = input_size + + if inp_h % out_h == 0 and inp_w % out_w == 0: + return avg_pool2d(x, [inp_h // out_h, inp_w // out_w], divisor_override=1) + + h_kernel_max = ceildiv(inp_h, out_h) + w_kernel_max = ceildiv(inp_w, out_w) + + def start_index(index, out_dim, inp_dim): + return CeilDiv(index * inp_dim, sympy.sympify(out_dim)) + + def end_index(index, out_dim, inp_dim): + return start_index((index + 1), out_dim, inp_dim) + + fn_sum = _adaptive_pooling_fn( + start_index=start_index, + end_index=end_index, + kernel_maxes=[h_kernel_max, w_kernel_max], + in_sizes=[inp_h, inp_w], + out_sizes=[out_h, out_w], + pooling_fn=ops.add, + ) + + def fn(idx): + return fn_sum(idx, pad_adaptive_loader(x)) + + rv = Pointwise.create( + device=x.get_device(), + dtype=x.get_dtype(), + inner_fn=fn, + ranges=list(input_size), + ) + + return rv + + +fallback_avg_pool2d = fallback_handler( + aten.avg_pool2d.default, add_to_fallback_set=False +) +fallback_avg_pool3d = fallback_handler( + aten.avg_pool3d.default, add_to_fallback_set=False +) + + +@register_lowering(aten.avg_pool2d, type_promotion_kind=None) +def avg_pool2d( + x, + kernel_size, + stride=(), + padding=0, + ceil_mode=False, + count_include_pad=True, + divisor_override=None, +): + return _avg_poolnd( + x, + kernel_size, + stride, + padding, + ceil_mode, + count_include_pad, + divisor_override, + dim=2, + ) + + +@register_lowering(aten.avg_pool3d, type_promotion_kind=None) +def avg_pool3d( + x, + kernel_size, + stride=(), + padding=0, + ceil_mode=False, + count_include_pad=True, + divisor_override=None, +): + return _avg_poolnd( + x, + kernel_size, + stride, + padding, + ceil_mode, + count_include_pad, + divisor_override, + dim=3, + ) + + +def _avg_poolnd( + x, + kernel_size, + stride, + padding, + ceil_mode, + count_include_pad, + divisor_override, + dim, +): + if not stride: + stride = kernel_size + if not padding: + padding = [0] * dim + kernel_size = pad_listlike(kernel_size, dim) + stride = pad_listlike(stride, dim) + padding = pad_listlike(padding, dim) + + assert isinstance(x, TensorBox) + assert len(kernel_size) == dim + assert len(stride) == dim + assert len(padding) == dim + assert len(x.get_size()) in (dim + 1, dim + 2) + + x.realize_hint() + batch = x.get_size()[:-dim] + h = x.get_size()[-dim:] + + h_out, ceil_modes = zip( + *[ + pooling_size(h[i], i, kernel_size, stride, padding, ceil_mode) + for i in range(dim) + ] + ) + + if any(padding) or any(ceil_modes): + x_loader = constant_boundary_condition(x, 0.0, dim=dim) + had_padding = True + else: + x_loader = x.make_loader() + had_padding = False + + new_size = list(batch) + list(h_out) + dtype = x.get_dtype() + + window_size = functools.reduce(operator.mul, kernel_size) + if window_size > 25: + # Kernel size too big. Results in hard-to-optimize Triton code. Use fallback. 
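+        # For _avg_poolnd the unroll count is the full kernel volume (note
+        # that even a 3x3x3 kernel, i.e. 27 taps, exceeds the limit), so
+        # oversized kernels are dispatched to the ATen avg_pool fallbacks
+        # selected below.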
+ if dim == 2: + fallback = fallback_avg_pool2d + elif dim == 3: + fallback = fallback_avg_pool3d + else: + raise ValueError(f"Unknown dim: {dim}") + + return fallback( + x, + kernel_size, + stride, + padding, + ceil_mode, + count_include_pad, + divisor_override, + ) + + def fn_sum(idx, loader): + prefix = idx[:-dim] + b = idx[-dim:] + total = None + for ih in itertools.product(*[range(kernel_size[i]) for i in range(dim)]): + inp = [b[i] * stride[i] + ih[i] - padding[i] for i in range(dim)] + val = loader([*prefix, *inp]) + if total is None: + total = val + else: + total = ops.add(val, total) + return total + + if not had_padding or divisor_override: + if divisor_override: + scale = 1 / divisor_override + else: + scale = 1.0 / window_size + + def fn(idx): + return ops.mul(fn_sum(idx, x_loader), ops.constant(scale, dtype)) + + else: + + def fn(idx): + prefix = idx[:-dim] + bh = idx[-dim:] + + divide_factors = [] + for i in range(dim): + hstart = bh[i] * stride[i] - padding[i] + hend = sympy.Min(hstart + kernel_size[i], h[i] + padding[i]) + if not count_include_pad: + hstart = sympy.Max(hstart, 0) + hend = sympy.Min(hend, h[i]) + factor = ops.index_expr(hend - hstart, torch.int32) + divide_factors.append(factor) + divide_factor = functools.reduce(ops.mul, divide_factors) + return ops.truediv(fn_sum(idx, x_loader), divide_factor) + + rv = Pointwise.create( + device=x.get_device(), + dtype=dtype, + inner_fn=fn, + ranges=new_size, + ) + # TODO(jansel): should we force these to be realized? + return rv + + +fallback_avg_pool2d_backward = fallback_handler( + aten.avg_pool2d_backward.default, add_to_fallback_set=False +) + + +@register_lowering(aten.avg_pool2d_backward, type_promotion_kind=None) +def avg_pool2d_backward( + grad_output, + x, + kernel_size, + stride, + padding, + ceil_mode, + count_include_pad, + divisor_override=None, +): + assert divisor_override is None or divisor_override != 0, "divisor must be not zero" + if not stride: + stride = kernel_size + if not padding: + padding = [0, 0] + + assert isinstance(grad_output, TensorBox) + assert isinstance(x, TensorBox) + assert len(kernel_size) == 2 + assert len(stride) == 2 + assert len(padding) == 2 + assert len(x.get_size()) in (3, 4) + + grad_output.realize_hint() # we will read this many times, so make sure it is computed + + *batch, height, width = x.get_size() + + h_out, ceil_mode1 = pooling_size(height, 0, kernel_size, stride, padding, ceil_mode) + w_out, ceil_mode2 = pooling_size(width, 1, kernel_size, stride, padding, ceil_mode) + + grad_loader = grad_output.make_loader() + + had_padding = padding[0] or padding[1] or ceil_mode1 or ceil_mode2 + + *_, pooled_height, pooled_width = grad_output.get_size() + new_size = list(x.get_size()) + dtype = x.get_dtype() + + h_window_size = max( + max(h // stride[0] - max(0, (h - kernel_size[0]) // stride[0]), 1) + for h in range(kernel_size[0] * 2) + ) + w_window_size = max( + max(w // stride[1] - max(0, (w - kernel_size[1]) // stride[1]), 1) + for w in range(kernel_size[1] * 2) + ) + + window_size = h_window_size * w_window_size + if window_size > 25: + # Kernel size too big. Results in hard-to-optimize Triton code. Use fallback. 
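+        # Same bound as in max_pool2d_with_indices_backward: e.g. a 3x3 kernel
+        # with stride 2 gives a window count of 2 per dimension (at most two
+        # pooling windows overlap any input row or column), i.e. 2 * 2 = 4
+        # unrolled gradient contributions per input element, well under 25.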
+ return fallback_avg_pool2d_backward( + grad_output, + x, + kernel_size, + stride, + padding, + ceil_mode, + count_include_pad, + divisor_override, + ) + + def compute_pool_size_without_padding(ph, pw): + """ + This computes the scaling factor that we will divide an element + by when `count_include_pad=False` + """ + stride_h = ops.constant(stride[0], torch.int32) + stride_w = ops.constant(stride[1], torch.int32) + pad_h = ops.constant(padding[0], torch.int32) + pad_w = ops.constant(padding[1], torch.int32) + kernel_h = ops.constant(kernel_size[0], torch.int32) + kernel_w = ops.constant(kernel_size[1], torch.int32) + hstart = ops.sub(ops.mul(ph, stride_h), pad_h) + wstart = ops.sub(ops.mul(pw, stride_w), pad_w) + hend = ops.minimum( + ops.add(hstart, kernel_h), + ops.add(ops.index_expr(height, torch.int32), pad_h), + ) + wend = ops.minimum( + ops.add(wstart, kernel_w), + ops.add(ops.index_expr(width, torch.int32), pad_w), + ) + hstart = ops.maximum(hstart, ops.constant(0, torch.int32)) + wstart = ops.maximum(wstart, ops.constant(0, torch.int32)) + hend = ops.minimum(hend, ops.index_expr(height, torch.int32)) + wend = ops.minimum(wend, ops.index_expr(width, torch.int32)) + divide_factor = ops.mul(ops.sub(hend, hstart), ops.sub(wend, wstart)) + return divide_factor + + def fn(idx): + *prefix, h, w = idx + h = h + padding[0] + w = w + padding[1] + phstart = ops.index_expr( + FloorDiv(h - kernel_size[0] + stride[0], stride[0]), torch.int32 + ) + pwstart = ops.index_expr( + FloorDiv(w - kernel_size[1] + stride[1], stride[1]), torch.int32 + ) + phend = ops.index_expr(FloorDiv(h, stride[0]) + 1, torch.int32) + pwend = ops.index_expr(FloorDiv(w, stride[1]) + 1, torch.int32) + + phstart = ops.maximum(phstart, ops.constant(0, torch.int32)) + pwstart = ops.maximum(pwstart, ops.constant(0, torch.int32)) + phend = ops.minimum(phend, ops.index_expr(pooled_height, torch.int32)) + pwend = ops.minimum(pwend, ops.index_expr(pooled_width, torch.int32)) + + gradient = None + for ph_ in range(h_window_size): + for pw_ in range(w_window_size): + ph = ops.add(phstart, ops.constant(ph_, torch.int32)) + pw = ops.add(pwstart, ops.constant(pw_, torch.int32)) + + if divisor_override is not None: + scale = divisor_override + elif count_include_pad or not had_padding: + scale = kernel_size[0] * kernel_size[1] + else: + scale = compute_pool_size_without_padding(ph, pw) + + part = ops.truediv( + grad_loader( + [ + *prefix, + ops.indirect_indexing( + ops.minimum( + ph, ops.sub(phend, ops.constant(1, torch.int32)) + ), + pooled_height, + check=False, + ), + ops.indirect_indexing( + ops.minimum( + pw, ops.sub(pwend, ops.constant(1, torch.int32)) + ), + pooled_width, + check=False, + ), + ] + ), + scale, + ) + + mask = ops.and_( + ops.lt(ph, phend), + ops.lt(pw, pwend), + ) + if gradient is None: + gradient = ops.where(mask, part, ops.constant(0.0, torch.float32)) + else: + gradient = ops.where(mask, ops.add(gradient, part), gradient) + assert gradient is not None + return gradient + + rv = Pointwise.create( + device=grad_output.get_device(), + dtype=dtype, + inner_fn=fn, + ranges=new_size, + ) + return rv + + +fallback_avg_pool3d_backward = fallback_handler( + aten.avg_pool3d_backward.default, add_to_fallback_set=False +) + + +@register_lowering(aten.avg_pool3d_backward, type_promotion_kind=None) +def avg_pool3d_backward( + grad_output, + x, + kernel_size, + stride, + padding, + ceil_mode, + count_include_pad, + divisor_override=None, +): + assert divisor_override is None or divisor_override != 0, "divisor must be not 
zero" + if not stride: + stride = kernel_size + if not padding: + padding = [0, 0, 0] + + assert isinstance(grad_output, TensorBox) + assert isinstance(x, TensorBox) + assert len(kernel_size) == 3 + assert len(stride) == 3 + assert len(padding) == 3 + assert len(x.get_size()) in (4, 5) + + grad_output.realize_hint() + + *batch, depth, height, width = x.get_size() + + d_out, ceil_mode_d = pooling_size(depth, 0, kernel_size, stride, padding, ceil_mode) + h_out, ceil_mode_h = pooling_size( + height, 1, kernel_size, stride, padding, ceil_mode + ) + w_out, ceil_mode_w = pooling_size(width, 2, kernel_size, stride, padding, ceil_mode) + + grad_loader = grad_output.make_loader() + had_padding = any(padding) or ceil_mode_d or ceil_mode_h or ceil_mode_w + + *_, pooled_depth, pooled_height, pooled_width = grad_output.get_size() + new_size = list(x.get_size()) + dtype = x.get_dtype() + + d_window_size, h_window_size, w_window_size = ( + max( + max(d // stride[i] - max(0, (d - kernel_size[i]) // stride[i]), 1) + for d in range(kernel_size[i] * 2) + ) + for i in range(3) + ) + + window_size = d_window_size * h_window_size * w_window_size + if window_size > 125: + # Kernel size too big. Results in hard-to-optimize Triton code. + return fallback_avg_pool3d_backward( + grad_output, + x, + kernel_size, + stride, + padding, + ceil_mode, + count_include_pad, + divisor_override, + ) + + def compute_pool_size_without_padding(pd, ph, pw): + stride_d, stride_h, stride_w = (ops.constant(s, torch.int32) for s in stride) + pad_d, pad_h, pad_w = (ops.constant(p, torch.int32) for p in padding) + kernel_d, kernel_h, kernel_w = ( + ops.constant(k, torch.int32) for k in kernel_size + ) + + dstart, hstart, wstart = ( + ops.sub(ops.mul(p, s), pad) + for p, s, pad in zip( + [pd, ph, pw], [stride_d, stride_h, stride_w], [pad_d, pad_h, pad_w] + ) + ) + dend, hend, wend = ( + ops.minimum( + ops.add(start, k), ops.add(ops.index_expr(dim, torch.int32), pad) + ) + for start, k, dim, pad in zip( + [dstart, hstart, wstart], + [kernel_d, kernel_h, kernel_w], + [depth, height, width], + [pad_d, pad_h, pad_w], + ) + ) + dstart, hstart, wstart = ( + ops.maximum(start, ops.constant(0, torch.int32)) + for start in [dstart, hstart, wstart] + ) + dend, hend, wend = ( + ops.minimum(end, ops.index_expr(dim, torch.int32)) + for end, dim in zip([dend, hend, wend], [depth, height, width]) + ) + divide_factor = ops.mul( + ops.mul(ops.sub(dend, dstart), ops.sub(hend, hstart)), ops.sub(wend, wstart) + ) + return divide_factor + + def fn(idx): + *prefix, d, h, w = idx + d, h, w = (v + pad for v, pad in zip([d, h, w], padding)) + + pdstart, phstart, pwstart = ( + ops.index_expr(FloorDiv(v - k + s, s), torch.int32) + for v, k, s in zip([d, h, w], kernel_size, stride) + ) + + pdend, phend, pwend = ( + ops.index_expr(FloorDiv(v, s) + 1, torch.int32) + for v, s in zip([d, h, w], stride) + ) + + pdstart, phstart, pwstart = ( + ops.maximum(pstart, ops.constant(0, torch.int32)) + for pstart in [pdstart, phstart, pwstart] + ) + pdend, phend, pwend = ( + ops.minimum(pend, ops.index_expr(pooled_dim, torch.int32)) + for pend, pooled_dim in zip( + [pdend, phend, pwend], [pooled_depth, pooled_height, pooled_width] + ) + ) + + gradient = None + # Iterate over the 3D region to accumulate gradients + for pd_ in range(d_window_size): + for ph_ in range(h_window_size): + for pw_ in range(w_window_size): + pd, ph, pw = ( + ops.add(pstart, ops.constant(p_, torch.int32)) + for pstart, p_ in zip( + [pdstart, phstart, pwstart], [pd_, ph_, pw_] + ) + ) + + if 
divisor_override is not None: + scale = divisor_override + elif count_include_pad or not had_padding: + scale = kernel_size[0] * kernel_size[1] * kernel_size[2] + else: + scale = compute_pool_size_without_padding(pd, ph, pw) + + part = ops.truediv( + grad_loader( + [ + *prefix, + ops.indirect_indexing( + ops.minimum( + pd, ops.sub(pdend, ops.constant(1, torch.int32)) + ), + pooled_depth, + check=False, + ), + ops.indirect_indexing( + ops.minimum( + ph, ops.sub(phend, ops.constant(1, torch.int32)) + ), + pooled_height, + check=False, + ), + ops.indirect_indexing( + ops.minimum( + pw, ops.sub(pwend, ops.constant(1, torch.int32)) + ), + pooled_width, + check=False, + ), + ] + ), + scale, + ) + + mask = ops.and_( + ops.and_(ops.lt(pd, pdend), ops.lt(ph, phend)), + ops.lt(pw, pwend), + ) + if gradient is None: + gradient = ops.where( + mask, part, ops.constant(0.0, torch.float32) + ) + else: + gradient = ops.where(mask, ops.add(gradient, part), gradient) + assert gradient is not None + return gradient + + rv = Pointwise.create( + device=grad_output.get_device(), + dtype=dtype, + inner_fn=fn, + ranges=new_size, + ) + return rv + + +def _validate_reduction_axis(x, axis): + size = x.get_size() + if isinstance(axis, int): + axis = [axis] + elif not axis: + axis = range(len(size)) + if len(size) == 0: + assert tuple(axis) in [(), (0,), (-1,)], f"invalid axis: {axis}" + return [] + axis = list(axis) + for i in range(len(axis)): + if axis[i] < 0: + axis[i] += len(size) if len(size) else 1 + assert 0 <= axis[i] < len(size) or (len(size) == 0 and axis[i] == 0) + assert len(set(axis)) == len(axis), "reduction axis not unique" + return axis + + +def _make_reduction_inner(x, *, axis, keepdims, dtype, override_return_dtype): + if dtype is not None: + x = to_dtype(x, dtype) + size = x.get_size() + axis = set(_validate_reduction_axis(x, axis)) + + kept_sizes = [] + kept_idx = [] + reduced_sizes = [] + reduced_idx = [] + for i in range(len(size)): + if i in axis: + reduced_idx.append(i) + reduced_sizes.append(size[i]) + else: + kept_idx.append(i) + kept_sizes.append(size[i]) + + def loader(index, reduction_index): + assert len(reduction_index) == len(reduced_idx) + if keepdims: + assert len(index) == len(size) + index = [index[i] for i in kept_idx] + assert len(index) == len(kept_idx) + new_index = [None] * (len(index) + len(reduction_index)) + for idx, var in itertools.chain( + zip(kept_idx, index), zip(reduced_idx, reduction_index) + ): + new_index[idx] = var + return inner_loader(new_index) + + if keepdims: + new_size = list(size) + for i in reduced_idx: + new_size[i] = sympy.Integer(1) + else: + new_size = kept_sizes + + inner_loader = x.make_loader() + return dict( + device=x.get_device(), + dst_dtype=override_return_dtype or x.get_dtype(), + src_dtype=x.get_dtype(), + inner_fn=loader, + ranges=new_size, + reduction_ranges=reduced_sizes, + ) + +def should_not_sum(a, b, keepdims): + if not b: + return False + for i in b: + if i < 0 or i >= len(a): + return False + if a[i] != 1: + return False + + unique_indices = reversed(b) + if not keepdims: + for i in unique_indices: + del a[i] + return a + return a + +def make_reduction(reduction_type: str, override_return_dtype=None): + def inner(x, axis=None, keepdims=False, *, dtype=None): + if dtype is not None: + x = to_dtype(x, dtype) + if axis and axis[-1] < 0: + offset = len(x.get_size()) + axis = [ax + offset for ax in axis] + axis = sorted(axis) + kwargs = _make_reduction_inner( + x, + axis=axis, + keepdims=keepdims, + dtype=dtype, + 
override_return_dtype=override_return_dtype, + ) + new_size = should_not_sum(x.get_size(), axis, keepdims) + if new_size: + node_name = f'reshape_{next(node_id)}' + input_graphs = fetch_graphs([x, new_size]) + new_graph = merge_traced_graphs(input_graphs, aten.reshape, node_name) + else: + node_name = f'reduction_{next(node_id)}' + input_graphs = fetch_graphs([x, axis if axis is not None else list(range(len(x.get_size())))]) + new_graph = merge_traced_graphs(input_graphs, reduction_type_to_aten_fn[reduction_type], + node_name, keepdim=keepdims) + result = Reduction.create(reduction_type=reduction_type, + input_node=x, + node_name=node_name, + traced_graph=new_graph, + **kwargs) + if isinstance( + result.data.data, Reduction + ): # Only realize if reduction isn't unrolled + result.realize() + return result + + return inner + + +def _make_scan_inner(x, *, axis, dtype): + if dtype is not None: + x = to_dtype(x, dtype) + size = x.get_size() + axis = _validate_dim(x, axis) + + return dict( + device=x.get_device(), + dtypes=(x.get_dtype(),), + inner_fns=(x.make_loader(),), + size=x.get_size(), + axis=axis, + ) + + +@register_lowering(aten.mean) +def mean(x, axis=None, keepdim=False, *, dtype=None): + if dtype is not None: + x = to_dtype(x, dtype) + size = x.get_size() + axis = _validate_reduction_axis(x, axis) + # compute in higher-precision until end of mean lowering + output_dtype = x.get_dtype() + if output_dtype in (torch.float16, torch.bfloat16): + x = to_dtype(x, torch.float) + sum_result = sum_(x, axis, keepdim) + denom = sympy_product(size[i] for i in axis) + denom = ir.IndexingConstant(index=denom, dtype=x.get_dtype(), device=x.get_device()) + denom = ExpandView.create(denom, list(sum_result.get_size())) + return to_dtype(div(sum_result, denom), output_dtype) + + +def var_mean_sum_(x, axis, correction, keepdim, return_mean): + if correction is None: + correction = 1 + + size = x.get_size() + axis = _validate_reduction_axis(x, axis) + x_mean = mean(x, axis, keepdim=True) + if return_mean: + x_mean.realize() + + diffs = square(sub(x, x_mean)) + sum_result = sum_(diffs, axis, keepdim) + + denom = sympy_product(size[i] for i in axis) + if correction: + denom = sympy.Max(denom - correction, 0) + denom = ir.IndexingConstant(index=denom, dtype=x.get_dtype(), device=x.get_device()) + denom = ExpandView.create(denom, list(sum_result.get_size())) + x_var = div(sum_result, denom) + if not return_mean: + return (x_var,) + + x_mean = x_mean if keepdim else squeeze(x_mean, axis) + return x_var, x_mean + + +def use_two_step_variance(x, axis, keepdim): + # Instead of unrolling welford, just unroll the simpler two-step var + axis = _validate_reduction_axis(x, axis) + kwargs = _make_reduction_inner( + x, axis=axis, keepdims=keepdim, dtype=None, override_return_dtype=None + ) + + ranges = kwargs["ranges"] + reduction_numel = sympy_product(kwargs["reduction_ranges"]) + return ( + isinstance(reduction_numel, sympy.Integer) + and int(reduction_numel) < config.unroll_reductions_threshold + and sympy_product(ranges) != 1 + ) + + +def var_mean_welford_(x, axis, *, correction, keepdim, return_mean): + if correction is None: + correction = 1 + + kwargs = _make_reduction_inner( + x, axis=axis, keepdims=keepdim, dtype=None, override_return_dtype=None + ) + loader = kwargs.pop("inner_fn") + kwargs.pop("dst_dtype") + kwargs.pop("src_dtype") + + mean, m2, _ = ir.WelfordReduction.create( + inner_fns=(loader,), + reduction_type="welford_reduce", + dtype=x.get_dtype(), + **kwargs, + ) + m2.realize() + + dtype = 
x.get_dtype() + size = x.get_size() + axis = _validate_reduction_axis(x, axis) + rnumel = sympy_product(size[i] for i in axis) + + def get_constant_or_index_expr(x, dtype): + if isinstance(x, sympy.Expr) and not x.is_number: + return ops.to_dtype(ops.index_expr(x, torch.int64), dtype) + return ops.constant(x, dtype) + + def scale_fn(data): + c = get_constant_or_index_expr(correction, dtype) + N = get_constant_or_index_expr(rnumel, dtype) + zero = ops.constant(0, dtype) + return data / ops.maximum(zero, N - c) + + var = make_pointwise(scale_fn)(m2) + + if return_mean: + mean.realize() + return var, mean + return (var,) + + +def var_mean_helper_(x, *, axis, correction, keepdim, return_mean): + out_dtype = x.get_dtype() + compute_dtype = get_computation_dtype(out_dtype) + x = to_dtype(x, compute_dtype, copy=False) + kwargs = dict( + x=x, + axis=axis, + correction=correction, + keepdim=keepdim, + return_mean=return_mean, + ) + output = ( + var_mean_sum_(**kwargs) + # if use_two_step_variance(x, axis=axis, keepdim=keepdim) + # else var_mean_welford_(**kwargs) + ) + output = tuple(to_dtype(x, out_dtype, copy=False) for x in output) + return output[0] if not return_mean else output + + +@register_lowering([aten.var, prims.var]) +def var_(x, axis=None, *, correction=None, keepdim=False): + return var_mean_helper_( + x, axis=axis, correction=correction, keepdim=keepdim, return_mean=False + ) + + +@register_lowering(aten.var_mean) +def var_mean(x, axis=None, *, correction=None, keepdim=False): + return var_mean_helper_( + x, axis=axis, correction=correction, keepdim=keepdim, return_mean=True + ) + + +def pow_recursive(x, y, dtype): + if y < 0: + return pow_recursive(ops.reciprocal(x), -y, dtype) + if y == 0: + return ops.constant(1, dtype) + if y == 1: + return x + + result = pow_recursive(x, y // 2, dtype) + result = ops.mul(result, result) + if (y % 2) == 1: + result = ops.mul(result, x) + return result + + +@make_pointwise +@register_to_aten(aten_fn=aten.pow) +def pow_native(a, b): + return ops.pow(a, b) + + +fallback_pow_tensor_tensor = fallback_handler( + aten.pow.Tensor_Tensor, add_to_fallback_set=False +) +fallback_pow_scalar = fallback_handler(aten.pow.Scalar, add_to_fallback_set=False) +fallback_pow_tensor_scalar = fallback_handler( + aten.pow.Tensor_Scalar, add_to_fallback_set=False +) + + +@register_lowering(aten.pow, broadcast=True) +def pow(a, b): + if isinstance(b, float) and b == int(b): + return pow(a, int(b)) + elif isinstance(b, float) and b == 0.5: + return sqrt(a) + elif isinstance(b, int) and b == 1: + return clone(a) + + # Type promotion ensures all tensor arguments have the same type + dtype = next(x.get_dtype() for x in (a, b) if isinstance(x, ir.TensorBox)) + is_integer_pow = is_integer_dtype(dtype) + + # Optimize away small fixed powers, or for integers avoid falling back to ATen + embed_exponent = isinstance(b, int) and ( + -32 < b < 32 or (is_integer_pow and b >= 0) + ) + if embed_exponent: + loader = a.make_loader() + + def fn(idx): + return pow_recursive(loader(idx), b, a.get_dtype()) + + input_graphs = fetch_graphs([a, b]) + node_name = f'pointwise_{next(node_id)}' + new_graph = merge_traced_graphs(input_graphs, aten.pow, node_name) + + return Pointwise.create( + device=a.get_device(), + dtype=a.get_dtype(), + inner_fn=fn, + ranges=a.get_size(), + node_name=node_name, + traced_graph=new_graph, + ) + + if isinstance(a, Number): + if a == 1: + return full_like(b, 1) + if a == 2 and is_float_dtype(b.get_dtype()): + return exp2(b) + + if is_integer_pow: + # ops.pow doesn't 
work for integers + if isinstance(a, Number): + return fallback_pow_scalar(a, b) + elif isinstance(b, Number): + return fallback_pow_tensor_scalar(a, b) + else: + return fallback_pow_tensor_tensor(a, b) + + return pow_native(a, b) + + +def mutate_to(changed, val, unsafe_alias=False): + if isinstance(changed, TensorBox): + changed_data = changed.data + else: + changed_data = changed + if isinstance(val, TensorBox): + val = val.data + + if not isinstance(val, ir.StorageBox): + # introduce a copy to handle views + input_graphs = fetch_graphs([changed, val]) + node_name = f'copy__{next(node_id)}' + new_graph = merge_traced_graphs(input_graphs, aten.copy_, node_name) + val = Pointwise.create( + device=changed.get_device(), + dtype=changed.get_dtype(), + inner_fn=val.make_loader(), + ranges=changed.get_size(), + traced_graph=new_graph, + node_name=node_name + ).data + assert isinstance(val, ir.StorageBox) + + if isinstance(changed_data, ir.StorageBox) and not ( + changed_data.is_input_buffer() + # In AOTI, module parameters and buffers are not lifted as graph inputs + or changed_data.is_module_buffer() + or isinstance(changed_data.data, ir.NopKernel) + ): + # Fast path, just swing the data pointer + val.realize() + changed_data.data = val.data + return changed + + ir.MutationLayoutSHOULDREMOVE.realize_into( + val, changed_data, unsafe_alias=unsafe_alias + ) + return changed + + +@register_lowering(aten.fill_) +def fill_(x, fill_value): + return mutate_to(x, full_like(x, fill_value)) + + +@register_lowering(aten.copy_, type_promotion_kind=None) +def copy_(dst, src, non_blocking=False): + if dst is src: + # dst.copy_(dst) can happen from the reinplacing pass + return dst + src = to_device(src, dst.get_device()) + src = to_dtype(src, dst.get_dtype()) + src = expand(src, dst.get_size()) + return mutate_to(dst, src) + + +@make_pointwise +def floordiv(a, b): + fn = ops.floordiv + return fn(a, b) + + +@make_pointwise +def truncdiv(a, b): + fn = ops.truncdiv + + return fn(a, b) + + +@register_lowering(aten.div, broadcast=True) +def div_mode(a, b, rounding_mode=None): + both_integer = is_integer_type(a) and is_integer_type(b) + both_boolean = is_boolean_type(a) and is_boolean_type(b) + + # floordiv and truncdiv need special handling for integer tensors on Triton, + # see the discussion at https://github.com/openai/triton/issues/605 + if rounding_mode == "floor": + assert not both_boolean, "floordiv operands can not be boolean at the same time" + return fallback_handler(aten.div.Tensor_mode)(a, b, rounding_mode=rounding_mode) + return floordiv(a, b) if both_integer else floor(div(a, b)) + if rounding_mode == "trunc": + assert not both_boolean, "truncdiv operands can not be boolean at the same time" + return fallback_handler(aten.div.Tensor_mode)(a, b, rounding_mode=rounding_mode) + return truncdiv(a, b) if both_integer else trunc(div(a, b)) + return div(a, b) + + +@register_lowering([aten.mul], broadcast=True) +def mul(a, b): + both_bool = is_boolean_type(a) and is_boolean_type(b) + if both_bool: + return logical_and(a, b) + else: + fn = ops_wrapper(aten.mul.__name__) + fn = register_fn_to_aten_fn(fn, aten.mul) + return make_pointwise(fn)(a, b) + + +def get_constant_value(x: ir.IRNode) -> Optional[ir.Constant]: + """Try convert an arbitrary IR node into an ir.Constant value""" + + # First try unwrapping the IRNode to see if it is already an ir.Constant + # Optional step, but avoids unnecessary inner_fn evaluation. 
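+    # The unwrap chain below peels MutableBox wrappers and views until it either
+    # reaches an ir.Constant or stops at a non-constant node; the more expensive
+    # inner_fn evaluation further down only runs in the latter case, and only
+    # for ir.Loops nodes.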
+ if isinstance(x, ir.MutableBox): + return get_constant_value(x.data) + if isinstance(x, ir.BaseView): + return get_constant_value(x.unwrap_view()) + if isinstance(x, ir.Constant): + return x + + # If the unwrapped node is not an ir.Constant, try evaluating inner_fn + # to see if the returned value is from an `ops.constant` call + if not isinstance(x, ir.Loops): + return None + + handler = torch._inductor.ops_handler.ExtractConstantsHandler(x.get_device()) + with V.set_ops_handler(handler), patch.object( + ir.FlexibleLayout, "allow_indexing", True + ): + out = x.inner_fn(*x.inner_fn_args()) + + assert isinstance(out, torch._inductor.virtualized.OpsValue) + if isinstance(out.value, ir.Constant): + return out.value + return None + + +# NOTE: prims.div maps to a / b in C, so performs truncation division on +# integer inputs and true division for floating and complex inputs. +@register_lowering([prims.div], broadcast=True) +def div_prim(a, b): + is_integral = all(is_boolean_type(x) or is_integer_type(x) for x in [a, b]) + + if is_integral: + return truncdiv(a, b) + + # if (divisor := get_constant_value(b)) is not None: + # # Replace divide by constant with multiply by reciprocal + # if divisor.value == 0: + # reciprocal = math.copysign(float("inf"), divisor.value) + # else: + # reciprocal = 1.0 / divisor.value + # return mul(a, reciprocal) + + def fn(*args): + return ops.truediv(*args) + + fn = register_fn_to_aten_fn(fn, aten.div) + return make_pointwise(fn)(a, b) + + +@register_lowering( + [aten.true_divide, aten.div.Tensor], + broadcast=True, + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, +) +def div(a, b): + a, b = promote_constants( + (a, b), type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT + ) + return div_prim(a, b) + + +@register_lowering([aten.reciprocal], broadcast=True,) +def reciprocal(a): + return div(1.0, a) + + +@register_lowering([aten.fmod, prims.fmod], broadcast=True) +def fmod(a, b): + is_integral = is_boolean_type(a) or is_integer_type(a) + + if is_integral: + + def fn(a, b): + return ops.mod(a, b) + + else: + + def fn(a, b): + return ops.fmod(a, b) + + return make_pointwise(fn)(a, b) + + +@register_lowering(aten.rsqrt) +def rsqrt(x): + dtype = x.get_dtype() + if is_integer_dtype(dtype) or is_boolean_dtype(dtype): + x = to_dtype(x, torch.get_default_dtype()) + + def _rsqrt(x): + return ops.rsqrt(x) + + register_fn_to_aten_fn(_rsqrt, aten.rsqrt) + return make_pointwise(_rsqrt)(x) + +def split_last_continuous(lst): + n = len(lst) + if n == 1: + return lst, [] + + i = n - 2 + while i >= 0: + if lst[i] + 1 != lst[i + 1]: + break + i -= 1 + + start = i + 1 + last_part = lst[start:] + remaining = lst[:start] + return last_part, remaining + +@register_lowering([aten.sum, prims.sum]) +def sum_(x, axis=None, keepdims=False, *, dtype=None): + if axis and axis[-1] < 0: + offset = len(x.get_size()) + axis = [ax + offset for ax in axis] + if ( + is_integer_dtype(x.get_dtype()) or is_boolean_dtype(x.get_dtype()) + ) and dtype is None: + dtype = torch.int64 + + if anir_config.disable_any_pbr and axis and len(x.get_size()) > 1 and axis[0] < len(x.get_size()) - len(axis): + fn = fallback_handler( + aten.sum.dim_IntList, add_to_fallback_set=False + ) + else: + fn = make_reduction("sum", override_return_dtype=dtype) + r = fn(x, axis, keepdims, dtype=dtype) + return r + + +fallback_cumsum = fallback_handler(aten.cumsum.default) +fallback_cumprod = fallback_handler(aten.cumprod.default) +fallback_logcumsumexp = fallback_handler(aten.logcumsumexp.default) 
+fallback_cummax = fallback_handler(aten.cummax.default) +fallback_cummin = fallback_handler(aten.cummin.default) + + +@register_lowering(aten.cumsum) +def cumsum(x, axis=None, dtype=None): + if ( + is_integer_dtype(x.get_dtype()) or is_boolean_dtype(x.get_dtype()) + ) and dtype is None: + dtype = torch.int64 + + if len(x.get_size()) == 0: + assert axis in [0, -1] + dtype = dtype or x.get_dtype() + return to_dtype(x, dtype, copy=True) + + def combine_fn(a_tuple, b_tuple): + (a,) = a_tuple + (b,) = b_tuple + return (ops.add(a, b),) + + kwargs = _make_scan_inner(x, axis=axis, dtype=dtype) + (result,) = ir.Scan.create(**kwargs, combine_fn=combine_fn) + if result is None: + return fallback_cumsum(x, dim=axis, dtype=dtype) + return result + + +@register_lowering(aten.cumprod) +def cumprod(x, axis=None, dtype=None): + if ( + is_integer_dtype(x.get_dtype()) or is_boolean_dtype(x.get_dtype()) + ) and dtype is None: + dtype = torch.int64 + + if len(x.get_size()) == 0: + assert axis in [0, -1] + dtype = dtype or x.get_dtype() + return to_dtype(x, dtype, copy=True) + + def combine_fn(a_tuple, b_tuple): + (a,) = a_tuple + (b,) = b_tuple + return (ops.mul(a, b),) + + kwargs = _make_scan_inner(x, axis=axis, dtype=dtype) + (result,) = ir.Scan.create(**kwargs, combine_fn=combine_fn) + if result is None: + return fallback_cumprod(x, dim=axis, dtype=dtype) + return result + + +@register_lowering(aten.logcumsumexp) +def logcumsumexp(x, dim): + def log_add_exp_helper(a_tuple, b_tuple): + (a,) = a_tuple + (b,) = b_tuple + min_v = ops.minimum(a, b) + max_v = ops.maximum(a, b) + mask = (min_v != max_v) | (~ops.isinf(min_v)) + return (ops.where(mask, ops.log1p(ops.exp(min_v - max_v)) + max_v, a),) + + dtype = x.get_dtype() + if len(x.get_size()) == 0: + assert dim in [0, -1] + return clone(x) + + kwargs = _make_scan_inner(x, axis=dim, dtype=dtype) + (result,) = ir.Scan.create(**kwargs, combine_fn=log_add_exp_helper) + if result is None: + return fallback_logcumsumexp(x, dim=dim) + return result + + +@register_lowering(aten.cummax, type_promotion_kind=None) +def cummax(x, axis=None): + if len(x.get_size()) == 0: + assert axis in [0, -1] + return clone(x), empty_like(x, dtype=torch.int64) + + dtype = x.get_dtype() + combine_fn = ir.get_reduction_combine_fn( + "argmax", dtype=dtype, arg_break_ties_left=False + ) + + min_value = ( + False + if dtype is torch.bool + else ( + torch.finfo(dtype).min + if dtype.is_floating_point + else torch.iinfo(dtype).min + ) + ) + + kwargs = _make_scan_inner(x, axis=axis, dtype=dtype) + kwargs["dtypes"] = (dtype, torch.int64) + kwargs["inner_fns"] = (x.make_loader(), lambda _: "rindex") + values, indices = ir.Scan.create(**kwargs, combine_fn=combine_fn) # type: ignore[arg-type] # next PR + if values is None: + return fallback_cummax(x, dim=axis) + return values, indices + + +@register_lowering(aten.cummin, type_promotion_kind=None) +def cummin(x, axis=None): + if len(x.get_size()) == 0: + assert axis in [0, -1] + return clone(x), empty_like(x, dtype=torch.int64) + + dtype = x.get_dtype() + combine_fn = ir.get_reduction_combine_fn( + "argmin", dtype=dtype, arg_break_ties_left=False + ) + + max_value = ( + True + if dtype is torch.bool + else ( + torch.finfo(dtype).max + if dtype.is_floating_point + else torch.iinfo(dtype).max + ) + ) + + kwargs = _make_scan_inner(x, axis=axis, dtype=dtype) + kwargs["dtypes"] = (dtype, torch.int64) + kwargs["inner_fns"] = (x.make_loader(), lambda _: "rindex") + values, indices = ir.Scan.create(**kwargs, combine_fn=combine_fn) # type: ignore[arg-type] # 
next PR + if values is None: + return fallback_cummin(x, dim=axis) + return values, indices + + +@register_lowering(aten.prod) +def prod(x, axis=None, keepdims=False, *, dtype=None): + if ( + is_integer_dtype(x.get_dtype()) or is_boolean_dtype(x.get_dtype()) + ) and dtype is None: + dtype = torch.int64 + + fn = make_reduction("prod", override_return_dtype=dtype) + return fn(x, axis, keepdims, dtype=dtype) + + +@register_lowering(aten.any) +def reduce_any(x, dim=None, keepdim=False): + x = to_dtype(x, torch.bool) + return make_reduction("any")(x, axis=dim, keepdims=keepdim) + + +@register_lowering(aten.max, type_promotion_kind=None) +def reduce_max(x, dim=None, keepdim=False): + if dim is not None: + return ( + reduce_amax(x, axis=dim, keepdims=keepdim), + reduce_argmax(x, axis=dim, keepdims=keepdim), + ) + + return reduce_amax(x, axis=None, keepdims=keepdim) + + +@register_lowering(aten.min, type_promotion_kind=None) +def reduce_min(x, dim=None, keepdim=False): + if dim is not None: + return ( + reduce_amin(x, axis=dim, keepdims=keepdim), + reduce_argmin(x, axis=dim, keepdims=keepdim), + ) + + return reduce_amin(x, axis=None, keepdims=keepdim) + + +register_lowering(prims.xor_sum)(make_reduction("xor_sum")) +reduce_amax = register_lowering(aten.amax)(make_reduction("max")) +reduce_amin = register_lowering(aten.amin)(make_reduction("min")) +reduce_argmax = register_lowering(aten.argmax)( + make_reduction("argmax", override_return_dtype=torch.int64) +) +reduce_argmin = register_lowering(aten.argmin)( + make_reduction("argmin", override_return_dtype=torch.int64) +) + +add = register_pointwise( + aten.add, allow_alpha=True, override_fn_when_input_bool="logical_or" +) + +sort_fallback = fallback_handler(aten.sort.stable, add_to_fallback_set=False) + + +@register_lowering(aten.sort.stable, type_promotion_kind=None) +def sort_stable(x, *, stable=None, dim=-1, descending=False): + if stable is None: + stable = False + + shape = x.get_size() + device = x.get_device() + dim = canonicalize_dim(len(shape), dim) + if len(shape) == 0: + return clone(x), _full(0, device, torch.int64, shape) + + dim_size = shape[dim] if len(shape) else 1 + if not V.graph.sizevars.statically_known_lt(dim_size, torch.iinfo(torch.int16).max): + return sort_fallback(x, stable=stable, dim=dim, descending=descending) + + indices = iota( + dim_size, start=0, step=1, dtype=torch.int16, device=device, requires_grad=False + ) + view_shape = [1] * len(shape) + if len(shape): + view_shape[dim] = dim_size + indices = view(indices, view_shape) + indices = expand(indices, shape) + + values, indices = ir.Sort.create( + device=device, + dtypes=(x.dtype, indices.dtype), + inner_fns=(x.make_loader(), indices.make_loader()), + size=shape, + axis=dim, + stable=stable, + descending=descending, + ) + if values is None: + return sort_fallback(x, stable=stable, dim=dim, descending=descending) + + assert indices is not None + return values, to_dtype(indices, torch.int64) + + +@register_lowering(aten.sort.default, type_promotion_kind=None) +def sort(x, dim=-1, descending=False): + return sort_stable(x, stable=False, dim=dim, descending=descending) + + +def register_pointwise_numeric(op, name=None, triton_fallback=None): + return register_pointwise( + op, + name=name, + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + triton_fallback=triton_fallback, + ) + + +def register_pointwise_numeric_ldf64(op): + return register_pointwise( + op, + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + use_libdevice_for_f64=True, 
+ ) + +@register_lowering([aten.neg]) +def neg(a): + if a.get_dtype() in (torch.int32, torch.int64): + return mul(a, -1) + fn = ops_wrapper(aten.neg.__name__) + fn = register_fn_to_aten_fn(fn, aten.neg) + return make_pointwise(fn)(a) + + +exp = register_pointwise_numeric_ldf64(aten.exp) +exp2 = register_pointwise_numeric(aten.exp2) +expm1 = register_pointwise_numeric(aten.expm1) +relu = register_pointwise(aten.relu) +sigmoid = register_pointwise_numeric_ldf64(aten.sigmoid) +sqrt = register_pointwise_numeric_ldf64(aten.sqrt) +square = register_pointwise(aten.square) +sub = register_pointwise(aten.sub, allow_alpha=True) +register_pointwise_numeric_ldf64(aten.cos) +register_pointwise_numeric_ldf64(aten.sin) +abs = register_pointwise(aten.abs) +bitwise_and = register_pointwise(aten.bitwise_and) +bitwise_left_shift = register_pointwise(aten.bitwise_left_shift) +bitwise_not = register_pointwise( + aten.bitwise_not, override_fn_when_input_bool="logical_not" +) +bitwise_or = register_pointwise(aten.bitwise_or) +bitwise_right_shift = register_pointwise(aten.bitwise_right_shift) +bitwise_xor = register_pointwise(aten.bitwise_xor) +register_pointwise_numeric(aten.lgamma) +erf = register_pointwise_numeric(aten.erf) +register_lowering( + aten.special_erf, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT +)(erf) + +register_pointwise_numeric(aten.log1p) +register_pointwise_numeric(aten.tan) +register_pointwise_numeric(aten.tanh) +register_pointwise_numeric_ldf64(aten.log) +logical_and = register_pointwise( + aten.logical_and, + type_promotion_kind=None, + convert_input_to_bool=True, + override_return_dtype=torch.bool, +) +logical_not = register_pointwise( + aten.logical_not, + type_promotion_kind=None, + convert_input_to_bool=True, + override_return_dtype=torch.bool, +) +logical_or = register_pointwise( + aten.logical_or, + type_promotion_kind=None, + convert_input_to_bool=True, + override_return_dtype=torch.bool, +) +logical_xor = register_pointwise( + aten.logical_xor, + type_promotion_kind=None, + convert_input_to_bool=True, + override_return_dtype=torch.bool, +) +maximum = register_pointwise(aten.maximum) +minimum = register_pointwise(aten.minimum) +clamp_min = register_pointwise(aten.clamp_min, name='maximum') +clamp_max = register_pointwise(aten.clamp_max, name='minimum') +register_lowering(aten.clamp_min)(clamp_min) +register_lowering(aten.clamp_max)(clamp_max) +abs = register_pointwise(aten.abs) +# reciprocal = register_pointwise_numeric(aten.reciprocal) +register_pointwise(aten.remainder) +sign = register_pointwise(aten.sign, override_fn_when_input_bool="identity") +register_pointwise(aten.ceil) +register_pointwise(aten.signbit, override_return_dtype=torch.bool) + +register_lowering(aten._neg_view)(neg) + +register_pointwise(aten.le, override_return_dtype=torch.bool) +register_pointwise(aten.lt, override_return_dtype=torch.bool) +register_pointwise(aten.ge, override_return_dtype=torch.bool) +gt = register_pointwise(aten.gt, override_return_dtype=torch.bool) +register_pointwise(aten.eq, override_return_dtype=torch.bool) +register_pointwise(aten.ne, override_return_dtype=torch.bool) + +register_pointwise_numeric(aten.cosh) +register_pointwise_numeric(aten.sinh) +register_pointwise_numeric(aten.acos) +register_pointwise_numeric(aten.acosh) +register_pointwise_numeric(aten.asin) +register_pointwise_numeric(aten.asinh) +register_pointwise_numeric(aten.atan2) +register_pointwise_numeric(aten.atan) +register_pointwise_numeric(aten.atanh) +register_pointwise_numeric(aten.copysign) 
+register_pointwise_numeric(aten.erfc) +register_pointwise_numeric(aten.erfinv) +register_pointwise_numeric(aten.hypot) +register_pointwise_numeric(aten.log10) +register_pointwise_numeric(aten.log2) +register_pointwise_numeric(aten.nextafter) + +from torch._inductor.codegen.common import BackendFeature, pointwise_overrides_data + + +def _get_pointwise_overrides(ns, name): + data = pointwise_overrides_data[name] + op = getattr(ns, data.name, None) + if op is None: + return + + def make_triton_fallback(op): + if data.triton is None: + return fallback_handler(op) + + if isinstance(op, torch._ops.OpOverloadPacket): + for olname in op.overloads(): + ol = getattr(op, olname) + yield ol, data.type_promotion_kind, make_triton_fallback(ol) + else: + yield op, data.type_promotion_kind, make_triton_fallback(op) + + +for name in pointwise_overrides_data: + for op, type_promotion_kind, triton_fallback in _get_pointwise_overrides( + aten, name + ): + register_pointwise( + op, + name=name, + type_promotion_kind=type_promotion_kind, + triton_fallback=triton_fallback, + ) + + for op, type_promotion_kind, triton_fallback in _get_pointwise_overrides( + prims, name + ): + register_pointwise( + op, + name=name, + type_promotion_kind=type_promotion_kind, + triton_fallback=triton_fallback, + ) + + +foreach_add_list = register_foreach_pointwise( + aten._foreach_add.List, add, allow_alpha=True +) +foreach_add_scalar = register_foreach_pointwise( + aten._foreach_add.Scalar, add, allow_alpha=True +) +register_foreach_pointwise(aten._foreach_add.Tensor, add, allow_alpha=True) +foreach_mul_list = register_foreach_pointwise(aten._foreach_mul.List, mul) +register_foreach_pointwise(aten._foreach_mul.Tensor, mul) +foreach_mul_scalar = register_foreach_pointwise(aten._foreach_mul.Scalar, mul) +register_foreach_pointwise(aten._foreach_sub.List, sub) +register_foreach_pointwise(aten._foreach_sub.Scalar, sub) +register_foreach_pointwise(aten._foreach_neg.default, neg) +register_foreach_pointwise(aten._foreach_abs.default, abs) +register_foreach_pointwise(aten._foreach_pow.Scalar, pow) +register_foreach_pointwise(aten._foreach_pow.List, pow) +register_foreach_pointwise(aten._foreach_pow.ScalarAndTensor, pow) +foreach_div_list = register_foreach_pointwise(aten._foreach_div.List, div) +register_foreach_pointwise(aten._foreach_div.Tensor, div) +foreach_div_scalar = register_foreach_pointwise(aten._foreach_div.Scalar, div) +register_foreach_pointwise(aten._foreach_sqrt, sqrt) +register_foreach_pointwise(aten._foreach_maximum.List, maximum) +register_foreach_pointwise(aten._foreach_maximum.Scalar, maximum) +register_foreach_pointwise(aten._foreach_minimum.List, minimum) +register_foreach_pointwise(aten._foreach_minimum.Scalar, minimum) +register_foreach_pointwise(aten._foreach_clamp_min.List, maximum) +register_foreach_pointwise(aten._foreach_clamp_min.Scalar, maximum) +register_foreach_pointwise(aten._foreach_clamp_max.List, minimum) +register_foreach_pointwise(aten._foreach_clamp_max.Scalar, minimum) +register_foreach_pointwise(aten._foreach_reciprocal, reciprocal) +register_foreach_pointwise(aten._foreach_sign, sign) +register_foreach_pointwise(aten._foreach_copy, copy) + + +# these are only encountered as outputs of the graph +# reinplacing epilogue copies improves compile time +# by removing extra buffers sent to the scheduler. 
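+# For example, once register_foreach_inplace(aten._foreach_add_.List,
+# aten._foreach_add.List, foreach_add_list) below has run, lowering
+# aten._foreach_add_.List executes the out-of-place foreach_add_list lowering
+# and then writes each result back into the matching input buffer via
+# mutate_to(arg, result, unsafe_alias=True), so no extra epilogue copy buffers
+# reach the scheduler.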
+def register_foreach_inplace(aten_op, outplace_aten_op, outplace_op): + inplaceable_foreach_ops[outplace_aten_op] = aten_op + inplace_foreach_ops.add(aten_op) + + def fn(*args, **kwargs): + results = outplace_op(*args, **kwargs) + mut_results = [] + for arg, result in zip(args[0], results): + mut_results.append(mutate_to(arg, result, unsafe_alias=True)) + + return mut_results + + _register_foreach_lowering(aten_op, fn) + + +register_foreach_inplace( + aten._foreach_add_.List, aten._foreach_add.List, foreach_add_list +) +register_foreach_inplace( + aten._foreach_add_.Scalar, aten._foreach_add.Scalar, foreach_add_scalar +) +register_foreach_inplace( + aten._foreach_mul_.List, aten._foreach_mul.List, foreach_mul_list +) +register_foreach_inplace( + aten._foreach_mul_.Scalar, aten._foreach_mul.Scalar, foreach_mul_scalar +) +register_foreach_inplace( + aten._foreach_div_.List, aten._foreach_div.List, foreach_div_list +) +register_foreach_inplace( + aten._foreach_div_.Scalar, aten._foreach_div.Scalar, foreach_div_scalar +) + + +def register_inplace(aten_op, outplace_op): + @register_lowering(aten_op, type_promotion_kind=None) + def fn(*args, **kwargs): + result = outplace_op(*args, **kwargs) + result = to_dtype(result, args[0].get_dtype()) + return mutate_to(args[0], result) + + return fn + + +register_inplace(aten.add_, add) +register_inplace(aten.bitwise_and_, bitwise_and) +register_inplace(aten.bitwise_left_shift_, bitwise_left_shift) +register_inplace(aten.bitwise_not_, bitwise_not) +register_inplace(aten.bitwise_or_, bitwise_or) +register_inplace(aten.bitwise_right_shift_, bitwise_right_shift) +register_inplace(aten.bitwise_xor_, bitwise_xor) +register_inplace(aten.mul_, mul) +register_inplace(aten.div_.Tensor, div) +register_inplace(aten.div_.Tensor_mode, div_mode) +register_inplace(aten.logical_and_, logical_and) +register_inplace(aten.logical_not_, logical_not) +register_inplace(aten.logical_or_, logical_or) +register_inplace(aten.logical_xor_, logical_xor) +register_inplace(aten.sub_, sub) +register_inplace(aten.relu_, relu) +register_inplace(aten.sigmoid_, sigmoid) + + +register_lowering(aten.__and__)(bitwise_and) +register_lowering(aten.__lshift__)(bitwise_left_shift) +register_lowering(aten.__or__)(bitwise_or) +register_lowering(aten.__rshift__)(bitwise_right_shift) +register_lowering(aten.__xor__)(bitwise_xor) + +register_inplace(aten.__iand__, aten.__and__) +register_inplace(aten.__ilshift__, aten.__lshift__) +register_inplace(aten.__ior__, aten.__or__) +register_inplace(aten.__irshift__, aten.__rshift__) +register_inplace(aten.__ixor__, aten.__xor__) + + +@register_lowering(aten.sym_constrain_range) +def sym_constrain_range(a, min=None, max=None): + return None + + +@register_lowering(aten.sym_size.int) +def sym_size(a, dim): + val = V.graph.current_node.meta["val"] + # Note [Can val be an int?] + # ~~~~~~~~~~~~~~~~~~~~~~~~~ + # In principle, someone could construct an FX graph where + # a call to size/stride has a val that is a plain int (not + # SymInt). However, we will maintain the invariant that + # this is not possible: if you are constructing an FX graph + # where there is a call to size/stride that returns an + # int, but you KNOW that int must always be a constant, + # then you do not need trace that call at all (and just + # constant propagate the integer as is.) + assert isinstance(val, torch.SymInt) + return val.node.expr + +@register_lowering(aten.sym_stride.int) +def sym_stride(a, dim): + val = V.graph.current_node.meta["val"] + # See Note [Can val be an int?] 
+ assert isinstance(val, torch.SymInt) + return val.node.expr + + +@register_lowering(aten.sym_numel) +def sym_numel(a): + return a.get_numel() + + +for method, func in magic_methods.items(): + register_lowering(method_to_operator(method))(func) + + +@register_lowering(torch.sym_sum) +def sym_sum(args): + return sympy.Add(*args) + + +@register_lowering(aten._foobar) +def foobar(self, *args, **kwargs): + raise NotImplementedError("Helpful for debugging") + + +@register_lowering(torch.ops._inductor_test.realize) +def _realize(x): + x.realize() + return clone(x) + + +@register_lowering(torch.ops.inductor.resize_storage_bytes_) +def resize_storage_bytes_(variable, new_size): + variable.realize() + ir.ResizeStorageBytes(variable, new_size) + return variable + + +@register_lowering(torch.ops.aten.set_.source_Tensor) +def set__source_tensor(self, source_tensor): + self.realize() + source_tensor.realize() + return TensorBox.create(ir.SetSourceTensorKernel(self, source_tensor)) + + +if hasattr(torch.ops.fsdp, "copy_"): + + @register_lowering(torch.ops.fsdp.copy_.default) + def fsdp_copy_(dst, src): + if dst is src: + # dst.copy_(dst) can happen from the reinplacing pass + return dst + src = to_device(src, dst.get_device()) + src = to_dtype(src, dst.get_dtype()) + src = expand(src, dst.get_size()) + return mutate_to(dst, src) + + +@register_lowering(torch.ops.aten.resize) +def resize(x, size, *, memory_format=None): + assert isinstance(x, TensorBox) + assert isinstance(size, (list, tuple)) + + if memory_format is None: + memory_format = torch.contiguous_format + if memory_format == torch.preserve_format: + raise RuntimeError(f"unsupported memory format: {memory_format}") + + if memory_format == torch.channels_last: + assert len(size) == 4 + if memory_format == torch.channels_last_3d: + assert len(size) == 5 + + old_numel = x.get_numel() + dtype = x.get_dtype() + device = x.get_device() + + if isinstance(x.data, ir.BaseView): + x.data = x.data.unwrap_view() + + if ( + torch.are_deterministic_algorithms_enabled() + and torch.utils.deterministic.fill_uninitialized_memory # type: ignore[attr-defined] + ): + if is_float_dtype(dtype): + uninitalized_val = float("nan") + elif is_integer_dtype(dtype): + uninitalized_val = torch.iinfo(dtype).max + else: + uninitalized_val = True + else: + # using zero as that is what empty does + uninitalized_val = 0.0 + + if V.graph.sizevars.statically_known_equals(old_numel, 0): # type: ignore[arg-type] + return full(size, uninitalized_val, dtype=dtype, device=device) + + x_flat = as_strided( + x, + [ + old_numel, + ], + [ + 1, + ], + ) + flat_loader = x_flat.make_loader() + out_stride = ir.FlexibleLayout.stride_ordered_for_memory_format(size, memory_format) + out_indexer = ir.FixedLayout(device, dtype, size, out_stride).make_indexer() + + def inner_fn(idx): + flat_index = out_indexer(idx) + flat_index_expr = ops.index_expr(flat_index, torch.int64) + limit = ops.index_expr(old_numel, torch.int64) + mask = ops.lt(flat_index_expr, limit) + return ops.masked(mask, lambda: flat_loader([flat_index]), uninitalized_val) + + out = Pointwise.create( + device=device, dtype=dtype, inner_fn=inner_fn, ranges=list(size) + ) + return out + + +from torch._higher_order_ops.auto_functionalize import auto_functionalized + + +make_fallback(auto_functionalized) + + +@register_lowering(triton_kernel_wrapper_mutation) +def triton_kernel_wrap_( + *, + kernel_idx, + constant_args_idx, + grid, + tma_descriptor_metadata, + kwargs, +): + from torch._higher_order_ops.triton_kernel_wrap import 
kernel_side_table + + constant_args = kernel_side_table.get_constant_args(constant_args_idx) + ir.UserDefinedTritonKernel( + kernel_idx=kernel_idx, + grid=grid, + tma_descriptor_metadata=tma_descriptor_metadata, + kernel_args={**kwargs, **constant_args}, + ) + return {key: val for key, val in kwargs.items() if isinstance(val, TensorBox)} + + +@register_lowering(torch.ops.higher_order.cond) +def cond(pred, true_fn, false_fn, operands): + if is_triton(pred) or any(map(is_triton, operands)): + msg = "control flow operator: torch.cond." + if stack_trace := V.graph.current_node.meta.get("stack_trace", None): + msg = f"{msg} Found from : \n {stack_trace}" + V.graph.disable_cudagraphs_reason = msg + + result = ir.Conditional.create(pred, true_fn, false_fn, operands) + return list(map(TensorBox.create, result)) + + +@register_lowering(torch.ops.higher_order.while_loop) +def while_loop(cond_fn, body_fn, carried_inputs, additional_inputs): + if any(map(is_triton, carried_inputs + additional_inputs)): + msg = "control flow operator: torch.while_loop." + if stack_trace := V.graph.current_node.meta.get("stack_trace", None): + msg = f"{msg} Found from : \n {stack_trace}" + V.graph.disable_cudagraphs_reason = msg + + result = ir.WhileLoop.create(cond_fn, body_fn, carried_inputs, additional_inputs) + return list(map(TensorBox.create, result)) + + +@register_lowering(torch.ops.higher_order.invoke_subgraph, type_promotion_kind=None) +def invoke_subgraph(subgraph_fn: ir.Subgraph, identifier: str, operands): + result = ir.InvokeSubgraph.create(subgraph_fn, operands) + return list(map(TensorBox.create, result)) + + +@register_lowering(associative_scan_op, type_promotion_kind=None) +def associative_scan(combine_fn: ir.Subgraph, xs, dim: int): + from .subgraph_lowering import InputDescriptor, lower_pointwise_subgraph + + subgraph_inputs = [ + InputDescriptor(dtype=x.get_dtype(), device=x.get_device()) + for x in itertools.chain(xs, xs) + ] + lowered_combine_fn = lower_pointwise_subgraph(combine_fn, subgraph_inputs) # type: ignore[var-annotated] + + def wrapped_combine_fn(lhs, rhs): + return lowered_combine_fn( + *pytree.tree_leaves(lhs), + *pytree.tree_leaves(rhs), + ) + + kwargs = _make_scan_inner(xs[0], axis=dim, dtype=None) + kwargs["dtypes"] = tuple(x.get_dtype() for x in xs) + kwargs["inner_fns"] = tuple(x.make_loader() for x in xs) + result = ir.Scan.create( + combine_fn=wrapped_combine_fn, + can_fallback_to_aten=False, + **kwargs, + ) + if result[0] is None: + raise RuntimeError("Unable to generate code for associative_scan op") + return result + + +@register_lowering(torch.ops.prims._sink_tokens.default) +def _sink_tokens(tokens): + return None + + +@register_lowering(torch.ops.higher_order.with_effects, type_promotion_kind=None) +def with_effects(token, op, *args, **kwargs): + result = ir.EffectfulKernel.create(op, *args, **kwargs) + + from torch._higher_order_ops.effects import get_effect_key + + effect_type = get_effect_key(op, args, kwargs) + assert effect_type is not None + effectful_kernel = V.graph.effectful_ops[effect_type] + + if result is None: + return (effectful_kernel,) + + result = pytree.tree_map_only(ir.MultiOutput, TensorBox.create, result) + if not isinstance(result, (list, tuple)): + return (effectful_kernel, result) + else: + return (effectful_kernel, *result) + + +from torch._inductor.comm_lowering import register_comm_lowerings + + +register_comm_lowerings() + +# populate lowerings defined in kernel/* +from torch._inductor import kernel + + +import_submodule(kernel) diff --git 
a/torch_npu/_inductor/ascend_npu_ir/ascend_npu_ir/npu/inductor_patch/scheduler.py b/torch_npu/_inductor/ascend_npu_ir/ascend_npu_ir/npu/inductor_patch/scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..71f87e75cf4e1714d1bb36d47065d39619be5136 --- /dev/null +++ b/torch_npu/_inductor/ascend_npu_ir/ascend_npu_ir/npu/inductor_patch/scheduler.py @@ -0,0 +1,137 @@ +import torch +import sympy +import collections +from typing import ( + Union, + Optional +) + +from torch._inductor import ir +from torch._inductor import scheduler +from torch._inductor.scheduler import ( + SchedulerNode, + BaseSchedulerNode, + FusedSchedulerNode, + NopKernelSchedulerNode, + ExternKernelSchedulerNode, + OutputNode, + MultiOutput, + MultiOutputLayout, + OrderedSet, + Sequence, + get_dtype_size, + sympy_product, + V, +) + +def _npu_get_read_write_buffers_sizes(self) -> int: + """ + Counting the number of bytes accessed for a kernel is + surprisingly tricky. In particular, there is a differentiation + between 'theoretical' memory accesses and practical memory + accesses. For example, a layernorm kernel may actually access an + input 3 times, but in theory, it only needs to access its input + once (and may be optimized to do so through say, persistent + reductions) + + Another example is that even though a buffer is passed in, we may + not access the entire buffer. This may occur if we are accessing + a slice of the buffer. Another tricky case is for indirect + indexing, where the amount of bytes accessed depends on the + values of the input. + + What this function aims to compute is the memory accesses for + worst-case inputs, best-case optimization. What this means is + that for each buffer we compute the amount of potential accesses in two ways and take the minimum. + + 1. Numel in ranges multiplied by number of deps the buffer has + 2. The buffer size + """ + if isinstance(self, NopKernelSchedulerNode): + return 0 + if isinstance(self, ExternKernelSchedulerNode) and isinstance( + self.node, MultiOutput + ): + # todo: Calculate this - it's kinda annoying. 
+ return 0 + + def try_size_hint(s: sympy.Expr) -> int: + return V.graph.sizevars.size_hint(s, fallback=0) + + if isinstance(self, SchedulerNode): + node_numel = try_size_hint( + sympy_product(self.get_ranges()[0]) + * sympy_product(self.get_ranges()[1]), + ) + else: + node_numel = int(1e9) + buf_accesses = collections.defaultdict(list) + for dep in self.read_writes.reads | self.read_writes.writes: + buf_accesses[dep.name].append(dep) + + reads = OrderedSet(dep.name for dep in self.read_writes.reads) + writes = OrderedSet(dep.name for dep in self.read_writes.writes) + + def is_materialized(buf: str, snodes: Sequence[BaseSchedulerNode]) -> bool: + users = self.scheduler.name_to_buf[buf].users + buf_uses = OrderedSet(user.node for user in users) + return len(buf_uses - OrderedSet(snodes)) > 0 + + if isinstance(self, FusedSchedulerNode): + removed_buffers = OrderedSet( + dep for dep in writes if not is_materialized(dep, self.snodes) + ) + writes = writes - removed_buffers + reads = reads - removed_buffers + node_bytes = 0 + + for buf_name in reads | writes: + buf_accessed_elems = sum(node_numel for dep in buf_accesses[buf_name]) + buf: Union[ir.Buffer, ir.TensorBox] + if buf_name in V.graph.name_to_buffer: + buf = V.graph.name_to_buffer[buf_name] + elif buf_name in V.graph.graph_inputs: + buf = V.graph.graph_inputs[buf_name] + else: + continue + + def get_buf_bytes(buf: Optional[Union[ir.Buffer, ir.TensorBox]]) -> int: + if not buf: + return 0 + # Kind of a lazy way to get the MultiOutput nodes corresponding to + # a MultiOutputLayout + if isinstance(buf.layout, MultiOutputLayout): + users = self.scheduler.name_to_buf[buf.get_name()].users + tot = 0 + for user in users: + # Custom ops can return a mixed output of tensor and ints. + # This could happen when the custom op return symints, + assert isinstance(user.node, (BaseSchedulerNode, OutputNode)) + if isinstance(user.node, BaseSchedulerNode): + if isinstance(user.node.node, MultiOutput): + for sched_buf in user.node.get_outputs(): + tot += get_buf_bytes(sched_buf.node) + else: + # Buf is a MultiOutputLayout but not all of its + # users are MultiOutputs... 
+ # TODO: Figure out what's going on + return 0 + return tot + elif isinstance(buf.layout, ir.NoneLayout): + return sum( + get_buf_bytes(V.graph.get_buffer(mut_name)) + for mut_name in buf.get_mutation_names() + ) + else: + buf_elems = try_size_hint(sympy_product(buf.get_size())) + return get_dtype_size(buf.get_dtype()) * min( + buf_accessed_elems, buf_elems + ) + + node_bytes += get_buf_bytes(buf) + + return node_bytes + + +def _patch_scheduler(): + scheduler.BaseSchedulerNode.get_read_write_buffers_sizes = _npu_get_read_write_buffers_sizes \ No newline at end of file diff --git a/torch_npu/_inductor/ascend_npu_ir/ascend_npu_ir/npu/mlir_compiler.py b/torch_npu/_inductor/ascend_npu_ir/ascend_npu_ir/npu/mlir_compiler.py new file mode 100644 index 0000000000000000000000000000000000000000..d2527030640cc8289b763fa5d9679fe591526900 --- /dev/null +++ b/torch_npu/_inductor/ascend_npu_ir/ascend_npu_ir/npu/mlir_compiler.py @@ -0,0 +1,654 @@ +import os +os.environ['TORCHINDUCTOR_MAX_AUTOTUNE'] = '1' +import sys +import functools +from typing import Callable, Dict, Any, Union, List, Tuple, Iterator +from itertools import count +import importlib +import tempfile +import subprocess +import shutil + +import torch +import torch_npu +from torch._inductor.compile_fx import clone_preserve_strides +from torch._inductor.runtime.cache_dir_utils import triton_cache_dir + +from .utils import ( + _build_npu_ext, + replace_placeholders, + do_bench, + logger, +) +from .. import config as anir_config +from ..cache import get_cache_manager +from .._C import load_kernel_binary +from .codegen.cpp_wrapper import cpp_launcher + + +reinterpret_tensor = torch.ops.inductor._reinterpret_tensor +global_cache = set() +_dump_id_iter: Iterator[int] = count() + +class NpuMlirCompiler: + def __init__(self, + kernel_name: str = '', + multiprocess_compile=False, + no_more_compile=False, + kernel_meta=None, + autotune=True): + self.function = None + self.mode = None + self.launch = None + self.dynamic = kernel_meta.get('dynamic') + self.mutated_indices = kernel_meta.get('mutated_indices') + self.kernel_hash = kernel_meta.get('kernel_hash') + self.signature = kernel_meta.get('signature') + self.ranks = kernel_meta.get('ranks') + self.num_outputs = kernel_meta.get('num_outputs') + self.num_call_functions = kernel_meta.get('num_call_functions') + self.device_index = kernel_meta.get('device_index', 0) + self.traced_graph_hash = kernel_meta.get('traced_graph_hash', 0) + self.kernel_meta = kernel_meta + if self.dynamic: + self.get_host_func_and_tiling_size = None + self.kernel_name = kernel_name + self.launchers = [] + self.kernel_paths = [] + self.is_fallback_kernels = [] + self.multiprocess_compile = multiprocess_compile + self.no_more_compile = no_more_compile + self.mlir_processed = False + self.fx_graph_launcher = None + self.mlir_text = None + self.non_contiguous_inputs = None + self.non_contiguous_outputs = None + self.autotuned = False + self.autotune = autotune + + def init(self, module, extra_env): + os.environ.update(extra_env) + self.mlir_text = module + if os.getenv("TRITON_CACHE_DIR") is None: + os.environ["TRITON_CACHE_DIR"] = triton_cache_dir( + self.kernel_meta.get("device_index", 0) + ) + self.cache = get_cache_manager(self.kernel_hash) + self.prepare_launch(need_pickle=self.multiprocess_compile) + self.get_named_op_path() + + def register_fx_fallback(self, kernel_meta): + def fx_graph_call(module: torch.nn.Module, num_outputs): + def module_call(*args, **kwargs): + actual_args = args[:-num_outputs] + actual_outputs = 
module.forward(*actual_args) + for out1, out2 in zip(actual_outputs, args[-num_outputs:]): + out2.data = out1.data + return module_call + num_outputs = kernel_meta.get('num_outputs', 0) + traced_graph_hash = kernel_meta.get('traced_graph_hash') + traced_graph_cache = os.path.join(os.getenv("TORCHINDUCTOR_CACHE_DIR"), kernel_meta.get('traced_graph_cache')) + device_index = kernel_meta.get('device_index') + dump_path = os.path.join(traced_graph_cache, str(device_index), traced_graph_hash) + sys.path.append(dump_path) + module = importlib.import_module(traced_graph_hash) + sys.path.remove(dump_path) + Model = getattr(module, traced_graph_hash) + if Model is None: + raise RuntimeError('Cannot find valid graph module!') + model = Model() + module_call = fx_graph_call(model, num_outputs) + self.register_launcher(module_call, kernel_path=self.kernel_name + "_fx_fallback", is_fallback_kernel=True) + + def bisheng_compile(self, + input_path: str, + output_path: str, + auto_db=True, + ops_reorder=False, + tiling_size=None, + extra_command=None): + bisheng_ir_compile_path = f"bishengir-compile" + command = [ + bisheng_ir_compile_path, + "-enable-hfusion-compile=true", + "--enable-bin-relocation=0", + f"-block-dim={anir_config.block_dim}", + ] + if auto_db: + command.append("--enable-auto-multi-buffer=true") + else: + command.append("--enable-auto-multi-buffer=false") + + if ops_reorder: + command.append("--enable-ops-reorder=true") + else: + command.append("--enable-ops-reorder=false") + + if tiling_size is not None: + command.append(f"--hfusion-max-buffer-count-tuning={tiling_size}") + + if anir_config.autotune: + command.append("-enable-tuning-mode=true") + + if self.dynamic: + command.append("--enable-static-bare-ptr=false") + command.append("--enable-symbol-analysis=true") + + if isinstance(extra_command, list) and extra_command: + command += extra_command + command += [ + input_path, + "-o", output_path + ] + logger.info(f"Start to compile, command is: [{' '.join(command)}]") + try: + subprocess.run(command, check=True, stdout=subprocess.DEVNULL, stderr=subprocess.PIPE, timeout=600) + logger.info(f"[bisheng-compile success]") + except subprocess.CalledProcessError as e: + logger.info(f"[bisheng-compile failed]") + logger.warning(f"Compile error msg: {e.stderr.decode('utf-8')}") + raise e + + def prepare_launch(self, need_pickle=False): + def get_launch_mod(so_path): + spec = importlib.util.spec_from_file_location("__launcher", so_path) + mod = importlib.util.module_from_spec(spec) + spec.loader.exec_module(mod) + return mod + + so_name = f"{self.kernel_name}.so" + cache_so_path = self.cache.get_file(so_name) + if cache_so_path is None or (anir_config.always_compile and cache_so_path not in global_cache): + with tempfile.TemporaryDirectory() as tmpdir: + c_wrapper_path = os.path.join(tmpdir, f"{self.kernel_name}_launch.cpp") + cache_mlir_path = self.cache.put(cpp_launcher(self.signature, self.kernel_name, self.ranks, dynamic=self.dynamic), f"{self.kernel_name}_launch.cpp") + global_cache.add(cache_mlir_path) + with open(c_wrapper_path, 'w') as c_wrapper_file: + c_wrapper_file.write(cpp_launcher(self.signature, self.kernel_name, self.ranks, dynamic=self.dynamic)) + if self.dynamic: + with open(c_wrapper_path, "rb") as f: + cache_c_wrapper_path = self.cache.put(f.read(), f"{self.kernel_name}_launch.cpp", binary=True) + global_cache.add(cache_c_wrapper_path) + so_path = _build_npu_ext(self.kernel_name, c_wrapper_path, tmpdir) + with open(so_path, "rb") as f: + cache_so_path = 
self.cache.put(f.read(), so_name, binary=True) + global_cache.add(cache_so_path) + if not need_pickle: + mod = get_launch_mod(cache_so_path) + self.launch = getattr(mod, "launch") + if self.dynamic: + self.get_host_func_and_tiling_size = getattr(mod, "get_host_func_and_tiling_size") + + def get_named_op_path(self): + named_op_name = f"{self.kernel_name}_named_op.mlir" + cache_mlir_path = self.cache.get_file(named_op_name) + if cache_mlir_path is None or (anir_config.always_compile and cache_mlir_path not in global_cache): + #if anir_config.cache_named_op: + cache_mlir_path = self.cache.put(self.mlir_text, named_op_name) + global_cache.add(cache_mlir_path) + if anir_config.fx_subgraph_dump_path: + shutil.copy(cache_mlir_path, os.path.join(anir_config.fx_subgraph_dump_path, \ + str(self.device_index), self.kernel_name)) + return cache_mlir_path + + def get_launch_dynamic(self, function, tiling_func, tiling_size): + block_dim = anir_config.block_dim + def kernel_call(*args, stream=None): + self.launch(block_dim, stream, function, tiling_func, tiling_size, None, None, None, *args) + return kernel_call + + def get_launch(self, function): + block_dim = anir_config.block_dim + def kernel_call(*args, function, stream=None): + self.launch(block_dim, stream, function, None, None, None, *args) + + return functools.partial(kernel_call, function=function) + + def get_launch_func(self, cache_kernel_path): + if self.dynamic: + function, tiling_func, tiling_size = self.get_host_func_and_tiling_size(self.kernel_name, + self.kernel_name + '_tiling_function', + self.kernel_name + '_get_tiling_struct_size_function', + cache_kernel_path) + return self.get_launch_dynamic(function, tiling_func, tiling_size) + else: + function = load_kernel_binary(self.kernel_name, cache_kernel_path) + return self.get_launch(function) + + def register_launcher(self, + launcher, + kernel_path=None, + num_outputs=None, + disable_dump=False, + auto_fallback=False, + is_fallback_kernel=False): + if num_outputs: + self.num_outputs = num_outputs + self.launchers.append(launcher) + self.kernel_paths.append(kernel_path) + self.is_fallback_kernels.append(is_fallback_kernel) + self.fx_graph_launcher = launcher + if kernel_path.endswith('_fx_fallback'): + if auto_fallback: + if anir_config.fallback_warning: + print(f"This kernel {self.kernel_name} has been fallback to the eager fx graph mode, ", \ + "which will lead to a significant decrease in performance.", flush=True) + if anir_config.fallback_dump and not disable_dump: + self.fx_subgraph_dump('fallback') + logger.info(f"register launcher {launcher} {kernel_path} success") + + def compile_mlir(self, + device_info: Tuple[Any], + compile_args: List[Any], + logger_level = None) -> Callable[..., None]: + if logger_level is not None: + # re-init logger level in subprocess + logger.setLevel(logger_level) + named_op_mlir_path = self.get_named_op_path() + + kernel_name = self.kernel_name + + tiling_size, ops_reorder, auto_db = compile_args + tiling_str = f"_{tiling_size}_{ops_reorder}_{auto_db}" + tiling_kernel_name = kernel_name + tiling_str + if self.dynamic: + cache_kernel_path = self.cache.get_file(f"lib{tiling_kernel_name}.so") + else: + cache_kernel_path = self.cache.get_file(f"{tiling_kernel_name}.o") + + logger.info("Start to get cached kernel. 
Tiling info: " + + f"tiling_size {tiling_size} ops_reorder {ops_reorder} auto_db {auto_db}") + + if cache_kernel_path is None and self.no_more_compile: + raise RuntimeError("Skip compile.") + + if cache_kernel_path is None or (anir_config.always_compile and cache_kernel_path not in global_cache): + logger.info("No cached kernel. Start to exec compile.") + with tempfile.TemporaryDirectory() as tmpdir: + kernel_path = os.path.join(tmpdir, tiling_kernel_name) + self.bisheng_compile(named_op_mlir_path, kernel_path, tiling_size=tiling_size, + ops_reorder=ops_reorder, auto_db=auto_db, + extra_command=anir_config.extra_command) + + + if self.dynamic: + kernel_path = os.path.join(tmpdir, f"lib{tiling_kernel_name}.so") + with open(kernel_path, "rb") as f: + cache_kernel_path = self.cache.put(f.read(), f"lib{tiling_kernel_name}.so", binary=True) + global_cache.add(cache_kernel_path) + else: + kernel_path = os.path.join(tmpdir, tiling_kernel_name + '.o') + with open(kernel_path, "rb") as f: + cache_kernel_path = self.cache.put(f.read(), f"{tiling_kernel_name}.o", binary=True) + global_cache.add(cache_kernel_path) + + logger.info("Get kernel success.") + if not self.multiprocess_compile: + logger.info(f"Start to register kernel, path '{cache_kernel_path}' func '{self.kernel_name}'") + launch_func = self.get_launch_func(cache_kernel_path) + self.register_launcher(launch_func, cache_kernel_path) + + if anir_config.fx_subgraph_dump_path: + kernel_dump_path = os.path.join(anir_config.fx_subgraph_dump_path, \ + str(self.device_index), self.kernel_name, 'kernel_dump') + os.makedirs(kernel_dump_path, exist_ok=True) + shutil.copy(cache_kernel_path, kernel_dump_path) + + def replace_kernel_by_path(self, kernel_path: str): + self.launchers.clear() + self.kernel_paths.clear() + self.is_fallback_kernels.clear() + logger.info(f"Start to replace kernel by specific path, path '{kernel_path}' func '{self.kernel_name}'") + launch_func = self.get_launch_func(kernel_path) + self.register_launcher(launch_func, kernel_path) + + def get_best_kernel(self): + def get_launch_mod(so_path): + spec = importlib.util.spec_from_file_location("__launcher", so_path) + mod = importlib.util.module_from_spec(spec) + spec.loader.exec_module(mod) + return mod + + best_kernel = self.cache.get_file('best_kernel') + if best_kernel is None: + raise RuntimeError("can not find best kernel") + with open(best_kernel, 'r') as f: + kernel_path = self.cache.get_file(f.read()) + if not kernel_path.endswith(('.so', '.o')): + self.register_fx_fallback(self.kernel_meta) + return + so_path = self.cache.get_file(f'{self.kernel_name}.so') + if kernel_path is None: + return RuntimeError() + mod = get_launch_mod(so_path) + self.launch = getattr(mod, "launch") + if self.dynamic: + self.get_host_func_and_tiling_size = getattr(mod, "get_host_func_and_tiling_size") + + launch_func = self.get_launch_func(kernel_path) + self.register_launcher(launch_func, kernel_path) + return True + + def get_autotune_config(self): + def get_tiling_range(): + return [i for i in range(-10, 20, 2)] + compile_args = [] + for ops_reorder in [True, False]: + for auto_db in [True, False]: + for tiling_size in get_tiling_range(): + compile_args.append((tiling_size, ops_reorder, auto_db)) + return compile_args + + def precompile(self, + device_info: Tuple[Any], + suppress_error=False, + logger_level=None): + + if anir_config.autotune: + compile_args = self.get_autotune_config() + else: + compile_args = [(None, True, True)] + for cargs in compile_args: + try: + 
self.compile_mlir(device_info, cargs, logger_level=logger_level) + except Exception as e: + if suppress_error: + logger.warning(f"compile args {cargs} fail, err msg: {e}") + else: + raise e + + def bench(self, idx, launcher, *args, **kwargs): + if anir_config.runtime_error_dump: + self.data_dump_fake(*args) + cloned_args = args + def kernel_call(): + launcher(*cloned_args, **kwargs) + try: + return do_bench(kernel_call, warmup=1, rep=5, fast_flush=True) + except Exception as e: + print(f"RUNTIME ERROR: eval kernel fail, kernel path: {self.kernel_paths[idx]}, ", + f"try to add {self.kernel_paths[idx]} to anir_config.force_fallback_kernel_paths", flush=True) + print(e, flush=True) + if anir_config.runtime_error_dump: + self.fx_subgraph_dump('runtime_error') + exit(0) + + def clone_args(self, *args, **kwargs) -> Tuple[List[Any], Dict[str, Any]]: + # [Note: clone mutated buffers] + # clone inplace buffers to avoid autotune contaminating them if + # the kernel does in-place stores. avoid cloning other buffers because + # it leads to increase memory use + cloned_args = [] + for i, arg in enumerate(args): + if i in self.mutated_indices: + assert isinstance(arg, torch.Tensor) + cloned_args.append(clone_preserve_strides(arg)) + else: + cloned_args.append(arg) + + return cloned_args + + def benchmark_all_configs(self, *args, **kwargs): + timings = [] + args_new = () + args = list(args) + if self.dynamic: + for idx, arg in enumerate(args): + if not torch.is_tensor(arg): + args_new = args_new + (arg, ) + continue + if idx in self.mutated_indices: + cloned_arg = clone_preserve_strides(arg) + args[idx] = cloned_arg + args_new = args_new + (cloned_arg, cloned_arg, 0) + arg.size() + arg.stride() + else: + args_new = args_new + (arg, arg, 0) + arg.size() + arg.stride() + else: + for idx, arg in enumerate(args): + if torch.is_tensor(arg) and idx in self.mutated_indices: + cloned_arg = clone_preserve_strides(arg) + args_new = args_new + (cloned_arg, ) + args[idx] = cloned_arg + else: + args_new = args_new + (arg, ) + + for idx, launcher in enumerate(self.launchers): + if self.dynamic and not self.is_fallback_kernels[idx]: + transformed_args = args_new + else: + transformed_args = args + if self.kernel_name in anir_config.force_fallback_kernel_names and not self.kernel_paths[idx].endswith('_fx_fallback'): + continue + if self.kernel_paths[idx] in anir_config.force_fallback_kernel_paths: + print(f"Skip kernel: {self.kernel_paths[idx]}", flush=True) + continue + try: + logger.info(f"start to eval kernel {self.kernel_paths[idx]}") + times = self.bench(idx, launcher, *transformed_args, **kwargs) + timings.append([times, idx]) + logger.info(f"eval over") + except Exception as e: + print(e) + continue + return timings + + def autotune_to_one_config(self, *args, **kwargs): + if anir_config.autotune_fx_fallback: + self.register_fx_fallback(self.kernel_meta) + if any([isinstance(arg, torch.Tensor) and not arg.is_contiguous() for arg in args]): + print(f'Non contiguous args exists! 
Kernel name is {self.kernel_name}') + timings = self.benchmark_all_configs(*args, **kwargs) + timings.sort() + logger.info(f"autotune over, timings: {timings}") + if timings[0][0] > 99999: + raise RuntimeError("All config exec failed.") + idx = timings[0][1] + logger.info(f"autotune benchmark over, using kernel {self.kernel_paths[idx]}") + self.kernel_paths = [self.kernel_paths[idx]] + self.launchers = [self.launchers[idx]] + self.is_fallback_kernels = [self.is_fallback_kernels[idx]] + if self.is_fallback_kernels[0]: + self.cache.put(self.traced_graph_hash, "best_kernel", binary=False) + else: + self.cache.put(self.kernel_paths[0].split('/')[-1], "best_kernel", binary=False) + + def data_dump(self, *args, dump_path=None): + if not dump_path: + dump_path = os.path.join(anir_config.fx_subgraph_dump_path, str(self.device_index), self.kernel_name) + data_dump_path = os.path.join(dump_path, 'data.pth') + args_cpu = [arg.cpu() if isinstance(arg, torch.Tensor) else arg for arg in args] + torch.save(args_cpu, data_dump_path) + + def data_dump_fake(self, *args, dump_path=None): + if not dump_path: + dump_path = os.path.join(anir_config.fx_subgraph_dump_path, str(self.device_index), self.kernel_name) + runable_py_path = os.path.join(dump_path, f'runnable_{self.kernel_name}.py') + fake_inputs = [f'rand_strided({arg.shape}, {arg.stride()}, device="{arg.device.type}", dtype={arg.dtype})' \ + if isinstance(arg, torch.Tensor) else str(arg) for arg in args[:-self.num_outputs]] + fake_outputs = [f'empty_strided({arg.shape}, {arg.stride()}, device="{arg.device.type}", dtype={arg.dtype})' \ + if isinstance(arg, torch.Tensor) else str(arg) for arg in args[-self.num_outputs:]] + replacements = {"FAKE_ARGS_PLACEHOLDER": f"args = [{', '.join(fake_inputs + fake_outputs)}]"} + replace_placeholders(runable_py_path, replacements) + + def fx_subgraph_dump(self, suffix): + subgraph_dump_path = os.path.join(anir_config.fx_subgraph_dump_path, str(self.device_index), self.kernel_name) + failed_fx_subgraph_dump_path = anir_config.fx_subgraph_dump_path + f'_{suffix}' + failed_subgraph_dump_path = os.path.join(failed_fx_subgraph_dump_path, str(self.device_index), f'{next(_dump_id_iter)}_' + self.kernel_name) + if os.path.exists(failed_subgraph_dump_path): + shutil.rmtree(failed_subgraph_dump_path) + shutil.copytree(subgraph_dump_path, failed_subgraph_dump_path) + return failed_subgraph_dump_path + + def acc_compare_and_dump(self, *args, **kwargs): + from torch.testing._comparison import _make_mismatch_msg + self.register_fx_fallback(self.kernel_meta) + launcher_fx = self.launchers[1] + launcher = self.launchers[0] + + fx_outputs = [clone_preserve_strides(arg).to(torch.float32) if arg.dtype == torch.bfloat16 \ + else clone_preserve_strides(arg) for arg in args[-self.num_outputs:]] + fx_inputs = [clone_preserve_strides(arg) if isinstance(arg, torch.Tensor) else arg for arg in args[:-self.num_outputs]] + fx_inputs = [inp.float() if isinstance(inp, torch.Tensor) and inp.dtype == torch.bfloat16 else inp for inp in fx_inputs] + + fx_args = fx_inputs + fx_outputs + launcher_fx(*fx_args, **kwargs) + + if self.dynamic: + args_new = () + for arg in args: + if not torch.is_tensor(arg): + args_new = args_new + (arg, ) + continue + args_new = args_new + (arg, arg, 0) + arg.size() + arg.stride() + else: + args_new = args + + output = launcher(*args_new, **kwargs) + + has_acc_error = False + num_inputs = len(args) - self.num_outputs + for idx, (actual, expected) in enumerate(zip(args[num_inputs:], fx_outputs)): + if actual.dtype != 
expected.dtype: + expected = expected.to(actual.dtype) + acc_comp_tol = anir_config.acc_comp_tol.get(actual.dtype, anir_config.acc_comp_tol['default']) + rtol = acc_comp_tol['rtol'] + atol = acc_comp_tol['atol'] + matches = torch.isclose( + actual, expected, rtol=rtol, atol=atol, equal_nan=True + ) + if not matches.all(): + abs_diff = abs(actual - expected) + rel_diff = abs_diff / abs(expected) + rel_diff.masked_fill_(matches, 0) + number_of_elements = matches.numel() + total_mismatches = number_of_elements - int(torch.sum(matches)) + extra = ( + f"Mismatched elements: {total_mismatches} / {number_of_elements} " + f"({total_mismatches / number_of_elements:.1%})" + ) + msg = _make_mismatch_msg( + default_identifier="Tensor-likes", + identifier=None, + extra=extra, + abs_diff=abs_diff.max().item(), + abs_diff_idx=None, + atol=atol, + rel_diff=rel_diff.max().item(), + rel_diff_idx=None, + rtol=rtol, + ) + print(f"Kernel Name: {self.kernel_name}\n{msg}", flush=True) + has_acc_error = True + args[idx + num_inputs].copy_(expected) + del abs_diff + del rel_diff + del matches + del expected + + if anir_config.fx_subgraph_dump_path: + data = args + if has_acc_error: + data_dump_path = self.fx_subgraph_dump('acc_failed') + self.data_dump_fake(*data, dump_path=data_dump_path) + del fx_inputs + torch.npu.synchronize() + self.launchers = [self.launchers[0]] + self.is_fallback_kernels = [self.is_fallback_kernels[0]] + + return output + + def mlir_dump(self, *args, **kwargs): + self.data_dump(*args) + launcher_fx = self.launchers[-1] + fx_output = launcher_fx(*args, **kwargs) + return fx_output + + def make_inputs_contiguous(self, args): + args = list(args) + for idx in self.non_contiguous_indices['inputs']: + args[idx] = args[idx].contiguous() + return tuple(args) + + def run(self, *args, **kwargs): + args = list(args) + + if self.non_contiguous_inputs is None: + self.non_contiguous_inputs = [] + if self.num_call_functions > 0: + for idx, arg in enumerate(args[:-self.num_outputs]): + if isinstance(arg, torch.Tensor) and not arg.is_contiguous(): + args[idx] = args[idx].contiguous() + self.non_contiguous_inputs.append(idx) + else: + for idx in self.non_contiguous_inputs: + args[idx] = args[idx].contiguous() + + contiguous_outputs = [] + + if self.non_contiguous_outputs is None: + self.non_contiguous_outputs = [] + original_outputs = [] + num_inputs = len(args) - self.num_outputs + for idx, arg in enumerate(args[num_inputs:]): + if isinstance(arg, torch.Tensor) and not arg.is_contiguous(): + contiguous_output = torch.empty( + arg.shape, + dtype=arg.dtype, + device=arg.device) + arg_idx = idx - self.num_outputs + original_outputs.append(arg) + args[arg_idx] = contiguous_output + self.non_contiguous_outputs.append(arg_idx) + contiguous_outputs.append(contiguous_output) + else: + original_outputs = [] + for idx in self.non_contiguous_outputs: + contiguous_output = torch.empty( + args[idx].shape, + dtype=args[idx].dtype, + device=args[idx].device) + original_outputs.append(args[idx]) + args[idx] = contiguous_output + contiguous_outputs.append(contiguous_output) + + if not self.autotuned: + if len(self.launchers) > 1: + self.autotune_to_one_config(*args, **kwargs) + elif self.autotune: + if self.kernel_paths[0].endswith('_fx_fallback'): + self.cache.put(self.traced_graph_hash, "best_kernel", binary=False) + else: + self.cache.put(self.kernel_paths[0].split('/')[-1], "best_kernel", binary=False) + else: + pass + self.autotuned = True + + (launcher,) = self.launchers + (is_fallback_kernel, ) = 
self.is_fallback_kernels + + if anir_config.fx_subgraph_dump_path and \ + anir_config.online_acc_comp and \ + not is_fallback_kernel: + output = self.acc_compare_and_dump(*args, **kwargs) + if self.non_contiguous_outputs: + for i, idx in enumerate(self.non_contiguous_outputs): + original_outputs[i].copy_(args[idx]) + return output + + if self.dynamic and not is_fallback_kernel: + args_new = () + for arg in args: + if not torch.is_tensor(arg): + args_new = args_new + (arg, ) + continue + args_new = args_new + (arg, arg, 0) + arg.size() + arg.stride() + args = args_new + + output = launcher(*args, **kwargs) + + if self.non_contiguous_outputs: + for i, idx in enumerate(self.non_contiguous_outputs): + original_outputs[i].copy_(args[idx]) + + del contiguous_outputs + return output \ No newline at end of file diff --git a/torch_npu/_inductor/ascend_npu_ir/ascend_npu_ir/npu/npu_decomp.py b/torch_npu/_inductor/ascend_npu_ir/ascend_npu_ir/npu/npu_decomp.py new file mode 100644 index 0000000000000000000000000000000000000000..24c55499ecdef0ef2191a4f09eb5cb933dfb2a5f --- /dev/null +++ b/torch_npu/_inductor/ascend_npu_ir/ascend_npu_ir/npu/npu_decomp.py @@ -0,0 +1,204 @@ +import functools +from typing import Optional, Tuple + +import torch +import torch.nn.functional as F +from torch._inductor import decomposition as inductor_decomp +from torch._C import DispatchKey +from torch import Tensor + +from torch._decomp import ( + remove_decompositions, +) + +from .. import config as anir_config + +aten = torch.ops.aten +npu = torch.ops.npu + +remove_decompositions(inductor_decomp.decompositions, anir_config.decomps_to_exclude_npu) + +# Batch_norm_decomposition function registered to fix dynamic shape dynamo tracing issue. +@aten.batch_norm.default.py_impl(DispatchKey.Autograd) +@aten.batch_norm.default.py_impl(DispatchKey.AutogradPrivateUse1) +def batch_norm_decomposition( + input: Tensor, + weight: Optional[Tensor], + bias: Optional[Tensor], + running_mean: Optional[Tensor], + running_var: Optional[Tensor], + training: bool, + momentum: float, + eps: float, + cudnn_enabled: bool, +) -> Tensor: + if input.numel() == 0: + out = input.clone() + if weight is not None: + out *= weight[0] + if bias is not None: + out += bias[0] + return out + return aten._batch_norm_impl_index.default( + input, + weight, + bias, + running_mean, + running_var, + training, + momentum, + eps, + cudnn_enabled, + )[0] + +def npu_convolution_backward( + grad_output, + input, + weight, + bias_sizes, + stride, + padding, + dilation, + transposed, + output_padding, + groups, + output_mask, +): + if not output_mask[2]: + return NotImplemented + grad_bias = torch.ops.aten.sum(grad_output, [0] + list(range(2, grad_output.dim()))) + grad_inp, grad_weight, _ = torch.ops.aten.convolution_backward( + grad_output, + input, + weight, + bias_sizes, + stride, + padding, + dilation, + transposed, + output_padding, + groups, + [output_mask[0], output_mask[1], False], + ) + return (grad_inp, grad_weight, grad_bias) + +def npu__softmax_backward_data( + grad_output: torch.Tensor, + output: torch.Tensor, + dim: int, + input_dtype: torch.dtype, +) -> torch.Tensor: + new_grad_output = grad_output * output + sum_new_grad = torch.sum(new_grad_output, dim=dim, keepdim=True) + grad_input = new_grad_output - output * sum_new_grad + # grad_input = inductor_prims.fma(-output, sum_new_grad, new_grad_output) + + # CPU kernel doesn't respect input_dtype, but following check doesn't work for meta tensor + # if grad_output.device == torch.device("cpu"): + # return 
grad_input.contiguous()
+
+    if grad_output.dtype != input_dtype:
+        grad_input = grad_input.to(input_dtype)
+    return grad_input.contiguous()
+
+def npu_rms_norm(
+    x: torch.Tensor,
+    weight: torch.Tensor,
+    epsilon=1e-6
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    dtype = x.dtype
+    x = x.float()
+    rsqrt = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + epsilon)
+    output = (x * rsqrt * weight).to(dtype)
+    return output, rsqrt
+
+def npu_rms_norm_backward(grad_output: torch.Tensor,
+                          x: torch.Tensor,
+                          weight: torch.Tensor,
+                          rsqrt: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+    dx = (grad_output * weight - x * rsqrt * (grad_output * weight * x * rsqrt).mean(-1, keepdim=True)) * rsqrt
+    dgamma = (grad_output * x * rsqrt).sum(0, keepdim=False)
+    return dx, dgamma
+
+def npu_swiglu(x, dim=-1):
+    x = torch.chunk(x, 2, dim=dim)
+    return F.silu(x[0]) * x[1]
+
+def npu_swiglu_backward(grad_output, x, dim=-1):
+    x0, x1 = torch.chunk(x, 2, dim=dim)
+
+    # Gradient w.r.t. x0
+    sigmoid_x0 = torch.sigmoid(x0)
+    silu_grad = sigmoid_x0 * (1 + x0 * (1 - sigmoid_x0))  # derivative of SiLU
+    grad_x0 = grad_output * x1 * silu_grad
+
+    # Gradient w.r.t. x1
+    grad_x1 = grad_output * F.silu(x0)
+    grad_x = torch.cat([grad_x0, grad_x1], dim=dim)
+    return grad_x
+
+def _rotate_half(x: Tensor) -> Tensor:
+    x1, x2 = torch.chunk(x, 2, dim=-1)
+    return torch.cat((-x2, x1), dim=-1)
+
+def npu_rotary_mul(t, cos_, sin_):
+    t = (t * cos_) + (_rotate_half(t) * sin_)
+    return t
+
+def npu_rotary_mul_backward(grad_output, t, cos_, sin_):
+    rotated_t = _rotate_half(t)
+    grad_t = cos_ * grad_output
+    grad_rotated_part = grad_output * sin_
+    a, b = torch.chunk(grad_rotated_part, 2, dim=-1)
+    grad_rotated_t = torch.cat((b, -a), dim=-1)
+    grad_t = grad_t + grad_rotated_t
+
+    grad_cos = t * grad_output
+    grad_sin = rotated_t * grad_output
+
+    return grad_t, grad_cos, grad_sin
+
+def gelu(a, approximate: str = "none"):
+    """
+    Reference implementation of torch.nn.functional.gelu
+    """
+    M_SQRT2 = 1.41421356237309504880
+    M_2_SQRTPI = 1.12837916709551257390
+    kBeta = M_SQRT2 * M_2_SQRTPI * 0.5
+    kKappa = 0.044715
+    a_cube = a * a * a
+    inner = kBeta * (a + kKappa * a_cube)
+    return 0.5 * a * (1 + torch.tanh(inner))
+
+def gelu_backward(grad: Tensor, self: Tensor, approximate: str = "none"):
+    M_SQRT2 = 1.41421356237309504880
+    M_SQRT1_2 = 0.70710678118654752440
+    M_2_SQRTPI = 1.12837916709551257390
+    kBeta = M_SQRT2 * M_2_SQRTPI * 0.5
+    kKappa = 0.044715
+    x_sq = self * self
+    x_cube = x_sq * self
+    inner = kBeta * (self + kKappa * x_cube)
+    tanh_inner = torch.tanh(inner)
+
+    left = 0.5 * self
+    right = 1.0 + tanh_inner
+
+    left_derivative = 0.5 * right
+
+    tanh_derivative = (tanh_inner * tanh_inner) * -1.0 + 1.0
+    inner_derivative = kBeta * (1.0 + 3.0 * kKappa * x_sq)
+    right_derivative = left * tanh_derivative * inner_derivative
+
+    return grad * (left_derivative + right_derivative)
+
+inductor_decomp.register_decomposition(torch.ops.aten.convolution_backward)(npu_convolution_backward)
+inductor_decomp.register_decomposition(torch.ops.aten._softmax_backward_data.default)(npu__softmax_backward_data)
+inductor_decomp.register_decomposition(torch.ops.aten.gelu.default)(gelu)
+inductor_decomp.register_decomposition(torch.ops.aten.gelu_backward.default)(gelu_backward)
+# inductor_decomp.register_decomposition(torch.ops.npu.npu_rms_norm.default)(npu_rms_norm)
+# inductor_decomp.register_decomposition(torch.ops.npu.npu_rms_norm_backward.default)(npu_rms_norm_backward)
+# inductor_decomp.register_decomposition(torch.ops.npu.npu_swiglu.default)(npu_swiglu)
+# 
inductor_decomp.register_decomposition(torch.ops.npu.npu_swiglu_backward.default)(npu_swiglu_backward) +# inductor_decomp.register_decomposition(torch.ops.npu.npu_rotary_mul.default)(npu_rotary_mul) +# inductor_decomp.register_decomposition(torch.ops.npu.npu_rotary_mul_backward.default)(npu_rotary_mul_backward) \ No newline at end of file diff --git a/torch_npu/_inductor/ascend_npu_ir/ascend_npu_ir/npu/npu_inductor_plugin.py b/torch_npu/_inductor/ascend_npu_ir/ascend_npu_ir/npu/npu_inductor_plugin.py new file mode 100644 index 0000000000000000000000000000000000000000..213d85c4f07d87dcd3a013893cb9197fd6a65d9e --- /dev/null +++ b/torch_npu/_inductor/ascend_npu_ir/ascend_npu_ir/npu/npu_inductor_plugin.py @@ -0,0 +1,430 @@ +import atexit +import collections +from collections import Counter +import functools +import itertools +import shutil +from torch.utils._ordered_set import OrderedSet + +import torch +import torch.nn.functional as F +from torch import Tensor + +from torch.utils import _triton +_triton.has_triton = lambda: False +_triton.has_triton_package = lambda: False + +from typing import ( + Set, + Dict, + Optional, + Tuple, + List +) +from torch._dynamo.utils import dynamo_timed +from torch._inductor.async_compile import shutdown_compile_workers +from torch._inductor.codegen.common import register_backend_for_device, register_device_op_overrides +from torch._inductor.virtualized import V +from torch._inductor import decomposition as inductor_decomp +from ..npu.codegen.mlir import NpuMlirScheduling +from ..npu.codegen.wrapper import NpuMlirWrapperCodeGen +from ..npu.npu_lowering import _register_npu_inductor_fallbacks +from ..npu.utils import ( + npu_optimize_fx_graph, + run_once +) + +from .. import config as anir_config + +from torch._decomp import ( + decomposition_table, +) + +from . import npu_patch_deprecated, torch_mlir_patch +from .npu_meta import npu_patch_meta + +# Fix Error: Exit earlier than child process. +atexit.register(shutdown_compile_workers) + +# new npu meta registration. 
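A minimal sketch (not part of this patch) of the Meta-kernel override pattern that npu_patch_meta() applies below; the chosen op and meta function are illustrative placeholders, and only the py_kernels.pop / py_impl calls mirror what npu_meta.py actually does:

import torch
from torch._C import DispatchKey

# Hypothetical example: install a custom Meta (fake-tensor) kernel for one op,
# echoing the aten.native_dropout registration in npu_meta.py.
op = torch.ops.aten.native_dropout.default

def example_meta(x, p, train):
    # Under the Meta dispatch key only shapes/dtypes matter; no values are computed.
    return torch.empty_like(x), torch.empty_like(x, dtype=torch.bool)

op.py_kernels.pop(DispatchKey.Meta, None)   # drop any existing Python Meta kernel
op.py_impl(DispatchKey.Meta)(example_meta)  # register the replacement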
+npu_patch_meta() + +from torch._dynamo import config as dynamo_config +dynamo_config.fake_tensor_cache_enabled = False + +from torch._inductor import config +config.layout_optimization = False +config.size_asserts = False +config.fallback_random = True +config.optimize_scatter_upon_const_tensor = False + +if anir_config.online_acc_comp: + config.fx_graph_cache = False + +aten = torch.ops.aten + +## Override original inductor device overrides in torch_npu +from torch_npu.utils._inductor import NPUDeviceOpOverrides + +# Not good implementation, but no other way +def get_current_raw_stream(device): + return torch.npu.current_stream(device).npu_stream + +class NewNPUDeviceOpOverrides(NPUDeviceOpOverrides): + def import_get_raw_stream_as(self, name): + return f"from torch_npu._inductor.ascend_npu_ir.ascend_npu_ir.npu.npu_inductor_plugin import get_current_raw_stream as {name}" + +def _inductor_register_device_op_overrides(): + register_device_op_overrides('npu', NewNPUDeviceOpOverrides()) + +_inductor_register_device_op_overrides() + +## Override original dynamo device interface in torch_npu +from torch_npu.utils._dynamo_device import NpuInterface +try: + from torch_npu.npu import device_count +except: + from torch_npu.npu.utils import device_count +from torch._dynamo.device_interface import register_interface_for_device + +class NewNpuInterface(NpuInterface): + + @staticmethod + def is_available() -> bool: + return device_count() > 0 + + @staticmethod + def get_compute_capability(device=None): + # npu has no concept of cc. triton-npu compiler depends on subarch instead + return torch.npu.get_device_name(device) + +register_interface_for_device("npu", NewNpuInterface) + +register_backend_for_device("npu", NpuMlirScheduling, NpuMlirWrapperCodeGen) + +# recover from torch_npu._inductor patches to source code +def src_call(self, model_, inputs_): + from torch._inductor.compile_fx import compile_fx + + return compile_fx(model_, inputs_, config_patches=self.config) + +from torch import _TorchCompileInductorWrapper + +_TorchCompileInductorWrapper.__call__ = src_call + +## npu patch +from ..npu import npu_decomp +from torch._C import DispatchKey +from torch._prims_common.wrappers import out_wrapper + +def disable_implicit_decomposition(): + ''' + Since torch official will implicitly decompose some aten ops, + disable some ops here to avoid poor performance after decompose. 
+ ''' + disable_aten_ops = [ + 'aten.upsample_nearest1d.vec', 'aten.upsample_nearest1d.default', + 'aten.upsample_nearest2d.vec', 'aten.upsample_nearest2d.default', + 'aten.upsample_nearest3d.vec', 'aten.upsample_nearest3d.default', + 'aten.upsample_bilinear2d.vec', 'aten.upsample_bilinear2d.default', + ] + for op_override in decomposition_table.keys(): + if str(op_override) in disable_aten_ops: + if DispatchKey.Autograd in op_override.py_kernels: + op_override.py_kernels.pop(DispatchKey.Autograd) + if DispatchKey.CompositeImplicitAutograd in op_override.py_kernels: + op_override.py_kernels.pop(DispatchKey.CompositeImplicitAutograd) + + +def wrap__dynamo_optimize(fn): + @functools.wraps(fn) + def npu__dynamo_optimize(*args, **kwargs): + from ..npu import inductor_patch + disable_implicit_decomposition() + return fn(*args, **kwargs) + return npu__dynamo_optimize + +from torch import _dynamo +_dynamo.optimize = wrap__dynamo_optimize(_dynamo.optimize) + + +from torch._dynamo.backends import common +from torch._dynamo.backends.common import AotAutograd + +def wrap_compiler(fn): + @functools.wraps(fn) + def npu_compiler(gm: torch.fx.GraphModule, example_inputs, *args, **kwargs): + npu_optimize_fx_graph(gm) + return fn(gm, example_inputs, *args, **kwargs) + return npu_compiler + +def wrap_aot_autograd(fn): + @functools.wraps(fn) + def npu_aot_autograd(*args, **kwargs): + _register_npu_inductor_fallbacks() + def wrap_compiler_by_key(name): + if name in kwargs: + kwargs[name] = wrap_compiler(kwargs[name]) + wrap_compiler_by_key('fw_compiler') + wrap_compiler_by_key('bw_compiler') + wrap_compiler_by_key('inference_compiler') + return fn(*args, **kwargs) + return npu_aot_autograd + +AotAutograd.__call__ = wrap_aot_autograd(AotAutograd.__call__) + +# recompute last usage for inductor scheduler +from torch._inductor import scheduler +from torch._inductor.scheduler import ( + Dep, + WeakDep, + Scheduler, + SchedulerNode, + SchedulerBuffer, + FusedSchedulerNode, + BaseSchedulerNode, + ForeachKernelSchedulerNode, + ExternKernelSchedulerNode, + NopKernelSchedulerNode, + WhyNoFuse, + MemoryDep + ) + +def used_or_aliased_buffer_names(node) -> Set[str]: + used_names: OrderedSet[str] = OrderedSet() + if isinstance(node, (SchedulerNode, FusedSchedulerNode)) and not isinstance(node, ForeachKernelSchedulerNode): + snodes = [node] if isinstance(node, SchedulerNode) else node.snodes + for snode in snodes: + traced_graph = snode.node.data.traced_graph + used_names = used_names.union(traced_graph.get_placeholder_names()) + used_names.add(snode.node.get_name()) + else: + deps = [ + dep.name + for dep in itertools.chain(node.read_writes.reads, node.read_writes.writes) + ] + while len(deps) > 0: + dep = deps.pop() + used_names.add(dep) + if V.graph.name_to_buffer.get(dep): + for alias in V.graph.name_to_buffer[dep].get_inputs_that_alias_output(): + if alias not in used_names: + deps.append(alias) + return used_names + +def set_last_usage( + node: BaseSchedulerNode, future_used_buffers: Set[str], mutation_real_name: Dict[str, str] +): + used_buffers = used_or_aliased_buffer_names(node) + used_buffers = OrderedSet(mutation_real_name.get(k, k) for k in used_buffers) + node.last_usage = used_buffers - future_used_buffers + +def wrap_scheduler_codegen(fn): + @functools.wraps(fn) + def npu_sheduler_codegen(self): + future_used_buffers = set() + for node_name in V.graph.get_output_names(): + future_used_buffers.add(node_name) + for node in reversed(self.nodes): + set_last_usage(node, future_used_buffers, 
self.mutation_real_name) + future_used_buffers.update(node.last_usage) + return fn(self) + return npu_sheduler_codegen + +def npu_compute_ancestors(self) -> None: + """ + Populate each node.ancestors + """ + # note self.nodes is topologically sorted + name_to_ancestors: Dict[str, OrderedSet[str]] = {} + for node in self.nodes: + ancestors: OrderedSet[str] = OrderedSet() + for dep in node.unmet_dependencies: + if dep.name not in self.name_to_buf: + continue + dep_node_name = self.name_to_buf[dep.name].defining_op.get_name() + ancestors.add(dep_node_name) + ancestors |= name_to_ancestors[dep_node_name] + name_to_ancestors[node.get_name()] = ancestors + node.ancestors = ancestors + + for order, node in enumerate(self.nodes): + node.min_order = order + node.max_order = order + +def _npu_prune_redundant_deps( + node: BaseSchedulerNode, + name_to_fused_node: Dict[str, BaseSchedulerNode], + name_to_buf: Dict[str, SchedulerBuffer], +) -> None: + """ + Prunes weakdeps intended for mutation ordering + on an upstream fused node if after fusion there is another dependency + on the fused upstream node, making the weakdep redundant + + In essence this enforces an ordering on fusions. As fusions occur, weakdeps will + be incrementally removed, enabling other fusions, ensuring they are fused in order. + """ + name_to_dep_count: Counter[str] = collections.Counter() + + for dep in node.unmet_dependencies: + if not isinstance(dep, WeakDep) and dep.name in name_to_buf: + op = name_to_buf[dep.name].defining_op + name_to_dep_count[name_to_fused_node[op.get_name()].get_name()] += 1 + + def should_prune(dep: Dep) -> bool: + if isinstance(dep, WeakDep) and dep.name in name_to_buf: + op_name = name_to_buf[dep.name].defining_op.get_name() + is_redundant = name_to_dep_count[name_to_fused_node[op_name].get_name()] > 0 + # These can occur because fused nodes always gather deps from their snodes + # If B has a weakdep on A + # B gets fused with C, then any time BC is fused, the weakdep will reappear + is_self_dep = name_to_fused_node[op_name] == node + return is_redundant or is_self_dep + else: + return False + + deps_to_prune = OrderedSet( + dep for dep in node.unmet_dependencies if should_prune(dep) + ) + + if deps_to_prune: + node.unmet_dependencies = node.unmet_dependencies - deps_to_prune + node.set_read_writes(node.read_writes.remove_reads(deps_to_prune)) + +def npu_can_fuse_vertical( + self, node1: BaseSchedulerNode, node2: BaseSchedulerNode +) -> bool: + """ + Check if it is legal to fuse a consumer (node2) into a producer (node1). + + We can fuse them if all the reads of node2 either match + corresponding writes in node1, or are written by nodes that can + be scheduled before the fusion of node1 and node2. + """ + node1_buf_names = node1.get_buffer_names() + node1_op_names = node1.get_operation_names() + computed_deps: OrderedSet[Dep] = OrderedSet() + why = WhyNoFuse(node1, node2) + + for cd in node1.read_writes.writes: + if not isinstance(cd, MemoryDep): + continue + for rd in node2.unmet_dependencies: + if self.fusable_read_and_write(rd, cd): + computed_deps.add(rd) + + for dep in node2.unmet_dependencies: + if isinstance(dep, WeakDep) and self.fusable_weak_dep(dep, node1, node2): + computed_deps.add(dep) + + remaining_deps = OrderedSet( + dep.name for dep in node2.unmet_dependencies - computed_deps + ) + if remaining_deps & node1_buf_names: + # MemoryDeps didn't match and read different locations of the same buffer. 
+ # Examples here include: + # - MemoryDep("foo", x) != MemoryDep("foo", x + 1) + # - MemoryDep("foo", x) != StarDep("foo") + why("memory deps did not match") + return False + for name in remaining_deps: + if name not in self.name_to_buf: + continue + op_name = self.name_to_buf[name].defining_op.get_name() + if node1_op_names & self.name_to_fused_node[op_name].ancestors: + why("intermediate nodes between node1 & node2") + return False + + return True + +def _npu_get_unmet_dep_nodes(self, snode: BaseSchedulerNode) -> List[BaseSchedulerNode]: + unmet_deps = set() + if isinstance( + snode, + ( + SchedulerNode, + ExternKernelSchedulerNode, + NopKernelSchedulerNode, + FusedSchedulerNode, + ), + ): + for dep in snode.unmet_dependencies: + unmet_deps.add(dep.name) + else: + raise RuntimeError( + f"get_unmet_dep_nodes is not implemented for {type(snode)}." + ) + unmet_dep_ops = (self.name_to_buf[dep].defining_op for dep in unmet_deps if dep in self.name_to_buf) + return list({self.name_to_fused_node[n.get_name()] for n in unmet_dep_ops}) + +if anir_config.enable_graph_trace: + Scheduler._codegen = wrap_scheduler_codegen(Scheduler._codegen) + Scheduler.compute_ancestors = npu_compute_ancestors + scheduler._prune_redundant_deps = _npu_prune_redundant_deps + Scheduler.can_fuse_vertical = npu_can_fuse_vertical + Scheduler._get_unmet_dep_nodes = _npu_get_unmet_dep_nodes + +def wrap_avg_pool2d(fn): + @functools.wraps(fn) + def dynamo_avg_pool2d(input, *args, **kwargs): + if input.dtype == torch.bfloat16: + input = input.to(torch.float32) + output = fn(input, *args, **kwargs) + return output.to(torch.bfloat16) + else: + return fn(input, *args, **kwargs) + return dynamo_avg_pool2d + +F.avg_pool2d = wrap_avg_pool2d(F.avg_pool2d) + +# patches for transfer_to_npu +def patch_transfer_to_npu(): + try: + import torch + import torch_npu + from torch_npu.contrib import transfer_to_npu + from torch_npu.contrib.transfer_to_npu import ( + _replace_cuda_to_npu_in_list, + device_kwargs_list, + _replace_cuda_to_npu_in_kwargs, + ) + + def new_wrapper_cuda(module, method): + src_method = f"_src_{method}" + if hasattr(getattr(module, method), '__wrapped__'): + src_func = getattr(module, method).__wrapped__ + else: + src_func = getattr(module, method) + + setattr(module, src_method, src_func) + fn = getattr(module, src_method) + + def decorated(*args, **kwargs): + replace_int = fn.__name__ in ['to', 'to_empty'] + if args: + args_new = list(args) + args = _replace_cuda_to_npu_in_list(args_new, replace_int) + if kwargs: + for device_arg in device_kwargs_list: + device = kwargs.get(device_arg, None) + if device is not None: + _replace_cuda_to_npu_in_kwargs(kwargs, device_arg, device) + device_ids = kwargs.get('device_ids', None) + if type(device_ids) == list: + device_ids = _replace_cuda_to_npu_in_list(device_ids, replace_int) + return fn(*args, **kwargs) + + setattr(module, method, decorated) + return decorated + + def new_device_wrapper(enter_fn, white_list): + for fn_name in white_list: + fn = getattr(enter_fn, fn_name, None) + if fn: + new_wrapper_cuda(enter_fn, fn_name) + + transfer_to_npu._device_wrapper = new_device_wrapper + transfer_to_npu._init() + except: + pass \ No newline at end of file diff --git a/torch_npu/_inductor/ascend_npu_ir/ascend_npu_ir/npu/npu_lowering.py b/torch_npu/_inductor/ascend_npu_ir/ascend_npu_ir/npu/npu_lowering.py new file mode 100644 index 0000000000000000000000000000000000000000..7a3488e50d31cbd9bf02e969f1957c9c79b77228 --- /dev/null +++ 
b/torch_npu/_inductor/ascend_npu_ir/ascend_npu_ir/npu/npu_lowering.py @@ -0,0 +1,59 @@ +import os +import sys +import time +from torch._inductor import lowering +from torch._inductor.lowering import make_fallback +from torch._inductor import decomposition +import torch._ops +from .. import config +from ..npu.utils import run_once, get_anir_mode + +aten = torch.ops.aten +tr_c10d = torch.ops.tr_c10d +prims = torch.ops.prims + +@run_once +def _register_npu_inductor_fallbacks(): + gen_set = set() + fallback_set = set() + + for fn in config.GENERATE_LIST: + gen_set.add(fn) + if isinstance(fn, torch._ops.OpOverloadPacket): + for overload in fn.overloads(): + other_fn = getattr(fn, overload) + gen_set.add(other_fn) + + for fn in config.FALLBACK_LIST: + fallback_set.add(fn) + if isinstance(fn, torch._ops.OpOverloadPacket): + for overload in fn.overloads(): + other_fn = getattr(fn, overload) + fallback_set.add(other_fn) + + def fallback_except_gen_set(gen_set): + for op in lowering.lowerings: + if op not in decomposition.decompositions and op not in gen_set: + if isinstance(op, torch._ops.OpOverloadPacket) or \ + isinstance(op, (torch._ops.OpOverload, torch._ops.HigherOrderOperator)): + make_fallback(op) + + def fallback_via_fallback_set(fallback_set): + for op in lowering.lowerings: + if op not in decomposition.decompositions and op in fallback_set: + if isinstance(op, torch._ops.OpOverloadPacket) or \ + isinstance(op, (torch._ops.OpOverload, torch._ops.HigherOrderOperator)): + make_fallback(op) + + if config.fallback_to_aten_mode not in {"off", "include", "exclude"}: + raise AssertionError(f"Error! Unsupported fallback_to_aten_mode: {config.fallback_to_aten_mode} was set!") + + if get_anir_mode() == 'O0': + fallback_except_gen_set(gen_set=[]) + decomposition.decompositions.clear() + return + + if config.fallback_to_aten_mode == 'include': + fallback_via_fallback_set(fallback_set=fallback_set) + elif config.fallback_to_aten_mode == 'exclude': + fallback_except_gen_set(gen_set=gen_set) \ No newline at end of file diff --git a/torch_npu/_inductor/ascend_npu_ir/ascend_npu_ir/npu/npu_meta.py b/torch_npu/_inductor/ascend_npu_ir/ascend_npu_ir/npu/npu_meta.py new file mode 100644 index 0000000000000000000000000000000000000000..2175c4e1718000a390133dd8352be943f50d466a --- /dev/null +++ b/torch_npu/_inductor/ascend_npu_ir/ascend_npu_ir/npu/npu_meta.py @@ -0,0 +1,150 @@ +import os +import sys +import operator +from functools import wraps, reduce, lru_cache +from typing import Callable, Optional, Tuple +import torch +import torch_npu +from torch import Tensor +from torch._ops import OpOverload, OpOverloadPacket +from torch._subclasses import fake_tensor as _subclasses_fake_tensor +from torch._C import DispatchKey +from torch._refs import div as refs_div, _broadcast_shapes +from torch._prims_common import corresponding_real_dtype, corresponding_complex_dtype +from torch._prims_common.wrappers import out_wrapper +from torch._decomp import decomposition_table, decompositions_for_rng, get_decompositions +from torch._dynamo.symbolic_convert import break_graph_if_unsupported, InstructionTranslatorBase, stack_op +from torch._dynamo.exc import Unsupported +from torch._dynamo.variables.lists import TupleVariable +from torch._dynamo.variables.nn_module import NNModuleVariable + + +aten = torch.ops.aten +npu = torch.ops.npu + + +def run_once(f): + """Runs a function (successfully) only once. 
+    The running can be reset by setting the `has_run` attribute to False
+    """
+    @wraps(f)
+    def wrapper(*args, **kwargs):
+        if not wrapper.has_run:
+            result = f(*args, **kwargs)
+            wrapper.has_run = True
+            return result
+        return None
+    wrapper.has_run = False
+    return wrapper
+
+
+npu_meta_table = {}
+break_fn_table = {}
+break_mapping_table = {}
+avoid_make_fallback_table = []
+
+
+def _add_op_to_meta_table(op, fn, avoid_fallback_flag=False):
+    overloads = []
+    if isinstance(op, OpOverload):
+        overloads.append(op)
+    else:
+        if not isinstance(op, OpOverloadPacket):
+            raise AssertionError("op must be instance of OpOverloadPacket.")
+        for ol in op.overloads():
+            overloads.append(getattr(op, ol))
+
+    for op_overload in overloads:
+        if op_overload in npu_meta_table:
+            raise RuntimeError(f"duplicate registrations for npu_meta_table {op_overload}")
+        npu_meta_table[op_overload] = fn
+        if avoid_fallback_flag:
+            avoid_make_fallback_table.append(op_overload)
+
+def patch_torch_decomp_decompositions():
+    '''
+    The stock torch_decomp_decompositions only enables the decompositions in
+    torch/_decomp/decompositions.py. Patch it so the decompositions registered
+    in this file take effect as well.
+    '''
+    src_func = _subclasses_fake_tensor.torch_decomp_decompositions
+
+    @lru_cache(None)
+    def torch_decomp_decompositions_new(func):
+        if func in npu_meta_table.keys():
+            return True
+        return src_func(func)
+    _subclasses_fake_tensor.torch_decomp_decompositions = torch_decomp_decompositions_new
+
+def register_meta_npu(op, avoid_fallback_flag=False):
+    def meta_decorator(fn: Callable):
+        _add_op_to_meta_table(op, fn, avoid_fallback_flag)
+        return fn
+
+    return meta_decorator
+
+@run_once
+def npu_patch_meta():
+    '''
+    PyTorch registers decompositions and meta functions for some aten ops, which
+    conflict when the NPU outputs' dtype and shape differ from the native
+    implementation. Delete the decompositions and meta functions of those ops
+    and register the NPU versions instead.
+ ''' + for op_overload, fn in npu_meta_table.items(): + if not isinstance(op_overload, OpOverload): + raise AssertionError("op_overload must be instance of OpOverload.") + if op_overload not in avoid_make_fallback_table: + decomposition_table[op_overload] = fn + op_overload.py_kernels.pop(DispatchKey.Meta, None) + op_overload.py_impl(DispatchKey.Meta)(fn) + + patch_torch_decomp_decompositions() + + +@register_meta_npu(aten.native_dropout) +def meta_native_dropout(tensor_input: Tensor, p: float, train: Optional[bool]): + if train and p != 0: + sizes_1 = tensor_input.shape + numel = reduce(operator.mul, sizes_1) + numel = (numel + 128 - 1) // 128 * 128 + numel = numel // 8 + return (torch.empty_like(tensor_input), torch.empty(numel, dtype=torch.uint8, device=tensor_input.device)) + else: + return (tensor_input, torch.ones_like(tensor_input, dtype=torch.bool)) + +@register_meta_npu(npu.npu_fusion_attention) +def npu_fusion_attention_forward(query, key, value, head_num, input_layout, pse=None, padding_mask=None, + atten_mask=None, scale=1.0, keep_prob=1.0, pre_tockens=2147483647, next_tockens=2147483647, + inner_precise=0, prefix=None, actual_seq_qlen=None, actual_seq_kvlen=None, sparse_mode=0, gen_mask_parallel=True, sync=False): + B = query.size(0) + N = head_num + S1 = query.size(2) + S2 = key.size(2) + + if input_layout == "BSH": + B = query.size(0) + S1 = query.size(1) + S2 = key.size(1) + + if input_layout == "SBH": + B = query.size(1) + S1 = query.size(0) + S2 = key.size(0) + + seed = 0 + offset = 0 + numels = 0 + return (torch.empty_like(query).contiguous(), + query.new_empty([B, head_num, S1, 8], dtype=torch.float32), + query.new_empty([B, head_num, S1, 8], dtype=torch.float32), + query.new_empty([0]), + seed, + offset, + numels) + +@register_meta_npu(npu.npu_fusion_attention_grad) +def npu_fusion_attention_backward(query, key, value, dy, head_num, input_layout, *, pse=None, padding_mask=None, atten_mask=None, + softmax_max=None, softmax_sum=None, softmax_in=None, attention_in=None, scale_value=1.0, + keep_prob=1.0, pre_tockens=2147483647, next_tockens=2147483647, inner_precise=0, seed=0, offset=0, + numels=0, prefix=None, actual_seq_qlen=None, actual_seq_kvlen=None, sparse_mode=0, gen_mask_parallel=True, sync=False): + return (torch.empty_like(query).contiguous(), torch.empty_like(key).contiguous(), torch.empty_like(value).contiguous(), query.new_empty([0])) diff --git a/torch_npu/_inductor/ascend_npu_ir/ascend_npu_ir/npu/npu_patch_deprecated.py b/torch_npu/_inductor/ascend_npu_ir/ascend_npu_ir/npu/npu_patch_deprecated.py new file mode 100644 index 0000000000000000000000000000000000000000..31ab9ccafa9ef609d72326a394613367f7dcbccb --- /dev/null +++ b/torch_npu/_inductor/ascend_npu_ir/ascend_npu_ir/npu/npu_patch_deprecated.py @@ -0,0 +1,71 @@ +from typing import Dict,Any +import hashlib +import functools +import json + +import torch +from torch._inductor.codecache import CacheBase +from torch.distributed import distributed_c10d +from torch.distributed.distributed_c10d import ( + _world, + timedelta, +) +from torch.library import Library, impl +from ..npu.utils import get_anir_mode + +python_dispatcher_lib = Library("aten", "IMPL", "PythonDispatcher") + +@impl(python_dispatcher_lib, "embedding_backward") +def embedding_backward(grad, indices, num_weights, padding_idx, scale_grad_by_freq, sparse): + if sparse != False: + raise RuntimeError("the current NPU does not yet support sparse tensor, when sparse is set to True") + return torch.ops.aten.embedding_dense_backward(grad, indices, 
num_weights, padding_idx, scale_grad_by_freq) + +@impl(python_dispatcher_lib, "contiguous") +def py_contiguous(x, memory_format=torch.contiguous_format): + return x.clone(memory_format=memory_format) + + +@staticmethod +@functools.lru_cache(None) +def _patch_get_system() -> Dict[str, Any]: + system = {} + system["hash"] = hashlib.sha256( + json.dumps(system, sort_keys=True).encode("utf-8") + ).hexdigest() + + return system + +def _patch_add_ephemeral_timeout_for_all_pgs(timeout: timedelta) -> None: + """ + This API adds an ephemeral timeout extension for all PGs locally + on one rank. The timeout gets reset when the first collective issued + after API called finished. + NOTE: We only support to set timeout for cuda backends for now. + NOTE: While this feature + provides flexibility in specific scenarios, it introduces statefulness + to timeout setting. Therefore, it is advisable to use this API sparingly + and consider alternative approaches, such as directly setting the timeout + or utilizing a barrier collective (one can set any timeout to the barrier), + whenever feasible. + + Args: + timeout (timedelta): The delta of timeout to extend. + + Returns: + None. + """ + for pg in _world.pg_map.keys(): + devices = pg._device_types + if torch.device("npu") in devices: + backend = pg._get_backend(torch.device("npu")) + +CacheBase.get_system = _patch_get_system +distributed_c10d._add_ephemeral_timeout_for_all_pgs = _patch_add_ephemeral_timeout_for_all_pgs + +if get_anir_mode() == 'O0': + @torch.ops.aten.silu_backward.default.py_impl(torch._C.DispatchKey.CompositeImplicitAutograd) + def my_silu_backward(grad_out, self): + # use with some caution: this is only really valid to run in the context of proxy tensor tracing + return NotImplemented + diff --git a/torch_npu/_inductor/ascend_npu_ir/ascend_npu_ir/npu/npu_stream.py b/torch_npu/_inductor/ascend_npu_ir/ascend_npu_ir/npu/npu_stream.py new file mode 100644 index 0000000000000000000000000000000000000000..dd70b0df43f8848375d5d8b5442b864e08694530 --- /dev/null +++ b/torch_npu/_inductor/ascend_npu_ir/ascend_npu_ir/npu/npu_stream.py @@ -0,0 +1,453 @@ +import torch +import torch_npu +import torch.library +from torch.library import Library + +from typing import Callable, Optional, Sequence, List, Tuple + +NPU_STREAMS = {} +NPU_EVENTS = {} + +# create a library to hold the custom op +npu_stream_lib = Library("npu_stream", "FRAGMENT") # noqa + +def direct_register_custom_op( + op_name: str, + op_func: Callable, + mutates_args: list[str], + fake_impl: Optional[Callable] = None, + target_lib: Optional[Library] = None, + dispatch_key: str = "PrivateUse1", + tags: tuple[torch.Tag, ...] = (), +): + """ + `torch.library.custom_op` can have significant overhead because it + needs to consider complicated dispatching logic. This function + directly registers a custom op and dispatches it to the CUDA backend. + See https://gist.github.com/youkaichao/ecbea9ec9fc79a45d2adce1784d7a9a5 + for more details. + + By default, the custom op is registered to the vLLM library. If you + want to register it to a different library, you can pass the library + object to the `target_lib` argument. + + IMPORTANT: the lifetime of the operator is tied to the lifetime of the + library object. If you want to bind the operator to a different library, + make sure the library object is alive when the operator is used. 
+ """ + import torch.library + if hasattr(torch.library, "infer_schema"): + schema_str = torch.library.infer_schema(op_func, + mutates_args=mutates_args) + else: + # for pytorch 2.4 + import torch._custom_op.impl + schema_str = torch._custom_op.impl.infer_schema(op_func, mutates_args) + my_lib = target_lib or npu_stream_lib + my_lib.define(op_name + schema_str, tags=tags) + my_lib.impl(op_name, op_func, dispatch_key=dispatch_key) + if fake_impl is not None: + my_lib._register_fake(op_name, fake_impl) + +class StreamResgistrator: + def __init__(self) -> None: + pass + + @staticmethod + def register_npu_stream(stream: torch.npu.Stream, tag: str = '0'): + NPU_STREAMS[tag] = stream + + @staticmethod + def register_npu_event(event: torch.npu.Event, tag: str = '0'): + NPU_EVENTS[tag] = event + +def npu_set_stream( + dependency: Sequence[torch.Tensor], + stream_tag: str, + ) -> List[torch.Tensor]: + stream = NPU_STREAMS[stream_tag] + torch_npu.npu.utils.set_stream(stream) + return dependency + +def npu_set_stream_fake( + dependency: Sequence[torch.Tensor], + stream_tag: str, + ) -> List[torch.Tensor]: + return dependency + +direct_register_custom_op( + op_name="npu_set_stream", + op_func=npu_set_stream, + mutates_args=[], + fake_impl=npu_set_stream_fake, + dispatch_key='PrivateUse1' +) + +def npu_event_record( + dependency: Sequence[torch.Tensor], + event_tag: str, + stream_tag: str + ) -> List[torch.Tensor]: + event = NPU_EVENTS[event_tag] + stream = NPU_STREAMS[stream_tag] + event.record(stream) + return dependency + +def npu_event_record_fake( + dependency: Sequence[torch.Tensor], + event_tag: str, + stream_tag: str + ) -> List[torch.Tensor]: + return dependency + +direct_register_custom_op( + op_name="npu_event_record", + op_func=npu_event_record, + mutates_args=[], + fake_impl=npu_event_record_fake, + dispatch_key='PrivateUse1' +) + +def npu_event_wait( + dependency: Sequence[torch.Tensor], + event_tag: str, + ) -> List[torch.Tensor]: + event = NPU_EVENTS[event_tag] + event.wait() + return dependency + +def npu_event_wait_fake( + dependency: Sequence[torch.Tensor], + event_tag: str, + ) -> List[torch.Tensor]: + return dependency + +direct_register_custom_op( + op_name="npu_event_wait", + op_func=npu_event_wait, + mutates_args=[], + fake_impl=npu_event_wait_fake, + dispatch_key='PrivateUse1' +) + +def graph_break( + dependency: Sequence[torch.Tensor], + ) -> List[torch.Tensor]: + return dependency + +def graph_break_fake( + dependency: Sequence[torch.Tensor], + ) -> List[torch.Tensor]: + return dependency + +utils_lib = Library("npu_utils", "FRAGMENT") # noqa + +direct_register_custom_op( + op_name="graph_break", + op_func=graph_break, + mutates_args=[], + target_lib=utils_lib, + fake_impl=graph_break_fake, + dispatch_key='PrivateUse1' +) + +def npu_wait_stream( + dependency: Sequence[torch.Tensor], + stream1_tag: str, + stream2_tag: str, + ) -> List[torch.Tensor]: + stream1 = NPU_STREAMS[stream1_tag] + stream2 = NPU_STREAMS[stream2_tag] + stream1.wait_stream(stream2) + return dependency + +def npu_wait_stream_fake( + dependency: Sequence[torch.Tensor], + stream1_tag: str, + stream2_tag: str, + ) -> List[torch.Tensor]: + return dependency + +direct_register_custom_op( + op_name="npu_wait_stream", + op_func=npu_wait_stream, + mutates_args=[], + fake_impl=npu_wait_stream_fake, + dispatch_key='PrivateUse1' +) + + +def graph_break( + *args +): + outputs = [] + for inp in args: + outputs.append(torch.ops.npu_utils.graph_break(inp)) + return outputs + +inductor_npu_lib = Library("inductor_npu", 
"FRAGMENT") # noqa + +def npu_fusion_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + head_num: int, + input_layout: str, + pse: Optional[torch.Tensor] = None, + padding_mask: Optional[torch.Tensor] = None, + atten_mask: Optional[torch.Tensor] = None, + scale: float = 1.0, + keep_prob: float = 1.0, + pre_tockens: int = 2147483647, + next_tockens: int = 2147483647, + inner_precise: int = 0, + prefix: Optional[torch.Tensor] = None, + actual_seq_qlen: Optional[torch.Tensor] = None, + actual_seq_kvlen: Optional[torch.Tensor] = None, + sparse_mode: int = 0, + gen_mask_parallel: bool = True, + sync: bool = False + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + prefix = prefix.tolist() if prefix is not None else prefix + actual_seq_qlen = actual_seq_qlen.tolist() if actual_seq_qlen is not None else actual_seq_qlen + actual_seq_kvlen = actual_seq_kvlen.tolist() if actual_seq_kvlen is not None else actual_seq_kvlen + attention_score, softmax_max, softmax_sum, softmax_out, seed, offset, numels = torch.ops.npu.npu_fusion_attention( + query, + key, + value, + head_num, + input_layout, + pse=pse, + padding_mask=padding_mask, + atten_mask=atten_mask, + scale=scale, + keep_prob=keep_prob, + pre_tockens=pre_tockens, + next_tockens=next_tockens, + inner_precise=inner_precise, + prefix=prefix, + actual_seq_qlen=actual_seq_qlen, + actual_seq_kvlen=actual_seq_kvlen, + sparse_mode=sparse_mode, + gen_mask_parallel=gen_mask_parallel, + sync=sync + ) + + seed = torch.tensor([seed], device='npu', dtype=torch.int64) + offset = torch.tensor([offset], device='npu', dtype=torch.int64) + numels = torch.tensor([numels], device='npu', dtype=torch.int64) + + return attention_score, softmax_max, softmax_sum, softmax_out, seed, offset, numels + +def npu_fusion_attention_fake( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + head_num: int, + input_layout: str, + pse: Optional[torch.Tensor] = None, + padding_mask: Optional[torch.Tensor] = None, + atten_mask: Optional[torch.Tensor] = None, + scale: float = 1.0, + keep_prob: float = 1.0, + pre_tockens: int = 2147483647, + next_tockens: int = 2147483647, + inner_precise: int = 0, + prefix: Optional[torch.Tensor] = None, + actual_seq_qlen: Optional[torch.Tensor] = None, + actual_seq_kvlen: Optional[torch.Tensor] = None, + sparse_mode: int = 0, + gen_mask_parallel: bool = True, + sync: bool = False + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + B = query.size(0) + N = head_num + S1 = query.size(2) + S2 = key.size(2) + + if input_layout == "BSH": + B = query.size(0) + S1 = query.size(1) + S2 = key.size(1) + + if input_layout == "SBH": + B = query.size(1) + S1 = query.size(0) + S2 = key.size(0) + + attention_score = torch.empty_like(query, dtype=query.dtype, device=query.device).contiguous() + softmax_max = torch.empty([B, head_num, S1, 8], dtype=torch.float32, device=query.device) + softmax_sum = torch.empty([B, head_num, S1, 8], dtype=torch.float32, device=query.device) + softmax_out = torch.empty([0], dtype=query.dtype, device=query.device) + seed = torch.empty([1], dtype=torch.int64, device=query.device) + offset = torch.empty([1], dtype=torch.int64, device=query.device) + numels = torch.empty([1], dtype=torch.int64, device=query.device) + + return (attention_score, + softmax_max, + softmax_sum, + softmax_out, + seed, + offset, + numels) + +direct_register_custom_op( + 
op_name="npu_fusion_attention", + op_func=npu_fusion_attention, + mutates_args=[], + target_lib=inductor_npu_lib, + fake_impl=npu_fusion_attention_fake, + dispatch_key='PrivateUse1' +) + + +def npu_fusion_attention_grad( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + dy: torch.Tensor, + head_num: int, + input_layout: str, + *, + pse: Optional[torch.Tensor] = None, + padding_mask: Optional[torch.Tensor] = None, + atten_mask: Optional[torch.Tensor] = None, + softmax_max: Optional[torch.Tensor] = None, + softmax_sum: Optional[torch.Tensor] = None, + softmax_in: Optional[torch.Tensor] = None, + attention_in: Optional[torch.Tensor] = None, + scale_value: float = 1.0, + keep_prob: float = 1.0, + pre_tockens: int = 2147483647, + next_tockens: int = 2147483647, + inner_precise: int = 0, + seed: Optional[torch.Tensor] = None, + offset: Optional[torch.Tensor] = None, + numels: Optional[torch.Tensor] = None, + prefix: Optional[torch.Tensor] = None, + actual_seq_qlen: Optional[torch.Tensor] = None, + actual_seq_kvlen: Optional[torch.Tensor] = None, + sparse_mode: int = 0, + gen_mask_parallel: bool = True, + sync: bool = False + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + prefix = prefix.tolist() if prefix is not None else prefix + actual_seq_qlen = actual_seq_qlen.tolist() if actual_seq_qlen is not None else actual_seq_qlen + actual_seq_kvlen = actual_seq_kvlen.tolist() if actual_seq_kvlen is not None else actual_seq_kvlen + + seed = seed.item() + offset = offset.item() + numels = numels.item() + + dq, dk, dv, dpse = torch.ops.npu.npu_fusion_attention_grad( + query, key, value, dy, head_num, input_layout, pse=pse, padding_mask=padding_mask, atten_mask=atten_mask, + softmax_max=softmax_max, softmax_sum=softmax_sum, softmax_in=softmax_in, attention_in=attention_in, scale_value=scale_value, + keep_prob=keep_prob, pre_tockens=pre_tockens, next_tockens=next_tockens, inner_precise=inner_precise, seed=seed, offset=offset, + numels=numels, prefix=prefix, actual_seq_qlen=actual_seq_qlen, actual_seq_kvlen=actual_seq_kvlen, sparse_mode=sparse_mode, + gen_mask_parallel=gen_mask_parallel, sync=sync + ) + + return dq, dk, dv, dpse + +def npu_fusion_attention_grad_fake( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + dy: torch.Tensor, + head_num: int, + input_layout: str, + *, + pse: Optional[torch.Tensor] = None, + padding_mask: Optional[torch.Tensor] = None, + atten_mask: Optional[torch.Tensor] = None, + softmax_max: Optional[torch.Tensor] = None, + softmax_sum: Optional[torch.Tensor] = None, + softmax_in: Optional[torch.Tensor] = None, + attention_in: Optional[torch.Tensor] = None, + scale_value: float = 1.0, + keep_prob: float = 1.0, + pre_tockens: int = 2147483647, + next_tockens: int = 2147483647, + inner_precise: int = 0, + seed: Optional[torch.Tensor] = None, + offset: Optional[torch.Tensor] = None, + numels: Optional[torch.Tensor] = None, + prefix: Optional[torch.Tensor] = None, + actual_seq_qlen: Optional[torch.Tensor] = None, + actual_seq_kvlen: Optional[torch.Tensor] = None, + sparse_mode: int = 0, + gen_mask_parallel: bool = True, + sync: bool = False + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + dq = torch.empty_like(query, dtype=query.dtype, device=query.device).contiguous() + dk = torch.empty_like(key, dtype=query.dtype, device=query.device).contiguous() + dv = torch.empty_like(value, dtype=query.dtype, device=query.device).contiguous() + dpse = torch.empty([0], dtype=query.dtype, 
device=query.device).contiguous() + return dq, dk, dv, dpse if pse else None + +direct_register_custom_op( + op_name="npu_fusion_attention_grad", + op_func=npu_fusion_attention_grad, + mutates_args=[], + target_lib=inductor_npu_lib, + fake_impl=npu_fusion_attention_grad_fake, + dispatch_key='PrivateUse1' +) + +class InductorNpuAttentionFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, query, key, value, head_num, input_layout, pse=None, padding_mask=None, atten_mask=None, scale=1.0, + keep_prob=1.0, pre_tockens=2147483647, next_tockens=2147483647, inner_precise=0, prefix=None, + actual_seq_qlen=None, actual_seq_kvlen=None, sparse_mode=0, gen_mask_parallel=True, sync=False): + attention_score, softmax_max, softmax_sum, softmax_out, seed, offset, numels = torch.ops.inductor_npu.npu_fusion_attention( + query, key, value, head_num, input_layout, pse=pse, padding_mask=padding_mask, atten_mask=atten_mask, + scale=scale, keep_prob=keep_prob, pre_tockens=pre_tockens, next_tockens=next_tockens, + inner_precise=inner_precise, prefix=prefix, actual_seq_qlen=actual_seq_qlen, + actual_seq_kvlen=actual_seq_kvlen, sparse_mode=sparse_mode, gen_mask_parallel=gen_mask_parallel, sync=sync + ) + ctx.save_for_backward(query, key, value, pse, padding_mask, atten_mask, actual_seq_qlen, actual_seq_kvlen,\ + softmax_max, softmax_sum, softmax_out, attention_score, seed, offset, numels) + ctx.head_num = head_num + ctx.input_layout = input_layout + ctx.scale = scale + ctx.keep_prob = keep_prob + ctx.pre_tockens = pre_tockens + ctx.next_tockens = next_tockens + ctx.inner_precise = inner_precise + ctx.prefix = prefix + # ctx.actual_seq_qlen = actual_seq_qlen + # ctx.actual_seq_kvlen = actual_seq_kvlen + ctx.sparse_mode = sparse_mode + ctx.gen_mask_parallel = gen_mask_parallel + ctx.sync = sync + + return attention_score, softmax_max, softmax_sum, softmax_out, seed, offset, numels + + @staticmethod + def backward(ctx, grad_attention_score, grad_softmax_max, grad_softmax_sum, grad_softmax_out, grad_seed, grad_offset, grad_numels): + query, key, value, pse, padding_mask, atten_mask, actual_seq_qlen, actual_seq_kvlen, \ + softmax_max, softmax_sum, softmax_out, attention_score, seed, offset, numels = ctx.saved_tensors + grad_query, grad_key, grad_value, grad_pse = torch.ops.inductor_npu.npu_fusion_attention_grad( + query, key, value, grad_attention_score, ctx.head_num, ctx.input_layout, pse=pse, padding_mask=padding_mask, + atten_mask=atten_mask, softmax_max=softmax_max, softmax_sum=softmax_sum, softmax_in=softmax_out, attention_in=attention_score, + scale_value=ctx.scale, keep_prob=ctx.keep_prob, pre_tockens=ctx.pre_tockens, next_tockens=ctx.next_tockens, + inner_precise=ctx.inner_precise, seed=seed, offset=offset, numels=numels, prefix=None, + actual_seq_qlen=actual_seq_qlen, actual_seq_kvlen=actual_seq_kvlen, sparse_mode=ctx.sparse_mode, + gen_mask_parallel=ctx.gen_mask_parallel, sync=ctx.sync + ) + return ( + grad_query, grad_key, grad_value, None, None, grad_pse, None, None, None, None, None, None, None, None, None, + None, None, None, None, None, None, None, None, None, None, None) + +def inductor_npu_fusion_attention(query, key, value, head_num, input_layout, pse=None, padding_mask=None, + atten_mask=None, scale=1.0, keep_prob=1.0, pre_tockens=2147483647, + next_tockens=2147483647, + inner_precise=0, prefix=None, actual_seq_qlen=None, actual_seq_kvlen=None, sparse_mode=0, + gen_mask_parallel=True, sync=False): + return InductorNpuAttentionFunction.apply(query, key, value, head_num, input_layout, 
pse, padding_mask, + atten_mask, scale, keep_prob, pre_tockens, next_tockens, + inner_precise, prefix, actual_seq_qlen, actual_seq_kvlen, sparse_mode, + gen_mask_parallel, sync) + +def apply_inductor_npu_attention_patch(): + torch.ops.npu.npu_fusion_attention = inductor_npu_fusion_attention diff --git a/torch_npu/_inductor/ascend_npu_ir/ascend_npu_ir/npu/torch_mlir_patch.py b/torch_npu/_inductor/ascend_npu_ir/ascend_npu_ir/npu/torch_mlir_patch.py new file mode 100644 index 0000000000000000000000000000000000000000..638ef59a3345403503212b258ae3096c94d8a7ca --- /dev/null +++ b/torch_npu/_inductor/ascend_npu_ir/ascend_npu_ir/npu/torch_mlir_patch.py @@ -0,0 +1,216 @@ +import math +import sympy +from typing import Optional, Dict, Union + +import torch +from torch_mlir import ir +from torch_mlir.extras import fx_importer +from torch_mlir.compiler_utils import ( + OutputType +) +from torch_mlir.dialects import torch as torch_d +from torch_mlir.fx import ( + _module_lowering, + FxImporter, + FxImporterHooks, +) +from torch_mlir.extras import fx_importer + +from torch_mlir.extras.fx_importer import ( + Graph, + Operation, + Callable, + func_dialect, + RangeConstraint, + Block, + GraphNodeImporter, + UnitAttr, + sympy_expr_to_semi_affine_expr +) + +from torch_mlir.ir import ( + AffineAddExpr, + AffineConstantExpr, + AffineExpr, + AffineMap, + AffineMapAttr, + AffineModExpr, + AffineMulExpr, + AffineSymbolExpr, + AffineFloorDivExpr +) + +from torch.utils._sympy.functions import ( + CeilDiv, + FloorDiv, + Identity, + IntTrueDiv, + ModularIndexing, +) + +def _patch_import_stateless_graph( + self, + g: Graph, + *, + func_name: str = "main", + func_visibility: Optional[str] = None, + import_symbolic_shape_expressions: bool = True, +) -> Operation: + """Low-level import of a functionalized, assumed stateless Graph as a func. + + TODO: This mechanism is deprecated by the `import_program` entry-point and + it should be removed when no longer required for backwards compatibility. + """ + + def get_range_constraints(graph: torch.fx.Graph): + range_constraints = {} + for nd in graph.find_nodes( + op="placeholder" + ): + if isinstance(nd.meta['val'], torch.Tensor): + for s in nd.meta['val'].size(): + if isinstance(s, torch.SymInt): + for symbol in s._sympy_().free_symbols: + range_constraints[symbol] = torch.utils._sympy.value_ranges.ValueRanges(128, 1024) + else: + for symbol in nd.meta['val']._sympy_().free_symbols: + range_constraints[symbol] = torch.utils._sympy.value_ranges.ValueRanges(128, 1024) + return range_constraints + + + def _sympy_int_to_int(val: sympy.Expr, adjust_func: Callable): + # Convert simple sympy Integers into concrete int + if val == sympy.oo: + return math.inf + if val == -sympy.oo: + return -math.inf + if isinstance(val, sympy.Integer): + return int(val) + # TODO: Remove this adjustment when fractional ranges are removed + return adjust_func(val) + + range_constraints = get_range_constraints(g) + + self._cc._symbolic_guards = { + str(k): RangeConstraint( + _sympy_int_to_int(v.lower, math.ceil), + _sympy_int_to_int(v.upper, math.floor), + ) + for k, v in range_constraints.items() + } + + ftype, loc = self._graph_to_function_meta(g) + # TODO: The FuncOp constructor requires a context-manager context. + # Fix upstream and then unnest. 
+ # See: https://github.com/nod-ai/SHARK-Turbine/issues/138 + with loc: + func = func_dialect.FuncOp( + func_name, + ftype, + ip=self._m_ip, + visibility=func_visibility, + ) + func.attributes["torch.assume_strict_symbolic_shapes"] = UnitAttr.get() + entry_block = Block.create_at_start(func.body, ftype.inputs) + node_importer = GraphNodeImporter( + self, + self._c, + self._cc, + entry_block, + ) + node_importer.import_nodes( + g.nodes, import_symbolic_shape_expressions=import_symbolic_shape_expressions + ) + self.symbol_table.insert(func) + return func + + +def _patch_sympy_expr_to_semi_affine_expr( + expr: sympy.Expr, symbols_map: Dict[str, AffineSymbolExpr] +) -> AffineExpr: + """Translate sympy expressions to MLIR (semi-)affine expressions. + + Recursively traverse the sympy expr AST and build the affine expr. + This is not a perfect translation. Sympy expressions are much more + expressive and not as constrained as affine (linear) expressions are. + However, for the most part, we don't need to support all of sympy. + PyTorch only uses a subset of sympy for capturing and expressing + symbolic shapes, and among what's supported, we expect the semi-affine + expressions (https://mlir.llvm.org/docs/Dialects/Affine/#semi-affine-maps) + to be sufficient. + """ + + if isinstance(expr, sympy.Symbol): + return symbols_map[str(expr)] + elif isinstance(expr, (int, sympy.Integer)): + return AffineConstantExpr.get(expr) + # This handles both add (`s0 + c`) and subtract (`s0 - c`). + # The expression is `sympy.Add` in both cases but with args + # (s0, c) in first case and (s0, -c) in the second case. + elif isinstance(expr, sympy.Add): + affine_expr = AffineConstantExpr.get(0) + for arg in expr.args: + affine_expr = AffineAddExpr.get( + affine_expr, sympy_expr_to_semi_affine_expr(arg, symbols_map) + ) + return affine_expr + elif isinstance(expr, sympy.Mul): + affine_expr = AffineConstantExpr.get(1) + for arg in expr.args: + affine_expr = AffineMulExpr.get( + affine_expr, sympy_expr_to_semi_affine_expr(arg, symbols_map) + ) + return affine_expr + elif isinstance(expr, sympy.Pow): + base, exp = expr.args + # Only integer exponent is supported + # So, s1 ** s0 isn't allowed. + assert isinstance(exp, (int, sympy.Integer)) + assert exp > 0, "Only positive exponents supported in sympy.Pow" + affine_expr = AffineConstantExpr.get(1) + for _ in range(exp): + affine_expr = AffineMulExpr.get( + affine_expr, sympy_expr_to_semi_affine_expr(base, symbols_map) + ) + return affine_expr + elif isinstance(expr, sympy.Mod): + dividend, divisor = expr.args + return AffineModExpr.get( + sympy_expr_to_semi_affine_expr(dividend, symbols_map), + sympy_expr_to_semi_affine_expr(divisor, symbols_map), + ) + elif isinstance(expr, FloorDiv): + dividend, divisor = expr.args + return AffineFloorDivExpr.get( + sympy_expr_to_semi_affine_expr(dividend, symbols_map), + sympy_expr_to_semi_affine_expr(divisor, symbols_map), + ) + else: + raise NotImplementedError( + f"Translation of sympy.Expr of type {type(expr)} not implemented yet." 
+ ) + + +fx_importer.FxImporter.import_stateless_graph = _patch_import_stateless_graph +fx_importer.sympy_expr_to_semi_affine_expr = _patch_sympy_expr_to_semi_affine_expr + +def stateless_fx_import( + gm: torch.fx.GraphModule, + output_type: Union[str, OutputType] = OutputType.RAW, + fx_importer: Optional[FxImporter] = None, + hooks: Optional[FxImporterHooks] = None, + model_name: str = "main", + enable_graph_printing: bool = False, + enable_ir_printing: bool = False, + import_symbolic_shape_expressions:bool = False, +): + if enable_graph_printing: + gm.print_readable() + context = ir.Context() + torch_d.register_dialect(context) + if fx_importer is None: + fx_importer = FxImporter(context=context, hooks=hooks) + fx_importer.import_stateless_graph(gm.graph, func_name=model_name, import_symbolic_shape_expressions=import_symbolic_shape_expressions) + return _module_lowering( + enable_ir_printing, OutputType.get(output_type), fx_importer.module + ) diff --git a/torch_npu/_inductor/ascend_npu_ir/ascend_npu_ir/npu/utils.py b/torch_npu/_inductor/ascend_npu_ir/ascend_npu_ir/npu/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..eb9e25fb05e6a1067ec624a33e17202b280465c0 --- /dev/null +++ b/torch_npu/_inductor/ascend_npu_ir/ascend_npu_ir/npu/utils.py @@ -0,0 +1,814 @@ +import os +import sys +import re +import hashlib +import tempfile +import textwrap +import functools +import logging +import sysconfig +import shutil +import subprocess +import copy +import torch +import torch.nn as nn +from typing import Any, Tuple +from sympy import Expr +from pathlib import Path +from typing import List + +from typing import ( + Optional, + List, + Dict, + Union, + Tuple +) + +from torch.fx.graph_module import ( + _custom_builtins, + _addindent, + warnings +) + +try: + from torch_mlir import ir + import torch_mlir + from torch_mlir.dialects import func as func_dialect +except ImportError: + print("Can NOT find torch_mlir, INSTALL it first.") + +from ..build_info import ABI_TAG + +MLIR_DTYPE_MAPPING = { + "f32": torch.float32, + "i1" : torch.bool, + "bf16" : torch.bfloat16, + "f16" : torch.float16, + "si64" : torch.int64 +} + +def run_once(f): + """Runs a function (successfully) only once. + The running can be reset by setting the `has_run` attribute to False + """ + @functools.wraps(f) + def wrapper(*args, **kwargs): + if not wrapper.has_run: + result = f(*args, **kwargs) + wrapper.has_run = True + return result + return None + wrapper.has_run = False + return wrapper + +def get_device_info(example_inputs) -> Union[Tuple[str, int], None]: + for inp in example_inputs: + if isinstance(inp, torch.Tensor): + return inp.device, inp.device.index + +@functools.lru_cache(None) +def _get_ascend_path() -> str: + path = os.getenv("ASCEND_HOME_PATH", "") + if path == "": + raise Exception("ASCEND_HOME_PATH is not set, source /set_env.sh first") + return Path(path) + +def _build_npu_ext(obj_name: str, src_path, src_dir) -> str: + suffix = sysconfig.get_config_var('EXT_SUFFIX') + so_path = os.path.join(src_dir, f"{obj_name}{suffix}") + + cxx = os.environ.get("CC") + if cxx is None: + clangxx = shutil.which("clang++") + gxx = shutil.which("g++") + cxx = gxx if gxx is not None else clangxx + if cxx is None: + raise RuntimeError("Failed to find C++ compiler") + cc_cmd = [cxx, src_path] + + # find the python library + if hasattr(sysconfig, 'get_default_scheme'): + scheme = sysconfig.get_default_scheme() + else: + scheme = sysconfig._get_default_scheme() + # 'posix_local' is a custom scheme on Debian. 
However, starting Python 3.10, the default install + # path changes to include 'local'. This change is required to use triton with system-wide python. + if scheme == 'posix_local': + scheme = 'posix_prefix' + py_include_dir = sysconfig.get_paths(scheme=scheme)["include"] + + cc_cmd += [f"-I{py_include_dir}"] + + anir_path = str(Path(os.path.realpath(__file__)).parent.parent) + lib_dir = os.path.join(anir_path, 'lib') + cc_cmd += [ + f"-I{os.path.join(anir_path, '_C/include')}", + f"-I{os.path.join(anir_path, 'cpp_common')}", + f"-L{lib_dir}", + "-lcpp_common", f"-Wl,-rpath,{lib_dir}", + "-std=c++17", f"-D_GLIBCXX_USE_CXX11_ABI={ABI_TAG}", "-shared" + ] + cc_cmd += ["-fPIC", "-o", so_path] + ret = subprocess.check_call(cc_cmd) + + if ret == 0: + return so_path + else: + raise RuntimeError("Failed to compile " + src_path) + + +def parse_fx_example_inputs(gm: torch.fx.GraphModule): + name_to_example_inputs = {} + for node in gm.graph.nodes: + if node.op == 'placeholder': + name_to_example_inputs[node.name] = node.meta['val'] + return name_to_example_inputs + + +def generate_compiler_repro_string(gm: torch.fx.GraphModule): + from torch._dynamo.debug_utils import NNModuleToString + model_str = textwrap.dedent( + f""" +import torch +import torch_npu +from torch import tensor, device +import torch.fx as fx +from torch._dynamo.testing import rand_strided +from math import inf +import torch._inductor.inductor_prims + + """ + ) + + model_str += NNModuleToString.convert(gm) + model_str += "\n" + model_str += "mod = Repro()\n" + return model_str + + +def get_fx_graph_code(code, num_args, method=2, runnable=False, kernel_code='', kernel_name=None): + kernel_header = '' + kernel_wrapper = '' + kernel_runner_and_acc_comp = '' + if len(kernel_code): + kernel_header = """ +from torch import empty_strided, empty, randn +from torch_npu._inductor.ascend_npu_ir.ascend_npu_ir.codecache import CustomAsyncCompile +from torch_npu._inductor.ascend_npu_ir.ascend_npu_ir.npu.utils import ( + logger, +) +import logging +logger.setLevel(logging.INFO) +async_compile = CustomAsyncCompile() + +""" + kernel_wrapper = """ +from torch_npu._inductor.ascend_npu_ir.ascend_npu_ir.npu.npu_inductor_plugin import get_current_raw_stream as get_raw_stream + +async_compile.wait(globals()) +del async_compile + +stream0 = get_raw_stream(0) +""" + kernel_runner_and_acc_comp = f""" +kernel_dump_path = os.path.join(dir_path, 'kernel_dump') +for file_name in os.listdir(kernel_dump_path): + kernel_path = os.path.join(kernel_dump_path, file_name) + {kernel_name}.replace_kernel_by_path(kernel_path) + {kernel_name}.run( + *args, + stream=stream0) + + output1 = args[num_args:] + + if not os.environ.get("DISABLE_ACC_COMP", "0") == "1": + for o1, o2 in zip(output1, output2): + if o2.dtype != o1.dtype: + o2 = o2.to(o1.dtype) + acc_comp_tol = npu_config.acc_comp_tol.get(o1.dtype, npu_config.acc_comp_tol['default']) + rtol = acc_comp_tol['rtol'] + atol = acc_comp_tol['atol'] + torch.testing.assert_close(o1, o2, rtol=rtol, atol=atol, equal_nan=False) + print('accuracy success!') +""" + code = textwrap.indent(code, ' ') + transformed_code_template = f""" +def get_args(): + args = torch.load(os.path.join(dir_path, "data.pth")) + args = [arg.npu() if isinstance(arg, torch.Tensor) else arg for arg in args] + num_args = {num_args} + + return args +""" + run_code_template = f""" + +try: + args = torch.load(os.path.join(dir_path, "data.pth")) +except Exception as e: + {{{{FAKE_ARGS_PLACEHOLDER}}}} +args = [arg.npu() if isinstance(arg, torch.Tensor) else arg 
for arg in args] +num_args = {num_args} + +fx_inputs = [clone_preserve_strides(arg) for arg in args[:num_args]] +""" + fx_runner = f""" +fx_inputs = [inp.float() if inp.dtype == torch.bfloat16 else inp for inp in fx_inputs] +with torch.no_grad(): + output2 = model(*fx_inputs) +""" + code_template = f""" +import os +import torch +from torch._inductor.compile_fx import clone_preserve_strides +from torch._dynamo.testing import rand_strided +from torch import device + +import torch_npu +from torch_npu._inductor.ascend_npu_ir.ascend_npu_ir import config as npu_config +{kernel_header} +file_path = os.path.abspath(__file__) +dir_path = os.path.dirname(file_path) + +{kernel_code} +{kernel_wrapper} + +class GraphModule(torch.nn.Module): + def __init__(self): + super().__init__() +{code} +model = GraphModule().npu() + +{run_code_template if runnable else transformed_code_template} +{fx_runner if runnable else ''} +{kernel_runner_and_acc_comp if runnable else ''} +""" + return code_template + +def codegen_python_shape_tuple(shape: Tuple[Expr, ...]) -> str: + from torch._inductor.virtualized import V + parts = list(map(V.graph.wrapper_code.codegen_python_sizevar, shape)) + if len(parts) == 0: + return "()" + if len(parts) == 1: + return f"({parts[0]}, )" + return f"({', '.join(parts)})" + +def view_to_reshape(gm: torch.fx.GraphModule): + for nd in gm.graph.find_nodes( + op="call_function", target=torch.ops.aten.view.default + ): + nd.target = torch.ops.aten.reshape.default + + for nd in gm.graph.find_nodes( + op="call_function", target=torch.ops.aten.div.Tensor + ): + if not (isinstance(nd.args[1], torch.fx.node.Node) and \ + isinstance(nd.args[1].meta['val'], torch.Tensor)): + nd.target = torch.ops.aten.div.Scalar + + for nd in gm.graph.find_nodes( + op="call_function", target=torch.ops.aten.add.Tensor + ): + if not (isinstance(nd.args[1], torch.fx.node.Node) and \ + isinstance(nd.args[1].meta['val'], torch.Tensor)): + nd.target = torch.ops.aten.add.Scalar + + for nd in gm.graph.find_nodes( + op="call_function", target=torch.ops.aten.sub.Tensor + ): + if not (isinstance(nd.args[1], torch.fx.node.Node) and \ + isinstance(nd.args[1].meta['val'], torch.Tensor)): + nd.target = torch.ops.aten.sub.Scalar + + for nd in gm.graph.find_nodes( + op="call_function", target=torch.ops.aten.mul.Tensor + ): + if not (isinstance(nd.args[1], torch.fx.node.Node) and \ + isinstance(nd.args[1].meta['val'], torch.Tensor)): + nd.target = torch.ops.aten.mul.Scalar + + for nd in gm.graph.find_nodes( + op="call_function", target=torch.ops.prims.convert_element_type.default + ): + nd.target = torch.ops.npu.npu_dtype_cast.default + +def npu_cast_to_prim_cast(gm: torch.fx.GraphModule): + """ + Replace npu.npu_dtype_cast ops in the GraphModule to prims.convert_element_type ops. + """ + new_gm = copy.deepcopy(gm) + for nd in new_gm.graph.nodes: + if nd.target in [torch.ops.npu.npu_dtype_cast.default, torch.ops.npu.npu_dtype_cast_backward.default, torch.ops.npu._npu_dtype_cast.default, torch.ops.npu._npu_dtype_cast_backward.default]: + nd.target = torch.ops.prims.convert_element_type.default + if nd.target in [torch.ops.aten.index_put_.default]: + nd.target = torch.ops.aten.index_put.default + return new_gm + +def modify_gm_for_acc_comp(gm: torch.fx.GraphModule): + """ + In precision comparison mode, if the second argument of npu_dtype_cast is torch.bfloat16, change it to torch.float32. 
+ """ + for nd in gm.graph.nodes: + if nd.target in [torch.ops.npu.npu_dtype_cast.default, torch.ops.npu.npu_dtype_cast_backward.default, torch.ops.npu._npu_dtype_cast.default, torch.ops.npu._npu_dtype_cast_backward.default]: + if nd.args[1] == torch.bfloat16: + new_args = list(nd.args) + new_args[1] = torch.float32 + nd.args = tuple(new_args) + +def replace_iota_int64_to_int32(nd: torch.fx.Node): + """ + Replace iota dtype from int64 to int32. + """ + if nd.target in [torch.ops.prims.iota.default] and nd.kwargs['dtype'] == torch.int64: + new_args = dict(nd.kwargs) + new_args['dtype'] = torch.int32 + nd.kwargs = new_args + +def npu_optimize_fx_graph(gm: torch.fx.GraphModule): + """ + optimize fx graph for npu + """ + aten_empty_nodes = set() + for nd in gm.graph.nodes: + replace_iota_int64_to_int32(nd) + # Replace npu type_as ops in the GraphModule to cast ops. + if nd.target == torch.ops.aten.empty.memory_format and len(nd.users) == 1: + aten_empty_nodes.add(nd) + if nd.target == torch.ops.aten.copy.default: + node0 = nd.args[0] + if node0 in aten_empty_nodes: + with gm.graph.inserting_after(nd): + dtype = node0.kwargs.get('dtype') + op_target = torch.ops.npu.npu_dtype_cast.default + args = (nd.args[1], dtype) + new_node = gm.graph.call_function(op_target, args=args) + new_node.name = nd.name + nd.replace_all_uses_with(new_node) + gm.graph.erase_node(nd) + aten_empty_nodes.remove(node0) + gm.graph.erase_node(node0) + + gm.recompile() + +def get_last_node(gm: torch.fx.GraphModule): + last_node = None + for node in gm.graph.nodes: + last_node = node + return last_node + +def fx_graph_op_types(gm: torch.fx.GraphModule) -> List[str]: + op_types = [] + for nd in gm.graph.nodes: + if nd.op not in ['call_function', 'call_method', 'call_module']: + continue + type_str = str(nd.target) + if type_str.startswith(('aten', 'prims', 'npu')): + op_types.append(type_str.split('.')[1]) + return op_types + +# Borrowed from https://github.com/llvm/torch-mlir/blob/2b01f8b7f3cca87c3dc9c75edd91397803e9f6d4/projects/pt1/python/torch_mlir_e2e_test/configs/torchdynamo.py#L67 +def scalarize_tensor_ops_on_scalars(gm: torch.fx.GraphModule): + # Modify gm.graph + for node in gm.graph.nodes: + # Checks if we're calling a function (i.e: + # torch.add) + if node.op == "call_function": + # The target attribute is the function + # that call_function calls. 
+            # call_function[target=torch.ops.aten.add.Tensor](args = (%arg64_1, 1), kwargs = {})
+            if node.target == torch.ops.aten.add.Tensor:
+                if len(node.args) != 2 or node.kwargs != {}:
+                    continue
+                elif not isinstance(node.args[1], torch.fx.node.Node):
+                    node.target = torch.ops.aten.add.Scalar
+            if node.target == torch.ops.aten.mul.Tensor:
+                if len(node.args) != 2 or node.kwargs != {}:
+                    continue
+                elif not isinstance(node.args[1], torch.fx.node.Node):
+                    node.target = torch.ops.aten.mul.Scalar
+
+    gm.graph.lint()  # Does some checks to make sure the
+    # Graph is well-formed.
+
+    # Recompile the forward() method of `gm` from its Graph
+    gm.recompile()
+
+def generate_fake_inputs(name_to_example_inputs):
+    inputs_str = ""
+    for name, example_input in name_to_example_inputs.items():
+        input_str = (
+            f"{name} = rand_strided("
+            f"{codegen_python_shape_tuple(example_input.size())}, "
+            f"{codegen_python_shape_tuple(example_input.stride())}, "
+            f"device='{example_input.device}', dtype={example_input.dtype})"
+        )
+        inputs_str += f" {input_str}\n"
+
+    return inputs_str
+
+
+def get_num_call_functions(graph):
+    num_call_functions = 0
+    for node in graph.graph.nodes:
+        if node.op == "call_function" and node.target != torch.ops.aten.reshape.default:
+            num_call_functions += 1
+            if num_call_functions > 1:
+                break
+    return num_call_functions
+
+class MLIRProcessor:
+    def __init__(self, bisheng_install_path: str = None):
+        """
+        Initialize the MLIR processor.
+
+        :param bisheng_install_path: Bisheng install path; taken from the environment by default
+        """
+        self.bisheng_torch_mlir_path = "bishengir-opt"
+
+    def extract_function(self, module: ir.Module) -> func_dialect.FuncOp:
+        """Extract the main function from the MLIR module and tag it with a marker attribute."""
+        with module.context:
+            for func in module.body.operations:
+                if isinstance(func, func_dialect.FuncOp):
+                    func.attributes["hacc.placeholder"] = ir.UnitAttr.get(func.context)
+                    return func
+            raise ValueError("No valid FuncOp found in module")
+
+    def rebuild_mlir_module(self, module_str: str) -> ir.Module:
+        """Rebuild an MLIR module from its string form."""
+        with ir.Context() as ctx:
+            ctx.allow_unregistered_dialects = True
+            torch_mlir.dialects.torch.register_dialect(ctx)
+            return ir.Module.parse(module_str)
+
+    def get_signature(self, func: func_dialect.FuncOp) -> tuple:
+        """Collect the function's signature info: type signature, number of outputs and tensor ranks."""
+        func_type = func.type
+        signature = {}
+        ranks = []
+
+        # Handle input and output types
+        for i, tensor_type in enumerate(func_type.inputs + func_type.results):
+            try:  # RankedTensorType
+                signature[i] = '*' + str(tensor_type.element_type)
+                ranks.append(len(tensor_type.shape))
+            except AttributeError:  # ValueTensorType
+                type_str = str(tensor_type)
+                signature[i] = '*' + type_str.split(',')[-1].split('>')[0]
+                # Extract the rank from the type string
+                dim_start = type_str.find('[') + 1
+                dim_end = type_str.find(']', dim_start)
+                dim_str = type_str[dim_start:dim_end]
+                ranks.append(dim_str.count(',') + 1 if dim_str else 1)
+
+        num_outputs = len(func_type.results)
+        return signature, num_outputs, ranks
+
+    def process_mlir(self,
+                     module: Union[str, ir.Module],
+                     get_sig: bool = True,
+                     dynamic: bool = False) -> tuple:
+        """
+        Core entry point for processing an MLIR module.
+
+        :param module: MLIR module string or object
+        :param get_sig: whether to collect the function signature
+        :param dynamic: whether dynamic execution mode is enabled
+        :return: (function string, metadata dict)
+        """
+        if isinstance(module, str):
+            module = self.rebuild_mlir_module(module)
+
+        func = self.extract_function(module)
+        kernel_info = None
+        func_str = str(func)
+        func_hash_str = func_str + "_host" if dynamic else func_str
+        module_hash = hashlib.sha256(func_hash_str.encode()).hexdigest()
+        logger.info(f"Generated kernel hash: {module_hash}")
+
+        if get_sig:
+            signature, num_outputs, ranks = self.get_signature(func)
+            kernel_info = {
+                "signature": signature,
+                "ranks": ranks,
+                'kernel_hash': module_hash,
+            }
+
+        return func_str, kernel_info
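A minimal usage sketch of the flow above and of `get_named_op_str` below (hedged: `linalg_module_str` is a hypothetical placeholder for the torch-mlir module text produced earlier in the pipeline, and the named-op step assumes `bishengir-opt` is on PATH):

    processor = MLIRProcessor()
    # process_mlir extracts the FuncOp text and, with get_sig=True, returns the
    # '*dtype' signature per input/output, the tensor ranks and a sha256 kernel hash.
    func_str, kernel_info = processor.process_mlir(linalg_module_str, get_sig=True, dynamic=False)
    # get_named_op_str (defined next) additionally shells out to bishengir-opt to
    # lower the module through the named-op pipeline and returns the converted MLIR text.
    named_op_mlir, sig_dict = processor.get_named_op_str(linalg_module_str, kernel_name="kernel_0")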
+    def get_named_op_str(self,
+                         module: Union[str, ir.Module],
+                         kernel_name: str,
+                         dynamic: bool = False) -> Dict[str, Any]:
+        """
+        Get the MLIR text in named-op form.
+
+        :param module: MLIR module string or object
+        :param kernel_name: kernel name (used for the temporary file)
+        :param dynamic: whether dynamic execution mode is enabled
+        :return: the processed MLIR text and the signature dict
+        """
+        func_str, sig_dict = self.process_mlir(module, get_sig=True, dynamic=dynamic)
+
+        cleaned_func = func_str.replace(
+            '"#hfusion.fusion_kind"',
+            '#hfusion.fusion_kind'
+        )
+        logger.debug(f"Original Linalg-dialect MLIR:\n{cleaned_func}")
+
+        # Run the conversion command
+        with tempfile.TemporaryDirectory() as tmpdir:
+            torch_mlir_path = os.path.join(tmpdir, f"{kernel_name}.mlir")
+            with open(torch_mlir_path, 'w') as f:
+                f.write(cleaned_func)
+
+            cmd = (f"{self.bisheng_torch_mlir_path} "
+                   "--torch-backend-to-named-op-backend-pipeline="
+                   "\"ensure-no-implicit-broadcast=true\" "
+                   f"{torch_mlir_path}")
+
+            try:
+                result = subprocess.check_output(
+                    cmd, text=True, shell=True
+                )
+                # Filter out global definitions and update the function attributes
+                processed_mlir = "\n".join(
+                    line for line in result.splitlines()
+                    if "ml_program.global" not in line
+                )
+
+                # Set the function attribute according to the mode
+                func_attr = ("hacc.entry, hacc.function_kind = #hacc.function_kind"
+                             if dynamic else
+                             "hacc.entry, hacc.function_kind = #hacc.function_kind")
+                processed_mlir = processed_mlir.replace("hacc.placeholder", func_attr)
+
+                # Apply extra dtype handling (requires mlir_match_and_replace_unsupported_dtypes)
+                final_mlir = self._replace_unsupported_dtypes(processed_mlir)
+                logger.debug(f"Converted named-op dialect MLIR:\n{final_mlir}")
+
+                return final_mlir, sig_dict
+
+            except subprocess.CalledProcessError as e:
+                logger.error(f"Command failed: {cmd}\nError: {e.output}")
+                raise RuntimeError(f"MLIR conversion failed: {e.stderr}") from e
+
+    def _replace_unsupported_dtypes(self, mlir_text: str) -> str:
+        """Replace unsupported MLIR data types."""
+        pattern1 = r"%(\d+) = arith\.truncf %(\w+) : f64 to bf16"
+        matches1 = re.findall(pattern1, mlir_text)
+
+        for var1, var2 in matches1:
+            pattern2 = rf"%" + var2 + r" = arith\.constant (\d+(\.\d+)?)
: f64" + match2 = re.search(pattern2, mlir_text) + if match2: + mlir_text = re.sub(r': f64', ': f32', mlir_text) + return mlir_text + + +def to_folder( + gm: torch.fx.GraphModule, + folder: Union[str, os.PathLike], + graph_hash: str, + module_name: str = "FxModule"): + """Dumps out module to ``folder`` with ``module_name`` so that it can be + imported with ``from import `` + + Args: + + folder (Union[str, os.PathLike]): The folder to write the code out to + + module_name (str): Top-level name to use for the ``Module`` while + writing out the code + """ + folder = Path(folder) + Path(folder).mkdir(exist_ok=True) + tab = " " * 4 + custom_builtins = "\n".join([v.import_str for v in _custom_builtins.values()]) + model_str = f""" +import torch +{custom_builtins} + +from torch.nn import * + +class {module_name}(torch.nn.Module): + def __init__(self): + super().__init__() +""" + + def _gen_model_repr(module_name: str, module: torch.nn.Module) -> Optional[str]: + safe_reprs = [ + nn.Linear, + nn.Conv1d, + nn.Conv2d, + nn.Conv3d, + nn.BatchNorm1d, + nn.BatchNorm2d, + nn.BatchNorm3d, + ] + if type(module) in safe_reprs: + return f"{module.__repr__()}" + else: + return None + + blobified_modules = [] + for module_name, module in gm.named_children(): + module_str = _gen_model_repr(module_name, module) + if module_str is None: + module_file = folder / f"{module_name}.pt" + torch.save(module, module_file) + blobified_modules.append(module_name) + module_repr = module.__repr__().replace("\r", " ").replace("\n", " ") + # weights_only=False as this is legacy code that saves the model + module_str = ( + f"torch.load(r'{module_file}', weights_only=False) # {module_repr}" + ) + model_str += f"{tab * 2}self.{module_name} = {module_str}\n" + + for buffer_name, buffer in gm._buffers.items(): + if buffer is None: + continue + model_str += f"{tab * 2}self.register_buffer('{buffer_name}', torch.empty({list(buffer.shape)}, dtype={buffer.dtype}))\n" # noqa: B950 + + for param_name, param in gm._parameters.items(): + if param is None: + continue + model_str += f"{tab * 2}self.{param_name} = torch.nn.Parameter(torch.empty({list(param.shape)}, dtype={param.dtype}))\n" # noqa: B950 + + model_str += f"{_addindent(gm.code, 4)}\n" + + module_file = folder / f"{graph_hash}.py" + module_file.write_text(model_str) + + if len(blobified_modules) > 0: + warnings.warn( + "Was not able to save the following children modules as reprs -" + f"saved as pickled files instead: {blobified_modules}" + ) + +def get_anir_mode(): + mode = os.getenv('ANIR_MODE', 'O1') + + if mode not in ["O0", "O1"]: + raise ValueError(f"Invalid MODE value: {mode}. 
Allowed values are 'O0' and 'O1'.")
+    return mode
+
+def is_fx_dynamic(graph):
+    for node in graph.graph.nodes:
+        if node.op == "placeholder":
+            if 'tensor_meta' in node.meta:
+                shape = node.meta['tensor_meta'].shape
+                if any(isinstance(dim, torch.SymInt) for dim in shape):
+                    return True
+        elif node.op == "call_function":
+            if isinstance(node.meta['val'], torch.Tensor):
+                if any(isinstance(dim, torch.SymInt) for dim in node.meta['val'].shape):
+                    return True
+    return False
+
+def replace_placeholders(file_path: str, replacements: dict, placeholder_format: str = r'\{\{(\w+)\}\}') -> None:
+    """
+    Replace placeholders in a file.
+
+    :param file_path: path to the file
+    :param replacements: replacement dict, e.g. {'function_body': 'your_code'}
+    :param placeholder_format: placeholder regex (matches {{xxx}} by default)
+    """
+    with open(file_path, 'r', encoding='utf-8') as f:
+        content = f.read()
+
+    pattern = re.compile(placeholder_format)
+
+    def replacer(match: re.Match) -> str:
+        placeholder = match.group(1)
+        replacement = replacements.get(placeholder, match.group(0))
+
+        line_start = content.rfind('\n', 0, match.start()) + 1
+        indent = re.match(r'^\s*', content[line_start:match.start()]).group(0)
+
+        return '\n'.join([indent + line for line in replacement.split('\n')])
+
+    new_content = pattern.sub(replacer, content)
+
+    with open(file_path, 'w', encoding='utf-8') as f:
+        f.write(new_content)
+
+def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, fast_flush=True, return_mode="mean",
+             device_type="npu"):
+    """
+    Benchmark the runtime of the provided function. This implementation returns the mean runtime of
+    :code:`fn` in milliseconds.
+
+    :param fn: Function to benchmark
+    :type fn: Callable
+    :param warmup: Warmup time (in ms)
+    :type warmup: int
+    :param rep: Repetition time (in ms)
+    :type rep: int
+    :param grad_to_none: Reset the gradient of the provided tensor to None
+    :type grad_to_none: torch.tensor, optional
+    :param quantiles: Performance percentile to return in addition to the median.
+    :type quantiles: list[float], optional
+    :param fast_flush: Use faster kernel to flush L2 cache between measurements
+    :type fast_flush: bool, default is True
+    :param return_mode: The statistical measure to return. Options are "min", "max", "mean", "median", or "all". Default is "mean".
:type return_mode: str + """ + assert return_mode in ["min", "max", "mean", "median", "all"] + import torch + + di = torch._dynamo.device_interface.get_interface_for_device(device_type) + + fn() + di.synchronize() + + # We maintain a buffer of 256 MB that we clear + # before each kernel call to make sure that the L2 cache + # doesn't contain any input data before the run + cache_size = 256 * 1024 * 1024 + if fast_flush: + cache = torch.empty(int(cache_size // 4), dtype=torch.int, device=device_type) + else: + cache = torch.empty(int(cache_size), dtype=torch.int8, device=device_type) + + # Estimate the runtime of the function + start_event = di.Event(enable_timing=True) + end_event = di.Event(enable_timing=True) + start_event.record() + for _ in range(5): + cache.zero_() + fn() + end_event.record() + di.synchronize() + estimate_ms = start_event.elapsed_time(end_event) / 5 + + # compute number of warmup and repeat + n_warmup = max(1, int(warmup / estimate_ms)) + n_repeat = max(1, int(rep / estimate_ms)) + start_event = [di.Event(enable_timing=True) for i in range(n_repeat)] + end_event = [di.Event(enable_timing=True) for i in range(n_repeat)] + # Warm-up + for _ in range(n_warmup): + fn() + # Benchmark + for i in range(n_repeat): + # we don't want `fn` to accumulate gradient values + # if it contains a backward pass. So we clear the + # provided gradients + if grad_to_none is not None: + for x in grad_to_none: + x.grad = None + # we clear the L2 cache before each run + cache.zero_() + # record time of `fn` + start_event[i].record() + fn() + end_event[i].record() + # Record clocks + di.synchronize() + times = torch.tensor([s.elapsed_time(e) for s, e in zip(start_event, end_event)], dtype=torch.float) + return torch.mean(times).item() + +def _get_logger(*, level=logging.ERROR, file=None, name=None): + logger = logging.getLogger(name) + logger.setLevel(level) + + try: + import colorlog + except ImportError: + formatter = logging.Formatter( + '[%(levelname)s] ANIR %(asctime)s %(message)s', datefmt='%Y-%m-%d %H:%M:%S') + else: + formatter = colorlog.ColoredFormatter( + '%(log_color)s[%(levelname)s] ANIR %(asctime)s %(message)s', + datefmt='%Y-%m-%d %H:%M:%S', + log_colors={ + 'DEBUG': 'cyan', + 'INFO': 'green', + 'WARNING': 'yellow', + 'ERROR': 'red', + 'CRITICAL': 'bold_red', + }) + + if file: + file_handler = logging.FileHandler(file) + file_handler.setFormatter(formatter) + logger.addHandler(file_handler) + else: + handler = logging.StreamHandler() + handler.setFormatter(formatter) + logger.addHandler(handler) + logger.propagate = False + return logger + +logger = _get_logger(name='anir') \ No newline at end of file diff --git a/torch_npu/_inductor/ascend_npu_ir/ascend_npu_ir/triton.py b/torch_npu/_inductor/ascend_npu_ir/ascend_npu_ir/triton.py new file mode 100644 index 0000000000000000000000000000000000000000..67ca4337cc420d1cd0ac88297492a332b408fb01 --- /dev/null +++ b/torch_npu/_inductor/ascend_npu_ir/ascend_npu_ir/triton.py @@ -0,0 +1,2 @@ +def constexpr(): + pass \ No newline at end of file diff --git a/torch_npu/_inductor/ascend_npu_ir/build_ext.py b/torch_npu/_inductor/ascend_npu_ir/build_ext.py new file mode 100755 index 0000000000000000000000000000000000000000..c9a9b26224064d60110c0610b39d4ea072569c56 --- /dev/null +++ b/torch_npu/_inductor/ascend_npu_ir/build_ext.py @@ -0,0 +1,169 @@ +import os +import shutil +import subprocess +import stat +import sysconfig +import sys +import functools +import threading +from pathlib import Path + +from setuptools import setup +from 
pybind11.setup_helpers import Pybind11Extension
+
+import torch
+import torch_npu
+import torch.distributed as dist
+
+# Absolute path of this script
+BASE_DIR = Path(os.path.dirname(os.path.abspath(__file__)))
+asc_path = os.getenv("ASCEND_HOME_PATH", "")
+if not asc_path:
+    raise RuntimeError("ASCEND_HOME_PATH is not set, source /set_env.sh first")
+
+torch_npu_path = os.path.dirname(os.path.realpath(torch_npu.__file__))
+
+# Create the global install directory
+INSTALL_DIR = BASE_DIR / "ascend_npu_ir"
+LIB_DIR = INSTALL_DIR / "lib"
+os.makedirs(LIB_DIR, exist_ok=True)
+os.chmod(LIB_DIR, stat.S_IRWXU | stat.S_IRGRP | stat.S_IXGRP)
+
+def run_once(func):
+    result = None
+    has_run = False
+    lock = threading.Lock()
+
+    @functools.wraps(func)
+    def wrapper(*args, **kwargs):
+        nonlocal result, has_run
+
+        if has_run:
+            return result
+
+        with lock:
+            if not has_run:
+                result = func(*args, **kwargs)
+                has_run = True
+
+        return result
+
+    return wrapper
+
+def get_cxx_compiler():
+    """Locate the C++ compiler."""
+    cxx = os.environ.get("CXX") or os.environ.get("CC")
+    if cxx:
+        return cxx
+    for compiler in ["clang++", "g++"]:
+        if path := shutil.which(compiler):
+            return path
+    raise RuntimeError("Failed to find C++ compiler (tried clang++, g++)")
+
+@run_once
+def anir_build_libcpp_common(so_path):
+    """Build the shared library (with a cache check)."""
+    src_path = BASE_DIR / "ascend_npu_ir" / "cpp_common" / "cpp_common.cpp"
+
+    # Get the Python include directory
+    scheme = sysconfig.get_default_scheme()
+    if scheme == 'posix_local':
+        scheme = 'posix_prefix'
+    py_include = sysconfig.get_paths(scheme=scheme)["include"]
+
+    # Assemble the compile command
+    cc_cmd = [
+        get_cxx_compiler(),
+        str(src_path),
+        f"-I{py_include}",
+        f"-I{BASE_DIR / 'ascend_npu_ir' / '_C' / 'include'}",
+        f"-I{asc_path}/include",
+        f"-L{asc_path}/lib64",
+        f"-I{os.path.dirname(os.path.realpath(torch.__file__))}/include",
+        f"-I{os.path.join(torch_npu_path, 'include')}",
+        f"-L{os.path.join(torch_npu_path, 'lib')}",
+        "-lruntime", "-lascendcl", "-ltorch_npu", "-lprofapi",
+        "-std=c++17",
+        f"-D_GLIBCXX_USE_CXX11_ABI={int(torch._C._GLIBCXX_USE_CXX11_ABI)}",
+        "-fPIC", "-shared", "-o", str(so_path)
+    ]
+
+    # Run the compilation
+    print(f"Executing: {' '.join(cc_cmd)}")
+    if (ret := subprocess.call(cc_cmd)) != 0:
+        raise RuntimeError(f"Build failed with code {ret}")
+    print(f"Successfully built: {so_path}")
+
+@run_once
+def anir_build_pybind_extension():
+    """Build the Python extension module."""
+    # Make sure the extension is built into the right directory
+    build_lib_dir = str(INSTALL_DIR)
+
+    extension = Pybind11Extension(
+        '_C',
+        [str(BASE_DIR / 'ascend_npu_ir' / '_C' / 'extension.cpp')],
+        include_dirs=[
+            str(BASE_DIR / 'ascend_npu_ir' / '_C' / 'include'),
+            f'{asc_path}/include'
+        ],
+        library_dirs=[
+            f'{asc_path}/lib64',
+            str(LIB_DIR)
+        ],
+        libraries=['runtime', 'cpp_common'],
+        extra_link_args=[
+            f'-Wl,-rpath,{asc_path}/lib64',
+            f'-Wl,-rpath,{LIB_DIR}'
+        ],
+        extra_compile_args=["-std=c++17"],
+    )
+
+    # Switch to the project root for the build
+    original_cwd = os.getcwd()
+    os.chdir(BASE_DIR)
+
+    try:
+        setup(
+            name="ascend_npu_ir",
+            version="0.1",
+            ext_modules=[extension],
+            script_args=["build_ext", f"--build-lib={build_lib_dir}"],
+        )
+    finally:
+        os.chdir(original_cwd)  # Restore the original working directory
+
+def main_process_only(func):
+    """
+    Decorator: only rank 0 runs the function; the other ranks wait.
+    Intended for functions whose return value is not needed (e.g. mkdir, print, download).
+    """
+    def wrapper(*args, **kwargs):
+        if dist.is_initialized():
+            if dist.get_rank() == 0:
+                result = func(*args, **kwargs)
+            else:
+                result = None  # non-main ranks do not execute
+            dist.barrier()  # synchronize: wait for rank 0 to finish
+            return result
+        else:
+            # run directly in a non-distributed environment
+            return func(*args, **kwargs)
+    return wrapper
+
+
+@main_process_only
+def build_ascend_npu_ir_ext():
+    try:
so_path = LIB_DIR / "libcpp_common.so" + if not so_path.exists(): + print(f"Building libcpp_common.so at {so_path}") + anir_build_libcpp_common(so_path) + anir_build_pybind_extension() + except Exception as e: + print(f"Error: {e}", file=sys.stderr) + sys.exit(1) + + +if __name__ == "__main__": + build_ascend_npu_ir_ext() \ No newline at end of file diff --git a/torch_npu/utils/_dynamo.py b/torch_npu/utils/_dynamo.py index c5f9ba5272eebeac95f4b023408e792b1df5cc13..ae97489ab2c4198313252824727b55b7cb4c2ff9 100644 --- a/torch_npu/utils/_dynamo.py +++ b/torch_npu/utils/_dynamo.py @@ -158,6 +158,9 @@ def patch_inductor_wrapper(): src_call = _TorchCompileInductorWrapper.__call__ def new_call(self, model_, inputs_): + if self.config.get('max_autotune', False): + import os + os.environ['TORCHINDUCTOR_MAX_AUTOTUNE'] = '1' register_inductor_npu() return src_call(self, model_, inputs_) _TorchCompileInductorWrapper.__call__ = new_call diff --git a/torch_npu/utils/_inductor.py b/torch_npu/utils/_inductor.py index 9a36ddb6f0b95595b8a74380cb868c3b7792a0a6..42e9d183361d5ea36a1c809b344b47f828b02782 100644 --- a/torch_npu/utils/_inductor.py +++ b/torch_npu/utils/_inductor.py @@ -3,7 +3,7 @@ from torch._inductor.codegen.common import DeviceOpOverrides, register_device_op class NPUDeviceOpOverrides(DeviceOpOverrides): def import_get_raw_stream_as(self, name): - return f"from torch._C import _npu_getCurrentRawStream as {name}" + return f"from torch_npu._C import _npu_getCurrentRawStream as {name}" def set_device(self, device_idx): return f"torch_npu.npu.set_device({device_idx})"
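As a closing illustration, here is a hedged end-to-end sketch of how the new max-autotune hook is expected to be exercised from user code (the Linear module and shapes are illustrative only; passing options={"max_autotune": True} is the standard way to enable Inductor's max_autotune config, which the patched _TorchCompileInductorWrapper.__call__ above turns into TORCHINDUCTOR_MAX_AUTOTUNE=1 before register_inductor_npu() runs):

    import torch
    import torch_npu

    model = torch.nn.Linear(128, 128).npu()
    # With max_autotune enabled, the patched wrapper exports
    # TORCHINDUCTOR_MAX_AUTOTUNE=1 and then falls through to the stock Inductor call.
    compiled = torch.compile(model, backend="inductor", options={"max_autotune": True})
    out = compiled(torch.randn(4, 128, device="npu"))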