diff --git a/test/_inductor/__init__.py b/test/_inductor/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d2497b84aeda72a81c72604bbd678e7cd0494594 --- /dev/null +++ b/test/_inductor/__init__.py @@ -0,0 +1,3 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved. \ No newline at end of file diff --git a/test/_inductor/commonutils.py b/test/_inductor/commonutils.py new file mode 100644 index 0000000000000000000000000000000000000000..00527a41d87baaf9de5f31495acf6fe9f852bc27 --- /dev/null +++ b/test/_inductor/commonutils.py @@ -0,0 +1,89 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved. + +import subprocess +from io import StringIO +from typing import List + + +""" +Unified entry point for external callers: returns the available NPU device ids, least-used first. +""" +def get_available_npu_device_ids(): +    npu_ids = _get_all_npu_device_ids() +    sorted_npu_dict = _sort_npu_by_usage_cap(npu_ids) +    return list(sorted_npu_dict.keys()) + + +""" +Collect all NPU device ids via `ls /dev/davinci*`. +""" +def _get_all_npu_device_ids(): +    ## ls /dev/davinci* +    buffer = StringIO() +    try: +        result = subprocess.run( +            ["ls /dev/davinci*"], +            capture_output=True, +            shell=True, +            text=True, +            check=True +        ) +        output = result.stdout +        buffer.write(output) +    except subprocess.CalledProcessError as e: +        print(f"Error running command: {e}") +    finally: +        content = buffer.getvalue() +        buffer.close() + +    npu_ids = [] +    if not content: +        return npu_ids +    for line in content.splitlines(): +        if not line[-1].isdigit(): +            continue +        idx = -1 +        while line[idx].isdigit(): +            idx -= 1 +        id = line[idx + 1:] +        npu_ids.append(id) +    return npu_ids + + +""" +Query each card via `npu-smi info -t usages -i <id>` and sort the cards. +Returns a dict {id: [HBM Capacity(MB), HBM Usage Rate(%)]} ordered by usage rate ascending; ties are broken by capacity descending. +""" +def _sort_npu_by_usage_cap(npu_ids: List[str]) -> dict: +    npu_dict = dict() +    try: +        for id in npu_ids: +            result = subprocess.run(["npu-smi info -t usages -i " + id], +                                    capture_output=True, +                                    text=True, +                                    shell=True, +                                    check=True) +            ss = result.stdout +            ## [HBM Capacity(MB), HBM Usage Rate(%)] +            tmp = [] +            for line in ss.splitlines(): +                if ":" not in line: +                    continue +                key, val = line.split(":", 1) +                key, val = key.strip(), val.strip() +                if key == "HBM Usage Rate(%)": +                    tmp.append(val) +                if key == "HBM Capacity(MB)": +                    tmp.append(val) +            if tmp: +                npu_dict[int(id)] = tmp +        sorted_npu_dict = dict(sorted(npu_dict.items(), key=lambda x: (int(x[1][1]), -int(x[1][0])))) +        return sorted_npu_dict +    except subprocess.CalledProcessError as e: +        print(f"Error running command: {e}") + + +if __name__ == '__main__': +    res = get_available_npu_device_ids() +    print(res) \ No newline at end of file diff --git a/test/_inductor/conftest.py b/test/_inductor/conftest.py new file mode 100644 index 0000000000000000000000000000000000000000..0b907d9a290b85769e9e746082fdda13c46791d6 --- /dev/null +++ b/test/_inductor/conftest.py @@ -0,0 +1,22 @@ +import pytest +import os +import torch_npu._inductor +import getpass + + +def pytest_addoption(parser): +    parser.addoption("--npu_indexing", action='store', default='False', +                     help='whether to enable npu indexing or not, default is False', choices=['True', 'False']) + + +@pytest.fixture(scope="session") +def clear_cache(): +    # os.system('rm -rf /tmp/torchinductor_' + getpass.getuser() + '/*') +    # os.system('rm -rf ~/.triton/dump') +    # os.system('rm -rf ~/.triton/cache') +    return + + +@pytest.fixture(scope="session", autouse=True) +def set_npu_indexing(pytestconfig): +    torch_npu._inductor.config.enable_npu_indexing = eval(pytestconfig.getoption("--npu_indexing")) diff --git a/test/_inductor/run_ut.sh b/test/_inductor/run_ut.sh new file mode 100644 index 0000000000000000000000000000000000000000..c68e41b8cd2d28fefd7d3deb3219fb9c762d7b3f --- /dev/null +++ b/test/_inductor/run_ut.sh @@ -0,0 +1,73 @@ +#!/bin/bash +set -ex + +source /root/anaconda3/bin/activate inductor260 +pip list + +# build the Triton NPU backend first +pip uninstall -y triton + +mkdir -p ${WORKSPACE}TritonNpu +cd ${WORKSPACE}TritonNpu +git clone https://gitee.com/ascend/triton-ascend.git -b master + +# clear inductor cache +rm -rf /tmp/torchinductor_* + +if [ -d ${WORKSPACE}TritonNpu/triton-ascend/triton ];then +    rm -rf ${WORKSPACE}TritonNpu/triton-ascend/triton +fi + +if [ -d ~/.triton/dump ];then +    rm -rf ~/.triton/dump +fi + +if [ -d ~/.triton/cache ];then +    rm -rf ~/.triton/cache +fi + +cd ${WORKSPACE}TritonNpu/triton-ascend +git clone --depth 1 https://gitee.com/shijingchang/triton.git +#cp -r /triton_depends/triton ${WORKSPACE}TritonNpu/triton-ascend/triton +#cd ${WORKSPACE}TritonNpu/triton-ascend/triton +#git apply ${WORKSPACE}TritonNpu/triton-ascend/build/patch/triton_ebce7f.patch +#git apply ${WORKSPACE}TritonNpu/triton-ascend/build/patch/0001-AttrDescriptor-fix-and-delete-power-of-two.patch +#cd ${WORKSPACE}TritonNpu/triton-ascend +echo ${PWD} + +TRITON_PLUGIN_DIRS=${WORKSPACE}TritonNpu/triton-ascend/ascend \ +LLVM_INCLUDE_DIRS=$LLVM_SYSPATH/include \ +LLVM_LIBRARY_DIR=$LLVM_SYSPATH/lib \ +LLVM_SYSPATH=$LLVM_SYSPATH \ +TRITON_BUILD_WITH_CLANG_LLD=true \ +TRITON_BUILD_PROTON=OFF \ +pip install -e ${WORKSPACE}TritonNpu/triton-ascend/triton/python --no-build-isolation -vvv + +pip list + +cd ${WORKSPACE} +echo ${PWD} +ls -al + +# run inductor ut +export PYTHONPATH=${WORKSPACE}:$PYTHONPATH +export TORCHINDUCTOR_COMPILE_THREADS=1 +export ASCEND_LAUNCH_BLOCKING=1 +export CI="" +env + +if [ -d ~/.triton/dump ];then +    rm -rf ~/.triton/dump +fi + +if [ -d ~/.triton/cache ];then +    rm -rf ~/.triton/cache +fi + +tree + +cd test + +pytest -svvv . --npu_indexing=True || { exit 1; } + + diff --git a/test/_inductor/test_abs.py b/test/_inductor/test_abs.py new file mode 100644 index 0000000000000000000000000000000000000000..8440aab1fcadce97e53de9bf5d28d25ae4335d6a --- /dev/null +++ b/test/_inductor/test_abs.py @@ -0,0 +1,41 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved.
+import torch +import torch_npu +import torch_npu._inductor + +import pytest +from .testutils import OperatorType, TestUtils + + +class TestAbs(TestUtils): + __TIME_LIMIT = 100 + __OPTYPE = OperatorType.POINTWISE + + # optimized function, auto timeout after __TIME_LIMIT seconds + + # @torch.compile(options={"aggressive_fusion": False}) + + def op_calc(self, first_element): + result = torch.abs(first_element) + return result + + # 在连续测试场景下,测试结果不稳定,建议单独重测批量测试未通过的 case + # 若需测试更多数据类型,将dtype后面的list改成 ProtoTestCase._test_dtypes即可 + # 对indexing开关情况的测试需要用外部参数--npu_indexing=True/False完成 + + @pytest.mark.timeout(__TIME_LIMIT) + @pytest.mark.parametrize('shape', [(1024, 32), (256, 8)]) + @pytest.mark.parametrize('dtype', ['float16', 'float32', 'bfloat16']) + def test_pointwise_cases(self, shape, dtype, clear_cache): + print(shape) + print('npu_indexing= {}'.format(torch_npu._inductor.config.enable_npu_indexing)) + first_element = self._generate_tensor(shape, dtype) + + std_result = self.op_calc(first_element) + + compiled_op_calc = torch.compile(self.op_calc, backend="inductor") + inductor_result = compiled_op_calc(first_element) + # print(std_result[0:8]) + # print(inductor_result[0:8]) + torch.testing.assert_close(std_result, inductor_result, atol=1e-3, rtol=1e-3) diff --git a/test/_inductor/test_add.py b/test/_inductor/test_add.py new file mode 100644 index 0000000000000000000000000000000000000000..8da8dff4f5949946980e23ac4a8a54e0f333d789 --- /dev/null +++ b/test/_inductor/test_add.py @@ -0,0 +1,54 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. +import torch +import torch_npu +import torch_npu._inductor + +import pytest +from testutils import OperatorType, TestUtils + + +class TestAdd(TestUtils): + __TIME_LIMIT = 100 + __OPTYPE = OperatorType.POINTWISE + + # optimized function, auto timeout after __TIME_LIMIT seconds + + # @torch.compile(options={"aggressive_fusion": False}) + + def op_calc(self, first_element, second_element): + result = first_element + second_element + return result + + # 在连续测试场景下,测试结果不稳定,建议单独重测批量测试未通过的 case + # 若需测试更多数据类型,将dtype后面的list改成 ProtoTestCase._test_dtypes即可 + # 对indexing开关情况的测试需要用外部参数--npu_indexing=True/False完成 + + @pytest.mark.timeout(__TIME_LIMIT) + @pytest.mark.parametrize('shape', TestUtils._pointwise_demo_shapes) + @pytest.mark.parametrize('dtype', ['float32', 'int64']) + def test_pointwise_cases(self, shape, dtype, clear_cache): + print(shape) + print('npu_indexing= {}'.format(torch_npu._inductor.config.enable_npu_indexing)) + first_element = self._generate_tensor(shape, dtype) + second_element = self._generate_tensor(shape, dtype) + + std_sum = self.op_calc(first_element, second_element) + + compiled_op_calc = torch.compile(self.op_calc, backend="inductor") + inductor_sum = compiled_op_calc(first_element, second_element) + + torch.testing.assert_close(std_sum, inductor_sum) + + # should be implemented when __OPTYPE is OperatorType.REDUCTION + @pytest.mark.timeout(__TIME_LIMIT) + @pytest.mark.parametrize('shape,dim', TestUtils._reduction_extest_SDbinding) + @pytest.mark.parametrize('dtype', TestUtils._test_dtypes) + @pytest.mark.skipif(__OPTYPE != OperatorType.REDUCTION, reason='not reduction operator') + def test_reduction_cases(self, shape, dim, dtype, clear_cache): + pass + +if __name__ == "__main__": + size = (1024, 1024) + test = TestAdd() + test.test_pointwise_cases(size, 'float32', None) \ No newline at end of file diff --git a/test/_inductor/test_add_sum.py b/test/_inductor/test_add_sum.py new 
file mode 100644 index 0000000000000000000000000000000000000000..3e54152be9d8ca082797a19c06c884ab5cabb194 --- /dev/null +++ b/test/_inductor/test_add_sum.py @@ -0,0 +1,49 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. +import torch +import torch_npu +import torch_npu._inductor + +import pytest +from testutils import OperatorType, TestUtils + +class TestSumAdd(TestUtils): + __TIME_LIMIT = 100 + __OPTYPE = OperatorType.REDUCTION + + def foo(self,a, b, dim): + y = a + b + y = y.sum(dim) + return y + + # case:change shapes + @pytest.mark.timeout(__TIME_LIMIT) + @pytest.mark.parametrize('shape', [(9, 9, 31, 64)]) + @pytest.mark.parametrize('dim', [3]) + @pytest.mark.parametrize('dtype', ['float32']) + def test_reduction_cases_shapes(self, shape, dim, dtype, clear_cache): + print(f"shape= {shape}") + print(f"dim= {dim}") + print(f"dtype= {dtype}") + print('npu_indexing= {}'.format(torch_npu._inductor.config.enable_npu_indexing)) + a, b = [torch.randn(shape, requires_grad=False, dtype=torch.float32, device="npu") for _ in range(2)] + r1 = self.foo(a, b, dim) + func = torch.compile(self.foo, backend="inductor", dynamic=False) + r = func(a, b, dim) + torch.testing.assert_close(r, r1, rtol=1e-3, atol=1e-3) + + @pytest.mark.timeout(__TIME_LIMIT) + @pytest.mark.parametrize('shape', [(9, 10, 31, 63)]) + @pytest.mark.parametrize('dim', [0, 1]) + @pytest.mark.parametrize('dtype', ['float32']) + def test_reduction_cases_shapes1(self, shape, dim, dtype, clear_cache): + print(f"shape= {shape}") + print(f"dim= {dim}") + print(f"dtype= {dtype}") + print('npu_indexing= {}'.format(torch_npu._inductor.config.enable_npu_indexing)) + + a, b = [torch.randn(shape, requires_grad=False, dtype=torch.float32, device="npu") for _ in range(2)] + r1 = self.foo(a, b, dim) + func = torch.compile(self.foo, backend="inductor", dynamic=False) + r = func(a, b, dim) + torch.testing.assert_close(r, r1, rtol=1e-3, atol=1e-3) diff --git a/test/_inductor/test_alias.py b/test/_inductor/test_alias.py new file mode 100644 index 0000000000000000000000000000000000000000..7f93f091ce4b6e5ecb4d60b34fa14f5b624db2d1 --- /dev/null +++ b/test/_inductor/test_alias.py @@ -0,0 +1,44 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. 
+import torch +import torch_npu +import torch_npu._inductor + +import pytest +from testutils import OperatorType, TestUtils + +class TestAlias(TestUtils): + __TIME_LIMIT = 100 + + def op_calc(self, input_element, dim): + x = torch.ops.aten.alias(input_element) + y = x + 1.0 + return y + + # case:change shapes + @pytest.mark.timeout(__TIME_LIMIT) + @pytest.mark.parametrize('shape', [(32, 64)]) + @pytest.mark.parametrize('dim', [0]) + @pytest.mark.parametrize('dtype', ['float32']) + def test_reduction_cases_shapes(self, shape, dim, dtype, clear_cache): + print(f"shape= {shape}") + print(f"dim= {dim}") + print(f"dtype= {dtype}") + print('npu_indexing= {}'.format(torch_npu._inductor.config.enable_npu_indexing)) + + input_element = self._generate_tensor(shape, dtype) + print(f"input_element= {input_element}") + std_ret = self.op_calc(input_element, dim) + print(f"std_ret= {std_ret}") + compiled_op_calc = torch.compile(self.op_calc, backend="inductor") + inductor_ret = compiled_op_calc(input_element, dim) + print(f"inductor_ret= {inductor_ret}") + rtol = 1e-1 + atol = 1e-1 + assert torch.allclose(std_ret, inductor_ret, equal_nan=True, rtol=rtol, atol=atol) + + +if __name__ == "__main__": + size = (32, 64) + test = TestAlias() + test.test_reduction_cases_shapes(size, -1, 'float32', None) diff --git a/test/_inductor/test_argmax.py b/test/_inductor/test_argmax.py new file mode 100644 index 0000000000000000000000000000000000000000..d8ddb3771f6df92cb6da150c4cc63f631f68c723 --- /dev/null +++ b/test/_inductor/test_argmax.py @@ -0,0 +1,32 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. +import torch +import torch_npu +import torch_npu._inductor +import pytest +from testutils import OperatorType, TestUtils + + +class TestArgmax(TestUtils): + __TIME_LIMIT = 100 + __OPTYPE = OperatorType.POINTWISE + + # 在连续测试场景下,测试结果不稳定,建议单独重测批量测试未通过的 case + # 若需测试更多数据类型,将dtype后面的list改成 ProtoTestCase._test_dtypes即可 + # 对indexing开关情况的测试需要用外部参数--npu_indexing=True/False完成 + def argmax(self, a, dim): + return torch.argmax(a, dim) + @pytest.mark.timeout(__TIME_LIMIT) + @pytest.mark.skip(reason='not support yet') + def test_argmax(self): + shape=(512, 64) + dim = -1 + print(f"start to test argmax on shape:{shape} dim:{dim} ") + a = torch.randn(shape, requires_grad=False, dtype=torch.float32, device='npu') + + argmax_triton = torch.compile(self.argmax, backend="inductor", dynamic=False) + r = self.argmax(a, dim) + r1 = argmax_triton(a, dim) + torch.testing.assert_close(r, r1, rtol=1e-3, atol=1e-3) + + diff --git a/test/_inductor/test_argmax_unalign.py b/test/_inductor/test_argmax_unalign.py new file mode 100644 index 0000000000000000000000000000000000000000..66beb403054f04cd733f93d25be168bbf7c981d6 --- /dev/null +++ b/test/_inductor/test_argmax_unalign.py @@ -0,0 +1,30 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. 
+import torch +import torch_npu +import sys +sys.path.append("../..") +import torch_npu._inductor + +import pytest +# from .testutils import OperatorType, TestUtils +torch_npu._inductor.config.enable_npu_indexing = True +class TestMaxWithIndex(): + __TIME_LIMIT = 100 + def op_calc(self, input_element, dim): + return torch.argmax(input_element, dim) + + @pytest.mark.timeout(__TIME_LIMIT) + @pytest.mark.parametrize('shape', [(512, 64)]) # (513, 64), (514,33) + @pytest.mark.parametrize('dim', [-1 ]) + @pytest.mark.parametrize('dtype', ['float32']) + def test_reduction_cases(self, shape, dim, dtype): + print('npu_indexing= {}'.format(torch_npu._inductor.config.enable_npu_indexing)) + input_element = torch.randn(size=shape, dtype=eval('torch.' + dtype), device=torch.device("npu")) * 2000 + std_argmax = self.op_calc(input_element, dim) + compiled_op_calc = torch.compile(self.op_calc, backend="inductor", dynamic=False) + inductor_argmax = compiled_op_calc(input_element, dim) + torch.testing.assert_close(std_argmax, inductor_argmax, rtol=1e-2, atol=1e-2) +if __name__ == '__main__': + self = TestMaxWithIndex() + self.test_reduction_cases((513, 64), -1, 'float32') \ No newline at end of file diff --git a/test/_inductor/test_arrange.py b/test/_inductor/test_arrange.py new file mode 100644 index 0000000000000000000000000000000000000000..3fe320fdb47d56e701d8d4029979ab2454a2721c --- /dev/null +++ b/test/_inductor/test_arrange.py @@ -0,0 +1,52 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. +import torch +import torch_npu +import torch_npu._inductor + +import pytest +from testutils import OperatorType, TestUtils + + +class TestArrange(TestUtils): + __TIME_LIMIT = 100 + __OPTYPE = OperatorType.POINTWISE + + # optimized function, auto timeout after __TIME_LIMIT seconds + + # @torch.compile(options={"aggressive_fusion": False}) + + def op_calc(self, start, end, step): + a = torch.arange(start, end, step, device=torch.device('npu')) + y = a + a + return y + + # 在连续测试场景下,测试结果不稳定,建议单独重测批量测试未通过的 case + # 若需测试更多数据类型,将dtype后面的list改成 ProtoTestCase._test_dtypes即可 + # 对indexing开关情况的测试需要用外部参数--npu_indexing=True/False完成 + + @pytest.mark.timeout(__TIME_LIMIT) + @pytest.mark.parametrize('shape', [(2, )]) + @pytest.mark.parametrize('dtype', TestUtils._test_dtypes) + def test_pointwise_cases(self, shape, dtype, clear_cache): + print(shape) + print('npu_indexing= {}'.format(torch_npu._inductor.config.enable_npu_indexing)) + s = self._generate_tensor(shape, dtype) + start = min(s) + end = max(s) + step = (end - start) / 32 + + std_arrange = self.op_calc(start, end, step) + + compiled_op_calc = torch.compile(self.op_calc, backend="inductor", dynamic=False) + inductor_arrange = compiled_op_calc(start, end, step) + + torch.testing.assert_close(std_arrange, inductor_arrange) + + # should be implemented when __OPTYPE is OperatorType.REDUCTION + @pytest.mark.timeout(__TIME_LIMIT) + @pytest.mark.parametrize('shape,dim', TestUtils._reduction_extest_SDbinding) + @pytest.mark.parametrize('dtype', TestUtils._test_dtypes) + @pytest.mark.skipif(__OPTYPE != OperatorType.REDUCTION, reason='not reduction operator') + def test_reduction_cases(self, shape, dim, dtype, clear_cache): + pass diff --git a/test/_inductor/test_attncp.py b/test/_inductor/test_attncp.py new file mode 100644 index 0000000000000000000000000000000000000000..c2e40d3469915334040f103f1938bee1064849c5 --- /dev/null +++ b/test/_inductor/test_attncp.py @@ -0,0 +1,37 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei 
Technologies Co., Ltd. 2023-2023. All rights reserved. +import torch +import torch_npu +import torch_npu._inductor +import pytest +from testutils import OperatorType, TestUtils + + +class TestAttnCp(TestUtils): + __TIME_LIMIT = 100 + __OPTYPE = OperatorType.POINTWISE + + # optimized function, auto timeout after __TIME_LIMIT seconds + # @torch.compile(options={"aggressive_fusion": False}) + shape = (8, 8, 256, 128) + dim = -1 + def foo(self, a, b, c): + y = a + b + y = y.sum(self.dim) + y = y.unsqueeze(self.dim) + y = y.broadcast_to(self.shape) + b + y = c + y.permute(0, 1, 3, 2) + return y + + # 在连续测试场景下,测试结果不稳定,建议单独重测批量测试未通过的 case + # 若需测试更多数据类型,将dtype后面的list改成 ProtoTestCase._test_dtypes即可 + # 对indexing开关情况的测试需要用外部参数--npu_indexing=True/False完成 + + def test_pointwise_cases(self): + a, b = [torch.randn(self.shape, dtype=torch.float32, device="npu") for _ in range(2)] + d = torch.randn(self.shape, dtype=torch.float32, device="npu") + c = d.permute(0, 1, 3, 2).contiguous() + func = torch.compile(self.foo, backend="inductor") + r = func(a, b, c) + r1 = self.foo(a, b, c) + torch.testing.assert_close(r, r1, rtol=1e-3, atol=1e-3) diff --git a/test/_inductor/test_batch_norm.py b/test/_inductor/test_batch_norm.py new file mode 100644 index 0000000000000000000000000000000000000000..d330800f8f75ca78cf33b2974d85e549569386f4 --- /dev/null +++ b/test/_inductor/test_batch_norm.py @@ -0,0 +1,61 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. +import torch +import torch_npu +import torch_npu._inductor + +import pytest +from testutils import OperatorType, TestUtils + +class TestNativeBatchNorm(TestUtils): + __TIME_LIMIT = 100 + + def op_calc(self, input_element): + # 创建权重和偏置张量 + weight = torch.ones(32).npu() + bias = torch.zeros(32).npu() + + # 创建运行均值和方差张量 + running_mean = torch.zeros(32).npu() + running_var = torch.ones(32).npu() + + + # 执行批量归一化 + output, running_mean_out, running_var_out = torch.native_batch_norm( + input=input_element, + weight=weight, + bias=bias, + running_mean=running_mean, + running_var=running_var, + training=True, + momentum=0.1, + eps=1e-05 + ) + return output, running_mean_out, running_var_out + + @pytest.mark.skip(reason="npu compiler bug") + @pytest.mark.timeout(__TIME_LIMIT) + @pytest.mark.parametrize('shape', [(16, 32, 64)]) + @pytest.mark.parametrize('dtype', ['float32']) + def test_reduction_cases_shapes(self, shape, dtype, clear_cache): + print(f"shape= {shape}") + print(f"dtype= {dtype}") + print('npu_indexing= {}'.format(torch_npu._inductor.config.enable_npu_indexing)) + + input_element = self._generate_tensor(shape, dtype) + + print(f"input_element= {input_element}") + std_ret, std_ret2, std_ret3 = self.op_calc(input_element) + + compiled_op_calc = torch.compile(self.op_calc, backend="inductor") + inductor_ret, inductor_ret2, inductor_ret3 = compiled_op_calc(input_element) + rtol = 1e-1 + atol = 1e-1 + assert torch.allclose(std_ret, inductor_ret, equal_nan=True, rtol=rtol, atol=atol) + + +if __name__ == "__main__": + size = (16, 32, 64) + test = TestNativeBatchNorm() + test.test_reduction_cases_shapes(size, 'float32', None) + diff --git a/test/_inductor/test_broadcast.py b/test/_inductor/test_broadcast.py new file mode 100644 index 0000000000000000000000000000000000000000..85ec062ff53acc408a149ca12620b09a36ff9aca --- /dev/null +++ b/test/_inductor/test_broadcast.py @@ -0,0 +1,58 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. 
+import torch +import torch_npu + +import torch_npu._inductor + +import copy +import pytest +from testutils import OperatorType, TestUtils + +class TestBroadcast(TestUtils): + __TIME_LIMIT = 100 + __OPTYPE = OperatorType.POINTWISE + + # optimized function, auto timeout after __TIME_LIMIT seconds + + # @torch.compile(options={"aggressive_fusion": False}) + broadcast_size = 128 + + def op_calc(self, a, b, dim, new_shape): + a = a.unsqueeze(dim) + a = a.broadcast_to(new_shape) + b = b.unsqueeze(dim) + b = b.broadcast_to(new_shape) + y = a + b + return y + + # 在连续测试场景下,测试结果不稳定,建议单独重测批量测试未通过的 case + # 若需测试更多数据类型,将dtype后面的list改成 ProtoTestCase._test_dtypes即可 + # 对indexing开关情况的测试需要用外部参数--npu_indexing=True/False完成 + + @pytest.mark.timeout(__TIME_LIMIT) + @pytest.mark.parametrize('shape', [(8, 8, 256)]) + @pytest.mark.parametrize('dtype', ['float32', 'int32', 'float16', 'bfloat16']) + def test_view_cases(self, shape, dtype, clear_cache): + print(shape) + print('npu_indexing= {}'.format(torch_npu._inductor.config.enable_npu_indexing)) + a = self._generate_tensor(shape, dtype) + b = self._generate_tensor(shape, dtype) + + compiled_op_calc = torch.compile(self.op_calc, backend="inductor", dynamic=False) + for dim in [3, 2, 1, 0]: + new_shape = list(copy.deepcopy(shape)) + new_shape.insert(dim, self.broadcast_size) + std_broadcast = self.op_calc(a, b, dim, new_shape) + inductor_broadcast = compiled_op_calc(a, b, dim, new_shape) + + torch.testing.assert_close(std_broadcast.float(), inductor_broadcast.float(), rtol=1e-3, atol=1e-3) + print(f"data validation passed") + + # should be implemented when __OPTYPE is OperatorType.REDUCTION + @pytest.mark.timeout(__TIME_LIMIT) + @pytest.mark.parametrize('shape,dim', TestUtils._reduction_extest_SDbinding) + @pytest.mark.parametrize('dtype', TestUtils._test_dtypes) + @pytest.mark.skipif(__OPTYPE != OperatorType.REDUCTION, reason='not reduction operator') + def test_reduction_cases(self, shape, dim, dtype, clear_cache): + pass diff --git a/test/_inductor/test_cat.py b/test/_inductor/test_cat.py new file mode 100644 index 0000000000000000000000000000000000000000..715c44bf561b14b8636b92c9f8360a8a1513e76a --- /dev/null +++ b/test/_inductor/test_cat.py @@ -0,0 +1,41 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. 
+import torch +import torch_npu +import torch_npu._inductor + +import pytest +from testutils import OperatorType, TestUtils + +class TestCat(TestUtils): + __TIME_LIMIT = 100 + + def op_calc(self, input_element, dim): + return torch.cat([input_element, input_element], dim) + + # case:change shapes + @pytest.mark.timeout(__TIME_LIMIT) + @pytest.mark.parametrize('shape', [(8, 16, 32, 64)]) + @pytest.mark.parametrize('dim', [-1]) + @pytest.mark.parametrize('dtype', ['bfloat16']) + def test_reduction_cases_shapes(self, shape, dim, dtype, clear_cache): + print(f"shape= {shape}") + print(f"dim= {dim}") + print(f"dtype= {dtype}") + print('npu_indexing= {}'.format(torch_npu._inductor.config.enable_npu_indexing)) + + input_element = self._generate_tensor(shape, dtype) + std_cat = self.op_calc(input_element, dim) + # print(f"std_cat.shape= {std_cat.shape}") + compiled_op_calc = torch.compile(self.op_calc, backend="inductor") + inductor_cat = compiled_op_calc(input_element, dim) + # print(f"inductor_cat.shape= {inductor_cat.shape}") + rtol = 1e-1 + atol = 1e-1 + assert torch.allclose(std_cat, inductor_cat, equal_nan=True, rtol=rtol, atol=atol) + + +if __name__ == "__main__": + size = (8, 8, 8, 2048) + test = TestCat() + test.test_reduction_cases_shapes(size, 2, 'float32', None) diff --git a/test/_inductor/test_ceil.py b/test/_inductor/test_ceil.py new file mode 100644 index 0000000000000000000000000000000000000000..da2d7cc73be89b4f4f676c1ce5fdb94db7462e10 --- /dev/null +++ b/test/_inductor/test_ceil.py @@ -0,0 +1,36 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. +import torch +import torch_npu +import torch_npu._inductor + +import pytest +from testutils import OperatorType, TestUtils + + +class TestRelu(TestUtils): + __TIME_LIMIT = 100 + __OPTYPE = OperatorType.POINTWISE + + def op_calc(self, first_element): + result = torch.ceil(first_element) + return result + + @pytest.mark.timeout(__TIME_LIMIT) + @pytest.mark.parametrize('shape', TestUtils._pointwise_demo_shapes) + @pytest.mark.parametrize('dtype', ['float32', 'float16', 'bfloat16']) + def test_pointwise_cases(self, shape, dtype): + print(shape) + print('npu_indexing= {}'.format(torch_npu._inductor.config.enable_npu_indexing)) + first_element = self._generate_tensor(shape, dtype) + + std_result = self.op_calc(first_element) + + compiled_op_calc = torch.compile(self.op_calc, backend="inductor") + inductor_result = compiled_op_calc(first_element) + + torch.testing.assert_close(std_result, inductor_result) + + +if __name__ == '__main__': + TestRelu() diff --git a/test/_inductor/test_clamp.py b/test/_inductor/test_clamp.py new file mode 100644 index 0000000000000000000000000000000000000000..adc3bcaf12286201a4df5db299ff341deb1d2da5 --- /dev/null +++ b/test/_inductor/test_clamp.py @@ -0,0 +1,117 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. 
+import torch +import torch_npu +import torch_npu._inductor +import pytest +from testutils import OperatorType, TestUtils + + +class TestClamp(TestUtils): + __TIME_LIMIT = 100 + __OPTYPE = OperatorType.POINTWISE + + # optimized function, auto timeout after __TIME_LIMIT seconds + + # @torch.compile(options={"aggressive_fusion": False}) + + def op_calc(self, input, min=None, max=None): + return input.clamp(min, max) + + + # 在连续测试场景下,测试结果不稳定,建议单独重测批量测试未通过的 case + # 若需测试更多数据类型,将dtype后面的list改成 ProtoTestCase._test_dtypes即可 + # 对indexing开关情况的测试需要用外部参数--npu_indexing=True/False完成 + + @pytest.mark.timeout(__TIME_LIMIT) + @pytest.mark.parametrize('shape', TestUtils._pointwise_demo_shapes) + @pytest.mark.parametrize('dtype', ['float16', 'float32', 'bfloat16', 'int32', 'int64']) + @pytest.mark.skip(reason='not support yet') + def test_pointwise_cases_minmax_is_tensor(self, shape, dtype, clear_cache): + print(shape) + print('npu_indexing= {}'.format(torch_npu._inductor.config.enable_npu_indexing)) + min = self._generate_tensor(shape, dtype) + max = self._generate_tensor(shape, dtype) + + first_element = self._generate_tensor(shape, dtype) + + std_result = self.op_calc(first_element, min=min, max=max) + + compiled_op_calc = torch.compile(self.op_calc, backend="inductor") + inductor_result = compiled_op_calc(first_element, min=min, max=max) + + torch.testing.assert_close(std_result, inductor_result) + + @pytest.mark.timeout(__TIME_LIMIT) + @pytest.mark.parametrize('shape', [(1,)]) + @pytest.mark.parametrize('dtype', ['float32']) + @pytest.mark.skip(reason='not support yet') + def test_pointwise_cases_single_scalar(self, shape, dtype, clear_cache): + print(shape) + print('npu_indexing= {}'.format(torch_npu._inductor.config.enable_npu_indexing)) + min = 0 + max = 100 + + first_element = 200 * torch.rand(size=shape, dtype=eval('torch.' 
+ dtype), device=torch.device("npu")) + + std_result = self.op_calc(first_element, min=min, max=max) + + compiled_op_calc = torch.compile(self.op_calc, backend="inductor") + inductor_result = compiled_op_calc(first_element, min=min, max=max) + torch.testing.assert_close(std_result, inductor_result) + + @pytest.mark.timeout(__TIME_LIMIT) + @pytest.mark.parametrize('shape', [(1024, 32)]) + @pytest.mark.parametrize('dtype', ['int32']) + @pytest.mark.skip(reason='not support yet') + def test_pointwise_cases_minmax_is_number(self, shape, dtype, clear_cache): + print(shape) + print('npu_indexing= {}'.format(torch_npu._inductor.config.enable_npu_indexing)) + min = 0 + max = 100 + + first_element = self._generate_tensor(shape, dtype) + + std_result = self.op_calc(first_element, min=min, max=max) + + compiled_op_calc = torch.compile(self.op_calc, backend="inductor") + inductor_result = compiled_op_calc(first_element, min=min, max=max) + + torch.testing.assert_close(std_result, inductor_result) + + @pytest.mark.timeout(__TIME_LIMIT) + @pytest.mark.parametrize('shape', TestUtils._pointwise_demo_shapes) + @pytest.mark.parametrize('dtype', ['float16', 'float32', 'bfloat16', 'int32', 'int64']) + def test_pointwise_cases_max_only(self, shape, dtype, clear_cache): + print(shape) + print('npu_indexing= {}'.format(torch_npu._inductor.config.enable_npu_indexing)) + max = 100 + + first_element = self._generate_tensor(shape, dtype) + + std_result = self.op_calc(first_element, min=None, max=max) + + compiled_op_calc = torch.compile(self.op_calc, backend="inductor") + inductor_result = compiled_op_calc(first_element, min=None, max=max) + + torch.testing.assert_close(std_result, inductor_result) + + @pytest.mark.timeout(__TIME_LIMIT) + @pytest.mark.parametrize('shape', TestUtils._pointwise_demo_shapes) + @pytest.mark.parametrize('dtype', ['float16', 'float32', 'bfloat16', 'int32', 'int64']) + def test_pointwise_cases_min_only(self, shape, dtype, clear_cache): + print(shape) + print('npu_indexing= {}'.format(torch_npu._inductor.config.enable_npu_indexing)) + min = 0 + + first_element = self._generate_tensor(shape, dtype) + + std_result = self.op_calc(first_element, min=min, max=None) + + compiled_op_calc = torch.compile(self.op_calc, backend="inductor") + inductor_result = compiled_op_calc(first_element, min=min, max=None) + + torch.testing.assert_close(std_result, inductor_result) +if __name__ == '__main__': + obj = TestClamp() + obj.test_pointwise_cases_single_scalar((1,), 'float32', None) \ No newline at end of file diff --git a/test/_inductor/test_clone.py b/test/_inductor/test_clone.py new file mode 100644 index 0000000000000000000000000000000000000000..374317523b7a9959015fdc6301937e1e0c1bf6fc --- /dev/null +++ b/test/_inductor/test_clone.py @@ -0,0 +1,42 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei TechNologies Co., Ltd. 2023-2023. All rights reserved. 
+import torch +import torch_npu +import torch_npu._inductor + +import pytest +from testutils import OperatorType, TestUtils + +class TestClone(TestUtils): + __TIME_LIMIT = 100 + + def op_calc(self, input_element, dim): + return torch.clone(input_element) + + # case: change shapes + @pytest.mark.timeout(__TIME_LIMIT) + @pytest.mark.parametrize('shape', [(8, 64, 128)]) + @pytest.mark.parametrize('dim', [0]) + @pytest.mark.parametrize('dtype', ['float32']) + def test_reduction_cases_shapes(self, shape, dim, dtype, clear_cache): + print(f"shape= {shape}") + print(f"dim= {dim}") + print(f"dtype= {dtype}") + print('npu_indexing= {}'.format(torch_npu._inductor.config.enable_npu_indexing)) + + input_element = self._generate_tensor(shape, dtype) + std_ret = self.op_calc(input_element, dim) + + compiled_op_calc = torch.compile(self.op_calc, backend="inductor") + inductor_ret = compiled_op_calc(input_element, dim) + + assert torch.allclose(std_ret, inductor_ret, equal_nan=True) + + +if __name__ == "__main__": + size = (8, 64, 128) + test = TestClone() + test.test_reduction_cases_shapes(size, 2, 'float32', None) + + + diff --git a/test/_inductor/test_cos.py b/test/_inductor/test_cos.py new file mode 100644 index 0000000000000000000000000000000000000000..b963eb8ce822c6f4be91057137aa82a56c487089 --- /dev/null +++ b/test/_inductor/test_cos.py @@ -0,0 +1,53 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei TechNologies Co., Ltd. 2023-2023. All rights reserved. +import torch +import torch_npu +import torch_npu._inductor + +import pytest +from testutils import OperatorType, TestUtils + + +class TestLog(TestUtils): + __TIME_LIMIT = 100 + __OPTYPE = OperatorType.POINTWISE + # optimized function, auto timeout after __TIME_LIMIT seconds + + # @torch.compile(options={"aggressive_fusion": False}) + + def op_calc(self, first_element): + result = torch.cos(first_element) + return result + + # 在连续测试场景下,测试结果不稳定,建议单独重测批量测试未通过的 case + # 若需测试更多数据类型, 将dtype后面的list改成 ProtoTestCase._test_dtypes即可 + # 对indexing开关情况的测试需要用外部参数--npu_indexing=True/False完成 + + @pytest.mark.timeout(__TIME_LIMIT) + @pytest.mark.parametrize('shape', TestUtils._pointwise_demo_shapes) + @pytest.mark.parametrize('dtype', ['float32', 'int64']) + @pytest.mark.skip(reason='not support yet') + def test_pointwise_cases(self, shape, dtype): + print(shape) + print('npu_indexing= {}'.format(torch_npu._inductor.config.enable_npu_indexing)) + first_element = self._generate_tensor(shape, dtype) + + std_result = self.op_calc(first_element) + + compiled_op_calc = torch.compile(self.op_calc, backend="inductor") + inductor_result = compiled_op_calc(first_element) + + torch.testing.assert_close(std_result, inductor_result) + + # should be implemented when __OPTYPE is OperatorType.REDUCTION + @pytest.mark.timeout(__TIME_LIMIT) + @pytest.mark.parametrize('shape,dim', TestUtils._reduction_extest_SDbinding) + @pytest.mark.parametrize('dtype', TestUtils._test_dtypes) + @pytest.mark.skipif(__OPTYPE != OperatorType.REDUCTION, reason='not reduction operator') + def test_reduction_cases(self, shape, dim, dtype, clear_cache): + pass + +if __name__ == '__main__': + TestLog() + + diff --git a/test/_inductor/test_device_put.py b/test/_inductor/test_device_put.py new file mode 100644 index 0000000000000000000000000000000000000000..39b17ea27db16d315dc28c8c2f876c1fc7bf35f9 --- /dev/null +++ b/test/_inductor/test_device_put.py @@ -0,0 +1,48 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei TechNologies Co., Ltd. 2023-2023. All rights reserved. 
+import torch +import torch_npu +import torch_npu._inductor + +import pytest +from testutils import OperatorType, TestUtils + +class TestDevicePut(TestUtils): +    __TIME_LIMIT = 100 + +    def op_calc(self, input_element1, input_element2): +        return torch.add(input_element1, input_element2) + +    # case: change shapes +    @pytest.mark.timeout(__TIME_LIMIT) +    @pytest.mark.parametrize('shape', [(8, 16, 8)]) +    @pytest.mark.parametrize('dtype', ['int32']) +    def test_cases_shapes(self, shape, dtype, clear_cache): +        low = 0 +        high = 2 +        dtype = eval('torch.' + dtype) +        print(f"shape= {shape}") +        print(f"dtype= {dtype}") +        print('npu_indexing= {}'.format(torch_npu._inductor.config.enable_npu_indexing)) +        # target device: NPU +        npu_device = torch.device('npu:0') +        input_element1_tmp = torch.randint(low, high, shape, dtype=dtype).cpu() +        input_element2_tmp = torch.randint(low, high, shape, dtype=dtype).cpu() +        input_element1 = torch.ops.prims.device_put(input_element1_tmp, npu_device) +        input_element2 = torch.ops.prims.device_put(input_element2_tmp, npu_device) + +        std_ret = self.op_calc(input_element1, input_element2) + +        compiled_op_calc = torch.compile(self.op_calc, backend="inductor") +        inductor_ret = compiled_op_calc(input_element1, input_element2) + +        assert torch.allclose(std_ret, inductor_ret, equal_nan=True) + + +if __name__ == "__main__": +    size = (8, 16, 8) +    test = TestDevicePut() +    test.test_cases_shapes(size, 'int32', None) + + + diff --git a/test/_inductor/test_div.py b/test/_inductor/test_div.py new file mode 100644 index 0000000000000000000000000000000000000000..318b521fe43b9e8ee72d2b58b0f75b5e1d28a5a9 --- /dev/null +++ b/test/_inductor/test_div.py @@ -0,0 +1,43 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. +import torch +import torch_npu +import torch_npu._inductor + +import pytest +from .testutils import OperatorType, TestUtils + + +class TestDiv(TestUtils): +    __TIME_LIMIT = 100 +    __OPTYPE = OperatorType.POINTWISE + +    # optimized function, auto timeout after __TIME_LIMIT seconds + +    # @torch.compile(options={"aggressive_fusion": False}) + +    def op_calc(self, first_element, second_element): +        result = torch.div(first_element, second_element) +        return result + +    # Results can be unstable when cases run back to back; rerun failed cases from a batch run individually. +    # To cover more dtypes, change the list after dtype to ProtoTestCase._test_dtypes. +    # Indexing on/off is controlled via the external option --npu_indexing=True/False. + +    @pytest.mark.timeout(__TIME_LIMIT) +    @pytest.mark.parametrize('shape', TestUtils._pointwise_demo_shapes) +    @pytest.mark.parametrize('dtype', ['float16', 'float32', 'bfloat16', 'int32', 'int64']) +    def test_pointwise_cases(self, shape, dtype, clear_cache): +        print(shape) +        print('npu_indexing= {}'.format(torch_npu._inductor.config.enable_npu_indexing)) +        first_element = self._generate_tensor(shape, dtype) +        second_element = self._generate_tensor(shape, dtype) + +        std_result = self.op_calc(first_element, second_element) + +        compiled_op_calc = torch.compile(self.op_calc, backend="inductor") +        inductor_result = compiled_op_calc(first_element, second_element) +        torch.testing.assert_close(std_result, inductor_result, equal_nan=True) + + + diff --git a/test/_inductor/test_embedding.py b/test/_inductor/test_embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..dce1208cb091e786fe08c0bf6f6aee23d6a7a502 --- /dev/null +++ b/test/_inductor/test_embedding.py @@ -0,0 +1,47 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved.
+import torch +import torch_npu +import torch_npu._inductor + +import pytest +#from .testutils import OperatorType, TestUtils +import torch.nn as nn + +class TestSub(): + + def op_calc(self, input): + embedding = nn.Embedding(16, 128).npu() + output = embedding(input) + return output + + def test_pointwise_cases(self): + + input = torch.tensor([[14, 1, 2, 10, 0, 10, 0], + [ 9, 13, 13, 4, 7, 15, 14], + [ 8, 0, 3, 15, 4, 2, 6], + [15, 12, 13, 9, 0, 8, 1], + [ 8, 15, 4, 15, 12, 9, 3], + [ 6, 11, 12, 8, 0, 13, 8], + [ 4, 10, 1, 12, 0, 0, 4], + [ 6, 6, 15, 6, 0, 10, 15], + [ 2, 5, 14, 0, 5, 7, 9], + [13, 4, 14, 11, 11, 9, 2], + [ 1, 1, 5, 1, 1, 6, 14], + [ 3, 9, 8, 4, 13, 8, 3], + [ 4, 10, 8, 13, 6, 8, 3]], device='npu:0') + + std_sub = self.op_calc(input) + + compiled_op_calc = torch.compile(self.op_calc, backend="inductor") + inductor_sum = compiled_op_calc(input) + #torch.testing.assert_close(std_sub, inductor_sum) + + +if __name__ == "__main__": + test = TestSub() + test.test_pointwise_cases() + + + + diff --git a/test/_inductor/test_embedding_fallback.py b/test/_inductor/test_embedding_fallback.py new file mode 100644 index 0000000000000000000000000000000000000000..4a5b492b9de662d8c5fe16a594faa73a68aa16ad --- /dev/null +++ b/test/_inductor/test_embedding_fallback.py @@ -0,0 +1,51 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei TechNologies Co., Ltd. 2023-2023. All rights reserved. +import torch +import torch_npu +import torch_npu._inductor + +import pytest +from testutils import OperatorType, TestUtils + + +class TestRsqrt(TestUtils): + __TIME_LIMIT = 100 + __OPTYPE = OperatorType.POINTWISE + + # optimized function, auto timeout after __TIME_LIMIT seconds + + # @torch.compile(options={"aggressive_fusion": False}) + + def op_calc(self, slice_4, sum_23): + result = torch.ops.aten.embedding_dense_backward.default(sum_23, slice_4, 512, -1, False) + return result + + @pytest.mark.timeout(__TIME_LIMIT) + @pytest.mark.parametrize('shape', [(1, 512, 128)]) + @pytest.mark.parametrize('dtype', ['float32']) + def test_pointwise_cases(self, shape, dtype): + torch_npu._inductor.config.enable_npu_indexing = True + print(shape) + print('npu_indexing= {}'.format(torch_npu._inductor.config.enable_npu_indexing)) + first_element = torch.randint(low=0, high=128, size=(1, 512), dtype=torch.int64).npu() + second_element = self._generate_tensor(shape, dtype) + + std_result = self.op_calc(first_element, second_element) + + compiled_op_calc = torch.compile(self.op_calc, backend="inductor") + inductor_result = compiled_op_calc(first_element, second_element) + + print(std_result) + print(inductor_result) + + torch.testing.assert_close(std_result, inductor_result, rtol=1e-1, atol=1e-1) + + +if __name__ == "__main__": + size = (1, 512, 128) + test = TestRsqrt() + test.test_pointwise_cases(size, 'float32') + + + + diff --git a/test/_inductor/test_empty.py b/test/_inductor/test_empty.py new file mode 100644 index 0000000000000000000000000000000000000000..9cc8fe36f5d0379cc61ac2ec48dffbbb8a890765 --- /dev/null +++ b/test/_inductor/test_empty.py @@ -0,0 +1,68 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei TechNologies Co., Ltd. 2023-2023. All rights reserved. 
+import torch +import torch_npu +import torch_npu._inductor + +import pytest +from testutils import OperatorType, TestUtils + +class TestEmpty(TestUtils): +    __TIME_LIMIT = 100 + +    def op_calc(self): +        x = torch.empty(8, 64, 128, dtype=torch.float32).npu() +        x.uniform_(-100, 100) +        return x +    def op_calc_empty_permuted(self): +        input_shape = (8, 64, 128) +        physical_layout = (0, 1, 2)  # physical layout matches the input shape +        x = torch.empty_permuted(input_shape, physical_layout).npu() +        x.uniform_(-100, 100) +        return x + +    # case: change shapes +    @pytest.mark.timeout(__TIME_LIMIT) +    @pytest.mark.parametrize('shape', [(8, 64, 128)]) +    @pytest.mark.parametrize('dim', [0]) +    @pytest.mark.parametrize('dtype', ['float32']) +    def test_cases_empty(self, shape, dim, dtype, clear_cache): +        print(f"shape= {shape}") +        print(f"dim= {dim}") +        print(f"dtype= {dtype}") +        print('npu_indexing= {}'.format(torch_npu._inductor.config.enable_npu_indexing)) + +        std_ret = self.op_calc() +        # print(f"std_ret= {std_ret}") +        compiled_op_calc = torch.compile(self.op_calc, backend="inductor") +        inductor_ret = compiled_op_calc() +        # print(f"inductor_ret= {inductor_ret}") + +        assert inductor_ret.numel() > 0 + +    @pytest.mark.timeout(__TIME_LIMIT) +    @pytest.mark.parametrize('shape', [(8, 64, 128)]) +    @pytest.mark.parametrize('dim', [0]) +    @pytest.mark.parametrize('dtype', ['float32']) +    def test_cases_empty_permuted(self, shape, dim, dtype, clear_cache): +        print(f"shape= {shape}") +        print(f"dim= {dim}") +        print(f"dtype= {dtype}") +        print('npu_indexing= {}'.format(torch_npu._inductor.config.enable_npu_indexing)) + +        std_ret = self.op_calc_empty_permuted() +        # print(f"std_ret= {std_ret}") +        compiled_op_calc = torch.compile(self.op_calc_empty_permuted, backend="inductor") +        inductor_ret = compiled_op_calc() +        # print(f"inductor_ret= {inductor_ret}") + +        assert inductor_ret.numel() > 0 + + +if __name__ == "__main__": +    size = (8, 64, 128) +    test = TestEmpty() +    test.test_cases_empty(size, 0, 'float32', None) + + + diff --git a/test/_inductor/test_eq.py b/test/_inductor/test_eq.py new file mode 100644 index 0000000000000000000000000000000000000000..0b4c9d103bc96d260f2337b14a694136c4c2b9da --- /dev/null +++ b/test/_inductor/test_eq.py @@ -0,0 +1,41 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved.
+import torch +import torch_npu +import torch_npu._inductor +import pytest +from testutils import OperatorType, TestUtils + + +class TestEq(TestUtils): + __TIME_LIMIT = 100 + __OPTYPE = OperatorType.POINTWISE + + def op_calc(self, first_element, second_element): + return torch.eq(first_element, second_element) + + @pytest.mark.timeout(__TIME_LIMIT) + @pytest.mark.parametrize('shape', TestUtils._pointwise_demo_shapes) + @pytest.mark.parametrize('dtype', ['float32', 'int32', 'float16', 'bfloat16']) + def test_pointwise_cases(self, shape, dtype, clear_cache): + print(shape) + print('npu_indexing= {}'.format(torch_npu._inductor.config.enable_npu_indexing)) + + first_element = self._generate_tensor(shape, dtype) + second_element = first_element.clone() + + # randomly change some elements in second tensor + flat_second_view = second_element.flatten() + num_elements_to_change = first_element.numel() //3 + random_indices = torch.randint(0, first_element.numel(), (num_elements_to_change,)) + flat_second_view[random_indices] = 1- flat_second_view[random_indices] + + std_result = self.op_calc(first_element, second_element) + + compiled_op_calc = torch.compile(self.op_calc, backend="inductor") + inductor_result = compiled_op_calc(first_element, second_element) + + torch.testing.assert_close(std_result, inductor_result) + + + diff --git a/test/_inductor/test_exp.py b/test/_inductor/test_exp.py new file mode 100644 index 0000000000000000000000000000000000000000..078f9e653d9cee2b9d292cd0717d3801e462bf2a --- /dev/null +++ b/test/_inductor/test_exp.py @@ -0,0 +1,48 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei TechNologies Co., Ltd. 2023-2023. All rights reserved. +import torch +import torch_npu +import torch_npu._inductor + +import pytest +from .testutils import OperatorType, TestUtils + + +class TestExp(TestUtils): + __TIME_LIMIT = 100 + __OPTYPE = OperatorType.POINTWISE + + # optimized function, auto timeout after __TIME_LIMIT seconds + + # @torch.compile(options={"aggressive_fusion": False}) + + def op_calc(self, first_element): + result = torch.exp(first_element) + return result + + # 在连续测试场景下,测试结果不稳定,建议单独重测批量测试未通过的 case + # 若需测试更多数据类型, 将dtype后面的list改成 ProtoTestCase._test_dtypes即可 + # 对indexing开关情况的测试需要用外部参数--npu_indexing=True/False完成 + + @pytest.mark.timeout(__TIME_LIMIT) + @pytest.mark.parametrize('shape', TestUtils._pointwise_demo_shapes) + @pytest.mark.parametrize('dtype', ['float16', 'float32', 'bfloat16', 'int64']) + def test_pointwise_cases(self, shape, dtype, clear_cache): + print(shape) + print('npu_indexing= {}'.format(torch_npu._inductor.config.enable_npu_indexing)) + first_element = self._generate_tensor(shape, dtype) + + std_result = self.op_calc(first_element) + + compiled_op_calc = torch.compile(self.op_calc, backend="inductor") + inductor_result = compiled_op_calc(first_element) + # print(std_result[0:8]) + # print(inductor_result[0:8]) + # torch.testing.assert_close(std_result, inductor_result) + # 需要比较包含 NaN 值的张量, 并且希望认为两个 NaN值是相等的, 您可以使用 torch.allclose 函数, 并设置 equal_nan=True 参数 + rtol = 1e-1 + atol = 1e-1 + assert torch.allclose(std_result, inductor_result, equal_nan=True, rtol=rtol, atol=atol) + + + diff --git a/test/_inductor/test_expm1.py b/test/_inductor/test_expm1.py new file mode 100644 index 0000000000000000000000000000000000000000..27d8e053466b9eb50c6909efa92e3d88573ee7f5 --- /dev/null +++ b/test/_inductor/test_expm1.py @@ -0,0 +1,41 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei TechNologies Co., Ltd. 2023-2023. All rights reserved. 
+import torch +import torch_npu +import torch_npu._inductor + +import pytest +from .testutils import OperatorType, TestUtils + + +class TestExpm1(TestUtils): +    __TIME_LIMIT = 100 +    __OPTYPE = OperatorType.POINTWISE + +    # optimized function, auto timeout after __TIME_LIMIT seconds + +    # @torch.compile(options={"aggressive_fusion": False}) + +    def op_calc(self, first_element): +        result = torch.expm1(first_element) +        return result + +    # Results can be unstable when cases run back to back; rerun failed cases from a batch run individually. +    # To cover more dtypes, change the list after dtype to ProtoTestCase._test_dtypes. +    # Indexing on/off is controlled via the external option --npu_indexing=True/False. + +    @pytest.mark.timeout(__TIME_LIMIT) +    @pytest.mark.parametrize('shape', TestUtils._pointwise_demo_shapes) +    @pytest.mark.parametrize('dtype', ['float16', 'float32', 'bfloat16', 'int64']) +    def test_pointwise_cases(self, shape, dtype, clear_cache): +        print(shape) +        print('npu_indexing= {}'.format(torch_npu._inductor.config.enable_npu_indexing)) +        first_element = self._generate_tensor(shape, dtype) + +        std_result = self.op_calc(first_element) +        compiled_op_calc = torch.compile(self.op_calc, backend="inductor") +        inductor_result = compiled_op_calc(first_element) + +        assert torch.allclose(std_result, inductor_result, equal_nan=True) + + diff --git a/test/_inductor/test_floor.py b/test/_inductor/test_floor.py new file mode 100644 index 0000000000000000000000000000000000000000..1d7d144feed47d8fddaf4119a1af812219424562 --- /dev/null +++ b/test/_inductor/test_floor.py @@ -0,0 +1,40 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. +import torch +import torch_npu +import torch_npu._inductor + +import pytest +from testutils import OperatorType, TestUtils + + +class TestFloor(TestUtils): +    __TIME_LIMIT = 100 +    __OPTYPE = OperatorType.POINTWISE + +    def op_calc(self, first_element): +        result = torch.floor(first_element) +        return result + +    @pytest.mark.timeout(__TIME_LIMIT) +    @pytest.mark.parametrize('shape', TestUtils._pointwise_demo_shapes) +    @pytest.mark.parametrize('dtype', ['float16', 'float32', 'bfloat16', 'int32', 'int64']) +    def test_pointwise_cases(self, shape, dtype): +        print(shape) +        print('npu_indexing= {}'.format(torch_npu._inductor.config.enable_npu_indexing)) +        first_element = self._generate_tensor(shape, dtype) + +        std_result = self.op_calc(first_element) + +        compiled_op_calc = torch.compile(self.op_calc, backend="inductor") +        inductor_result = compiled_op_calc(first_element) + +        torch.testing.assert_close(std_result, inductor_result) + + +if __name__ == '__main__': +    TestFloor() + + + + + diff --git a/test/_inductor/test_foreach_add.py b/test/_inductor/test_foreach_add.py new file mode 100644 index 0000000000000000000000000000000000000000..66111096f641f0b3af5d31d90afe1c63b79c51eb --- /dev/null +++ b/test/_inductor/test_foreach_add.py @@ -0,0 +1,47 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved.
+import torch +import torch_npu +import torch_npu._inductor + +import pytest +from testutils import OperatorType, TestUtils + + +class TestRsqrt(TestUtils): + __TIME_LIMIT = 100 + __OPTYPE = OperatorType.POINTWISE + + def op_calc(self, first_element, second_element): + tensor_list = [first_element, second_element] + + add_list =[first_element, second_element] + result = torch._foreach_add_(tensor_list, add_list) + return result + + @pytest.mark.skip(reason='compile error, torch npu segmet fault') + @pytest.mark.timeout(__TIME_LIMIT) + @pytest.mark.parametrize('shape', TestUtils._pointwise_demo_shapes) + @pytest.mark.parametrize('dtype', ['int32']) + def test_pointwise_cases(self, shape, dtype): + torch_npu._inductor.config.enable_npu_indexing = True + print(shape) + print('npu_indexing= {}'.format(torch_npu._inductor.config.enable_npu_indexing)) + first_element = self._generate_tensor(shape, dtype) + second_element = self._generate_tensor(shape, dtype) + + std_result = self.op_calc(first_element, second_element) + + compiled_op_calc = torch.compile(self.op_calc, backend="inductor") + inductor_result = compiled_op_calc(first_element, second_element) + + torch.testing.assert_close(std_result, inductor_result, rtol=1e-1, atol=1e-1) + + +if __name__ == "__main__": + size = (1024, 32) + test = TestRsqrt() + test.test_pointwise_cases(size, 'float32') + + + diff --git a/test/_inductor/test_ge.py b/test/_inductor/test_ge.py new file mode 100644 index 0000000000000000000000000000000000000000..4fe7a95e50f2d15c2b7500485926efde143c5f6e --- /dev/null +++ b/test/_inductor/test_ge.py @@ -0,0 +1,35 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei TechNologies Co., Ltd. 2023-2023. All rights reserved. +import torch +import torch_npu +import torch_npu._inductor +import pytest +from testutils import OperatorType, TestUtils + + +class TestGe(TestUtils): + __TIME_LIMIT = 100 + __OPTYPE = OperatorType.POINTWISE + + def op_calc(self, first_element, second_element): + return torch.ge(first_element, second_element) + + @pytest.mark.timeout(__TIME_LIMIT) + @pytest.mark.parametrize('shape', TestUtils._pointwise_demo_shapes) + @pytest.mark.parametrize('dtype', ['float16', 'float32', 'bfloat16', 'int32']) + def test_pointwise_cases(self, shape, dtype, clear_cache): + print(shape) + print('npu_indexing= {}'.format(torch_npu._inductor.config.enable_npu_indexing)) + + first_element = self._generate_tensor(shape, dtype) + second_element = self._generate_tensor(shape, dtype) + + std_result = self.op_calc(first_element, second_element) + + compiled_op_calc = torch.compile(self.op_calc, backend="inductor") + inductor_result = compiled_op_calc(first_element, second_element) + + torch.testing.assert_close(std_result, inductor_result) + + + diff --git a/test/_inductor/test_geometric.py b/test/_inductor/test_geometric.py new file mode 100644 index 0000000000000000000000000000000000000000..827146f2e5a43b3190867d53b33e22005649bcb3 --- /dev/null +++ b/test/_inductor/test_geometric.py @@ -0,0 +1,52 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. 
+import torch +import torch_npu +import torch_npu._inductor + +import pytest +from testutils import OperatorType, TestUtils + +class TestGeometric(TestUtils): + __TIME_LIMIT = 100 + + def op_calc(self): + # 创建一个形状为 (3, 3)的张量, 每个位置的概率为 0.5 + prob =torch.full((16, 16), 0.5).npu() + + #使用 aten.geometric生成几何分布的随机数 + geometric_tensor =torch.ops.aten.geometric(prob, p=0.5) + + return geometric_tensor + + # case: change shapes + @pytest.mark.skip(reason="this has problem in torch 260") + @pytest.mark.timeout(__TIME_LIMIT) + @pytest.mark.parametrize('shape', [(16, 16, 16)]) + @pytest.mark.parametrize('dim', [0]) + @pytest.mark.parametrize('dtype', ['int32']) + def test_reduction_cases_shapes(self, shape, dim, dtype, clear_cache): + print(f"shape= {shape}") + print(f"dim= {dim}") + print(f"dtype= {dtype}") + print('npu_indexing= {}'.format(torch_npu._inductor.config.enable_npu_indexing)) + + std_ret = self.op_calc() + std_ret_mean =torch.mean(std_ret) + print(f"std_ret_mean= {std_ret_mean}") + + compiled_op_calc = torch.compile(self.op_calc, backend="inductor") + inductor_ret = compiled_op_calc() + + inductor_ret_mean = torch.mean(inductor_ret) + print(f"inductor_ret_mean= {inductor_ret_mean}") + assert inductor_ret_mean is not None + + +if __name__ == "__main__": + size = (16, 16, 16) + test = TestGeometric() + test.test_reduction_cases_shapes(size, -1, 'float32', None) + + + diff --git a/test/_inductor/test_gt.py b/test/_inductor/test_gt.py new file mode 100644 index 0000000000000000000000000000000000000000..4b29d7eef7ed75e4561b23a62db8c422e66745db --- /dev/null +++ b/test/_inductor/test_gt.py @@ -0,0 +1,57 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei TechNologies Co., Ltd. 2023-2023. All rights reserved. +import torch +import torch_npu +import torch_npu._inductor + +import pytest +from .testutils import OperatorType, TestUtils + + +class TestGt(TestUtils): + __TIME_LIMIT = 100 + __OPTYPE = OperatorType.POINTWISE + # optimized function, auto timeout after __TIME_LIMIT seconds + + # @torch.compile(options={"aggressive_fusion": False}) + + def op_calc(self, first_element, second_element): + result = torch.gt(first_element, second_element) + return result + + # 在连续测试场景下,测试结果不稳定,建议单独重测批量测试未通过的 case + # 若需测试更多数据类型, 将dtype后面的list改成 ProtoTestCase._test_dtypes即可 + # 对indexing开关情况的测试需要用外部参数--npu_indexing=True/False完成 + + @pytest.mark.timeout(__TIME_LIMIT) + @pytest.mark.parametrize('shape', TestUtils._pointwise_demo_shapes) + @pytest.mark.parametrize('dtype', ['float16', 'float32', 'bfloat16', 'int32']) + def test_pointwise_cases(self, shape, dtype): + print(shape) + print('npu_indexing= {}'.format(torch_npu._inductor.config.enable_npu_indexing)) + first_element = self._generate_tensor(shape, dtype) + second_element = self._generate_tensor(shape, dtype) + + std_result = self.op_calc(first_element, second_element) + + compiled_op_calc = torch.compile(self.op_calc, backend="inductor") + inductor_result = compiled_op_calc(first_element, second_element) + + print("start test!") + torch.testing.assert_close(std_result, inductor_result) + + # should be implemented when __OPTYPE is OperatorType.REDUCTION + @pytest.mark.timeout(__TIME_LIMIT) + @pytest.mark.parametrize('shape,dim', TestUtils._reduction_extest_SDbinding) + @pytest.mark.parametrize('dtype', TestUtils._test_dtypes) + @pytest.mark.skipif(__OPTYPE != OperatorType.REDUCTION, reason='not reduction operator') + def test_reduction_cases(self, shape, dim, dtype, clear_cache): + pass + +if __name__ == '__main__': + TestGt() + + + + + diff --git 
a/test/_inductor/test_high_order_sum.py b/test/_inductor/test_high_order_sum.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b48e963e3bd205e279c24b17fa68642220c299d
--- /dev/null
+++ b/test/_inductor/test_high_order_sum.py
@@ -0,0 +1,66 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved.
+import torch.nn.functional as F
+import torch
+import torch_npu
+import torch_npu._inductor
+
+def op_sum(npu_dropout_backward_9):
+    view_337: "f32[32768, 256]" = torch.ops.aten.view.default(npu_dropout_backward_9, [32768, 256]);
+    sum_63: "f32[1, 256]" = torch.ops.aten.sum.dim_IntList(view_337, [0], True);
+    view_338: "f32[256]" = torch.ops.aten.view.default(sum_63, [256]);
+    return view_338
+
+device='npu'
+
+def test_high_order_sum():
+    npu_dropout_backward_9 = torch.randn((32768, 256), device=device, dtype=torch.float32)
+    ref = op_sum(npu_dropout_backward_9)
+    func = torch.compile(op_sum, backend="inductor", dynamic=False)
+    calc = func(npu_dropout_backward_9)
+
+    torch.testing.assert_close(ref, calc, rtol=1e-3, atol=1e-3)
+    torch.testing.assert_close(ref, calc, rtol=1e-3, atol=1e-3)
+
+if __name__ == "__main__":
+    npu_dropout_backward_9 = torch.randn((32768, 256), device=device, dtype=torch.float32)
+    ref = op_sum(npu_dropout_backward_9)
+    func = torch.compile(op_sum, backend="inductor", dynamic=False)
+    calc = func(npu_dropout_backward_9)
+
+    torch.testing.assert_close(ref, calc, rtol=1e-3, atol=1e-3)
+    torch.testing.assert_close(ref, calc, rtol=1e-3, atol=1e-3)
+
+    experimental_config = torch_npu.profiler._ExperimentalConfig(
+        aic_metrics=torch_npu.profiler.AiCMetrics.PipeUtilization,
+        profiler_level=torch_npu.profiler.ProfilerLevel.Level1, l2_cache=False
+    )
+    with torch_npu.profiler.profile(
+            activities=[  # torch_npu.profiler.ProfilerActivity.CPU,
+                torch_npu.profiler.ProfilerActivity.NPU],
+            with_stack=False,  # switch for collecting the call stacks of torch ops; optional, off by default
+            record_shapes=False,  # switch for collecting input shapes and input types of torch ops; optional, off by default
+            profile_memory=False,  # switch for collecting memory-related data; optional, off by default
+            schedule=torch_npu.profiler.schedule(wait=1,
+                                                 warmup=1,
+                                                 active=10,
+                                                 repeat=1,
+                                                 skip_first=1),
+            # schedule=torch_npu.profiler.schedule(wait=1, warmup=1, active=1, skip_first=6),
+            # warmup defaults to 0; in older torch_npu releases this argument was mandatory
+            experimental_config=experimental_config,  # optional, defaults to Level0
+            # location of the generated profiling files
+            on_trace_ready=torch_npu.profiler.tensorboard_trace_handler("./result_dir")
+            # exports data in a form tensorboard can display; worker_name can be specified, default is {hostname}_{pid}
+    ) as prof:
+        for i in range(20):
+            # ref1 = call(args)
+            op_sum(npu_dropout_backward_9)
+            func(npu_dropout_backward_9)
+            prof.step()
+
+
+
+
+
diff --git a/test/_inductor/test_issue54.py b/test/_inductor/test_issue54.py
new file mode 100644
index 0000000000000000000000000000000000000000..ce943a616dddf70a4b696986693543e4355febb9
--- /dev/null
+++ b/test/_inductor/test_issue54.py
@@ -0,0 +1,71 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.
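+# Hedged sketch (not part of the original repro): the issue-repro tests in this
+# directory all follow the same pattern -- run an FX-style aten graph eagerly,
+# compile the same callable with the inductor backend, then compare outputs with
+# torch.testing.assert_close. The helper below is an illustrative generalisation;
+# its name and default tolerances are assumptions, not an existing utility.
+def _assert_inductor_matches_eager(fn, *args, rtol=1e-2, atol=1e-2):
+    import torch
+    eager_out = fn(*args)
+    compiled_out = torch.compile(fn, backend="inductor", dynamic=False)(*args)
+    # Works for a single tensor or a tuple/list of tensors.
+    if isinstance(eager_out, (tuple, list)):
+        for ref, calc in zip(eager_out, compiled_out):
+            torch.testing.assert_close(ref, calc, rtol=rtol, atol=atol)
+    else:
+        torch.testing.assert_close(eager_out, compiled_out, rtol=rtol, atol=atol)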
+import torch.nn.functional as F +import torch +import torch_npu +import torch_npu._inductor + +import pytest +from torch.nn import CrossEntropyLoss +from torch import nn + + +class Test_issue54(): + def func_layernorm(self, add_3, primals_6, primals_7, view, primals_9, permute_1, primals_10, primals_11): + # 原网络 + permute: "f32[256, 256]" = torch.ops.aten.permute.default(primals_6, [1, 0]); + addmm: "f32[32768, 256]" = torch.ops.aten.addmm.default(primals_7, view, permute); + view_1: "f32[64, 512, 256]" = torch.ops.aten.view.default(addmm, [64, 512, 256]); + addmm_1: "f32[32768, 256]" = torch.ops.aten.addmm.default(primals_9, view, permute_1); + view_3: "f32[64, 512, 256]" = torch.ops.aten.view.default(addmm_1, [64, 512, 256]); + view_4: "f32[64, 512, 4, 64]" = torch.ops.aten.view.default(view_3, [64, 512, 4, 64]); + permute_2: "f32[64, 4, 512, 64]" = torch.ops.aten.permute.default(view_4, [0, 2, 1, 3]); + permute_3: "f32[256, 256]" = torch.ops.aten.permute.default(primals_10, [1, 0]); + addmm_2: "f32[32768, 256]" = torch.ops.aten.addmm.default(primals_11, view, permute_3); + view_6: "f32[64, 512, 256]" = torch.ops.aten.view.default(addmm_2, [64, 512, 256]); + + view_8: "f32[64, 512, 4, 64]" = torch.ops.aten.view.default(view_1, [64, 512, 4, 64]); + permute_5: "f32[64, 4, 512, 64]" = torch.ops.aten.permute.default(view_8, [0, 2, 1, 3]); + + permute_6: "f32[64, 4, 64, 512]" = torch.ops.aten.permute.default(permute_2, [0, 1, 3, 2]); + expand_1: "f32[64, 4, 512, 64]" = torch.ops.aten.expand.default(permute_5, [64, 4, 512, 64]) + clone: "f32[64, 4, 512, 64]" = torch.ops.aten.clone.default(expand_1, memory_format=torch.contiguous_format); + view_9: "f32[256, 512, 64]" = torch.ops.aten.view.default(clone, [256, 512, 64]); + expand_2: "f32[64, 4, 64, 512]" = torch.ops.aten.expand.default(permute_6, [64, 4, 64, 512]) + clone_1: "f32[64, 4, 64, 512]" = torch.ops.aten.clone.default(expand_2, memory_format=torch.contiguous_format); + view_10: "f32[256, 64, 512]" = torch.ops.aten.view.default(clone_1, [256, 64, 512]); + bmm: "f32[256, 512, 512]" = torch.ops.aten.bmm.default(view_9, view_10); + view_7: "f32[64, 512, 4, 64]" = torch.ops.aten.view.default(view_6, [64, 512, 4, 64]); + permute_4: "f32[64, 4, 512, 64]" = torch.ops.aten.permute.default(view_7, [0, 2, 1, 3]); + expand_4: "f32[64, 4, 512, 64]" = torch.ops.aten.expand.default(permute_4, [64, 4, 512, 64]) + clone_2: "f32[64, 4, 512, 64]" = torch.ops.aten.clone.default(expand_4, memory_format=torch.contiguous_format); + view_13: "f32[256, 512, 64]" = torch.ops.aten.view.default(clone_2, [256, 512, 64]); + + return bmm, view_13 + + def test_issue54(self): + device = 'npu' + test = Test_issue54() + # add_3, primals_6, primals_7, view, primals_9, permute_1, primals_10, primals_11 + + add_3 = torch.randn((64, 512, 256), device=device, dtype=torch.float32) + primals_6 = torch.randn((256, 256), device=device, dtype=torch.float32) + primals_7 = torch.randn((256), device=device, dtype=torch.float32) + view = torch.randn((32768, 256), device=device, dtype=torch.float32) + primals_9 = torch.randn((256), device=device, dtype=torch.float32) + permute_1 = torch.randn((256, 256), device=device, dtype=torch.float32) + primals_10 = torch.randn((256, 256), device=device, dtype=torch.float32) + primals_11 = torch.randn((256), device=device, dtype=torch.float32) + + ref = test.func_layernorm(add_3, primals_6, primals_7, view, primals_9, permute_1, primals_10, primals_11) + func = torch.compile(test.func_layernorm, backend="inductor", dynamic=False, + 
options={"unroll_reductions_threshold": 1, "aggressive_fusion": True}) + calc = func(add_3, primals_6, primals_7, view, primals_9, permute_1, primals_10, primals_11) + torch.testing.assert_close(ref[0], calc[0], rtol=1e-2, atol=1e-2) + torch.testing.assert_close(ref[1], calc[1], rtol=1e-2, atol=1e-2) + print("valid ok") + + +if __name__ == "__main__": + test = Test_issue54() + test.test_issue54() \ No newline at end of file diff --git a/test/_inductor/test_issue57.py b/test/_inductor/test_issue57.py new file mode 100644 index 0000000000000000000000000000000000000000..ac0cde11ae9b817551da7f28c450797aed98e100 --- /dev/null +++ b/test/_inductor/test_issue57.py @@ -0,0 +1,45 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. +import torch.nn.functional as F +import torch +import torch_npu +import torch_npu._inductor +import pytest + + +class Test_issue57(): + def op_sum(self, view_12, embedding_1, slice_11): + # 原网络 + + permute_7 = torch.ops.aten.permute.default(embedding_1, [2, 0, 1]); + embedding_1 = None + unsqueeze_4 = torch.ops.aten.unsqueeze.default(permute_7, 0); + permute_7 = None + + add_5 = torch.ops.aten.add.Tensor(unsqueeze_4, slice_11); + slice_8 = slice_11 = None + add_6 = torch.ops.aten.add.Tensor(view_12, add_5); + view_12 = None + return add_6 + + def test_issue57(self): + device = 'npu' + test = Test_issue57() + embedding_1 = torch.randn((512, 512, 64), device=device, dtype=torch.float32) + primals_221 = torch.randn((1, 1, 1, 512), device=device, dtype=torch.float32) + view_12 = torch.randn((1, 64, 512, 512), device=device, dtype=torch.float32) + slice_11 = torch.randn((1, 1, 1, 512), device=device, dtype=torch.float32) + + ref = test.op_sum(view_12, embedding_1, primals_221) + func = torch.compile(test.op_sum, backend="inductor", dynamic=False) + calc = func(view_12, embedding_1, primals_221) + + torch.testing.assert_close(ref, calc, rtol=1e-3, atol=1e-3) + torch.testing.assert_close(ref, calc, rtol=1e-3, atol=1e-3) + + print("valid ok") + + +if __name__ == "__main__": + test = Test_issue57() + test.test_issue57() \ No newline at end of file diff --git a/test/_inductor/test_issue59.py b/test/_inductor/test_issue59.py new file mode 100644 index 0000000000000000000000000000000000000000..eac1ae795b2915b78fdc1d3054259774a0f10a0e --- /dev/null +++ b/test/_inductor/test_issue59.py @@ -0,0 +1,44 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. 
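+# Hedged reference (not part of the original repro): layernorm_backward below builds the
+# normalisation by hand from sum / sub / rsqrt primitives. Assuming the intent is a standard
+# layer norm over the last dimension, an equivalent reference could use the built-in
+# functional op; the helper name and eps value are illustrative assumptions.
+def _layer_norm_reference(x, weight=None, bias=None, eps=1e-5):
+    import torch.nn.functional as F
+    # Normalises over the last dimension only.
+    return F.layer_norm(x, x.shape[-1:], weight=weight, bias=bias, eps=eps)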
+import torch +import torch_npu +import torch_npu._inductor +import pytest + + +class Test_issue59(): + def layernorm_backward(self, x, y, z): + sum = torch.sum(x) + mean = sum / torch.numel(sum) + sub = x - mean + sqr = sub * sub + sum_1 = torch.sum(sqr) + mean_1 = sum_1 / torch.numel(sum_1) + 1e-05 + rsqrt = torch.rsqrt(mean_1) + mul = sub * rsqrt + mul_1 = mul * y + add = mul_1 + z + mean_2 = rsqrt / torch.numel(rsqrt) + return mul, add, mean_2 + + def test_issue59(self): + device = 'npu' + test = Test_issue59() + x = torch.randn((1, 1024), device=device, dtype=torch.float32) + y = torch.randn((1, 1024), device=device, dtype=torch.float32) + z = torch.randn((1, 1024), device=device, dtype=torch.float32) + + mul, add, mean_2 = test.layernorm_backward(x, y, z) + func = torch.compile(test.layernorm_backward, backend="inductor", dynamic=False) + mul_t, add_t, mean_2_t = func(x, y, z) + + torch.testing.assert_close(mul, mul_t, rtol=1e-3, atol=1e-3) + torch.testing.assert_close(add, add_t, rtol=1e-3, atol=1e-3) + torch.testing.assert_close(mean_2, mean_2_t, rtol=1e-3, atol=1e-3) + + print("valid ok") + + +if __name__ == "__main__": + test = Test_issue59() + test.test_issue59() diff --git a/test/_inductor/test_issue62.py b/test/_inductor/test_issue62.py new file mode 100644 index 0000000000000000000000000000000000000000..075b45a7b04ef9c6f37a340deb47c1d0ff157235 --- /dev/null +++ b/test/_inductor/test_issue62.py @@ -0,0 +1,55 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. +import torch +import torch_npu +import triton +import triton.language as tl +import torch_npu._inductor +import pytest + + +# 实际就是 layernorm的计算过程 : torch.nn.LayerNorm(convert_element_type_25, elementwise_affine=False, eps=1e-6) +class Test_issue62(): + def op_func(self, addmm_5, add): + split = torch.ops.aten.split.Tensor(addmm_5, 1536, 1) + getitem = split[0] + getitem_1 = split[1] + getitem_2 = split[2] + getitem_3 = split[3] + getitem_4 = split[4] + getitem_5 = split[5] + + clone_1 = torch.ops.aten.clone.default(add, memory_format=torch.contiguous_format) + convert_element_type_25 = torch.ops.prims.convert_element_type.default(clone_1, torch.float32) + var_mean = torch.ops.aten.var_mean.correction(convert_element_type_25, [2], correction=0, keepdim=True) + getitem_6 = var_mean[0] + getitem_7 = var_mean[1] + add_3 = torch.ops.aten.add.Tensor(getitem_6, 1e-06) + rsqrt = torch.ops.aten.rsqrt.default(add_3) + sub = torch.ops.aten.sub.Tensor(clone_1, getitem_7) + mul_7 = torch.ops.aten.mul.Tensor(sub, rsqrt) + convert_element_type_26 = torch.ops.prims.convert_element_type.default(mul_7, torch.float16) + slice_11 = torch.ops.aten.slice.Tensor(getitem_1, 0, 0, 9223372036854775807) + unsqueeze_2 = torch.ops.aten.unsqueeze.default(slice_11, 1) + add_4 = torch.ops.aten.add.Tensor(unsqueeze_2, 1) + mul_8 = torch.ops.aten.mul.Tensor(convert_element_type_26, add_4) + slice_12 = torch.ops.aten.slice.Tensor(getitem, 0, 0, 9223372036854775807) + unsqueeze_3 = torch.ops.aten.unsqueeze.default(slice_12, 1) + add_5 = torch.ops.aten.add.Tensor(mul_8, unsqueeze_3) + return add_5 + + def test_issue62(self): + test = Test_issue62() + addmm_5 = torch.randn((2, 9216), device='npu:0', dtype=torch.float16) + add = torch.randn((2, 4096, 1536), device='npu:0', dtype=torch.float16) + + std_ret = test.op_func(addmm_5, add) + compiled_func = torch.compile(test.op_func, backend="inductor") + inductor_ret = compiled_func(addmm_5, add) + assert torch.allclose(std_ret, inductor_ret, atol=1e-2, 
rtol=1e-2), "Tensors are not close enough!" + print("valid ok") + + +if __name__ == "__main__": + test = Test_issue62() + test.test_issue62() \ No newline at end of file diff --git a/test/_inductor/test_issue70.py b/test/_inductor/test_issue70.py new file mode 100644 index 0000000000000000000000000000000000000000..6b8410bb1d100cdba70eb8dbf4eb187778fb78d1 --- /dev/null +++ b/test/_inductor/test_issue70.py @@ -0,0 +1,30 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. +import torch +import torch_npu +import torch_npu._inductor +import torch.nn as nn +import pytest + + +class Test_issue70(): + + def op_forward(self, x): + return x.mean(-1) + + + def test_issue70(self): + test = Test_issue70() + compiled_net = torch.compile(test.op_forward, backend="inductor") + + input = torch.randn((1, 1, 7168)).npu() + + output = test.op_forward(input) + output1 = compiled_net(input) + torch.testing.assert_allclose(output, output1, rtol=1e-03, atol=1e-03) + print("valid ok") + + +if __name__ == "__main__": + test = Test_issue70() + test.test_issue70() diff --git a/test/_inductor/test_opensora_graph1.py b/test/_inductor/test_opensora_graph1.py new file mode 100644 index 0000000000000000000000000000000000000000..5b7f35992f986abac4c90f52a65f00379ba9974b --- /dev/null +++ b/test/_inductor/test_opensora_graph1.py @@ -0,0 +1,343 @@ +import torch +import torch_npu +import torch_npu._inductor +import pytest +__TIME_LIMIT = 100 +from torch import device +device_npu = 'npu' + +@pytest.mark.timeout(__TIME_LIMIT) +def test_opensora_cases_model_9_inference(): + def forward(primals_1: "f32[1, 9600, 2304]"): + permute: "f32[9600, 1, 2304]" = torch.ops.aten.permute.default(primals_1, [1, 0, 2]); + return permute + primals_2 = torch.randn((1, 9600, 2304), device = device_npu, dtype=torch.float32) + ref = forward(primals_2) + forward_calc = torch.compile(forward, backend="inductor", dynamic=False) + calc = forward_calc(primals_2) + assert torch.allclose(ref, calc, equal_nan=True, rtol=1e-4, atol=1e-4) + primals_3 = torch.randn((1, 512, 2304), device=device_npu, dtype=torch.float32) + forward_calc = torch.compile(forward, backend="inductor", dynamic=False) + calc = forward_calc(primals_3) + ref = forward(primals_3) + assert torch.allclose(ref, calc, equal_nan=True, rtol=1e-4, atol=1e-4) + primals_4 = torch.randn((9600, 1, 2304), device=device_npu, dtype=torch.float32) + forward_calc = torch.compile(forward, backend="inductor", dynamic=False) + calc = forward_calc(primals_4) + ref = forward(primals_4) + assert torch.allclose(ref, calc, equal_nan=True, rtol=1e-4, atol=1e-4) + +@pytest.mark.skip +@pytest.mark.timeout(__TIME_LIMIT) +def test_opensora_cases_model_11_inference(): + def forward(arg0_1: "f32[1, 1, 9600]", arg1_1: "f32[1, 1, 512]"): + # File: /home/w00685865/osl1.3/mindspeed_mm/models/predictor/dits/video_dit_sparse.py:119 in prepare_sparse_mask, code: video_mask = video_mask.unsqueeze(1) + unsqueeze: "f32[1, 1, 1, 9600]" = torch.ops.aten.unsqueeze.default(arg0_1, 1); + arg0_1 = None + # File: /home/w00685865/osl1.3/mindspeed_mm/models/predictor/dits/video_dit_sparse.py:120 in prepare_sparse_mask, code: prompt_mask = prompt_mask.unsqueeze(1) + unsqueeze_1: "f32[1, 1, 1, 512]" = torch.ops.aten.unsqueeze.default(arg1_1, 1); + arg1_1 = None + # File: /root/anaconda3/envs/inductor2.3_sora/lib/python3.9/site-packages/torch/nn/functional.py:4522 in pad, code: return torch._C._nn.pad(input, pad, mode, value) + constant_pad_nd: "f32[1, 1, 1, 9600]" = 
torch.ops.aten.constant_pad_nd.default(unsqueeze, [0, 0, 0, 0], + -9980.0); + unsqueeze = None + # File: /home/w00685865/osl1.3/mindspeed_mm/models/predictor/dits/video_dit_sparse.py:128 in prepare_sparse_mask, code: video_mask_sparse_1d = rearrange( + view: "f32[1, 9600, 1]" = torch.ops.aten.view.default(constant_pad_nd, [1, 9600, 1]) + permute: "f32[1, 1, 9600]" = torch.ops.aten.permute.default(view, [2, 0, 1]); + view = None + view_1: "f32[1, 1, 1, 9600]" = torch.ops.aten.view.default(permute, [1, 1, 1, 9600]); + permute = None + # File: /home/w00685865/osl1.3/mindspeed_mm/models/predictor/dits/video_dit_sparse.py:133 in prepare_sparse_mask, code: video_mask_sparse_1d_group = rearrange( + view_2: "f32[1, 9600, 1, 1]" = torch.ops.aten.view.default(constant_pad_nd, [1, 9600, 1, 1]); + constant_pad_nd = None + permute_1: "f32[1, 1, 9600, 1]" = torch.ops.aten.permute.default(view_2, [2, 0, 1, 3]); + view_2 = None + view_3: "f32[1, 1, 1, 9600]" = torch.ops.aten.view.default(permute_1, [1, 1, 1, 9600]); + permute_1 = None + # File: /home/w00685865/osl1.3/mindspeed_mm/models/predictor/dits/video_dit_sparse.py:139 in prepare_sparse_mask, code: prompt_mask_sparse = prompt_mask.repeat(sparse_n, 1, 1, 1) + repeat: "f32[1, 1, 1, 512]" = torch.ops.aten.repeat.default(unsqueeze_1, [1, 1, 1, 1]); + unsqueeze_1 = None + # File: /home/w00685865/osl1.3/mindspeed_mm/models/predictor/dits/video_dit_sparse.py:142 in get_attention_mask, code: mask = mask.to(torch.bool) + npu_dtype_cast: "b8[1, 1, 1, 9600]" = torch.ops.npu.npu_dtype_cast.default(view_1, torch.bool); + view_1 = None + # File: /home/w00685865/osl1.3/mindspeed_mm/models/predictor/dits/video_dit_sparse.py:143 in get_attention_mask, code: mask = mask.repeat(1, 1, repeat_num, 1) + repeat_1: "b8[1, 1, 9600, 9600]" = torch.ops.aten.repeat.default(npu_dtype_cast, [1, 1, 9600, 1]); + npu_dtype_cast = None + # File: /home/w00685865/osl1.3/mindspeed_mm/models/predictor/dits/video_dit_sparse.py:142 in get_attention_mask, code: mask = mask.to(torch.bool) + npu_dtype_cast_1: "b8[1, 1, 1, 9600]" = torch.ops.npu.npu_dtype_cast.default(view_3, torch.bool); + view_3 = None + # File: /home/w00685865/osl1.3/mindspeed_mm/models/predictor/dits/video_dit_sparse.py:143 in get_attention_mask, code: mask = mask.repeat(1, 1, repeat_num, 1) + repeat_2: "b8[1, 1, 9600, 9600]" = torch.ops.aten.repeat.default(npu_dtype_cast_1, [1, 1, 9600, 1]); + npu_dtype_cast_1 = None + # File: /home/w00685865/osl1.3/mindspeed_mm/models/predictor/dits/video_dit_sparse.py:142 in get_attention_mask, code: mask = mask.to(torch.bool) + npu_dtype_cast_2: "b8[1, 1, 1, 512]" = torch.ops.npu.npu_dtype_cast.default(repeat, torch.bool); + repeat = None + # File: /home/w00685865/osl1.3/mindspeed_mm/models/predictor/dits/video_dit_sparse.py:143 in get_attention_mask, code: mask = mask.repeat(1, 1, repeat_num, 1) + repeat_3: "b8[1, 1, 9600, 512]" = torch.ops.aten.repeat.default(npu_dtype_cast_2, [1, 1, 9600, 1]); + npu_dtype_cast_2 = None + return (repeat_1, repeat_3, repeat_2) + arg0_1 = torch.rand((1, 1, 9600), device=device_npu, dtype=torch.float32) + arg1_1 = torch.rand((1, 1, 512), device=device_npu, dtype=torch.float32) + ref = forward(arg0_1, arg1_1) + forward_calc = torch.compile(forward, backend="inductor", dynamic=False) + calc = forward_calc(arg0_1, arg1_1) + for i in range(len(ref)): + print(ref[i]) + assert torch.allclose(ref[i], calc[i], equal_nan=True, rtol=1e-4, atol=1e-4) + +@pytest.mark.skip +@pytest.mark.timeout(__TIME_LIMIT) +def test_opensora_cases_model_14_backward(): + def 
forward(primals_5: "f32[1, 9600, 2304]", getitem_3: "f32[1, 9600, 1]", rsqrt: "f32[1, 9600, 1]", + add_2: "f32[1, 1, 2304]", view: "f32[9600, 2304]", permute_1: "f32[32, 2304]", + tangents_1: "f32[1, 9600, 32]"): + # File: /home/w00685865/osl1.3/mindspeed_mm/models/predictor/dits/video_dit_sparse.py:384 in _get_output_for_patched_inputs, code: latents = self.norm_out(latents) + sub: "f32[1, 9600, 2304]" = torch.ops.aten.sub.Tensor(primals_5, getitem_3); + primals_5 = getitem_3 = None + mul: "f32[1, 9600, 2304]" = torch.ops.aten.mul.Tensor(sub, rsqrt); + sub = None + # File: /home/w00685865/osl1.3/mindspeed_mm/models/predictor/dits/video_dit_sparse.py:387 in _get_output_for_patched_inputs, code: latents = self.proj_out(latents) + view_2: "f32[9600, 32]" = torch.ops.aten.view.default(tangents_1, [9600, 32]); + tangents_1 = None + mm: "f32[9600, 2304]" = torch.ops.aten.mm.default(view_2, permute_1); + permute_1 = None + permute_2: "f32[32, 9600]" = torch.ops.aten.permute.default(view_2, [1, 0]) + mm_1: "f32[32, 2304]" = torch.ops.aten.mm.default(permute_2, view); + permute_2 = view = None + permute_3: "f32[2304, 32]" = torch.ops.aten.permute.default(mm_1, [1, 0]); + mm_1 = None + sum_1: "f32[1, 32]" = torch.ops.aten.sum.dim_IntList(view_2, [0], True); + view_2 = None + view_3: "f32[32]" = torch.ops.aten.view.default(sum_1, [32]); + sum_1 = None + permute_4: "f32[32, 2304]" = torch.ops.aten.permute.default(permute_3, [1, 0]); + permute_3 = None + view_4: "f32[1, 9600, 2304]" = torch.ops.aten.view.default(mm, [1, 9600, 2304]); + mm = None + # File: /home/w00685865/osl1.3/mindspeed_mm/models/predictor/dits/video_dit_sparse.py:386 in _get_output_for_patched_inputs, code: latents = latents * (1 + scale) + shift + sum_2: "f32[1, 1, 2304]" = torch.ops.aten.sum.dim_IntList(view_4, [1], True) + mul_2: "f32[1, 9600, 2304]" = torch.ops.aten.mul.Tensor(view_4, mul) + mul_3: "f32[1, 9600, 2304]" = torch.ops.aten.mul.Tensor(view_4, add_2); + view_4 = add_2 = None + sum_3: "f32[1, 1, 2304]" = torch.ops.aten.sum.dim_IntList(mul_2, [1], True); + mul_2 = None + # File: /home/w00685865/osl1.3/mindspeed_mm/models/predictor/dits/video_dit_sparse.py:384 in _get_output_for_patched_inputs, code: latents = self.norm_out(latents) + mul_5: "f32[1, 9600, 2304]" = torch.ops.aten.mul.Tensor(mul_3, 2304) + sum_4: "f32[1, 9600, 1]" = torch.ops.aten.sum.dim_IntList(mul_3, [2], True) + mul_6: "f32[1, 9600, 2304]" = torch.ops.aten.mul.Tensor(mul_3, mul); + mul_3 = None + sum_5: "f32[1, 9600, 1]" = torch.ops.aten.sum.dim_IntList(mul_6, [2], True); + mul_6 = None + mul_7: "f32[1, 9600, 2304]" = torch.ops.aten.mul.Tensor(mul, sum_5); + mul = sum_5 = None + sub_2: "f32[1, 9600, 2304]" = torch.ops.aten.sub.Tensor(mul_5, sum_4); + mul_5 = sum_4 = None + sub_3: "f32[1, 9600, 2304]" = torch.ops.aten.sub.Tensor(sub_2, mul_7); + sub_2 = mul_7 = None + div: "f32[1, 9600, 1]" = torch.ops.aten.div.Tensor(rsqrt, 2304); + rsqrt = None + mul_8: "f32[1, 9600, 2304]" = torch.ops.aten.mul.Tensor(div, sub_3); + div = sub_3 = None + # File: /home/w00685865/osl1.3/mindspeed_mm/models/predictor/dits/video_dit_sparse.py:383 in _get_output_for_patched_inputs, code: shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1) + cat: "f32[1, 2, 2304]" = torch.ops.aten.cat.default([sum_2, sum_3], 1); + sum_2 = sum_3 = None + sum_6: "f32[1, 1, 2304]" = torch.ops.aten.sum.dim_IntList(cat, [1], True) + squeeze_1: "f32[1, 2304]" = torch.ops.aten.squeeze.dim(sum_6, 1); + sum_6 = None + full_default: "f32[1, 2304]" = 
torch.ops.aten.full.default([1, 2304], 0, dtype=torch.float32, + layout=torch.strided, + device=device(type='npu', index=0), pin_memory=False) + slice_scatter: "f32[1, 2304]" = torch.ops.aten.slice_scatter.default(full_default, squeeze_1, 0, 0, + 9223372036854775807); + full_default = squeeze_1 = None + squeeze_2: "f32[2, 2304]" = torch.ops.aten.squeeze.dim(cat, 0); + cat = None + return [squeeze_2, permute_4, view_3, slice_scatter, mul_8] + primals_5 = torch.randn((1, 9600, 2304), device=device_npu, dtype=torch.float32) + getitem_3 = torch.randn((1, 9600, 1), device=device_npu, dtype=torch.float32) + rsqrt = torch.randn((1, 9600, 1), device=device_npu, dtype=torch.float32) + add_2 = torch.randn((1, 1, 2304), device=device_npu, dtype=torch.float32) + view = torch.randn((9600, 2304), device=device_npu, dtype=torch.float32) + permute_1 = torch.randn((32, 2304), device=device_npu, dtype=torch.float32) + tangents_1 = torch.randn((1, 9600, 32), device=device_npu, dtype=torch.float32) + ref = forward(primals_5, getitem_3, rsqrt, + add_2, view, permute_1,tangents_1) + forward_calc = torch.compile(forward, backend="inductor", dynamic=False) + calc = forward_calc(primals_5, getitem_3, rsqrt, + add_2, view, permute_1,tangents_1) + for i in range(len(ref)): + # 1e-3 can not pass, should check reduction accuracy + assert torch.allclose(ref[i], calc[i], equal_nan=True, rtol=1e-4, atol=1e-4) + +@pytest.mark.timeout(__TIME_LIMIT) +def test_opensora_cases_model_14_forward(): + def forward(primals_1: "f32[2, 2304]", primals_2: "f32[32, 2304]", primals_3: "f32[32]", + primals_4: "f32[1, 2304]", primals_5: "f32[1, 9600, 2304]"): + # File: /home/w00685865/osl1.3/mindspeed_mm/models/predictor/dits/video_dit_sparse.py:383 in _get_output_for_patched_inputs, code: shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1) + unsqueeze: "f32[1, 2, 2304]" = torch.ops.aten.unsqueeze.default(primals_1, 0); + primals_1 = None + slice_1: "f32[1, 2304]" = torch.ops.aten.slice.Tensor(primals_4, 0, 0, 9223372036854775807); + primals_4 = None + unsqueeze_1: "f32[1, 1, 2304]" = torch.ops.aten.unsqueeze.default(slice_1, 1); + slice_1 = None + add: "f32[1, 2, 2304]" = torch.ops.aten.add.Tensor(unsqueeze, unsqueeze_1); + unsqueeze = unsqueeze_1 = None + split = torch.ops.aten.split.Tensor(add, 1, 1); + add = None + getitem: "f32[1, 1, 2304]" = split[0] + getitem_1: "f32[1, 1, 2304]" = split[1]; + split = None + # File: /home/w00685865/osl1.3/mindspeed_mm/models/predictor/dits/video_dit_sparse.py:384 in _get_output_for_patched_inputs, code: latents = self.norm_out(latents) + var_mean = torch.ops.aten.var_mean.correction(primals_5, [2], correction=0, keepdim=True) + getitem_2: "f32[1, 9600, 1]" = var_mean[0] + getitem_3: "f32[1, 9600, 1]" = var_mean[1]; + var_mean = None + add_1: "f32[1, 9600, 1]" = torch.ops.aten.add.Tensor(getitem_2, 1e-06); + getitem_2 = None + rsqrt: "f32[1, 9600, 1]" = torch.ops.aten.rsqrt.default(add_1); + add_1 = None + sub: "f32[1, 9600, 2304]" = torch.ops.aten.sub.Tensor(primals_5, getitem_3) + mul: "f32[1, 9600, 2304]" = torch.ops.aten.mul.Tensor(sub, rsqrt); + sub = None + # File: /home/w00685865/osl1.3/mindspeed_mm/models/predictor/dits/video_dit_sparse.py:386 in _get_output_for_patched_inputs, code: latents = latents * (1 + scale) + shift + add_2: "f32[1, 1, 2304]" = torch.ops.aten.add.Tensor(getitem_1, 1); + getitem_1 = None + mul_1: "f32[1, 9600, 2304]" = torch.ops.aten.mul.Tensor(mul, add_2); + mul = None + add_3: "f32[1, 9600, 2304]" = 
torch.ops.aten.add.Tensor(mul_1, getitem); + mul_1 = getitem = None + # File: /home/w00685865/osl1.3/mindspeed_mm/models/predictor/dits/video_dit_sparse.py:387 in _get_output_for_patched_inputs, code: latents = self.proj_out(latents) + view: "f32[9600, 2304]" = torch.ops.aten.view.default(add_3, [9600, 2304]); + add_3 = None + permute: "f32[2304, 32]" = torch.ops.aten.permute.default(primals_2, [1, 0]); + primals_2 = None + addmm: "f32[9600, 32]" = torch.ops.aten.addmm.default(primals_3, view, permute); + primals_3 = None + view_1: "f32[1, 9600, 32]" = torch.ops.aten.view.default(addmm, [1, 9600, 32]); + addmm = None + # No stacktrace found for following nodes + squeeze: "f32[1, 9600, 32]" = torch.ops.aten.squeeze.dim(view_1, 1); + view_1 = None + # File: /home/w00685865/osl1.3/mindspeed_mm/models/predictor/dits/video_dit_sparse.py:387 in _get_output_for_patched_inputs, code: latents = self.proj_out(latents) + permute_1: "f32[32, 2304]" = torch.ops.aten.permute.default(permute, [1, 0]); + permute = None + return [squeeze, primals_5, getitem_3, rsqrt, add_2, view, permute_1] + primals_1 = torch.ones((2, 2304), device=device_npu, dtype=torch.float32) + primals_2 = torch.ones((32, 2304), device=device_npu, dtype=torch.float32) + primals_3 = torch.ones((32,), device=device_npu, dtype=torch.float32) + primals_4 = torch.ones((1, 2304), device=device_npu, dtype=torch.float32) + primals_5 = torch.ones((1, 9600, 2304), device=device_npu, dtype=torch.float32) + ref = forward(primals_1, primals_2, primals_3,primals_4, primals_5) + forward_calc = torch.compile(forward, backend="inductor", dynamic=False) + calc = forward_calc(primals_1, primals_2, primals_3,primals_4, primals_5) + for i in range(len(ref)): + assert torch.allclose(ref[i], calc[i], equal_nan=True, rtol=1e-4, atol=1e-4) + +@pytest.mark.timeout(__TIME_LIMIT) +def test_opensora_cases_model_15_forward(): + def forward(primals_1: "f32[1, 8, 30, 40, 1, 2, 2, 8]", primals_2: "i64[]", primals_3: "i64[]", + primals_4: "i64[]"): + permute: "f32[1, 8, 8, 1, 30, 2, 40, 2]" = torch.ops.aten.permute.default(primals_1, [0, 7, 1, 4, 2, 5, 3, 6]); + mul: "i64[]" = torch.ops.aten.mul.Tensor(primals_2, 1); + mul_1: "i64[]" = torch.ops.aten.mul.Tensor(primals_3, 2); + mul_2: "i64[]" = torch.ops.aten.mul.Tensor(primals_4, 2); + return [permute, mul, mul_1, mul_2] + + primals_1 = torch.randn((1, 8, 30, 40, 1, 2, 2, 8), device=device_npu, dtype=torch.float32) + primals_2 = torch.tensor((1), device=device_npu, dtype=torch.int64) + primals_3 = torch.tensor((1), device=device_npu, dtype=torch.int64) + primals_4 = torch.tensor((1), device=device_npu, dtype=torch.int64) + ref = forward(primals_1, primals_2, primals_3, + primals_4) + forward_calc = torch.compile(forward, backend="inductor", dynamic=False) + calc = forward_calc(primals_1, primals_2, primals_3, + primals_4) + for i in range(len(ref)): + assert torch.allclose(ref[i], calc[i], equal_nan=True, rtol=1e-4, atol=1e-4) + +def find_first_mismatch(output_calc, out, rtol=1e-2, atol=1e-2): + for index in torch.cartesian_prod(*[torch.arange(s) for s in output_calc.shape]): + index = tuple(index.tolist()) + diff = torch.abs(output_calc[index] - out[index]) + rel_diff = diff / torch.abs(out[index]) if torch.abs(out[index]) > 0 else 0 + if diff > atol or rel_diff > rtol: + return index + return None + +@pytest.mark.skip +@pytest.mark.timeout(__TIME_LIMIT) +def test_opensora_cases_model_16_forward(): + def forward(primals_1: "f32[2, 2304]", primals_2: "f32[32, 2304]", primals_3: "f32[32]", primals_4: "f32[1, 2304]", 
primals_5: "f32[1, 9600, 2304]"): + # File: /home/w00685865/osl1.3/mindspeed_mm/models/predictor/dits/video_dit_sparse.py:407 in _get_output_for_patched_inputs, code: shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1) + unsqueeze: "f32[1, 2, 2304]" = torch.ops.aten.unsqueeze.default(primals_1, 0); primals_1 = None + slice_1: "f32[1, 2304]" = torch.ops.aten.slice.Tensor(primals_4, 0, 0, 9223372036854775807); primals_4 = None + unsqueeze_1: "f32[1, 1, 2304]" = torch.ops.aten.unsqueeze.default(slice_1, 1); slice_1 = None + add: "f32[1, 2, 2304]" = torch.ops.aten.add.Tensor(unsqueeze, unsqueeze_1); unsqueeze = unsqueeze_1 = None + split = torch.ops.aten.split.Tensor(add, 1, 1); add = None + getitem: "f32[1, 1, 2304]" = split[0] + getitem_1: "f32[1, 1, 2304]" = split[1]; split = None + # File: /home/w00685865/osl1.3/mindspeed_mm/models/predictor/dits/video_dit_sparse.py:408 in _get_output_for_patched_inputs, code: latents = self.norm_out(latents) + var_mean = torch.ops.aten.var_mean.correction(primals_5, [2], correction = 0, keepdim = True) + getitem_2: "f32[1, 9600, 1]" = var_mean[0] + getitem_3: "f32[1, 9600, 1]" = var_mean[1]; var_mean = None + add_1: "f32[1, 9600, 1]" = torch.ops.aten.add.Tensor(getitem_2, 1e-06); getitem_2 = None + rsqrt: "f32[1, 9600, 1]" = torch.ops.aten.rsqrt.default(add_1); add_1 = None + sub: "f32[1, 9600, 2304]" = torch.ops.aten.sub.Tensor(primals_5, getitem_3) + mul: "f32[1, 9600, 2304]" = torch.ops.aten.mul.Tensor(sub, rsqrt); sub = None + # File: /home/w00685865/osl1.3/mindspeed_mm/models/predictor/dits/video_dit_sparse.py:410 in _get_output_for_patched_inputs, code: latents = latents * (1 + scale) + shift + add_2: "f32[1, 1, 2304]" = torch.ops.aten.add.Tensor(getitem_1, 1); getitem_1 = None + mul_1: "f32[1, 9600, 2304]" = torch.ops.aten.mul.Tensor(mul, add_2); mul = None + add_3: "f32[1, 9600, 2304]" = torch.ops.aten.add.Tensor(mul_1, getitem); mul_1 = getitem = None + # File: /home/w00685865/osl1.3/mindspeed_mm/models/predictor/dits/video_dit_sparse.py:411 in _get_output_for_patched_inputs, code: latents = self.proj_out(latents) + view: "f32[9600, 2304]" = torch.ops.aten.view.default(add_3, [9600, 2304]); add_3 = None + permute: "f32[2304, 32]" = torch.ops.aten.permute.default(primals_2, [1, 0]); primals_2 = None + addmm: "f32[9600, 32]" = torch.ops.aten.addmm.default(primals_3, view, permute); primals_3 = None + #import pdb;pdb.set_trace() + view_1: "f32[1, 9600, 32]" = torch.ops.aten.view.default(addmm, [1, 9600, 32]); + # No stacktrace found for following nodes + squeeze: "f32[1, 9600, 32]" = torch.ops.aten.squeeze.dim(view_1, 1); + # import pdb; + # pdb.set_trace() + # File: /home/w00685865/osl1.3/mindspeed_mm/models/predictor/dits/video_dit_sparse.py:418 in _get_output_for_patched_inputs, code: latents = latents.reshape( + view_2: "f32[1, 8, 30, 40, 1, 2, 2, 8]" = torch.ops.aten.view.default(squeeze, [1, 8, 30, 40, 1, 2, 2, 8]); squeeze = None + # File: /home/w00685865/osl1.3/mindspeed_mm/models/predictor/dits/video_dit_sparse.py:428 in _get_output_for_patched_inputs, code: latents = latents.permute(0, 7, 1, 4, 2, 5, 3, 6).contiguous() + permute_1: "f32[1, 8, 8, 1, 30, 2, 40, 2]" = torch.ops.aten.permute.default(view_2, [0, 7, 1, 4, 2, 5, 3, 6]); view_2 = None + clone: "f32[1, 8, 8, 1, 30, 2, 40, 2]" = torch.ops.aten.clone.default(permute_1); permute_1 = None + # File: /home/w00685865/osl1.3/mindspeed_mm/models/predictor/dits/video_dit_sparse.py:429 in _get_output_for_patched_inputs, code: output = latents.reshape( + 
clone_1: "f32[1, 8, 8, 1, 30, 2, 40, 2]" = torch.ops.aten.clone.default(clone, memory_format = torch.contiguous_format); clone = None + view_3: "f32[1, 8, 8, 60, 80]" = torch.ops.aten.view.default(clone_1, [1, 8, 8, 60, 80]); clone_1 = None + # File: /home/w00685865/osl1.3/mindspeed_mm/models/predictor/dits/video_dit_sparse.py:411 in _get_output_for_patched_inputs, code: latents = self.proj_out(latents) + permute_3: "f32[32, 2304]" = torch.ops.aten.permute.default(permute, [1, 0]); permute = None + return [view_3, primals_5, getitem_3, rsqrt, add_2, view, permute_3] + + import random + import numpy as np + import os + def seed_all(seed=1234, mode=False): + random.seed(seed) + os.environ['PYTHONHASHSEED'] = str(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.use_deterministic_algorithms(mode) + torch_npu.npu.manual_seed_all(seed) + torch_npu.npu.manual_seed(seed) + + seed_all(True) + primals_1 = torch.randn((2, 2304), device=device_npu,dtype=torch.float32) + print(primals_1) + primals_2 = torch.randn((32, 2304), device=device_npu,dtype=torch.float32) + primals_3 = torch.randn((32,), device=device_npu,dtype=torch.float32) + primals_4 = torch.randn((1, 2304), device=device_npu,dtype=torch.float32) + primals_5 = torch.randn((1, 9600, 2304), device=device_npu,dtype=torch.float32) + + ref = forward(primals_1, primals_2, primals_3, primals_4, primals_5) + forward_calc = torch.compile(forward, backend="inductor", dynamic=False) + calc = forward_calc(primals_1, primals_2, primals_3, primals_4, primals_5) + for i in range(len(ref)): + print("i=", i) + assert torch.allclose(ref[i], calc[i], equal_nan=True, rtol=1e-3, atol=1e-3) + +if __name__ == '__main__': + test_opensora_cases_model_15_forward() + #test_opensora_cases_model_15_forward() + #test_opensora_cases_model_16_forward() diff --git a/test/_inductor/test_permute.py b/test/_inductor/test_permute.py new file mode 100644 index 0000000000000000000000000000000000000000..fee281959207f1e1dfb11be26ad38c1016cf45d9 --- /dev/null +++ b/test/_inductor/test_permute.py @@ -0,0 +1,47 @@ +import torch +import torch_npu +import torch_npu._inductor + +import pytest +from testutils import OperatorType, TestUtils + +torch_npu._inductor.config.enable_npu_indexing = True + + +class TestPermute(TestUtils): + __TIME_LIMIT = 100 + + _permute_dims = [ + (0, 1, 2, 3), (0, 1, 3, 2), (0, 2, 1, 3), (0, 2, 3, 1), + (0, 3, 1, 2), (0, 3, 2, 1), (1, 0, 2, 3), (1, 0, 3, 2), + (1, 2, 0, 3), (1, 2, 3, 0), (1, 3, 0, 2), (1, 3, 2, 0), + (2, 0, 1, 3), (2, 0, 3, 1), (2, 1, 0, 3), (2, 1, 3, 0), + (2, 3, 0, 1), (2, 3, 1, 0), (3, 0, 1, 2), (3, 0, 2, 1), + (3, 1, 0, 2), (3, 1, 2, 0), (3, 2, 0, 1), (3, 2, 1, 0), + ] + + def op_calc(self, a, b, dim): + a = a.permute(dim) + b = b.permute(dim) + y = a + b + return y + + @pytest.mark.timeout(__TIME_LIMIT) + @pytest.mark.parametrize('shape', [(8, 8, 512, 128)]) + @pytest.mark.parametrize('dtype', ['float32', 'int32', 'float16', 'bfloat16', 'int64']) + def test_view_cases(self, shape, dtype, clear_cache): + print(f"shape={shape}") + print(f"dtype={dtype}") + print("npu_indexing={}".format(torch_npu._inductor.config.enable_npu_indexing)) + + a = self._generate_tensor(shape, dtype) + b = self._generate_tensor(shape, dtype) + + for dim in self._permute_dims: + print(f"start to test permute on dim :{dim}") + std_permute = self.op_calc(a, b, dim) + compiled_op_calc = torch.compile(self.op_calc, backend="inductor") + inductor_permute = compiled_op_calc(a, b, dim) + + torch.testing.assert_close(std_permute, inductor_permute, rtol=1e-3, 
atol=1e-3) + print("data validation passed.") diff --git a/test/_inductor/test_reduction_brocast_add.py b/test/_inductor/test_reduction_brocast_add.py new file mode 100644 index 0000000000000000000000000000000000000000..29e86fdae90af454a819ee7ddc624e8d3ab2ecd5 --- /dev/null +++ b/test/_inductor/test_reduction_brocast_add.py @@ -0,0 +1,34 @@ +import torch +import torch_npu +import torch_npu._inductor + +import pytest +from testutils import OperatorType, TestUtils + +class TestSumAdd(TestUtils): + __TIME_LIMIT = 100 + __OPTYPE = OperatorType.REDUCTION + + def foo(self,a, b, dim, shape): + y = a + b + y = y.sum(dim) + y = y.unsqueeze(dim) + y = y.broadcast_to(shape) + b + return y + + # case:change shapes + @pytest.mark.timeout(__TIME_LIMIT) + @pytest.mark.parametrize('shape', [(9, 9, 31, 63)]) + @pytest.mark.parametrize('dim', [0, 1, 2]) + @pytest.mark.parametrize('dtype', ['float32']) + def test_reduction_cases_shapes1(self, shape, dim, dtype, clear_cache): + print(f"shape= {shape}") + print(f"dim= {dim}") + print(f"dtype= {dtype}") + print('npu_indexing= {}'.format(torch_npu._inductor.config.enable_npu_indexing)) + + a, b = [torch.randn(shape, requires_grad=False, dtype=torch.float32, device="npu") for _ in range(2)] + r1 = self.foo(a, b, dim, shape) + func = torch.compile(self.foo, backend="inductor", dynamic=False) + r = func(a, b, dim, shape) + torch.testing.assert_close(r, r1, rtol=1e-3, atol=1e-3) diff --git a/test/_inductor/test_relu.py b/test/_inductor/test_relu.py new file mode 100644 index 0000000000000000000000000000000000000000..3107d4d5cb73d3b525238687597ab557ff4e612e --- /dev/null +++ b/test/_inductor/test_relu.py @@ -0,0 +1,34 @@ +import torch +import torch_npu +import torch_npu._inductor + +import pytest +from testutils import OperatorType, TestUtils + + +class TestRelu(TestUtils): + __TIME_LIMIT = 100 + __OPTYPE = OperatorType.POINTWISE + + def op_calc(self, first_element): + result = torch.relu(first_element) + return result + + @pytest.mark.timeout(__TIME_LIMIT) + @pytest.mark.parametrize('shape', TestUtils._pointwise_demo_shapes) + @pytest.mark.parametrize('dtype', ['float16', 'float32', 'bfloat16', 'int32', 'int64']) + def test_pointwise_cases(self, shape, dtype): + print(shape) + print('npu_indexing= {}'.format(torch_npu._inductor.config.enable_npu_indexing)) + first_element = self._generate_tensor(shape, dtype) + + std_result = self.op_calc(first_element) + + compiled_op_calc = torch.compile(self.op_calc, backend="inductor") + inductor_result = compiled_op_calc(first_element) + + torch.testing.assert_close(std_result, inductor_result) + + +if __name__ == '__main__': + TestRelu() diff --git a/test/_inductor/test_renorm.py b/test/_inductor/test_renorm.py new file mode 100644 index 0000000000000000000000000000000000000000..f2e55a833d9849cd37a903f3d96f52ecad90e0b6 --- /dev/null +++ b/test/_inductor/test_renorm.py @@ -0,0 +1,40 @@ +import torch +import torch_npu +import torch_npu._inductor + +import pytest +from testutils import OperatorType, TestUtils + +class TestRenorm(TestUtils): + __TIME_LIMIT = 100 + + def op_calc(self, input_element, dim): + return torch.renorm(input_element, p=2, dim=dim, maxnorm=5) + + # case:change shapes + @pytest.mark.timeout(__TIME_LIMIT) + @pytest.mark.parametrize('shape', [(32, 64)]) + @pytest.mark.parametrize('dim', [-1]) + @pytest.mark.parametrize('dtype', ['float32']) + def test_reduction_cases_shapes(self, shape, dim, dtype, clear_cache): + print(f"shape= {shape}") + print(f"dim= {dim}") + print(f"dtype= {dtype}") + print('npu_indexing= 
{}'.format(torch_npu._inductor.config.enable_npu_indexing)) + + input_element = self._generate_tensor(shape, dtype) + print(f"input_element= {input_element}") + std_ret = self.op_calc(input_element, dim) + print(f"std_ret= {std_ret}") + compiled_op_calc = torch.compile(self.op_calc, backend="inductor") + inductor_ret = compiled_op_calc(input_element, dim) + print(f"inductor_ret= {inductor_ret}") + + assert torch.allclose(std_ret, inductor_ret, equal_nan=True) + + +if __name__ == "__main__": + size = (32, 64) + test = TestRenorm() + test.test_reduction_cases_shapes(size, -1, 'float32', None) + diff --git a/test/_inductor/test_repeat.py b/test/_inductor/test_repeat.py new file mode 100644 index 0000000000000000000000000000000000000000..d3df6a138d27d23f63501175402a17cc4496f0bb --- /dev/null +++ b/test/_inductor/test_repeat.py @@ -0,0 +1,40 @@ +import torch +import torch_npu +import torch_npu._inductor + +import pytest +from testutils import OperatorType, TestUtils + +class TestRepeat(TestUtils): + __TIME_LIMIT = 100 + + def op_calc(self, input_element, dim): + return input_element.repeat(dim) + + # case:change shapes + @pytest.mark.timeout(__TIME_LIMIT) + @pytest.mark.parametrize('shape', [(16, 128, 64)]) + @pytest.mark.parametrize('dim', [(1, 1, 2), (1, 2, 1), (2, 1, 1)]) #(2, 3, 4), (1, 2, 3) + @pytest.mark.parametrize('dtype', ['float32']) + def test_reduction_cases_shapes(self, shape, dim, dtype, clear_cache): + print(f"shape= {shape}") + print(f"dim= {dim}") + print(f"dtype= {dtype}") + print('npu_indexing= {}'.format(torch_npu._inductor.config.enable_npu_indexing)) + + input_element = self._generate_tensor(shape, dtype) + + std_ret = self.op_calc(input_element, dim) + + compiled_op_calc = torch.compile(self.op_calc, backend="inductor") + inductor_ret = compiled_op_calc(input_element, dim) + + torch.testing.assert_close(std_ret, inductor_ret, rtol=1e-1, atol=1e-1) + + +if __name__ == "__main__": + size = (16, 512, 64) + dim = (2, 3, 4) + test = TestRepeat() + test.test_reduction_cases_shapes(size, dim, 'float32', None) + diff --git a/test/_inductor/test_reshape.py b/test/_inductor/test_reshape.py new file mode 100644 index 0000000000000000000000000000000000000000..910e0c83b402776be45963e4101fba839d0f941e --- /dev/null +++ b/test/_inductor/test_reshape.py @@ -0,0 +1,39 @@ +import torch +import torch_npu +import pytest +import torch_npu._inductor +from testutils import OperatorType, TestUtils + +torch_npu._inductor.config.enable_npu_indexing = True + +class TestReshape(TestUtils): + __TIME_LIMIT = 100 + __OPTYPE = OperatorType.POINTWISE + + B, N, S, D = (1, 12, 256, 8) + + def op_calc(self, a, b): + a = a.reshape(self.S, self.B, self.N * self.D) + b = b.reshape(self.S, self.B, self.N * self.D) + y = a + b + return y + + @pytest.mark.timeout(__TIME_LIMIT) + @pytest.mark.parametrize('shape', [(1, 12, 256, 8)]) + @pytest.mark.parametrize('dtype', ['float32', 'int32', 'float16', 'bfloat16', 'int64']) + def test_view_cases(self, shape, dtype, clear_cache): + print(shape) + print('npu_indexing= {}'.format(torch_npu._inductor.config.enable_npu_indexing)) + a = self._generate_tensor(shape, dtype) + b = self._generate_tensor(shape, dtype) + + print(f"start to test reshape on shape :{shape} ") + std_reshape = self.op_calc(a, b) + + compiled_op_calc = torch.compile(self.op_calc, backend="inductor") + inductor_reshape = compiled_op_calc(a, b) + + torch.testing.assert_close(std_reshape, inductor_reshape, rtol=1e-3, atol=1e-3) + + print("data validation passed") + diff --git 
a/test/_inductor/test_rsqrt.py b/test/_inductor/test_rsqrt.py new file mode 100644 index 0000000000000000000000000000000000000000..b76e1779f48f43db46f2a61594d3967188fe3857 --- /dev/null +++ b/test/_inductor/test_rsqrt.py @@ -0,0 +1,35 @@ +import torch +import torch_npu +import torch_npu._inductor + +import pytest +from testutils import OperatorType, TestUtils + + +class TestRsqrt(TestUtils): + __TIME_LIMIT = 100 + __OPTYPE = OperatorType.POINTWISE + + def op_calc(self, first_element): + result = torch.rsqrt(first_element) + return result + + @pytest.mark.timeout(__TIME_LIMIT) + @pytest.mark.parametrize('shape', TestUtils._pointwise_demo_shapes) + @pytest.mark.parametrize('dtype', ['float16', 'float32', 'bfloat16', 'int32', 'int64']) + def test_pointwise_cases(self, shape, dtype): + print(shape) + print('npu_indexing= {}'.format(torch_npu._inductor.config.enable_npu_indexing)) + first_element = self._generate_tensor(shape, dtype, 1) + + std_result = self.op_calc(first_element) + + compiled_op_calc = torch.compile(self.op_calc, backend="inductor") + inductor_result = compiled_op_calc(first_element) + + torch.testing.assert_close(std_result, inductor_result, rtol=1e-1, atol=1e-1) + + +if __name__ == '__main__': + TestRsqrt() + diff --git a/test/_inductor/test_slice.py b/test/_inductor/test_slice.py new file mode 100644 index 0000000000000000000000000000000000000000..2b8e75a91ba170a172a4ef27b3730360aff6e262 --- /dev/null +++ b/test/_inductor/test_slice.py @@ -0,0 +1,55 @@ +import torch +import torch_npu +import pytest +import torch_npu._inductor +from testutils import OperatorType, TestUtils + + +class TestSlice(TestUtils): + __TIME_LIMIT = 100 + __OPTYPE = OperatorType.POINTWISE + + def op_calc(self, a, b, dim, step): + if dim == 0: + target = a.shape[0] + end = target // step + a = a[:end:, ::, ::, ::] + b = b[:end:, ::, ::, ::] + elif dim == 1: + target = a.shape[1] + end = target // step + a = a[::, :end:, ::, ::] + b = b[::, :end:, ::, ::] + elif dim == 2: + target = a.shape[2] + end = target // step + a = a[::, ::, :end:, ::] + b = b[::, ::, :end:, ::] + elif dim == 3: + target = a.shape[3] + end = target // step + a = a[::, ::, ::, :end:] + b = b[::, ::, ::, :end:] + y = a + b + return y + + @pytest.mark.timeout(__TIME_LIMIT) + @pytest.mark.parametrize('shape', [(8, 8, 256, 128)]) + @pytest.mark.parametrize('dtype', ['float32', 'int32', 'float16', 'bfloat16', 'int64']) + def test_view_cases(self, shape, dtype, clear_cache): + print(shape) + print('npu_indexing= {}'.format(torch_npu._inductor.config.enable_npu_indexing)) + a = self._generate_tensor(shape, dtype) + b = self._generate_tensor(shape, dtype) + + for dim in [3, 2, 1, 0]: + print(f"start to test slice on dim :{dim} ") + std_slice = self.op_calc(a, b, dim, min(shape)//2) + + compiled_op_calc = torch.compile(self.op_calc, backend="inductor") + inductor_slice = compiled_op_calc(a, b, dim, min(shape)//2) + + torch.testing.assert_close(std_slice, inductor_slice, rtol=1e-3, atol=1e-3) + + print("data validation passed") + diff --git a/test/_inductor/test_split_loop.py b/test/_inductor/test_split_loop.py new file mode 100644 index 0000000000000000000000000000000000000000..5682a22836ec958114bc752c30b3646d180e7a55 --- /dev/null +++ b/test/_inductor/test_split_loop.py @@ -0,0 +1,38 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. 
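+# Illustrative note (not part of the original test): the case below adds a (N, 1, C)
+# bias to an (N, M, C) tensor before gelu, so the generated kernel has to handle a
+# broadcast along the middle axis. A quick, hedged way to confirm the expected output
+# shape on the host side -- the helper name is an assumption:
+def _expected_broadcast_shape(shape_a, shape_b):
+    import torch
+    # e.g. _expected_broadcast_shape((8, 86, 1152), (8, 1, 1152)) -> torch.Size([8, 86, 1152])
+    return torch.broadcast_shapes(shape_a, shape_b)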
+import torch
+import torch_npu
+import torch_npu._inductor
+
+import pytest
+from testutils import OperatorType, TestUtils
+
+
+class TestSplitLoop(TestUtils):
+    __TIME_LIMIT = 100
+
+    def op_calc(self, a, b):
+        return torch.nn.functional.gelu(a + b)
+
+    # case: change shapes
+    @pytest.mark.timeout(__TIME_LIMIT)
+    @pytest.mark.parametrize('shape', [(8, 86, 1152), (61, 89, 157), (7, 89, 971)])
+    @pytest.mark.parametrize('dtype', ['float32'])
+    def test_split_loop(self, shape, dtype):
+        print('npu_indexing= {}'.format(torch_npu._inductor.config.enable_npu_indexing))
+
+        a = self._generate_tensor(shape, dtype)
+        b = self._generate_tensor((shape[0], 1, shape[2]), dtype)
+
+        std_ = self.op_calc(a, b)
+
+        compiled_op_calc = torch.compile(self.op_calc, backend="inductor", dynamic=False)
+        inductor_ = compiled_op_calc(a, b)
+        # print(f"inductor_cat.shape= {inductor_cat.shape}")
+        torch.testing.assert_close(std_, inductor_, atol=1e-3, rtol=1e-3)
+
+
+if __name__ == "__main__":
+    size = (8, 86, 1152)
+    test = TestSplitLoop()
+    test.test_split_loop(size, 'float32')
diff --git a/test/_inductor/test_sqrt.py b/test/_inductor/test_sqrt.py
new file mode 100644
index 0000000000000000000000000000000000000000..201b646f9c2cb369ce5adac6f57a4edf629a6935
--- /dev/null
+++ b/test/_inductor/test_sqrt.py
@@ -0,0 +1,44 @@
+import torch
+import torch_npu
+import torch_npu._inductor
+
+import pytest
+from .testutils import OperatorType, TestUtils
+
+
+class TestSqrt(TestUtils):
+    __TIME_LIMIT = 100
+    __OPTYPE = OperatorType.POINTWISE
+
+    # optimized function, auto timeout after __TIME_LIMIT seconds
+
+    # @torch.compile(options={"aggressive_fusion": False})
+
+    def op_calc(self, first_element):
+        result = torch.sqrt(first_element)
+        return result
+
+    # Results can be unstable in back-to-back runs; re-run individually any case that fails in a batch run.
+    # To cover more data types, replace the list after dtype with ProtoTestCase._test_dtypes.
+    # Testing with the indexing switch on/off is driven by the external option --npu_indexing=True/False.
+
+    @pytest.mark.timeout(__TIME_LIMIT)
+    @pytest.mark.parametrize('shape', TestUtils._pointwise_demo_shapes)
+    @pytest.mark.parametrize('dtype', ['float16', 'float32', 'bfloat16', 'int32', 'int64'])
+    def test_pointwise_cases(self, shape, dtype, clear_cache):
+        print(shape)
+        print('npu_indexing= {}'.format(torch_npu._inductor.config.enable_npu_indexing))
+        first_element = self._generate_tensor(shape, dtype, 1)
+
+        std_result = self.op_calc(first_element)
+
+        compiled_op_calc = torch.compile(self.op_calc, backend="inductor")
+        inductor_result = compiled_op_calc(first_element)
+        # print(std_result[0:8])
+        # print(inductor_result[0:8])
+        # torch.testing.assert_close(std_result, inductor_result)
+        # To compare tensors that contain NaN and treat two NaNs as equal, use torch.allclose with equal_nan=True.
+        rtol = 1e-1
+        atol = 1e-1
+        assert torch.allclose(std_result, inductor_result, equal_nan=True, rtol=rtol, atol=atol)
+
diff --git a/test/_inductor/test_sub.py b/test/_inductor/test_sub.py
new file mode 100644
index 0000000000000000000000000000000000000000..dfc44938c3554b79353b574e498a6490dd6378a6
--- /dev/null
+++ b/test/_inductor/test_sub.py
@@ -0,0 +1,42 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved.
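+# Hedged sketch (not part of the original test): the comments in these pointwise tests
+# note that npu indexing is switched via the external --npu_indexing pytest option. If one
+# wanted to exercise both settings inside a single run, a sketch could temporarily flip the
+# config flag as below; the helper name is an assumption, and restoring the previous value
+# matters because the flag is a process-wide setting.
+def _run_with_npu_indexing(fn, enabled, *args):
+    import torch_npu._inductor
+    previous = torch_npu._inductor.config.enable_npu_indexing
+    torch_npu._inductor.config.enable_npu_indexing = enabled
+    try:
+        return fn(*args)
+    finally:
+        torch_npu._inductor.config.enable_npu_indexing = previous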
+import torch +import torch_npu +import torch_npu._inductor + +import pytest +from .testutils import OperatorType, TestUtils + + +class TestSub(TestUtils): + __TIME_LIMIT = 100 + __OPTYPE = OperatorType.POINTWISE + + # optimized function, auto timeout after __TIME_LIMIT seconds + + # @torch.compile(options={"aggressive_fusion": False}) + + def op_calc(self, first_element, second_element): + result = first_element - second_element + return result + + # 在连续测试场景下,测试结果不稳定,建议单独重测批量测试未通过的 case + # 若需测试更多数据类型,将dtype后面的list改成 ProtoTestCase._test_dtypes即可 + # 对indexing开关情况的测试需要用外部参数--npu_indexing=True/False完成 + + @pytest.mark.timeout(__TIME_LIMIT) + @pytest.mark.parametrize('shape', TestUtils._pointwise_demo_shapes) + @pytest.mark.parametrize('dtype', ['float16', 'float32', 'bfloat16', 'int32', 'int64']) + def test_pointwise_cases(self, shape, dtype, clear_cache): + print(shape) + print('npu_indexing= {}'.format(torch_npu._inductor.config.enable_npu_indexing)) + first_element = self._generate_tensor(shape, dtype) + second_element = self._generate_tensor(shape, dtype) + + std_sub = self.op_calc(first_element, second_element) + + compiled_op_calc = torch.compile(self.op_calc, backend="inductor") + inductor_sum = compiled_op_calc(first_element, second_element) + # print(std_sub[0:8]) + # print(inductor_sum[0:8]) + torch.testing.assert_close(std_sub, inductor_sum) diff --git a/test/_inductor/test_sum.py b/test/_inductor/test_sum.py new file mode 100644 index 0000000000000000000000000000000000000000..6b13e88c48b854c01b2e63fe5201907fd233d598 --- /dev/null +++ b/test/_inductor/test_sum.py @@ -0,0 +1,75 @@ +import torch +import torch_npu +import torch_npu._inductor + +import pytest +from testutils import OperatorType, TestUtils + + +class TestSum(TestUtils): + __TIME_LIMIT = 100 + __OPTYPE = OperatorType.REDUCTION + + def op_calc(self, input_element, dim): + return torch.sum(input_element, dim) + # 规约轴和非规约轴对齐用例 float32 XBLOCK_SUB>=8:shape=(8,32) + # non-persistent reduction 用例 规约轴>1024:shape=(8,8,8,2048) dim=-1 + _reduction_extest_shape4d_all = [(8, 32), (8, 8, 8, 2048)] + _reduction_extest_dim4d_low = [-1] + _reduction_extest_dim4d_all = [0, 1, 2] + + # 在连续测试场景下,测试结果不稳定,建议单独重测批量测试未通过的case + # 若需测试更多数据类型,将dtype手动修改,若在一个ut中涉及多个dtype的更改,可能因为tiling固化导致失败 + # 对indexing开关情况的测试需要用外部参数--npu-indexing=True/False完成 + @pytest.mark.timeout(__TIME_LIMIT) + @pytest.mark.parametrize('shape', _reduction_extest_shape4d_all) + @pytest.mark.parametrize('dim', _reduction_extest_dim4d_low) + @pytest.mark.parametrize('dtype', ['float32']) + def test_reduction_cases_shapes(self, shape, dim, dtype, clear_cache): + print(f"shape={shape}") + print(f"dim={dim}") + print(f"dtype={dtype}") + print('npu_indexing={}'.format(torch_npu._inductor.config.enable_npu_indexing)) + + input_element = self._generate_tensor(shape, dtype) + std_sum = self.op_calc(input_element, dim) + compiled_op_calc = torch.compile(self.op_calc, backend="inductor", dynamic=False) + inductor_sum_tmp = compiled_op_calc(input_element, dim) + if dtype == 'int32' or dtype == 'int64': + # inductor return float32,need to change int64 for assert + inductor_sum = inductor_sum_tmp.long() + elif dtype == 'float16': + # inductor return float32,need to change float16 for assert + inductor_sum = inductor_sum_tmp.half() + elif dtype == 'bfloat16': + # inductor return float32,need to change float32 for assert + std_sum = std_sum.float() + inductor_sum = inductor_sum_tmp + else: + inductor_sum = inductor_sum_tmp + + # print(f"std_sum={std_sum[0:8]}") + # 
print(f"inductor_sum={inductor_sum[0:8]}") + torch.testing.assert_close(std_sum, inductor_sum, rtol=1e-1, atol=1e-1) + + @pytest.mark.timeout(__TIME_LIMIT) + @pytest.mark.parametrize('shape', [(32, 16, 64, 128)]) + @pytest.mark.parametrize('dim', _reduction_extest_dim4d_all) + @pytest.mark.parametrize('dtype', ['float32']) + def test_reduction_cases_dims(self, shape, dim, dtype, clear_cache): + print(f"shape= {shape}") + print(f"dim= {dim}") + print(f"dtype= {dtype}") + print('npu_indexing= {}'.format(torch_npu._inductor.config.enable_npu_indexing)) + + input_element = self._generate_tensor(shape, dtype) + std_sum = self.op_calc(input_element, dim) + compiled_op_calc = torch.compile(self.op_calc, backend="inductor", dynamic=False) + inductor_sum = compiled_op_calc(input_element, dim) + + torch.testing.assert_close(std_sum, inductor_sum, rtol=1e-1, atol=1e-1) + +if __name__ == "__main__": + size = (32, 16, 64, 128) + test = TestSum() + test.test_reduction_cases_shapes(size, 2, 'float32', None) diff --git a/test/_inductor/test_sum_add.py b/test/_inductor/test_sum_add.py new file mode 100644 index 0000000000000000000000000000000000000000..670623d722fd07904631ed6c3028a3474eb40729 --- /dev/null +++ b/test/_inductor/test_sum_add.py @@ -0,0 +1,49 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. +import torch +import torch_npu +import torch_npu._inductor + +import pytest +from testutils import OperatorType, TestUtils + +class TestSumAdd(TestUtils): + __TIME_LIMIT = 100 + __OPTYPE = OperatorType.REDUCTION + def op_calc(self, input_element, dim, input_element2): + tmp = torch.sum(input_element, dim) + return tmp + input_element2 + + # case:change shapes + @pytest.mark.timeout(__TIME_LIMIT) + @pytest.mark.parametrize('shape', [(32, 64, 128, 2048)]) + @pytest.mark.parametrize('dim', [0, 1, 2, 3]) + @pytest.mark.parametrize('dtype', ['float32']) + def test_reduction_cases_shapes(self, shape, dim, dtype, clear_cache): + print(f"shape= {shape}") + print(f"dim= {dim}") + print(f"dtype= {dtype}") + print('npu_indexing= {}'.format(torch_npu._inductor.config.enable_npu_indexing)) + + input_element = self._generate_tensor(shape, dtype) + if dim == -1 or dim == 3: + input_element2 = torch.full(size=(32, 64, 128), fill_value=1000.0, dtype=torch.float32, device=torch.device("npu")) + elif dim == 2: + input_element2 = torch.full(size=(32, 64, 2048), fill_value=1000.0, dtype=torch.float32, device=torch.device("npu")) + elif dim == 1: + input_element2 = torch.full(size=(32, 128, 2048), fill_value=1000.0, dtype=torch.float32, device=torch.device("npu")) + else: + input_element2 = torch.full(size=(64, 128, 2048), fill_value=1000.0, dtype=torch.float32, device=torch.device("npu")) + + std_sum = self.op_calc(input_element, dim, input_element2) + + compiled_op_calc = torch.compile(self.op_calc, backend="inductor") + inductor_sum = compiled_op_calc(input_element, dim, input_element2) + + torch.testing.assert_close(std_sum, inductor_sum, rtol=1e-1, atol=1e-1) + + +if __name__ == "__main__": + size = (32, 64, 128, 2048) + test = TestSumAdd() + test.test_reduction_cases_shapes(size, -1, 'float32', None) \ No newline at end of file diff --git a/test/_inductor/test_var.py b/test/_inductor/test_var.py new file mode 100644 index 0000000000000000000000000000000000000000..5c583452c8d54e5ab178b51271504fb6081b7fb6 --- /dev/null +++ b/test/_inductor/test_var.py @@ -0,0 +1,41 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. 
All rights reserved. +import torch +import torch_npu +import torch_npu._inductor + +import pytest +from testutils import OperatorType, TestUtils + +class TestVar(TestUtils): + __TIME_LIMIT = 100 + + def op_calc(self, input_element, dim): + return torch.var(input_element, dim) + + # case:change shapes + @pytest.mark.timeout(__TIME_LIMIT) + @pytest.mark.parametrize('shape', [(8, 64, 128)]) + @pytest.mark.parametrize('dim', [0, 1, 2]) + @pytest.mark.parametrize('dtype', ['float16']) + def test_reduction_cases_shapes(self, shape, dim, dtype, clear_cache): + print(f"shape= {shape}") + print(f"dim= {dim}") + print(f"dtype= {dtype}") + print('npu_indexing= {}'.format(torch_npu._inductor.config.enable_npu_indexing)) + + input_element = self._generate_tensor(shape, dtype) + std_ret = self.op_calc(input_element, dim) + # print(f"std_ret= {std_ret}") + compiled_op_calc = torch.compile(self.op_calc, backend="inductor") + inductor_ret = compiled_op_calc(input_element, dim) + # print(f"inductor_ret= {inductor_ret}") + rtol = 1e-1 + atol = 1e-1 + assert torch.allclose(std_ret, inductor_ret, equal_nan=True, rtol=rtol, atol=atol) + + +if __name__ == "__main__": + size = (8, 64, 128) + test = TestVar() + test.test_reduction_cases_shapes(size, 2, 'float32', None) \ No newline at end of file diff --git a/test/_inductor/test_var_mean.py b/test/_inductor/test_var_mean.py new file mode 100644 index 0000000000000000000000000000000000000000..a36403daabde8f546fb1af2c86c6fd03d6e143fa --- /dev/null +++ b/test/_inductor/test_var_mean.py @@ -0,0 +1,44 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. +import torch +import torch_npu +import torch_npu._inductor + +import pytest +from testutils import OperatorType, TestUtils + +class TestVarMean(TestUtils): + __TIME_LIMIT = 100 + + def op_calc(self, input_element, dim): + return torch.var_mean(input_element, dim) + + # case:The shape must not be too large + #@pytest.mark.skip(reason="npu compiler bug") + @pytest.mark.timeout(__TIME_LIMIT) + @pytest.mark.parametrize('shape', [(8, 64, 128)]) + @pytest.mark.parametrize('dim', [0, 1, 2, (0, 2), (0, 1)]) + @pytest.mark.parametrize('dtype', ['float32']) + def test_reduction_cases_shapes(self, shape, dim, dtype, clear_cache): + print(f"shape= {shape}") + print(f"dim= {dim}") + print(f"dtype= {dtype}") + print('npu_indexing= {}'.format(torch_npu._inductor.config.enable_npu_indexing)) + + input_element = self._generate_tensor(shape, dtype) + + std_var, std_mean = self.op_calc(input_element, dim) + + compiled_op_calc = torch.compile(self.op_calc, backend="inductor", dynamic=False) + inductor_var, inductor_mean = compiled_op_calc(input_element, dim) + + rtol = 1e-1 + atol = 1e-1 + assert torch.allclose(std_var, inductor_var, equal_nan=True, rtol=rtol, atol=atol) + assert torch.allclose(std_mean, inductor_mean, equal_nan=True, rtol=rtol, atol=atol) + + +if __name__ == "__main__": + size = (8, 64, 1024) + test = TestVarMean() + test.test_reduction_cases_shapes(size, 2, 'float32', None) \ No newline at end of file diff --git a/test/_inductor/test_var_mean_add_mul.py b/test/_inductor/test_var_mean_add_mul.py new file mode 100644 index 0000000000000000000000000000000000000000..a20cdfe54749b0ea17d2b73667c26770be988d8d --- /dev/null +++ b/test/_inductor/test_var_mean_add_mul.py @@ -0,0 +1,45 @@ +import torch +import torch_npu +import torch_npu._inductor +import pytest + +__TIME_LIMIT = 100 +@pytest.mark.timeout(__TIME_LIMIT) +def test_reduction_cases_shapes(): + device = 'npu' + + 
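# The nested `forward` below replays a small aten-level FX graph: it splits `add`
+    # into two 1-wide slices along dim 1, layer-normalizes `primals_5` with
+    # var_mean/rsqrt (correction=0, eps=1e-06), scales by (getitem_1 + 1), shifts by
+    # getitem, and finally views the result as [9600, 2304].
+    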
def forward(add: "f32[1, 2, 2304]", primals_2: "f32[32, 2304]", primals_5: "f32[1, 9600, 2304]"): + split = torch.ops.aten.split.Tensor(add, 1, 1); + getitem: "f32[1, 1, 2304]" = split[0] + getitem_1: "f32[1, 1, 2304]" = split[1]; + + var_mean = torch.ops.aten.var_mean.correction(primals_5, [2], correction=0, keepdim=True) + getitem_2: "f32[1, 9600, 1]" = var_mean[0] + getitem_3: "f32[1, 9600, 1]" = var_mean[1]; + add_1: "f32[1, 9600, 1]" = torch.ops.aten.add.Tensor(getitem_2, 1e-06); + rsqrt: "f32[1, 9600, 1]" = torch.ops.aten.rsqrt.default(add_1); + sub: "f32[1, 9600, 2304]" = torch.ops.aten.sub.Tensor(primals_5, getitem_3) + mul: "f32[1, 9600, 2304]" = torch.ops.aten.mul.Tensor(sub, rsqrt); + + add_2: "f32[1, 1, 2304]" = torch.ops.aten.add.Tensor(getitem_1, 1); + mul_1: "f32[1, 9600, 2304]" = torch.ops.aten.mul.Tensor(mul, add_2); + add_3: "f32[1, 9600, 2304]" = torch.ops.aten.add.Tensor(mul_1, getitem); + + view: "f32[9600, 2304]" = torch.ops.aten.view.default(add_3, [9600, 2304]); + return [None, primals_5, getitem_3, rsqrt, add_2, view, primals_2] + + torch_npu._inductor.config.enable_npu_indexing = True + primals_2: "f32[32, 2304]" = torch.randn((32, 2304), device = device, dtype=torch.float32) + primals_5: "f32[1, 9600, 2304]" = torch.randn((1, 9600, 2304), device = device, dtype=torch.float32) + add: "f32[1, 2, 2304]" = torch.randn((1, 2, 2304), device =device, dtype=torch.float32) + + _, primals_5_ref, getitem_3_ref, rsqrt_ref, add_2_ref, view_ref, primals_2_ref = forward(add, primals_2, primals_5) + + forward = torch.compile(forward, backend="inductor", dynamic=False) + _, primals_5, getitem_3, rsqrt, add_2, view, primals_2 = forward(add, primals_2, primals_5) + + assert torch.allclose(primals_5_ref, primals_5, equal_nan=True, rtol=1e-3, atol=1e-3) + assert torch.allclose(getitem_3_ref, getitem_3, equal_nan=True, rtol=1e-3, atol=1e-3) + assert torch.allclose(rsqrt_ref, rsqrt, equal_nan=True, rtol=1e-3, atol=1e-3) + assert torch.allclose(add_2_ref, add_2, equal_nan=True, rtol=1e-3, atol=1e-3) + assert torch.allclose(primals_2_ref, primals_2, equal_nan=True, rtol=1e-3, atol=1e-3) \ No newline at end of file diff --git a/test/_inductor/test_where.py b/test/_inductor/test_where.py new file mode 100644 index 0000000000000000000000000000000000000000..b10b0aa3d98fe5fe5cb743f1277368a16c036ec4 --- /dev/null +++ b/test/_inductor/test_where.py @@ -0,0 +1,41 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. 
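+# How these cases are typically run (a minimal sketch; the --npu_indexing switch is
+# the session option referenced in the comments below):
+#   pytest -svvv test_where.py --npu_indexing=True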
+import torch
+import torch_npu
+import torch_npu._inductor
+import pytest
+from testutils import OperatorType, TestUtils
+
+
+class TestWhere(TestUtils):
+    __TIME_LIMIT = 100
+    __OPTYPE = OperatorType.POINTWISE
+
+    # optimized function, auto timeout after __TIME_LIMIT seconds
+
+    # @torch.compile(options={"aggressive_fusion": False})
+
+    def op_calc(self, condition, first_element, second_element):
+        return torch.where(condition, first_element, second_element)
+
+    # In back-to-back runs the results can be unstable; rerun failed cases from a batch run individually.
+    # To cover more data types, replace the list after 'dtype' with ProtoTestCase._test_dtypes.
+    # Toggling npu indexing is controlled by the external option --npu_indexing=True/False.
+
+    @pytest.mark.timeout(__TIME_LIMIT)
+    @pytest.mark.parametrize('shape', TestUtils._pointwise_demo_shapes)
+    @pytest.mark.parametrize('dtype', ['float16', 'float32', 'bfloat16', 'int32'])
+    def test_pointwise_cases(self, shape, dtype, clear_cache):
+        print(shape)
+        print('npu_indexing= {}'.format(torch_npu._inductor.config.enable_npu_indexing))
+
+        first_element = self._generate_tensor(shape, dtype)
+        second_element = self._generate_tensor(shape, dtype)
+        condition = self._generate_tensor(shape, 'bool')
+
+        std_result = self.op_calc(condition, first_element, second_element)
+
+        compiled_op_calc = torch.compile(self.op_calc, backend="inductor")
+        inductor_result = compiled_op_calc(condition, first_element, second_element)
+
+        torch.testing.assert_close(std_result, inductor_result)
\ No newline at end of file
diff --git a/test/_inductor/testutils.py b/test/_inductor/testutils.py
new file mode 100644
index 0000000000000000000000000000000000000000..3559820fc21e9a32a4559dbe122b96b7b3806e7d
--- /dev/null
+++ b/test/_inductor/testutils.py
@@ -0,0 +1,41 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved.
+import torch
+import torch_npu
+from enum import Enum, unique
+import os
+
+
+@unique
+class OperatorType(Enum):
+    POINTWISE = 'POINTWISE'
+    REDUCTION = 'REDUCTION'
+
+
+class TestUtils:
+    _pointwise_test_shape2d = [(4096, 256), (1024, 32), (8, 2048), (8, 4096)]  # (8, 4), (8, 8), not supported
+    _pointwise_test_shape3d = [(8, 8, 4), (8, 8, 8), (8, 8, 2048), (8, 8, 4096)]
+    _pointwise_test_shape4d = [(128, 128, 4096, 4), (128, 128, 4096, 8),
+                               (32, 32, 1024, 1024)]  # 128*128*4096*2048 is too big (512G)
+    _pointwise_test_shapes = _pointwise_test_shape2d + _pointwise_test_shape3d + _pointwise_test_shape4d
+
+    _pointwise_demo_shapes = [(1024, 32), (8, 16, 256, 32)]
+    _reduction_extest_shape4d = [(8, 8, 8, 16384), (8, 8, 16384, 8), (8, 16384, 8, 8), (16384, 8, 8, 8)]
+    _reduction_extest_dim4d = [-1, -2, 1, 0]
+    _reduction_extest_SDbinding = list(zip(_reduction_extest_shape4d, _reduction_extest_dim4d))
+
+    _test_dtypes = ['float32', 'int32', 'float16', 'bfloat16', 'int64']
+
+    @staticmethod
+    def _generate_tensor(shape, dtype, floatPOSIFLAG=0):
+        if dtype == 'float32' or dtype == 'float16' or dtype == 'bfloat16':
+            if floatPOSIFLAG:
+                return 1000 * torch.rand(size=shape, dtype=eval('torch.' + dtype), device=torch.device("npu"))
+            else:
+                return torch.randn(size=shape, dtype=eval('torch.' + dtype), device=torch.device("npu")) * 2000
+        elif dtype == 'int32' or dtype == 'int64':
+            return torch.randint(low=0, high=2000, size=shape, dtype=eval('torch.' 
+ dtype), device=torch.device("npu")) + elif dtype == 'bool': + return torch.randint(low=0, high=2, size=shape, device=torch.device("npu")).bool() + else: + raise ValueError('Invalid parameter \"dtype\" is found : {}'.format(dtype)) \ No newline at end of file diff --git a/torch_npu/_inductor/__init__.py b/torch_npu/_inductor/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fc3e21dc93371cd5c31a3e890f33cd1c1ca968e3 --- /dev/null +++ b/torch_npu/_inductor/__init__.py @@ -0,0 +1,79 @@ +import torch +from torch._inductor.codegen.common import register_backend_for_device, register_device_op_overrides +from torch._dynamo.device_interface import register_interface_for_device, get_interface_for_device +from torch._inductor import lowering as inductor_lowering +from torch._inductor.choices import InductorChoices +from .npu_device import NewNPUDeviceOpOverrides, NewNpuInterface +from torch._inductor.runtime import autotune_cache +from torch_npu.utils._inductor import NPUDeviceOpOverrides +from torch_npu.utils._dynamo_device import NpuInterface, current_device, set_device +from torch_npu.npu.utils import device_count + +from .lowering import make_reduction +from .decomposition import _register_npu_inductor_decompositons +from .utils import get_current_raw_stream +from .config import log as npulog +from .config import aggresive_autotune, num_vector_core +from .npu_choices import should_use_persistent_reduction +from . import config as npu_config + +from .runtime import _load_cached_autotuning + +npulog.info("perform torch_npu._inductor patch") +import torch +from torch_npu.utils._inductor import NPUDeviceOpOverrides +from torch_npu.utils._dynamo_device import NpuInterface, current_device, set_device +from torch_npu.npu.utils import device_count + +def _inductor_register_backend_for_device(): + from .codegen.schduling import NPUTritonScheduling + from .codegen.wrapper import NPUWrapperCodeGen + from .codegen.cppwrapper import CppWrapperNpu + register_backend_for_device('npu', NPUTritonScheduling, NPUWrapperCodeGen, CppWrapperNpu) + +_inductor_register_backend_for_device() + +def _inductor_register_device_op_overrides(): + register_device_op_overrides('npu', NewNPUDeviceOpOverrides()) + +_inductor_register_device_op_overrides() +register_interface_for_device("npu", NewNpuInterface) +for i in range(16) : + register_interface_for_device(f"npu:{i}", NewNpuInterface) +device = get_interface_for_device("npu") + +from . import codegen + +inductor_lowering.make_reduction = make_reduction + + +if npu_config.check_accuracy: + from .codegen.ir_fx import _patch_npu_inductor_ir + _patch_npu_inductor_ir() + +if npu_config.check_accuracy: + from .lowering_fx import _register_npu_inductor_fallbacks +else: + from .lowering import _register_npu_inductor_fallbacks + +_register_npu_inductor_fallbacks() +_register_npu_inductor_decompositons() + +#register fx_pass should be put behind of _register_npu_inductor_decompositons + +from . import npu_fusion_attention_graph +from . 
import dynamo_embedding_backward_dispatch + +def _replace_benchmark_all_configs(): + from torch._inductor.triton_heuristics import CachingAutotuner + from .npu_triton_heuristics import benchmark_all_configs + CachingAutotuner.benchmark_all_configs = benchmark_all_configs + + +if (aggresive_autotune): + _replace_benchmark_all_configs() + import os + os.environ["TRITON_BENCH_METHOD"] = "npu" + +InductorChoices.should_use_persistent_reduction = should_use_persistent_reduction +autotune_cache._load_cached_autotuning = _load_cached_autotuning \ No newline at end of file diff --git a/torch_npu/_inductor/codegen/__init__.py b/torch_npu/_inductor/codegen/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e7a6832665e9c1afe472d174a93d26dce1337e61 --- /dev/null +++ b/torch_npu/_inductor/codegen/__init__.py @@ -0,0 +1,41 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved. + + + +from torch._inductor.ir import Reduction, LoopBody +from torch._inductor.codegen.triton import TritonScheduling +from torch._inductor import sizevars +from torch._inductor.codegen.triton import TritonKernel +from torch._inductor.codegen.simd import SIMDKernel + +from torch_npu._inductor.codegen._sizevars import simplify +from torch_npu._inductor.codegen.ir import (num_splits, loopbody__call__, transform_dims_in_indexing, substituted_dims_in_indexing) +from torch_npu._inductor.codegen.triton import is_compatible +from torch_npu._inductor.codegen.triton import group_fn, select_index_dtype +from torch_npu._inductor.codegen.schduling import create_tiling + +from ..config import log as npulog +npulog.info("perform npu_indexing patch") +#graph +#common +#ir + + +Reduction.num_splits = num_splits +setattr(LoopBody, 'transform_dims_in_indexing', transform_dims_in_indexing) +setattr(LoopBody, 'substituted_dims_in_indexing', substituted_dims_in_indexing) + +LoopBody.__call__ = loopbody__call__ +#need to enable this to speedup attn_cp_test +#ComputedBuffer.simplify_and_reorder = simplify_and_reorder +#triton scheduling +TritonScheduling.group_fn = group_fn +TritonScheduling.select_index_dtype = select_index_dtype +TritonScheduling.create_tiling = create_tiling +#triton kernel +setattr(SIMDKernel, 'is_compatible', is_compatible) + +#util +sizevars.SizeVarAllocator.simplify = simplify \ No newline at end of file diff --git a/torch_npu/_inductor/codegen/_sizevars.py b/torch_npu/_inductor/codegen/_sizevars.py new file mode 100644 index 0000000000000000000000000000000000000000..84206554041b15e3930fead7d0759bb3b9c8ab8e --- /dev/null +++ b/torch_npu/_inductor/codegen/_sizevars.py @@ -0,0 +1,10 @@ +import sympy +from sympy import Expr +from torch._inductor.utils import sympy_subs + + +def simplify(self, expr: Expr): + if isinstance(expr, (tuple, list)): + return [sympy.expand(s).xreplace(self.replacements) for s in expr] + return sympy.expand(expr).xreplace(self.replacements) + diff --git a/torch_npu/_inductor/codegen/cppwrapper.py b/torch_npu/_inductor/codegen/cppwrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..ffa695c9109b927b3d0cb90ac7c6d699fb9cc280 --- /dev/null +++ b/torch_npu/_inductor/codegen/cppwrapper.py @@ -0,0 +1,737 @@ +# mypy: allow-untyped-defs +import functools +import os +from itertools import chain, count, zip_longest +from typing import Any, Callable, List, Optional, Tuple, TYPE_CHECKING, Union + +import sympy +import torch +from torch import dtype as torch_dtype +from 
torch._inductor.codecache import get_cpp_wrapper_cubin_path_name +from torch._inductor.runtime.runtime_utils import dynamo_timed +from torch._inductor.runtime.triton_heuristics import grid as default_grid_fn + +from torch._inductor import config +from torch._inductor.codecache import CudaKernelParamCache +from torch._inductor.ir import IRNode, TensorBox +from torch._inductor.utils import DeferredLineBase +from torch._inductor.virtualized import V +from torch._inductor.codegen.aoti_hipify_utils import maybe_hipify_code_wrapper +from torch._inductor.codegen.common import get_device_op_overrides +from torch._inductor.codegen.cpp_utils import cexpr, DTYPE_TO_CPP +from torch._inductor.codegen.cpp_wrapper_cpu import CppWrapperCpu +from torch._inductor.codegen.multi_kernel import MultiKernelCall +from torch._inductor.codegen.wrapper import PythonWrapperCodegen, SymbolicCallArg + +from ..config import npu_block as NPU_ALIGN_BYTES + +if TYPE_CHECKING: + from torch._inductor.graph import GraphLowering + +def checkIfTrue(value, msg): + if not value : + raise RuntimeError(msg) + return True + +class DeferredNpuKernelLine(DeferredLineBase): + """ + When using cpp wrapper, NPU kernel load and launch needs to wait for Triton kernels + to be tuned and stored as cubin files, so use a deferred line to backfill those information + """ + + def __init__( + self, + kernel_name: str, + line_template: str, + keys: Tuple[str, ...], + additional_files: List[str], + ): + super().__init__(line_template) + checkIfTrue(not isinstance(line_template, DeferredLineBase), "line template can not be DeferredLineBase") + self.additional_files = additional_files + self.kernel_name = kernel_name + self.line_template = line_template + self.keys = keys + + def __call__(self): + if self.kernel_name.startswith("multi_kernel_"): + # MultiKernel will select one kernel after running the autotune block + self.kernel_name = MultiKernelCall.lookup_choice(self.kernel_name) + params = CudaKernelParamCache.get(self.kernel_name) + checkIfTrue(params is not None, f"{self.kernel_name} not found in CudaKernelParamCache") + + for key in self.keys: + checkIfTrue(key in params, f"{key} not found in CudaKernelParamCache[{self.kernel_name}]") + + if key == get_cpp_wrapper_cubin_path_name(): + checkIfTrue(os.path.exists(params[key]), f"{params[key]} does not exist") + self.additional_files.append(params[key]) + + return self.line_template % tuple(params[key] for key in self.keys) + + def _new_line(self, line): + return DeferredNpuKernelLine( + self.kernel_name, line, self.keys, self.additional_files + ) + + +class DeferredNpuDefaultGrid: + """ + A container for the default grid, which may be used by DeferredNpuGridLine + """ + + def __init__( + self, + kernel_name: str, + grid, + grid_callable: Optional[Callable[..., Any]] = None, + **grid_extra_kwargs, + ): + self.kernel_name = kernel_name + self.grid = grid + self.grid_callable = grid_callable + self.grid_extra_kwargs = grid_extra_kwargs + + def __iter__(self): + # DeferredNpuDefaultGrid can be passed to the base class, PythonWrapperCodegen, + # to generate the autotune code block, and thus we need this iterator + return iter(self.grid) + + def _process_grid(self, grid: Union[List[Any], Tuple[Any, ...]]): + if isinstance(grid, (list, tuple)): + return [self._process_grid(e) for e in grid] + else: + return grid.inner_expr if isinstance(grid, SymbolicCallArg) else grid + + def __call__(self): + if self.kernel_name.startswith("multi_kernel_"): + # MultiKernel will select one kernel after running the 
autotune block + self.kernel_name = MultiKernelCall.lookup_choice(self.kernel_name) + + grid = self.grid + checkIfTrue(isinstance(grid, (list, tuple)), f"expected {grid=} to be a list") + + grid = self._process_grid(grid) + + checkIfTrue(self.grid_callable is not None, "grid_callable can't be None") + + if not self.grid_extra_kwargs: + grid_fn = self.grid_callable(*grid) + else: + grid_fn = self.grid_callable(*grid, **self.grid_extra_kwargs) + + params = CudaKernelParamCache.get(self.kernel_name) + checkIfTrue(params is not None, f"{self.kernel_name} not found in CudaKernelParamCache") + + return grid_fn(params["meta"]) + + +class DeferredNpuGridLine(DeferredLineBase): + """ + When using cpp wrapper, NPU kernel load and launch needs to wait for Triton kernels + to be tuned and stored as cubin files, so use a deferred line to backfill those information + """ + + def __init__( + self, + kernel_name: str, + grid_var: str, + grid, + autotune_configs, + ): + super().__init__("") + self.kernel_name = kernel_name + self.grid_var = grid_var + self.grid = grid + self.autotune_configs = autotune_configs + + def __call__(self): + if self.kernel_name.startswith("multi_kernel_"): + # MultiKernel will select one kernel after running the autotune block + self.kernel_name = MultiKernelCall.lookup_choice(self.kernel_name) + + params = CudaKernelParamCache.get(self.kernel_name) + + checkIfTrue(params is not None, f"{self.kernel_name} not found in CudaKernelParamCache") + + if self.autotune_configs is not None: + # This indicates the Triton kernel is a user-defined one. + grid = None + if len(self.grid) == 1: + grid = self.grid[0] + else: + for i, c in enumerate(self.autotune_configs): + if all(arg == params["meta"][key] for key, arg in c.kwargs.items()): + grid = self.grid[i] + break + checkIfTrue(grid is not None, "grid can not be None") + grid_args_str = ", ".join( + [cexpr(V.graph.sizevars.simplify(item)) for item in grid] + ) + else: + launch_grid = (params['grid_x'], params['grid_y'], params['grid_z']) + grid_args_str = ", ".join( + [cexpr(item) for item in launch_grid] + ) + + return f"\n Grid {self.grid_var} = Grid({grid_args_str});\n" + + def _new_line(self, line): + return DeferredNpuGridLine( + self.kernel_name, self.grid_var, self.grid, self.autotune_configs + ) + + +class CppWrapperNpu(CppWrapperCpu): + """ + Generates cpp wrapper for running on NPU and calls CUDA kernels + """ + + def __init__(self) -> None: + self.device = 'npu' + self.device_codegen = get_device_op_overrides(self.device) + super().__init__() + self.grid_id = count() + + @staticmethod + def create( + is_subgraph: bool, subgraph_name: str, parent_wrapper: PythonWrapperCodegen + ): + # TODO - support subgraph codegen by lifting functions. Check the + # comment at CppWrapperCpu `codegen_subgraph` function. + return CppWrapperNpu() + + def write_header(self): + if V.graph.is_const_graph: + # We do not write header for constant graph, it will be written by main module. 
+ return + + super().write_header() + self.header.splice("#include ") + self.header.splice("#include ") + self.header.splice(self.device_codegen.abi_compatible_header()) + self.header.splice( + maybe_hipify_code_wrapper(self.device_codegen.kernel_driver()) + ) + self.header.splice("#include ") + self.header.splice("#include \"experiment/runtime/runtime/rt.h\"") + + def write_get_raw_stream(self, device_idx: int, graph=None) -> str: + name = f"stream{device_idx}" + self.writeline( + maybe_hipify_code_wrapper( + f"{self.device_codegen.cpp_stream_type()} {name};" + ) + ) + self.writeline( + f"AOTI_TORCH_ERROR_CODE_CHECK({self.device_codegen.aoti_get_stream()}({device_idx}, (void**)&{name}));" + ) + return name + + def codegen_inputs(self): + # See Note: [Input Alignment handling in Inductor] + # + # JIT Inductor does not guard on input alignment. It relies on copy_misaligned_inputs to + # copy misaligned inputs to aligned buffers. For AOTInductor, we expect users to use it + # as non-Python deployment for its best performance, so implicitly copying misaligned inputs + # to aligned buffers is going to bring a surprising performance hit. Instead, we check input + # alignment and throw an error if any input is misaligned. + if V.graph.aot_mode and V.graph.inputs_to_check: + for idx in V.graph.inputs_to_check: + input_name = V.graph.graph_input_names[idx] + checkIfTrue(input_name in V.graph.graph_inputs, f"{input_name} not found in graph inputs") + + value = V.graph.graph_inputs[input_name] + checkIfTrue(isinstance(value, TensorBox), f"{input_name} is expected to be tensor but found as {type(value)}") + + self.prefix.splice( + f""" + if ((long({input_name}.data_ptr()) & ({NPU_ALIGN_BYTES} -1)) != 0) {{ + throw std::runtime_error("{input_name} is not aligned to {NPU_ALIGN_BYTES} bytes"); + }} + """ + ) + + super().codegen_inputs() + + def define_kernel( + self, + kernel_name: str, + kernel_body: str, + metadata: Optional[str] = None, + gpu=True, + ): + if gpu: + if config.triton.autotune_at_compile_time: + # Call PythonWrapperCodegen to create the autotune code block + PythonWrapperCodegen.define_kernel( + self, kernel_name, kernel_body, metadata, gpu + ) + else: + return CppWrapperCpu.define_kernel( + self, kernel_name, kernel_body, metadata, gpu + ) + + def generate(self, is_inference): + with dynamo_timed("CppWrapperNpu.generate", log_pt2_compile_event=True): + self.prefix.writeline("\n") + if not V.graph.aot_mode: + for kernel in chain( + sorted(self.src_to_kernel.values()), + sorted( + [entry[0] for entry in self.user_defined_kernel_cache.values()] + ), + ): + self.prefix.writeline( + maybe_hipify_code_wrapper( + f"static {self.device_codegen.cpp_kernel_type()} {kernel} = nullptr;" + ) + ) + self.prefix.writeline("\n") + return super().generate(is_inference) + + def generate_user_defined_triton_kernel( + self, + kernel_name: str, + raw_args: List[Any], + grid: List[Any], + configs, + triton_meta, + constexprs, + ): + if ( + config.triton.autotune_at_compile_time + and kernel_name not in self.kernel_autotune_names + ): + # Call PythonWrapperCodegen to create the autotune code block + PythonWrapperCodegen.generate_user_defined_triton_kernel( + self, + kernel_name, + raw_args, + grid, + configs, + triton_meta, + constexprs, + ) + + # in C++ wrapper, we don't pass constexpr args, as they don't + # get added as parameters to the PTX code compiled from the + # user-defined Triton kernel (only non-constexpr args do) + raw_args = [ + raw_arg for i, raw_arg in enumerate(raw_args) if i not in 
constexprs + ] + args = [self.val_to_arg_str(v) for v in raw_args] + arg_types = [ + arg.get_dtype() if isinstance(arg, IRNode) else type(arg) + for arg in raw_args + ] + + # Call self.generate_kernel_call to generate the real kernel call in cpp + self.generate_kernel_call( + kernel_name, + args, + arg_types=arg_types, + raw_args=raw_args, + grid=grid, + gpu=True, + triton=True, + triton_meta=triton_meta, + autotune_configs=configs, + ) + + + @functools.lru_cache(None) # noqa: B019 + def generate_load_kernel_once( + self, + kernel_name: str, + device_index, + graph: "GraphLowering", # for per-graph caching + ): + """ + typedef struct { + const char *name; //mangled_name + const char *kernelPath; //get_cpp_wrapper_cubin_path_name() + int shared; // shared_mem + int device; // device_index + } LoadKernelArgs; + """ + + # keys = ("mangled_name", get_cpp_wrapper_cubin_path_name() , "shared_mem") + keys = (get_cpp_wrapper_cubin_path_name(), "mangled_name", "shared_mem") + kernel_var_name = f"kernels.{kernel_name}" if V.graph.aot_mode else kernel_name + self.writeline(f"if ({kernel_var_name} == nullptr) {{") + deferred_gpu_kernel_line = DeferredNpuKernelLine( + kernel_name, + # " " + kernel_var_name + r' = loadKernel("%s", "%s", %s, {});'.format( + # device_index + # ), + " " + kernel_var_name + r' = loadKernel("%s", "%s", %s);', + keys, + self.additional_files, + ) + self.writeline(deferred_gpu_kernel_line) + self.writeline("}") + return kernel_var_name + + def codegen_tensor_item_npu( + self, dtype: torch.dtype, tensor: str, scalar: str, indented_buffer=None + ): + dtype_str = str(dtype).split(".")[-1] + writer = indented_buffer or self + + if dtype == torch.float16 or dtype == torch.bfloat16: + scalar_tmp = f"{scalar}_tmp" + writer.writeline(f"{DTYPE_TO_CPP[dtype]} {scalar_tmp};") + writer.writeline( + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_item_{dtype_str}({tensor}, &{scalar_tmp}));" + ) + writer.writeline(f"float {scalar} = float({scalar_tmp});") + struct_data = f'float {scalar} __attribute__((aligned(4)));' + arg_data = f'static_cast({scalar})' + else: + writer.writeline(f"{DTYPE_TO_CPP[dtype]} {scalar};") + writer.writeline( + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_item_{dtype_str}({tensor}, &{scalar}));" + ) + struct_data = f'{DTYPE_TO_CPP[dtype]} {scalar} __attribute__((aligned(sizeof({DTYPE_TO_CPP[dtype]} ))));' + arg_data = f'static_cast<{DTYPE_TO_CPP[dtype]}>({scalar})' + + return struct_data, arg_data + + def generate_args_decl(self, call_args, arg_types, arg_signatures, kernel_id, grid_var): + new_args: list[str] = [] + + # Add more cases for other types as needed + signature2dtype = { + "i1": "int32_t", + "i8": "int8_t", + "i16": "int16_t", + "i32": "int32_t", + "i64": "int64_t", + "u32": "uint32_t", + "u64": "uint64_t", + "fp16": "float", + "bf16": "float", + "fp32": "float", + "f32": "float", + "fp64": "double", + } + + kernel_args_var = f"kernel_args_var_{kernel_id}" + + rtError_t_str = f'ret_{kernel_id}' + ffts_addr_str = f'ffts_addr_{kernel_id}' + ffts_len_str = f'ffts_len_{kernel_id}' + workspace_addr_str = f'workspace_addr_{kernel_id}' + + before_strucr_str =\ + f"\n rtError_t {rtError_t_str};\n" +\ + f" void* {ffts_addr_str} = NULL;\n" +\ + f" uint32_t {ffts_len_str};\n" +\ + f" {rtError_t_str} = rtGetC2cCtrlAddr((uint64_t*)&{ffts_addr_str}, &{ffts_len_str});\n" +\ + f" if ({rtError_t_str} != RT_ERROR_NONE) return;\n" +\ + f" void* {workspace_addr_str} = NULL;\n\n" + + struct_def_head = f' struct __attribute__((packed)) {{\n void* ffts_addr 
__attribute__((aligned(8)));\n void* workspace_addr __attribute__((aligned(8)));\n' + struct_def_end = f'\n int32_t gridX __attribute__((aligned(4))); int32_t gridY __attribute__((aligned(4))); int32_t gridZ __attribute__((aligned(4)));\n }}' + + struct_arg_head = f' {kernel_args_var} = {{\n static_cast({ffts_addr_str}),\n static_cast({workspace_addr_str}),\n' + struct_arg_end = f'\n static_cast({grid_var}.grid_x), static_cast({grid_var}.grid_y), static_cast({grid_var}.grid_z)\n }};\n' + + struct_def_body = ' ' + struct_arg_body = ' ' + + def process_args(arg, arg_type, arg_signature=None): + var_name = f"var_{next(self.arg_var_id)}" + # ignore nvTmaDesc, as host-side TMA descriptors need + # to be passed to the compiled Triton kernel by value + if isinstance(arg_type, torch_dtype) and arg_signature != "nvTmaDesc": + if arg.endswith(".item()"): # scalar + # Need to declare a scalar in this case + arg = arg[:-7] + # TODO: override to return dtype + struct_data, arg_data = self.codegen_tensor_item_npu( + arg_type, + arg, + var_name, + ) + else: + # TODO: void* + device_ptr_type = self.device_codegen.cpp_device_ptr() + self.writeline( + maybe_hipify_code_wrapper( + f"{device_ptr_type} {var_name} = reinterpret_cast<{device_ptr_type}>({arg}.data_ptr());" + ) + ) + struct_data = f'void* {var_name} __attribute__((aligned(8)));' + arg_data = f'static_cast({var_name})' + + elif arg_type in (sympy.Integer, int): + # TODO: int + self.writeline(f"int {var_name} = {cexpr(arg)};") + struct_data = f'int {var_name} __attribute__((aligned(4)));' + arg_data = f'static_cast({var_name})' + + elif arg_type in (sympy.Float, float): + # TODO: float + self.writeline(f"float {var_name} = {cexpr(arg)};") + struct_data = f'float {var_name} __attribute__((aligned(4)));' + arg_data = f'static_cast({var_name})' + + # For symbolic call arguments, examine the arg signatures from triton meta + # to explicitly cast to the right type + # Reason: `auto` can infer unexpected type against kernel input signature. + elif ( + isinstance(arg_type, type(SymbolicCallArg)) + and arg_signature is not None + and arg_signature in signature2dtype.keys() + ): + # TODO: * or scalar symbolic type,currently only support scalar symbolic type + self.writeline( + f"{signature2dtype[arg_signature]} {var_name} = {cexpr(arg)};" + ) + struct_data = f'{signature2dtype[arg_signature]} {var_name} __attribute__((aligned(sizeof({signature2dtype[arg_signature]}))));' + arg_data = f'static_cast<{signature2dtype[arg_signature]}>({var_name})' + else: + raise TypeError("Infer arg_type to cpp failed!") + # self.writeline(f"auto {var_name} = {cexpr(arg)};") + + nonlocal struct_def_body + nonlocal struct_arg_body + struct_def_body += struct_data + ' ' + struct_arg_body += arg_data + ', ' + + for arg, arg_type, arg_signature in zip_longest( + call_args, arg_types, arg_signatures + ): + process_args(arg, arg_type, arg_signature) + + return kernel_args_var, before_strucr_str +\ + struct_def_head + struct_def_body + struct_def_end +\ + struct_arg_head + struct_arg_body + struct_arg_end + + def generate_default_grid( + self, + kernel_name: str, + grid_args: List[Any], + gpu: bool = True, + grid_callable: Optional[Callable[..., Any]] = default_grid_fn, + **grid_extra_kwargs, + ): + """ + Generate grid configs for launching a CUDA kernel using the grid + function from triton_heuristics. Because its computation needs + to read kernel config after autotune, it is done in a deferred way + using DeferredNpuDefaultGrid. 
+ """ + checkIfTrue(gpu, "CppWrapperNpu.generate_default_grid does not support non-NPU") + return DeferredNpuDefaultGrid( + kernel_name, grid_args, grid_callable, **grid_extra_kwargs + ) + + def generate_kernel_call_npu( + self, + kernel_name: str, + call_args, + grid=None, + device_index=None, + npu=True, + triton=True, + arg_types=None, + raw_args=None, + grid_fn: str = "grid", + triton_meta=None, + autotune_configs=None, + grid_extra_kwargs="", + ): + if ( + config.triton.autotune_at_compile_time + and kernel_name not in self.kernel_autotune_names + ): + # Call PythonWrapperCodegen to create the autotune code block + PythonWrapperCodegen.generate_kernel_call( + self, + kernel_name, + call_args, + grid, + device_index, + npu, + triton, + arg_types, + raw_args, + grid_fn, + triton_meta, + autotune_configs, + grid_extra_kwargs, + ) + + if device_index is None: + current_device = V.graph.get_current_device_or_throw() + device_index = current_device.index + + stream = ( + "stream" + if V.graph.aot_mode + else self.write_get_raw_stream(device_index, V.graph) + ) + + if triton: + device_index, call_args = self.prepare_triton_kernel_call( + device_index, call_args + ) + kernel_var_name = self.generate_load_kernel_once(kernel_name, device_index, V.graph) + + # args with value 1 are added into equal_to_1 and constants + # in triton_meta (in the Python codegen) which makes them + # inlined in the PTX and compiled CUBIN + arg_signatures = [] + if ( + triton_meta is not None + and triton_meta.get("configs") + and triton_meta.get("signature") + ): + equal_to_1 = triton_meta["configs"][0].equal_to_1 + call_args = [ + arg for i, arg in enumerate(call_args) if i not in equal_to_1 + ] + arg_types = [t for i, t in enumerate(arg_types) if i not in equal_to_1] + # extract the arg signatures from triton_meta + arg_signatures = triton_meta["signature"].values() + arg_signatures = [ + v for i, v in enumerate(arg_signatures) if i not in equal_to_1 + ] + + current_kernel_id = next(self.kernel_callsite_id) + current_grid_id = next(self.grid_id) + + # >>>>> gen grids + grid_var = f"{kernel_name}_grid_{current_grid_id}" + self.writeline( + DeferredNpuGridLine(kernel_name, grid_var, grid, autotune_configs) + ) + # <<<<< + + # >>>>> gen kernel args + kernel_args_var, call_args_str = self.generate_args_decl( + call_args, arg_types, arg_signatures, current_kernel_id, grid_var + ) + self.writeline(f"{call_args_str}") + # <<<<< + + kernel_var_name = ( + f"kernels.{kernel_name}" if V.graph.aot_mode else kernel_name + ) + # add debug printer code for all triton kernel related calls + debug_printer_manager = V.graph.wrapper_code.debug_printer + debug_printer_manager.set_printer_args( + call_args, kernel_name, arg_types, None + ) + with debug_printer_manager: + + ''' + typedef struct { + const char* kernelName; // f"'{kernel_name}'" <- 'triton_unk_fused_sigmoid_1' + const void* func; // kernel_var_name <- kernels.triton_unk_fused_sigmoid_1 + rtStream_t stream; // stream + int gridX; // f"{grid_var}.grid_x", + int gridY; // f"{grid_var}.grid_y", + int gridZ; // f"{grid_var}.grid_z", + int *profilerRegistered; //nullptr + void *kernelArgs; // f'static_cast(&{kernel_args_var})' + int32_t kernelArgsSize; // f'sizeof({kernel_args_var})' + } LaunchKernelArgs; + ''' + + self.writeline(f"if ({grid_var}.is_non_zero()) {{") + self.writeline( + DeferredNpuKernelLine( + kernel_name, + r" launchKernel({}, {}, {}, {}, {}, {}, {}, {});".format(\ + f'"{kernel_name}"', + kernel_var_name, + stream, + f"{grid_var}.grid_x", + 
f"{grid_var}.grid_y", + f"{grid_var}.grid_z", + f"static_cast(&{kernel_args_var})", + f'sizeof({kernel_args_var})', + ), + tuple(), + self.additional_files, + ), + ) + + self.writeline("}\n") + else: + casted = [] + for arg_type, arg in zip(arg_types, call_args): + new_arg = arg + if arg_type.endswith("*") and arg != "nullptr": + new_arg = f"{arg}.data_ptr()" + casted.append(f"({arg_type}){new_arg}") + call_args_str = ", ".join(casted) + self.writeline(f"kernels.{kernel_name}({call_args_str}, {stream});") + + + def generate_kernel_call( + self, + kernel_name: str, + call_args, + grid=None, + device_index=None, + gpu=True, + triton=True, + arg_types=None, + raw_args=None, + grid_fn: str = "grid", + triton_meta=None, + autotune_configs=None, + grid_extra_kwargs="", + ): + """ + Override the default value of argument 'gpu' to True here. + generate_kernel_call can still be called with gpu=False because of + a mix of cpu kernels and gpu kernels. + """ + + """ + To fit with NPU: we write a new function 'generate_kernel_call_npu + and make a new parameter called 'npu', which always equals to 'gpu', + because 'gpu' parameter means 'not cpu' in upper logic + """ + + if not gpu: + # Even in CppWrapperNpu, we may see cpp kernels + return CppWrapperCpu.generate_kernel_call( + self, + kernel_name, + call_args, + grid, + device_index, + gpu, + triton, + arg_types, + raw_args, + grid_fn, + triton_meta, + autotune_configs, + grid_extra_kwargs, + ) + + self.generate_kernel_call_npu( + kernel_name, + call_args, + grid, + device_index, + gpu, + triton, + arg_types, + raw_args, + grid_fn, + triton_meta, + autotune_configs, + grid_extra_kwargs, + ) + + def make_zero_buffer(self, name): + return f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_zero_({name}.get()));" diff --git a/torch_npu/_inductor/codegen/ir.py b/torch_npu/_inductor/codegen/ir.py new file mode 100644 index 0000000000000000000000000000000000000000..2cb83e5c1fa4d13f9e54c38a48a12a0379fce318 --- /dev/null +++ b/torch_npu/_inductor/codegen/ir.py @@ -0,0 +1,203 @@ + +from typing import List, Tuple, Dict, Any, Optional +import itertools +import sympy + + +from torch._inductor.virtualized import V +from torch._inductor.ir import (ReductionHint, IRNode, ModularIndexing, FloorDiv) +from torch._inductor.utils import sympy_subs, sympy_index_symbol +from torch_npu._inductor.codegen.triton import NPUIndexTritonKernel + +from ..config import log + + +# NPU doesn't need to support ReductionHint.OUTER, and persistent reduction +def num_splits( + device, + dst_dtype, + src_dtype, + inner_fn, + ranges, + reduction_ranges, + reduction_type, + reduction_numel, + input_node: Optional[IRNode] = None, + ): + return ReductionHint.DEFAULT, 1 + + +def detect_flattened_dims(kernel, index): + new_vars = {} + if not isinstance(index, (sympy.core.add.Add, ModularIndexing, FloorDiv)): + return new_vars + + def detect_flattened_axis(expr): + def init_new_vars(var, length): + if var not in new_vars: + new_vars[var] = {length: [None, None]} + if length not in new_vars[var]: + new_vars[var][length] = [None, None] + if isinstance(expr, ModularIndexing): + var, divisor, length = expr.args + init_new_vars(var, length) + new_vars[var][length][1] = (expr, divisor, length) + elif isinstance(expr, FloorDiv): + var, divisor = expr.args + init_new_vars(var, divisor) + # over than 1 node_schedule, var may be deleted in kernel.range_tree_nodes + # it shoule be find in range_tree_nodes_removed dict + if (var in kernel.range_tree_nodes): + numel = kernel.range_tree_nodes[var].length + else: + numel 
= kernel.range_tree_nodes_removed[var].length + + length = expr.eval(numel, divisor) + new_vars[var][divisor][0] = (expr, divisor, length) + + else: + for x in expr.args: + detect_flattened_axis(x) + + # add + if isinstance(index, sympy.core.add.Add): + for x in index.args: + detect_flattened_axis(x) + elif isinstance(index, (ModularIndexing, FloorDiv)): + detect_flattened_axis(index) + else: + pass + + # make sure FloorDiv, MouldarIndexing must be in-pair + for var, divisors in new_vars.items(): + if var in kernel.range_tree_nodes: + parent_axis = kernel.range_tree_nodes[var] + else: + parent_axis = kernel.range_tree_nodes_removed[var] + for divisor, pair in divisors.items(): + if not pair[0] and not pair[1]: + pass + #FloorDiv not inplace + elif not pair[0]: + _, _, length = pair[1] + expr = FloorDiv(var, length) + new_vars[var][divisor][0] = (expr, length, parent_axis.length // length) + #ModularIndexing not inplace + elif not pair[1]: + expr = ModularIndexing(var, 1, divisor) + new_vars[var][divisor][1] = (expr, 1, divisor) + else: + pass + + return new_vars + + +def rebuild_flattened_dims(indexing): + def rebuild_flattened_dim(key, index, old_node, flatten_dim): + for _, pair in flatten_dim.items(): + new_var_expr = sympy.Integer(0) + origin_axis_length = 0 + pair_is_valid = True + # don't create duplicated axis, e.g. y1:1024, y1 % 1024 is duplicated + expr, divisor, length = pair[1] + if not old_node.parent.duplicated_check(divisor, length): + V.kernel.expr_substituted[expr] = old_node.symbol() + break + + for axis in pair: + expr, divisor, length = axis + # 3. try to rebuild the axis in kernel + new_node = old_node.parent.lookup(divisor, length) + + # 4. substitute div/mod expression in indexing + index = index.subs(expr, new_node.symbol()) + indexing[key] = index + if isinstance(expr, FloorDiv): + new_var_expr = new_var_expr + new_node.symbol() * divisor + origin_axis_length = divisor * length + elif isinstance(expr, ModularIndexing): + new_var_expr = new_var_expr + new_node.symbol() + V.kernel.expr_substituted[expr] = new_node.symbol() + + if var not in V.kernel.range_tree_nodes_substituted: + V.kernel.range_tree_nodes_substituted[var] = [] + V.kernel.range_tree_nodes_substituted[var].append((origin_axis_length, new_var_expr)) + + def find_index_in_substitute(index, kernel): + return any([index.find(key) for key in kernel.expr_substituted.keys()]) + + kernel = V.kernel + for key, index in indexing.items(): + # 1. try to find out flattened axis from indexing + flatten_dims = detect_flattened_dims(kernel, index) + #2. 
try to rebuild these flattened dims + for var, flatten_dim in flatten_dims.items(): + if (var in kernel.range_tree_nodes): + old_node = kernel.range_tree_nodes[var] + else: + old_node = kernel.range_tree_nodes_removed[var] + + rebuild_flattened_dim(key, index, old_node, flatten_dim) + + if find_index_in_substitute(index, kernel): + new_index = sympy_subs(index, kernel.expr_substituted) + indexing[key] = new_index + + +def substituted_dims_in_indexing(self, indexing, kernel, range_tree_nodes_substituted): + substituted = False + for var, candidates in range_tree_nodes_substituted.items(): + if not (len(candidates) > 0): + raise RuntimeError("assert len(candidates) > 0, candidates") + exprs = sorted(candidates, reverse=True, key=lambda x: x[0]) + # the best candidate is with the longest numel + numel = exprs[0][0] + expr = exprs[0][1] + node = kernel.range_tree_nodes[var] + if node.length != numel: + log.debug("sub nodes (expr%s, numel:%d) can not substitute parent node(%s:%d)", + expr, numel, node.symbol(), node.length) + continue + for key, index in indexing.items(): + if var in index.free_symbols: + index = index.subs(var, expr) + indexing[key] = index + substituted = True + + return substituted + + +def generate_body_indexing(body, indices): + index = list(itertools.chain.from_iterable(indices)) + if not (len(index) == len(body.var_ranges)): + raise RuntimeError("assert len(index) == len(body.var_ranges), (index, body.var_ranges)") + if not (all(v not in body.var_ranges for v in index)): + raise RuntimeError("assert all(v not in body.var_ranges for v in index)") + + replacements = dict(zip(body.var_ranges.keys(), index)) + indexing_map = dict(zip(index, body.var_ranges.keys())) + setattr(body, 'indexing_map', indexing_map) + body.indexing = { + name: sympy_subs(expr, replacements) + for name, expr in body.indexing_exprs.items() + } + + +def transform_dims_in_indexing(self, indices): + if self.indexing is None: + generate_body_indexing(self, indices) + + if V.kernel is not None and isinstance(V.kernel, NPUIndexTritonKernel): + rebuild_flattened_dims(self.indexing) + + +# select tiling axis, recover missing dimensions, +def loopbody__call__(self, *indices): + if self.indexing is None: + generate_body_indexing(self, indices) + result = self.root_block() + self.indexing = None + return result + + + diff --git a/torch_npu/_inductor/codegen/ir_fx.py b/torch_npu/_inductor/codegen/ir_fx.py new file mode 100644 index 0000000000000000000000000000000000000000..4ac670b8588a436cfc6e08859885d80b9d86e369 --- /dev/null +++ b/torch_npu/_inductor/codegen/ir_fx.py @@ -0,0 +1,837 @@ +import traceback +from unittest.mock import patch + +import typing + +from typing import ( + Any, + Callable, + List, + Optional, + Union +) +from typing import Optional + +import sympy +from sympy import Expr + +import torch +from torch._inductor import ir +from torch._inductor import config + +from torch._inductor.virtualized import ops, V +from torch.utils._ordered_set import OrderedSet + +from ..lowering_fx import ( + fetch_graphs, + merge_traced_graphs, + node_id, + clone, + create_fake_input, + subtract_graph +) + + +def _patch_loops_get_name(self): + return self.node_name + +def _patch_loops_get_traced_graph(self): + return self.traced_graph + +@classmethod +def _patch_loops_create(cls, *args, **kwargs): + origin_node = kwargs.pop("origin_node", None) + traced_graph = kwargs.pop("traced_graph", None) + node_name = kwargs.pop("node_name", None) + tb = kwargs.pop("traceback", None) + r = cls(*args, **kwargs) + # Need to 
explicitly set origin_node here to propagate it down. + # todo(chilli): I think it would be better for IRNode to directly set + # origin_node + r._post_init_setattr("origin_node", origin_node) + r._post_init_setattr("traceback", tb or r.traceback) + r._post_init_setattr("traced_graph", traced_graph) + r._post_init_setattr("node_name", node_name) + return ir.TensorBox.create(r) + +def _patch_pointwise_constant_to_device(self, device, traced_graph=None, node_name=None): + """Move this to a given device. Requires that all reads are to constants.""" + loader = self.make_loader() + loader = patch.object(ir.ConstantBuffer, "override_device", device)(loader) + + r = ir.Pointwise(device, self.dtype, loader, self.ranges) + r._post_init_setattr("traced_graph", traced_graph) + r._post_init_setattr("node_name", node_name) + return r + + +@classmethod +def _patch_reduction_create( + cls, + device: torch.device, + dst_dtype: torch.dtype, + src_dtype: torch.dtype, + inner_fn: Callable[..., Any], + ranges: ir.Sequence[Expr], + reduction_ranges: ir.Sequence[Expr], + reduction_type: str, + reduction_hint: ir.ReductionHint = ir.ReductionHint.DEFAULT, + input_node: Optional[ir.IRNode] = None, + traced_graph = None, + node_name: str = None +) -> ir.TensorBox: + reduction_numel = V.graph.sizevars.simplify(ir.sympy_product(reduction_ranges)) + + if reduction_numel == 0: + # N.B. This is a hack to generate the literal of the given type + # Ideally, we should be fixing `def constant` in triton.py + # but it breaks due to hardcoded dtypes in other places + def py_cnst(val: object) -> Union[bool, float, int]: + if dst_dtype == torch.bool: + return bool(val) + elif dst_dtype.is_floating_point: + assert isinstance(val, typing.SupportsFloat) + return float(val) + else: + assert isinstance(val, typing.SupportsInt) + return int(val) + + rtypes_to_inits = { + "sum": py_cnst(0), + "xor_sum": py_cnst(0), + "prod": py_cnst(1), + "any": py_cnst(0), + # "all" is desugared to `!any(!val)` + } + + assert ( + reduction_type in rtypes_to_inits.keys() + ), f"{reduction_type} not supported for zero-dimension tensors!" 
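+        # With nothing to reduce, lower to a Pointwise that simply broadcasts the
+        # reduction identity from rtypes_to_inits (0 for sum/xor_sum/any, 1 for prod).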
+ + def const_fn(index: int) -> ir.OpsValue: + return ops.constant(rtypes_to_inits[reduction_type], dst_dtype) + + return ir.Pointwise.create( + device=device, + dtype=src_dtype, + inner_fn=const_fn, + ranges=list(ranges), + traced_graph=traced_graph, + node_name=node_name + ) + + if reduction_numel == 1: + # this reduction is actually a pointwise op + if reduction_type in ("argmin", "argmax"): + + def fn(index: int) -> ir.OpsValue: + return ops.constant(0, dst_dtype) + + else: + + def fn(index: int) -> ir.OpsValue: + reduction_index = [sympy.S.Zero for _ in reduction_ranges] + return inner_fn(index, reduction_index) + + return ir.Pointwise.create( + device=device, dtype=dst_dtype, inner_fn=fn, ranges=ranges + ) + + if ( + isinstance(reduction_numel, ir.Integer) + and V.graph.sizevars.size_hint(reduction_numel) + < config.unroll_reductions_threshold + and (ir.sympy_product(ranges) != 1 or ir.is_gpu(device.type)) + ): + # NB: This works around https://github.com/pytorch/pytorch/issues/140457 + # since turning reductions into pointwise ops can exacerbate this problem + return ir.Pointwise.create( + device=device, + dtype=dst_dtype, + inner_fn=cls._unroll_reduction_fn( + inner_fn, reduction_ranges, reduction_type, src_dtype + ), + ranges=ranges, + traced_graph=traced_graph, + node_name=node_name + ) + + # triton doesn't support reduce to single element well, so break it up + hint, split = cls.num_splits( + device, + dst_dtype, + src_dtype, + inner_fn, + ranges, + reduction_ranges, + reduction_type, + reduction_numel, + input_node, + ) + # intermediate reduction in split can contain complex indexing, + # and num_splits will fail to correctly set the hint + # reuse the passed hint if available + if reduction_hint == ir.ReductionHint.DEFAULT: + reduction_hint = hint + if split == -1: + assert input_node is not None + new_ranges, new_reduction_ranges = ir.extract_input_node_reduction_ranges( + input_node + ) + assert new_ranges is not None + assert new_reduction_ranges is not None + return cls.create_multilayer_existing_ranges( + device, + dst_dtype, + src_dtype, + inner_fn, + ranges, + reduction_ranges, + new_ranges, + new_reduction_ranges, + reduction_type, + reduction_hint, + ) + elif split > 1: + # triton doesn't support reduce to single element well, so break it up + return cls.create_multilayer( + device, + dst_dtype, + src_dtype, + inner_fn, + ranges, + reduction_ranges, + reduction_type, + split, + reduction_hint, + ) + + r = ir.Reduction( + device=device, + dtype=dst_dtype, + inner_fn=inner_fn, + ranges=ranges, + reduction_ranges=reduction_ranges, + reduction_type=reduction_type, + src_dtype=src_dtype, + reduction_hint=reduction_hint, + ) + r._post_init_setattr("traced_graph", traced_graph) + r._post_init_setattr("node_name", node_name) + + return ir.TensorBox.create(r) + + +def _patch_baseview_get_traced_graph(self): + if hasattr(self, 'traced_graph') and self.traced_graph is not None: + return self.traced_graph + return self.data.get_traced_graph() + + +def _patch_base_view_get_reads(self): + with patch.object(ir.FlexibleLayout, "allow_indexing", True): + r = ir.extract_read_writes( + self.make_loader(), + self.get_size(), + ).reads + for md in r: + if md.index.has(ir.ModularIndexing): + if md.index.has(ir.FloorDiv): + self.realize() + return r + else: + for m in md.index.find(ir.ModularIndexing): + for arg in m.args: + if arg.has(ir.ModularIndexing): + self.realize() + return r + return r + + +def has_buffer(inp): + if not hasattr(inp, 'data'): + return False + if isinstance(inp.data, 
ir.Buffer): + return True + return has_buffer(inp.data) + +def get_buffer(inp): + if isinstance(inp.data, ir.Buffer): + return inp.data + return get_buffer(inp.data) + +def _patch_baseview_realize(self): + if hasattr(self, 'traced_graph') and self.traced_graph is not None: + r = self.data.realize() + buffer = get_buffer(self) + if isinstance(buffer, (ir.MultiOutput, ir.InputBuffer, ir.ConcatKernel)): + return r + traced_graph = buffer.data.get_traced_graph() + buf_name = buffer.get_name() + new_traced_graph, placeholder = subtract_graph(self.traced_graph, traced_graph, node_name=buf_name) + if placeholder is not None: + placeholder.name = buf_name + device = buffer.get_device() + dtype = buffer.get_dtype() + size = buffer.get_size() + stride = buffer.get_stride() + fake_input = create_fake_input(size, stride, device, dtype) + placeholder.meta['val'] = fake_input + self._post_init_setattr("traced_graph", new_traced_graph) + return r + else: + return self.data.realize() + +def _patch_baseview_realize_hint(self): + if hasattr(self, 'traced_graph') and self.traced_graph is not None: + r = self.data.realize_hint() + if not has_buffer(self): + return r + buffer = get_buffer(self) + if isinstance(buffer, (ir.MultiOutput, ir.InputBuffer, ir.ConcatKernel)): + return r + traced_graph = buffer.data.get_traced_graph() + buf_name = buffer.get_name() + new_traced_graph, placeholder = subtract_graph(self.traced_graph, traced_graph, node_name=buf_name) + if placeholder is not None: + placeholder.name = buf_name + device = buffer.get_device() + dtype = buffer.get_dtype() + size = buffer.get_size() + stride = buffer.get_stride() + fake_input = create_fake_input(size, stride, device, dtype) + placeholder.meta['val'] = fake_input + self._post_init_setattr("traced_graph", new_traced_graph) + return r + else: + return self.data.realize_hint() + + +def _patch_mark_reuse(self, users): + if isinstance(self.data, ir.StorageBox): + if self.data.should_realize_on_reuse(users): + if hasattr(self, 'traced_graph') and self.traced_graph is not None: + r = self.data.realize() + buffer = get_buffer(self) + if isinstance(buffer, (ir.MultiOutput, ir.InputBuffer, ir.ConcatKernel)): + return r + traced_graph = buffer.data.get_traced_graph() + buf_name = buffer.get_name() + new_traced_graph, placeholder = subtract_graph(self.traced_graph, traced_graph, node_name=buf_name) + if placeholder is not None: + placeholder.name = buf_name + device = buffer.get_device() + dtype = buffer.get_dtype() + size = buffer.get_size() + stride = buffer.get_stride() + fake_input = create_fake_input(size, stride, device, dtype) + placeholder.meta['val'] = fake_input + self._post_init_setattr("traced_graph", new_traced_graph) + return r + else: + return self.data.realize() + else: + return self.data.mark_reuse(users) + + +@classmethod +def _patch_expandview_create(cls, x, new_size, traced_graph=None, node_name=None): + new_size = cls._normalize_size(x, new_size) + + if ir.is_storage_and_layout(x): + storage, old_layout = ir.as_storage_and_layout(x) + skip = len(new_size) - len(old_layout.size) + assert skip >= 0 + new_stride = [sympy.Integer(0)] * skip + for stride, size in zip(old_layout.stride, old_layout.size): + new_stride.append( + stride + if not V.graph.sizevars.shape_env.evaluate_expr( + sympy.Eq(size, 1), size_oblivious=True + ) + else sympy.Integer(0) + ) + new_layout = ir.FixedLayout( + old_layout.device, + old_layout.dtype, + list(new_size), + new_stride, + old_layout.offset, + ) + + r = ir.ReinterpretView(data=storage, 
layout=new_layout) + r._post_init_setattr("traced_graph", traced_graph) + r._post_init_setattr("node_name", node_name) + return r + + r = ir.ExpandView(data=x, size=new_size) + r._post_init_setattr("traced_graph", traced_graph) + r._post_init_setattr("node_name", node_name) + + return r + + +@classmethod +def _patch_permuteview_create(cls, x, dims, traced_graph=None, node_name=None): + dims = cls._map_neg_dims(dims) + assert OrderedSet(dims) == OrderedSet(range(len(dims))) + + if ir.is_storage_and_layout(x): + storage, old_layout = ir.as_storage_and_layout(x) + new_layout = ir.FixedLayout( + old_layout.device, + old_layout.dtype, + [old_layout.size[i] for i in dims], + [old_layout.stride[i] for i in dims], + old_layout.offset, + ) + r = ir.ReinterpretView(data=storage, layout=new_layout) + r._post_init_setattr("traced_graph", traced_graph) + r._post_init_setattr("node_name", node_name) + return r + + + r = ir.PermuteView(data=x, dims=dims) + r._post_init_setattr("traced_graph", traced_graph) + r._post_init_setattr("node_name", node_name) + return r + + +@classmethod +def _patch_view_create(cls, x, new_size, traced_graph=None, node_name=None): + assert isinstance(new_size, (tuple, list)) + old_size, new_size = cls.resolve_negative_size(x.get_size(), new_size) + # Skip pointless views + if V.graph.sizevars.statically_known_list_equals(old_size, new_size): + return x + + unbacked_symbols_in_sizes = False + if ( + len(ir.free_unbacked_symbols(old_size)) > 0 + or len(ir.free_unbacked_symbols(new_size)) > 0 + ): + unbacked_symbols_in_sizes = True + + if 0 in new_size: + + def fake_reindex(index): + return tuple([0] * len(old_size)) + + r = cls(x, list(new_size), fake_reindex) + r._post_init_setattr("traced_graph", traced_graph) + r._post_init_setattr("node_name", node_name) + return r + + # TODO: a new class for FixedTransferLayout that output layout is constrained by input layout + elif (ir.is_contiguous_storage_and_layout(x) or unbacked_symbols_in_sizes): # and not isinstance(x.data, ir.ReinterpretView): + if unbacked_symbols_in_sizes and (not ir.is_contiguous_storage_and_layout(x)): + # realize x; otherwise, the dynamic_reshape_indexer below will fail + # due to the size_hint's inability to process unbacked SymInts + x = ir.ExternKernel.realize_input(x) + + storage, old_layout = ir.as_contiguous_storage_and_layout(x) + new_layout = ir.FixedLayout( + old_layout.device, + old_layout.dtype, + new_size, + ir.FlexibleLayout.contiguous_strides(new_size), + old_layout.offset, + ) + + r = ir.ReinterpretView(data=storage, layout=new_layout) + r._post_init_setattr("traced_graph", traced_graph) + r._post_init_setattr("node_name", node_name) + return r + + reindex = cls.dynamic_reshape_indexer(old_size, new_size) + + r = cls(data=x, size=list(new_size), reindex=reindex) + r._post_init_setattr("traced_graph", traced_graph) + r._post_init_setattr("node_name", node_name) + return r + + +@classmethod +def _patch_sliceview_create(cls, x, dim, start, end, step=1, clamp=True, traced_graph=None, node_name=None): # TODO: crm, clamp=True + step = sympy.expand(step) + assert isinstance(step, sympy.Expr) or step > 0 + try: + if start == 0 and end >= 2**63 - 1 and step == 1: + return x + except TypeError: + pass + sizevars = V.graph.sizevars + new_size = list(x.get_size()) + + if clamp: + start, end = cls.normalize_start_end(x, dim, start, end) + + new_size[dim] = ir.FloorDiv(end - start + (step - 1), step) + + if ir.is_storage_and_layout(x): + # Fast path + storage, old_layout = ir.as_storage_and_layout(x) + 
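# Express the slice as a ReinterpretView over the existing storage: the stride along
+        # `dim` is scaled by `step` and the layout offset is advanced by
+        # old_layout.stride[dim] * start, so no data copy is materialized.
+        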
new_stride = list(old_layout.stride) + new_stride[dim] = new_stride[dim] * step + new_layout = ir.FixedLayout( + old_layout.device, + old_layout.dtype, + new_size, + new_stride, + old_layout.offset + old_layout.stride[dim] * start, + ) + r = ir.ReinterpretView(data=storage, layout=new_layout) + r._post_init_setattr("traced_graph", traced_graph) + r._post_init_setattr("node_name", node_name) + return r + + def reindex(index): + assert len(index) == len(new_size), f"wrong ndim {index} {new_size}" + index = list(index) + index[dim] = index[dim] * step + start + return index + + # redirect to a generic view + r = ir.SliceView(data=x, size=new_size, reindex=reindex) + r._post_init_setattr("traced_graph", traced_graph) + r._post_init_setattr("node_name", node_name) + return r + + +def _patch_buffer_get_traced_graph(self): + return self.traced_graph + + +@classmethod +def _patch_concatkernel_create(cls, inputs, dim): + device = inputs[0].get_device() + dtype = inputs[0].get_dtype() + new_size = list(inputs[0].get_size()) + offsets_start = [0] + offsets_end = [new_size[dim]] + assert 0 <= dim < len(new_size) + for i in range(1, len(inputs)): + input_size = inputs[i].get_size() + offsets_start.append(new_size[dim]) + assert len(input_size) == len(new_size) + assert inputs[i].get_dtype() == dtype + assert inputs[i].get_device() == device + for j in range(len(new_size)): + if j == dim: + new_size[j] = new_size[j] + input_size[j] + else: + new_size[j] = V.graph.sizevars.guard_equals( + new_size[j], input_size[j] + ) + offsets_end.append(new_size[dim]) + + output_stride = ir.FlexibleLayout.contiguous_strides(new_size) + # If any of the inputs is in CL format, use CL format for the output + for i in range(len(inputs)): + x = inputs[i] + if ir.is_storage_and_layout(x): + layout = x.get_layout() + if ( + isinstance(layout, ir.FixedLayout) + and layout.is_channels_last_contiguous(layout.size, layout.stride) + ): + # use CL stride for the output + output_stride = ir.make_channels_last_strides_for(new_size) + break + + any_input_is_storage_and_layout = any(ir.is_storage_and_layout(x) for x in inputs) + fx_node_args = V.graph.current_node.args[0] + assert isinstance(fx_node_args, list) + # If any of the inputs has meta tensor and the meta tensor is in CL format, use CL format for the output + if any_input_is_storage_and_layout is False and any( + "val" in arg.meta + and ( + arg.meta["val"].is_contiguous(memory_format=torch.channels_last) + or arg.meta["val"].is_contiguous(memory_format=torch.channels_last_3d) + ) + for arg in fx_node_args + ): + output_stride = ir.make_channels_last_strides_for(new_size) + + concat_kernel = ir.ConcatKernel( + name=None, + layout=ir.FixedLayout( + device=device, + dtype=dtype, + size=new_size, + stride=output_stride, + ), + inputs=[], + ) + + kernel = ir.StorageBox(concat_kernel) + op_names = [] + for i in range(len(inputs)): + input_buffer = cls.realize_into( + inputs[i], + ir.SliceView.create( + kernel, dim, offsets_start[i], offsets_end[i], clamp=False + ), + ) + concat_kernel.inputs.append(input_buffer) + + if isinstance(inputs[i].data, ir.BaseView): + input_unwrapped = inputs[i].data.unwrap_view() + else: + input_unwrapped = inputs[i].data + + if ( + input_unwrapped.is_input_buffer() + and ir.is_gpu(inputs[i].get_device().type) + and not ir.is_dynamic(input_buffer) + ): + op_names.append(input_buffer.get_operation_name()) + + if len(op_names) > 1 and V.graph.has_feature(device, ir.BackendFeature.FOREACH): + V.graph.register_operation_list(op_names) + + cat_inputs = 
[ir.TensorBox(ir.StorageBox(inp)) for inp in concat_kernel.inputs] + input_graphs = fetch_graphs([cat_inputs]) + node_name = f'cat_{next(node_id)}' + new_graph = merge_traced_graphs(input_graphs, torch.ops.aten.cat, node_name, dim=dim) + + concat_kernel._post_init_setattr("name", V.graph.register_buffer(concat_kernel)) + concat_kernel._post_init_setattr("inputs", cls.unwrap_storage(concat_kernel.inputs)) + concat_kernel._post_init_setattr("traced_graph", new_graph) + concat_kernel._post_init_setattr("node_name", node_name) + + return kernel + +def _patch_concatkernel_get_traced_graph(self): + return self.traced_graph + +@classmethod +def _patch_concatkernel_realize_into(cls, src, dst): + # Attempt to turn this into a ReinterpretView rather than assert. + # This has concessions around layout, as as_storage_and_layout + # can cause us to go from flexible to fixed layout. + if not isinstance(dst, ir.ReinterpretView): + if ir.is_storage_and_layout(dst): + storage, layout = ir.as_storage_and_layout(dst) + dst = ir.ReinterpretView(data=storage, layout=layout) + assert isinstance(dst, ir.ReinterpretView), dst + if isinstance(src, ir.TensorBox): + # unwrap a TensorBox + return cls.realize_into(src.data, dst) + if isinstance(src, ir.StorageBox): + src.realize() + # ExternKernelAlloc has specific requirements for output layout, should create a copy + assert hasattr(src.data, "layout") + if cls.can_realize_into_without_copy(src): + src.data.layout = ir.NonOwningLayout(dst) + return src.data + pw = clone(src, memory_format=torch.contiguous_format) + return cls.realize_into(pw, dst) + + +def _patch_externkernel_copy_input(x): + traced_graph = x.get_traced_graph() + node_name = x.get_name() + if traced_graph is None: + traced_graph = fetch_graphs([x])[0] + node_name = f'getitem_{next(node_id)}' + + pw = ir.Pointwise.create( + device=x.get_device(), + dtype=x.get_dtype(), + inner_fn=x.make_loader(), + ranges=x.get_size(), + origin_node=x.get_origin_node(), + traceback=x.get_traceback(), + traced_graph=traced_graph, + node_name=node_name + ) + pw.realize() + return pw + + +@classmethod +def _patch_externkernel_convert_to_reinterpret_view(cls, x): + """ + In order to pass this to an extern kernel we need a + ReinterpretView not a View. This allows us to avoid some + unneeded copies. + """ + assert isinstance(x, ir.BaseView) + if isinstance(x, ir.ReinterpretView): + return x + + # NOTE: Don't use extract_read_writes here as it fails when + # make_loader() inlines the computation + x_unwrap_view = x.unwrap_view() + buf = V.graph.get_buffer(x_unwrap_view.get_name()) + assert buf is not None + x_unwrap_view_fx_node = buf.get_origin_node() + # Prefer channels last format according to how the format is set from eager. 
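+            # If the eager-side meta tensor is channels-last (2D or 3D) and the unwrapped buffer
+            # still has a FlexibleLayout, freeze it to channels-last strides; otherwise freeze the
+            # current layout before deriving the strides/offset for the ReinterpretView below.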
+ if ( + x_unwrap_view_fx_node is not None + and "val" in x_unwrap_view_fx_node.meta + and isinstance(x_unwrap_view.layout, ir.FlexibleLayout) + and ( + x_unwrap_view_fx_node.meta["val"].is_contiguous( + memory_format=torch.channels_last + ) + or x_unwrap_view_fx_node.meta["val"].is_contiguous( + memory_format=torch.channels_last_3d + ) + ) + ): + x_unwrap_view.freeze_layout_with_same_order( + ir.make_channels_last_strides_for(x_unwrap_view.get_size()) + ) + else: + x_unwrap_view.freeze_layout() + + index_args, var_ranges = ir.dependencies.index_vars_squeeze( + x.get_size(), prefix="r" + ) + range_vars = index_args[0] + index = x.make_indexer()(range_vars) + + index = V.graph.sizevars.simplify_with_ranges(index, var_ranges) + strides = V.graph.sizevars.stride_vars(index, range_vars) + offset = V.graph.sizevars.offset_var(index, range_vars) + expected = ir.sympy_dot(range_vars, strides) + offset + + if index != expected: + ir.log.debug( + "convert_to_reinterpret_view failed: stride=%s offset=%s index=%s", + strides, + offset, + index, + ) + raise NotImplementedError + + r = ir.ReinterpretView( + data=x.data, + layout=ir.FixedLayout( + device=x.get_device(), + dtype=x.get_dtype(), + size=x.get_size(), + stride=strides, + offset=offset, + ), + ) + r._post_init_setattr("traced_graph", x.get_traced_graph()) + r._post_init_setattr("node_name", x.get_name()) + return r + + +@classmethod +def _patch_devicecopy_create(cls, x, device, non_blocking, traced_graph=None, node_name=None): + if ( + not x.is_extern() + and all(r in V.graph.constants for r in x.get_read_names()) + and not config.aot_inductor.use_runtime_constant_folding + ): + return x.constant_to_device(device) + + V.graph.add_device_info(device) + V.graph.add_device_info(x.get_device()) + + ir.developer_warning("DeviceCopy in input program") + constant_args = (non_blocking,) + r = ir.DeviceCopy( + ir.FlexibleLayout( + device=device, + dtype=x.get_dtype(), + size=x.get_size(), + ), + [cls.realize_input(x)], + constant_args, + ) + r._post_init_setattr("traced_graph", traced_graph) + r._post_init_setattr("node_name", node_name) + return r + + +def _patch_devicecopy_get_traced_graph(self): + return self.traced_graph + + +def _patch_multioutput_get_traced_graph(self): + return None + +ir.MultiOutput.get_traced_graph = _patch_multioutput_get_traced_graph + +def _patch_mutablebox_get_name(self): + return self.data.get_name() + +def _patch_mutablebox_get_traced_graph(self): + return self.data.get_traced_graph() + + +@classmethod +def _patch_mutationlayout_realize_into(cls, src, dst, unsafe_alias=False): + dst.realize() + # NOTE: We must realize users of `dst` before we realize `src`, since + # realization order determines scheduling order. Otherwise, src's + # mutation would be scheduled before the existing users of dst! + V.graph.mark_buffer_mutated(dst.get_name()) + + if isinstance(src, ir.TensorBox): + src = src.data + + # We copy the contents of src into dst. In most cases this should + # be fused into a single kernel by the scheduler. + # NOTE: We cannot change src's layout to mutate dst directly as this + # would alias src to dst, which is not correct as further s to + # dst would effect users of src. However if there are no more users of + # dst, we can alias src to dst. 
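+        # NPU addition: besides materializing the copy of src into dst, record the mutation in
+        # the traced FX graph. When a real copy is emitted below, the Pointwise node carries a
+        # graph merged from dst and src via torch.ops.aten.copy, keeping the traced graph
+        # consistent for the accuracy-check path.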
+ src.realize_hint() + + if not unsafe_alias: + + input_graphs = fetch_graphs([dst, src]) + node_name = f'copy__{next(node_id)}' + new_graph = merge_traced_graphs(input_graphs, torch.ops.aten.copy, node_name) + + src = ir.Pointwise.create( + device=src.get_device(), + dtype=src.get_dtype(), + inner_fn=src.make_loader(), + ranges=[ + V.graph.sizevars.guard_equals(a, b) + for a, b in zip(src.get_size(), dst.get_size()) + ], + traced_graph=new_graph, + node_name=node_name, + ).data + + src.realize() + assert isinstance(src.data.layout, ir.FlexibleLayout) + src.data.layout = ir.MutationLayoutSHOULDREMOVE(dst) + return src.data + +def _patch_npu_inductor_ir(): + ir.Reduction.create = _patch_reduction_create + ir.BaseView.get_traced_graph = _patch_baseview_get_traced_graph + ir.BaseView.get_reads = _patch_base_view_get_reads + ir.BaseView.realize = _patch_baseview_realize + ir.BaseView.realize_hint = _patch_baseview_realize_hint + ir.BaseView.mark_reuse = _patch_mark_reuse + ir.ExpandView.create = _patch_expandview_create + ir.PermuteView.create = _patch_permuteview_create + ir.View.create = _patch_view_create + ir.SliceView.create = _patch_sliceview_create + ir.Buffer.traced_graph = None + ir.Buffer.get_traced_graph = _patch_buffer_get_traced_graph + ir.ConcatKernel.create = _patch_concatkernel_create + ir.ConcatKernel.get_traced_graph = _patch_concatkernel_get_traced_graph + ir.ConcatKernel.realize_into = _patch_concatkernel_realize_into + ir.ExternKernel.copy_input = _patch_externkernel_copy_input + ir.ExternKernel.convert_to_reinterpret_view = _patch_externkernel_convert_to_reinterpret_view + ir.DeviceCopy.create = _patch_devicecopy_create + ir.DeviceCopy.get_traced_graph = _patch_devicecopy_get_traced_graph + ir.MutableBox.get_name = _patch_mutablebox_get_name + ir.MutableBox.get_traced_graph = _patch_mutablebox_get_traced_graph + ir.Loops.get_name = _patch_loops_get_name + ir.Loops.get_traced_graph = _patch_loops_get_traced_graph + ir.Loops.create = _patch_loops_create + ir.Pointwise.constant_to_device = _patch_pointwise_constant_to_device + ir.MutationLayoutSHOULDREMOVE.realize_into = _patch_mutationlayout_realize_into \ No newline at end of file diff --git a/torch_npu/_inductor/codegen/kernel_analysis.py b/torch_npu/_inductor/codegen/kernel_analysis.py new file mode 100644 index 0000000000000000000000000000000000000000..266a7526264ea6a763d8e4e9b71250c53d161e91 --- /dev/null +++ b/torch_npu/_inductor/codegen/kernel_analysis.py @@ -0,0 +1,312 @@ +from torch._inductor.virtualized import V +from torch._inductor.utils import sympy_index_symbol +from torch._inductor.scheduler import SchedulerNode +from typing import List, Tuple +from torch._inductor import ir +import sympy +import pdb + +class IndexAnalysis : + def __init__(self, kernel, raw_index, is_store_index = False ) : + self.index = raw_index.subs(V.graph.sizevars.var_to_val) + self.kernel = kernel + self.tiling_axis = [x.symbol() for x in self.kernel.tiling_axis] + # self.var_stride = None # var list [(r,1),(x,2),(y,4),(z,24)] + # self.var_list = None # sorted by stride [r,x,y,z], in reversed order + self.stride_list = None # stride list [1,2,4,24] + self.reshape_sizes = [] # [RBLOCK, 1, 1, XBLOCK_SUB] + self.broadcast_sizes = [] # [RBLOCK, XBLOCK_SUB] + self.permute_shape = [] # [0,2,1,3] + self.var_replacements = {} # r2 ->r2_0, etc + self.var_directions = {} # r2_0 -> [None,:,None] + self.similar = None #(r,x,z,y) + self.need_permute = False + self.need_broadcast = False + self.need_reshape = False + self.gold = 
kernel.golden_var_list #tuple([x.symbol() for x in reversed(kernel.tiling_axis)]) + self.var_stride = [(key,coeff) for key, coeff in self.index.as_coefficients_dict().items() if not isinstance(key, sympy.Integer)] + # sort by stride + self.var_stride.sort(key = lambda x : x[1] ) + # only contains tiing axis var + self.var_list= tuple([x[0] for x in self.var_stride if x[0] in self.tiling_axis ]) + self.stride_list = tuple([x[1] for x in self.var_stride if x[0] in self.tiling_axis]) + self.is_store_index = is_store_index + + + def get_most_similar_shape(self) : + matched_dims = 0 + self.similar = None + for vars in self.kernel.index_analysis.keys() : + if len(vars) != len(self.gold) : + continue + i = 0 + while i < len(self.var_list) : + if vars[i] == self.var_list[i] : + i = i + 1 + else : + break + + if i > matched_dims : + matched_dims = i + self.similar = vars + return self.similar + + def same_var_list(self, var1, var2) : + if len(var1) != len(var2) : + return False + for i, v in enumerate(var1) : + if v != var2[i] : + return False + return True + + def shrink_permute_shape(self, permute_shape) : + diff = len(self.gold) - len(self.kernel.tiling_axis) + new_shape = [x for x in permute_shape if x - diff >= 0] + return new_shape + + def analyze_permute_shape(self): + if self.gold == self.similar: + self.need_permute = False + return + + similar = tuple(reversed(self.similar)) + gold = tuple(reversed(self.gold)) + self.permute_shape = [None] * len(gold) + + # kernel_name = self.kernel.get_kernel_name("", self.kernel.node_schedule, self.kernel) + # if kernel_name == "triton_unk_fused_add_clone_tanh_20" : + # pdb.set_trace() + + if self.is_store_index : + for i, x in enumerate(similar) : + if x != gold[i] : + index = gold.index(x) + self.permute_shape[i] = index + self.need_permute = True + else : + self.permute_shape[i] = i + return + + for i, x in enumerate(gold) : + if x != similar[i] : + index = similar.index(x) + self.permute_shape[i] = index + self.need_permute = True + else : + self.permute_shape[i] = i + + def analyze_broadcast_sizes(self) : + if not self.need_reshape : + self.need_broadcast = False + return + self.need_broadcast = True + reversed_similar = reversed(self.similar) + similar = [x for x in reversed_similar] + self.broadcast_sizes = ["1"] * len(similar) + for i, x in enumerate(similar) : + self.broadcast_sizes[i] = f"{x.name.upper()}BLOCK_SUB" + + def analyze_reshape_sizes(self) : + if all(x in self.var_list for x in self.tiling_axis ) : + self.need_reshape = False + return + self.need_reshape = True + reversed_similar = reversed(self.similar) + similar = [x for x in reversed_similar ] + var_list = [x for x in reversed(self.var_list) ] + self.reshape_sizes = ["1"] * len(similar) + for i, x in enumerate(var_list): + index = similar.index(x) + self.reshape_sizes[index] = f"{x.name.upper()}BLOCK_SUB" + + def analyze_var_direction(self) : + if self.var_list == self.gold : + return + var_list = self.var_list if len(self.var_list) == len(self.gold) else self.similar + if var_list == self.gold : + return + if not var_list : + return + var_list = list(tuple(reversed(var_list))) + gold = list(tuple(reversed(self.gold))) + assert len(var_list) == len(gold) + var_list = [x for x in var_list if x in self.kernel.tiling_axis] + gold = [x for x in gold if x in self.kernel.tiling_axis] + for i, x in enumerate(gold ): + index = var_list.index(x) + if(index == i) : + continue + new_var = sympy_index_symbol(f"{x}_{index}") + if new_var in self.var_replacements: + continue + direction = 
["None"] * len(gold) + direction[index] = ":" + direction_str = f"[{','.join(direction)}]" + self.var_replacements[x] = new_var + self.var_directions[new_var] = direction_str + self.kernel.range_tree_nodes[x].var_directions[new_var] = direction_str + + + def analyze_index(self) : + if isinstance(self.index, sympy.Integer ) : + return + if not self.kernel.golden_var_list : + self.kernel.select_golden_varlist() + self.gold = self.kernel.golden_var_list + + assert self.gold is not None + assert len(self.gold) == len(self.tiling_axis) + + def all_tiling_in_var_list() : + return all([x in self.var_list for x in self.tiling_axis] ) + #2 analyze permute shape for full_dim_len index + if all_tiling_in_var_list() : + self.similar = self.var_list + self.analyze_permute_shape() + if self.var_list not in self.kernel.index_analysis : + self.kernel.index_analysis[self.var_list] = self + #3. analyze reshape and broadcast sizes + else : + pass + # self.similar = self.get_most_similar_shape() + # if self.similar is None : + # return + # self.analyze_reshape_sizes() + # self.analyze_broadcast_sizes() + # self.analyze_permute_shape() + + #4 analyze var direction + self.analyze_var_direction() + + def generate_statement(self) : + statement = "" + if self.need_reshape : + reshape_sizes = f"[{','.join(self.reshape_sizes)}]" + statement = f".reshape({reshape_sizes})" + if self.need_broadcast: + broadcast_sizes = f"[{','.join(self.broadcast_sizes)}]" + statement = f"{statement}.broadcast_to({broadcast_sizes})" + if self.need_permute: + statement = f"{statement}.permute({self.permute_shape})" + return statement + +class ReductionAnalysis : + def __init__(self, kernel) : + self.kernel = kernel + self.reduction = None + self.reduced_dim = None + if self.numof_reduction_axis() > 1 : + self.kernel.persistent_reduction = True + self.reduced_dim = 0 + return + + reduction = self.kernel.find_reduction_node() + if reduction is None or not isinstance(reduction, ir.Reduction) : + raise RuntimeError("failed to get one reduction node") + if not hasattr(reduction, "reduced_idx") : + raise RuntimeError("reduction node doesn't have attr reduced_idx") + self.reduction = reduction + self.reduced_dim = self.analyze_reduction_dim() + + def is_higher_order_reduction(self ): + return self.dim < len(self.kernel.tiling_axis) -1 + + def is_1d_reduction(self) : + return self.kernel.numels["r"] > 1 and len(self.kernel.numels) == 1 + + def get_reduce_dim_reshape(self, reduce_axis) : + if self.is_1d_reduction(): + shape_str = f"[{reduce_axis.name.upper()}BLOCK_SUB]" + else : + shape = ["1"] * len(self.kernel.tiling_axis) + shape[self.reduced_dim] = f"{reduce_axis.name.upper()}BLOCK_SUB" + shape_str = f"[{','.join(shape)}]" + return shape_str + + def dense_size_list(self) -> List[str]: + sizes = [f"{x.name.upper()}BLOCK_SUB" for x in self.kernel.tiling_axis] + if self.numof_reduction_axis() > 1 : + return sizes + + reduce_axis = self.kernel.tiling_axis[-1] + sizes.pop(-1) + sizes.insert(self.reduced_dim, f"{reduce_axis.name.upper()}BLOCK_SUB" ) + return sizes + + def dense_size_str(self) : + sizes = self.dense_size_list() + if self.numof_reduction_axis() > 1: + return f"[{'* '.join(sizes)}]" + return f"[{', '.join(sizes)}]" + + def numof_reduction_axis(self): + return self.kernel.numof_reduction_axis() + + def reduction_axis_list(self): + return self.kernel.reduction_axis_list() + + def analyze_reduction_dim(self) : + + if self.numof_reduction_axis() > 1 : + self.kernel.persistent_reduction = True + self.reduced_dim = 0 + return 0 + + if not 
self.kernel.golden_var_list : + self.kernel.select_golden_varlist() + assert self.kernel.golden_var_list is not None + + dim = -1 + for i, x in enumerate(reversed(self.kernel.golden_var_list)) : + if x.name[0] == 'r' : + dim = i + break + return dim + + + + def analyze_reduction_dim1(self) : + # kernel_name = self.kernel.get_kernel_name("", self.kernel.node_schedule, self.kernel) + # if kernel_name == "triton_unk_fused_14" : + # pdb.set_trace() + + if self.numof_reduction_axis() > 1 : + self.kernel.persistent_reduction = True + self.reduced_dim = 0 + return 0 + reduction = self.reduction + # kept = [0,1,3], reduced = [2] + for i,x in enumerate(reduction.reduced_idx) : + if reduction.reduction_ranges[i] <=1 : + continue + reduced_idx = x + break + # the index (in reduction.ranges) of low_dims + low_dims = [i for i, x in enumerate(reduction.kept_idx) if x > reduced_idx] + if not low_dims : + return len(self.kernel.tiling_axis) -1 + elif len(low_dims) == len(reduction.kept_idx) : + return 0 + # reduction dim when low_dims are not meraged + dim = len(reduction.kept_idx) - len(low_dims) + + tiling_axis = self.kernel.tiling_axis[:-1] + merged =1 + j = len(tiling_axis) -1 + # remove all low_dims from tiling_axis + # all axis before ahead of j are high-orders + # then following is reduced dim + ranges = [x for x in reduction.ranges if x > 1] + for i in reversed(low_dims) : + len_axis = tiling_axis[j].length + len_reduction = ranges[i] * merged + if len_reduction < len_axis : + merged = merged * len_reduction + elif len_reduction == len_axis: + j = j - 1 + merged = 1 + else : + assert False, f"should not reach here low_dims({i})={len_reduction}, axis[{j}]=len)" + dim = j + 1 + return dim + \ No newline at end of file diff --git a/torch_npu/_inductor/codegen/npu_kernel_features.py b/torch_npu/_inductor/codegen/npu_kernel_features.py new file mode 100644 index 0000000000000000000000000000000000000000..6dd3189b679c93ccbcadde457aa5814688df709d --- /dev/null +++ b/torch_npu/_inductor/codegen/npu_kernel_features.py @@ -0,0 +1,94 @@ +import functools +from typing import Tuple, List +from typing import Iterable +import sympy + +import torch +from torch._inductor.codegen.simd_kernel_features import SIMDKernelFeatures, NodeScheduleEntry +from torch._inductor.utils import cache_on_self +from torch.utils._ordered_set import OrderedSet +from torch._inductor.virtualized import V +from torch._inductor.codegen.simd import SIMDScheduling +from typing import Iterable + +class NumelList(Tuple): + + def numels(self): + numel = functools.reduce(lambda a, b: a * b, self) + return numel + + def __eq__(self, other): + numel = self.numels() + numel2 = other.numels() if isinstance(other, NumelList) else other + return numel == numel2 + + def __le__(self, other): + numel = self.numels() + numel2 = other.numels() if isinstance(other, NumelList) else other + return numel <= numel2 + + def __lt__(self, other): + numel = self.numels() + numel2 = other.numels() if isinstance(other, NumelList) else other + return numel < numel2 + + def __ge__(self, other): + numel = self.numels() + numel2 = other.numels() if isinstance(other, NumelList) else other + return numel >= numel2 + + def __gt__(self, other): + numel = self.numels() + numel2 = other.numels() if isinstance(other, NumelList) else other + return numel > numel2 + + + def __mod__(self, other): + numel = self.numels() + numel2 = other.numels() if isinstance(other, NumelList) else other + return numel % numel2 + + def __truediv__(self, other): + numel = self.numels() + numel2 = 
other.numels() if isinstance(other, NumelList) else other + return numel / numel2 + + def __floordiv__(self, other): + numel = self.numels() + numel2 = other.numels() if isinstance(other, NumelList) else other + return numel // numel2 + + def __mul__(self, other): + numel = self.numels() + numel2 = other.numels() if isinstance(other, NumelList) else other + return numel * numel2 + + def __rmul__(self, other): + numel = self.numels() + numel2 = other.numels() if isinstance(other, NumelList) else other + return numel * numel2 + + def __add__(self, other): + numel = self.numels() + numel2 = other.numels() if isinstance(other, NumelList) else other + return numel + numel2 + + def __radd__(self, other): + numel = self.numels() + numel2 = other.numels() if isinstance(other, NumelList) else other + return numel + numel2 + + def __hash__(self): + return super(NumelList, self).__hash__() + + +class NPUKernelFeatures(SIMDKernelFeatures): + def __init__( + self, + node_schedule: List[NodeScheduleEntry], + numel: sympy.Expr, + reduction_numel: sympy.Expr = sympy.S.One, + ): + super().__init__(node_schedule, numel, reduction_numel) + self.numel = NumelList(self.numel) if isinstance(self.numel, Iterable) else self.numel + self.reduction_numel = NumelList(self.reduction_numel) if isinstance(self.reduction_numel, Iterable) else self.reduction_numel diff --git a/torch_npu/_inductor/codegen/schduling.py b/torch_npu/_inductor/codegen/schduling.py new file mode 100644 index 0000000000000000000000000000000000000000..2fae3cd0185cb386bd29843da9b9dadb07957730 --- /dev/null +++ b/torch_npu/_inductor/codegen/schduling.py @@ -0,0 +1,343 @@ +import itertools +import contextlib +from typing import Union, Iterable +from typing import Dict, Sequence, List, Iterable +import sympy + + +from torch.fx.immutable_collections import immutable_dict +from torch._inductor.codegen.triton import (TritonScheduling, log, config) +from torch._inductor.codegen.simd import DisableReduction, EnableReduction, SIMDKernelFeatures, SIMDKernel +from torch._inductor.codegen.simd import schedule_log, scheduler +from torch._inductor.codegen.multi_kernel import MultiKernel +from torch._inductor.virtualized import (V,) +from torch._inductor.codecache import code_hash +from torch._dynamo.utils import counters +from torch._inductor.utils import sympy_index_symbol, ModularIndexing, FloorDiv + +from torch_npu._inductor.codegen.triton import NPUIndexTritonKernel, flatten +from .split_tiling import SplitTiling +from torch.fx.immutable_collections import immutable_dict +from .npu_kernel_features import NumelList, NPUKernelFeatures + + +import os +from typing import List, Union, Any +import collections + +from .triton import NPUIndexTritonKernel +from .. 
import config as npu_config +from torch._inductor.codegen.triton import ( + TritonScheduling, + log, + config, + schedule_log, + get_fused_kernel_name, + get_kernel_category_by_source_code, + Placeholder, + get_kernel_metadata, + get_path, + IndentedBuffer + ) +from torch._inductor.codegen.simd import DisableReduction, EnableReduction +from torch._inductor import scheduler, metrics +from torch._inductor.virtualized import ( + V, +) +from torch._inductor.codecache import code_hash +from torch._dynamo.utils import counters +import itertools, contextlib +from torch._inductor.utils import sympy_index_symbol +import sympy +from .split_tiling import SplitTiling +from ..lowering_fx import ( + create_fx_from_snodes_by_traced_graph, + create_compile_kwargs, + generate_fx_graph_code, + dump_fx_graph_code + ) +from .kernel_analysis import ReductionAnalysis + +def flatten_groups(nums): + res = [] + for i in nums: + if isinstance(i, Iterable): + for x in i: + res.append(x) + else: + res.append(i) + return res + + +@classmethod +def create_tiling( + cls, pw_tiling: Sequence[sympy.Expr], reduction_tiling: Sequence[sympy.Expr] + ) -> Dict[str, sympy.Expr]: + """ + Create a tiling dict from pointwise and reduction splits. + """ + + pw_tiling = flatten_groups(pw_tiling) + pw_prefixes = ["w", "v", "t", "z", "y", "x"][-len(pw_tiling):] + reduction_tiling = flatten_groups(reduction_tiling) + reduction_tiling = [NumelList(reduction_tiling).numels()] + reduction_prefixes = ["r"][: len(reduction_tiling)] + tiling = immutable_dict( + list(zip(pw_prefixes, pw_tiling)) + + list(zip(reduction_prefixes, reduction_tiling))) + return tiling + + + +class NPUTritonScheduling(TritonScheduling): + def __init__(self, input_scheduler): + super().__init__(input_scheduler) + self.kernel_type = NPUIndexTritonKernel + + def create_kernel_choices( + self, kernel_features: SIMDKernelFeatures, kernel_args, kernel_kwargs + ) -> List[SIMDKernel]: + + return [ + self.kernel_type( + *kernel_args, + **kernel_kwargs, + ) + ] + + # transform indexing before call codegen_node_schedule_with_kernel + def codegen_node_schedule(self, kernel_features: SIMDKernelFeatures, nodes): + node_schedule = kernel_features.node_schedule + tiling = self.select_tiling( + node_schedule, kernel_features.numel, kernel_features.reduction_numel + ) + + kernels = self.create_kernel_choices( + kernel_features, [tiling], {"features": kernel_features} + ) + kernel = kernels[0] + setattr(kernel, "node_schedule", node_schedule) + self.decide_codegen_dims_in_kernel(node_schedule, kernel) + + for kernel in kernels: + self.codegen_node_schedule_with_kernel(node_schedule, kernel) + + MultiKernel.merge_workspaces_inplace(kernels) + for kernel in kernels: + with V.set_kernel_handler(kernel): + src_code = kernel.codegen_kernel() + + V.graph.removed_buffers |= kernel.removed_buffers + V.graph.inplaced_to_remove |= kernel.inplaced_to_remove + + if npu_config.check_accuracy: + if not npu_config.traced_fx_graph_cache: + npu_config.traced_fx_graph_cache = os.path.join(os.getenv("TORCHINDUCTOR_CACHE_DIR"), 'traced_fx_graph_cache') + os.makedirs(npu_config.traced_fx_graph_cache, exist_ok=True) + traced_graph, fx_call_args, fx_args, compile_kwargs = create_fx_from_snodes_by_traced_graph(nodes) + traced_graph_hash = code_hash(traced_graph.print_readable(print_output=False)) + + kernel_name, src_code = self.define_kernel(src_code, node_schedule, kernel, traced_graph_hash \ + if npu_config.check_accuracy else None) + + log.debug("Generating kernel code with kernel_name: %s", 
kernel_name) + kernel.kernel_name = kernel_name + kernel.code_hash = code_hash(src_code) + del kernel + + final_kernel: Union[SIMDKernel, MultiKernel] + if len(kernels) > 1: + final_kernel = MultiKernel(kernels) + else: + (final_kernel,) = kernels + + with V.set_kernel_handler(final_kernel): + for node in kernel_features.scheduler_nodes(): + node.mark_run() + + self.codegen_comment(node_schedule) + final_kernel.call_kernel(final_kernel.kernel_name) + + if npu_config.check_accuracy: + new_compile_kwargs = create_compile_kwargs(final_kernel, fx_call_args, fx_args) + if new_compile_kwargs: + compile_kwargs |= new_compile_kwargs + fx_dump_path = os.path.join(npu_config.traced_fx_graph_cache, traced_graph_hash) + os.makedirs(fx_dump_path, exist_ok=True) + fx_code = generate_fx_graph_code(traced_graph.code, src_code, kernel_name, compile_kwargs) + dump_fx_graph_code(fx_code, fx_dump_path, traced_graph_hash) + os.environ[traced_graph_hash] = fx_dump_path + + if config.nan_asserts: + final_kernel.codegen_nan_check() + if config.warn_mix_layout: + final_kernel.warn_mix_layout(kernels[0].kernel_name) + + V.graph.removed_buffers |= final_kernel.removed_buffers + V.graph.inplaced_to_remove |= final_kernel.inplaced_to_remove + + if ( + V.graph.wrapper_code.supports_intermediate_hooks + and config.generate_intermediate_hooks + ): + # Not every node in the schedule will actually be live on output; + # we can't check dead buffers. + live_outs = kernels[0].args.live_output_buffers() + for node in kernel_features.scheduler_nodes(): + name = node.get_name() + if name not in live_outs: + continue + if node.node is None: + raise RuntimeError("assert node.node is not None") + + origin_node = node.node.get_origin_node() + if origin_node is not None: + counters["inductor"]["intermediate_hooks"] += 1 + V.graph.wrapper_code.writeline( + f"run_intermediate_hooks({origin_node.name!r}, {name})" + ) + + self.scheduler.free_buffers() + + + def define_kernel(self, src_code, node_schedule, kernel, traced_graph_hash: str): + wrapper = V.graph.wrapper_code + if (src_code, traced_graph_hash) in wrapper.src_to_kernel: + kernel_name = wrapper.src_to_kernel[(src_code, traced_graph_hash)] + if npu_config.check_accuracy: + src_code = src_code.replace(str(Placeholder.DESCRIPTIVE_NAME), kernel_name) + subs_name = kernel_name if config.triton.unique_kernel_names else "triton_" + src_code = src_code.replace(str(Placeholder.KERNEL_NAME), subs_name) + if traced_graph_hash: + src_code = src_code.replace('TRACED_GRAPH_HASH', traced_graph_hash) + else: + fused_name = ( + get_fused_kernel_name(node_schedule, config.triton.descriptive_names) + if config.triton.descriptive_names + else "" + ) + kernel_category = get_kernel_category_by_source_code(src_code)[:3] + kernel_name = "_".join( + ["triton", kernel_category, fused_name, wrapper.next_kernel_suffix()] + ) + # use the original src_code as the key + wrapper.src_to_kernel[(src_code, traced_graph_hash)] = kernel_name + subs_name = kernel_name if config.triton.unique_kernel_names else "triton_" + + # DESCRIPTIVE_NAME is used for profiling purposes; it shows the full kernel name + # even when unique_kernel_names is turned off. Meanwhile, KERNEL_NAME is sometimes set + # to "triton_" to maximize caching opportunities (when unique_kernel_names = False). 
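+            # TRACED_GRAPH_HASH is an NPU-specific placeholder: when check_accuracy is enabled it is
+            # replaced with the hash of the traced FX graph, and the dump path of that graph is
+            # published under the same hash (via os.environ in codegen_node_schedule), so a generated
+            # kernel can be mapped back to the FX graph it was lowered from.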
+ src_code = src_code.replace(str(Placeholder.DESCRIPTIVE_NAME), kernel_name) + src_code = src_code.replace(str(Placeholder.KERNEL_NAME), subs_name) + if traced_graph_hash: + src_code = src_code.replace('TRACED_GRAPH_HASH', traced_graph_hash) + + # TODO(voz): Ostensibly, we should not need this. But there are cases where C++ codegen does + # not use BracesBuffer, so we have no good indicator of a C++ buffer atm. + src_code = src_code.replace("#pragma CMT", "#") + + basename, _, kernel_path = get_path(code_hash(src_code.strip()), "py") + + compile_wrapper = IndentedBuffer() + compile_wrapper.writeline(f"async_compile.triton({subs_name!r}, '''") + compile_wrapper.splice(src_code, strip=True) + current_device = V.graph.get_current_device_or_throw() + compile_wrapper.writeline(f"''', device_str='{current_device.type}')") + + metadata_comment = f"# kernel path: {kernel_path}" + origins, detailed_origins = get_kernel_metadata(node_schedule, wrapper) + metadata_comment += "\n" + origins + "\n" + detailed_origins + wrapper.define_kernel( + kernel_name, compile_wrapper.getvalue(), metadata_comment + ) + + # log kernel metadata for offline analysis. + # E.g. one can find all unaligned inner reduction and check if + # padding helps with the perf kernel by kernel. + if metrics.is_metric_table_enabled("kernel_metadata"): + metrics.log_kernel_metadata(kernel_name, kernel_path, src_code) + + return kernel_name, src_code + + + def codegen_node( + self, node: Union[scheduler.FusedSchedulerNode, scheduler.SchedulerNode] + ): + """ + Given a set of pre-fused nodes, generate a Triton kernel. + """ + + nodes: List[scheduler.SchedulerNode] = node.get_nodes() # type: ignore[assignment] + _, (numel, rnumel) = max(nodes, key=lambda x: int(x.is_reduction())).group + + node_schedule = self.generate_node_schedule(nodes, numel, rnumel) + schedule_log.debug("Schedule:\n %s", node_schedule) + + return self.codegen_node_schedule( + NPUKernelFeatures(node_schedule, numel, rnumel), nodes + ) + + + def decide_codegen_dims_in_kernel(self, node_schedule, kernel): + def current_reduction_nodes(nodes): + return itertools.takewhile(lambda n: n is not DisableReduction, nodes) + + with kernel: + # 1. transform dims: create new dims to substitute floor_divide and modular expression + stack = contextlib.ExitStack() + for _, node in enumerate(node_schedule): + if node is DisableReduction: + stack.enter_context(kernel.disable_reduction()) + elif node is EnableReduction: + stack.close() + else: + index_vars = kernel.split_and_set_ranges(node.get_ranges()) + node._body.transform_dims_in_indexing(index_vars) + # 2. 
go through range_tree_nodes to findout, to find one axis could be substituted by others + self.additional_nodes_to_be_subs(kernel, kernel.range_tree_nodes_substituted) + # 3.do the substitution on all indexing + for node in node_schedule: + if node in (EnableReduction, DisableReduction): + continue + indexing = node._body.indexing + node._body.substituted_dims_in_indexing(indexing, kernel, kernel.range_tree_nodes_substituted) + + # 4.remove the substituted dims from kernel + for var, _ in kernel.range_tree_nodes_substituted.items(): + if (var in kernel.range_tree_nodes): + root = kernel.range_tree_nodes[var].parent + root.remove_entry(var) + # select split and tiling axis + split_tiling = SplitTiling(kernel) + split_tiling.select_split_tiling_axis() + kernel.load_store_indexing = split_tiling.indexing + # debug print index transforms + for node in node_schedule: + if node in (EnableReduction, DisableReduction): + continue + for x, y in zip(node._body.indexing_exprs.values(), node._body.indexing.values()): + print(f"index transform:{x}->{y}") + # ReductionAnalysis depends on kernel.load_store_indexing + if kernel.inside_reduction : + kernel.reduce_analysis = ReductionAnalysis(kernel) + + def additional_nodes_to_be_subs(self, kernel, node_to_be_substituted): + for node in kernel.range_tree_nodes.values(): + if node.expr != sympy_index_symbol(f"{node.parent.prefix}index") \ + or len(node.parent.var_ranges) == 1 \ + or node.symbol() in node_to_be_substituted: + continue + numel = sympy.Integer(1) + new_var_expr = sympy.Integer(0) + for k, s in node.parent.var_ranges.items(): + if k == node.symbol(): + continue + numel = numel * s + sub_node = kernel.range_tree_nodes[k] + new_var_expr = new_var_expr + sub_node.symbol() * sub_node.divisor + + if numel == node.length: + node_to_be_substituted[node.symbol()] = [(node.length, new_var_expr)] + else: + log.warning("sub nodes (expr%s, numel:%d) can not make up parent node(%s:%d)", + new_var_expr, numel, node.symbol(), node.length) \ No newline at end of file diff --git a/torch_npu/_inductor/codegen/split_tiling.py b/torch_npu/_inductor/codegen/split_tiling.py new file mode 100644 index 0000000000000000000000000000000000000000..d52717fe711905dd73487e4fcd0623280e65a616 --- /dev/null +++ b/torch_npu/_inductor/codegen/split_tiling.py @@ -0,0 +1,278 @@ +import pdb + +from torch._inductor.codegen.triton import TritonKernel +from torch._inductor.utils import ModularIndexing,sympy_subs +import sympy as sympy +from ..config import num_vector_core, log +from torch._inductor.virtualized import V +from torch._inductor.codegen.simd import ( EnableReduction, DisableReduction) +from torch._inductor.runtime.runtime_utils import next_power_of_2 +from .triton_utils import get_aligned_numel +from torch._inductor.loop_body import MemoryUsageType +from functools import reduce +from .kernel_analysis import IndexAnalysis + +# split and tiling axis selector +class SplitTiling : + def __init__(self, kernel : TritonKernel) : + self.kernel = kernel + self.indexing = [] # load and store indexing among all scheduler nodes + kernel.sorted_axis = [x for x in kernel.range_tree_nodes.values()] + kernel.sorted_axis.sort(reverse=True, key = self.key) + for i, dim in enumerate(kernel.sorted_axis): + dim.sorted_order = i + + self.find_lowest_dimension() + self.should_outer_reduce = False + self.possible_need_permute = self.find_possible_permutes() + + def find_possible_permutes(self) : + if len(self.kernel.low_dims) <= 1 : + return False + var_lists = [] + low_dims = 
[self.kernel.sorted_axis[x].symbol() for x in self.kernel.low_dims] + for index in self.indexing : + var_stride = [(key,coeff) for key, coeff in index.as_coefficients_dict().items() if not isinstance(key, sympy.Integer)] + var_stride.sort(key = lambda x : x[1] ) + var_list= tuple([x[0] for x in var_stride if x[0] in low_dims ]) + var_lists.append(var_list) + for i, var_list in enumerate(var_lists) : + if len(var_list) < len(low_dims) : + continue + for j, other in enumerate(var_lists) : + if i == j or len(other) < len(low_dims): + continue + if var_list != other : + return True + return False + + + def key(self, x) : + # to be higher than x and y + if x.name[0] == 'w' or x.name[0] == 'v' or x.name[0] == 't': + return "zz" + x.name + # to be lower than floor_div + elif isinstance(x.expr, ModularIndexing): + return x.name[0] + "0" + x.name[1:] + else : + return x.name + + def total_split_numels(self, axis_list): + numels = [x.length for x in axis_list] + return reduce(lambda x,y:x*y, numels) if numels else 1 + + # Split principle 1: merge dimensions first, then split. Merging dimensions lowers the rank and reduces the complexity of the split/tiling axis selection strategy. + # Split principle 2: prefer higher-order (outer) axes as split axes, so that loads/stores keep reasonably good linearity. + # Split principle 3: reduction axes and low-dimension (innermost) axes should not be chosen as split axes. However, for fused higher-order reduction ops whose outer dimension is very large (>= 64KB) and whose other dimensions cannot support the split, splitting the reduction axis may be considered. + # Split principle 4: the total numel of the split axes should exceed the number of AI cores. Preferably use no more than 3 split axes (Triton launches at most 3 grid dimensions), so merge dimensions first if more would be needed. + def select_split_axis(self): + self.kernel.split_axis.clear() + + # stop when the total split numel reaches the AI core count or there are already 3 split axes + def meet_stop_condition() : + if self.total_split_numels(self.kernel.split_axis) >= num_vector_core : + return True + if len(self.kernel.split_axis) == 3 : + return True + return False + + def select_one_split_axis(not_reduction = True, not_low_dims = True ) : + for axis in self.kernel.sorted_axis : + if not_reduction and axis.prefix == "r" : + continue + if not_low_dims and axis.sorted_order in self.kernel.low_dims : + continue + if axis in self.kernel.split_axis : + continue + axis.is_split_axis = True + return axis + return None + count = 0 + while not meet_stop_condition() : + count += 1 + axis = select_one_split_axis(not_reduction=True, not_low_dims=True) + if axis is not None : + self.kernel.split_axis.append(axis) + continue + axis = select_one_split_axis(not_reduction=True, not_low_dims=False ) + if axis is not None : + self.kernel.split_axis.append(axis) + continue + # fixme: later, support splitting the reduction dim + if count > 10 : + break + + if not self.kernel.split_axis and self.kernel.sorted_axis: + self.kernel.split_axis.append(self.kernel.sorted_axis[0]) + + self.kernel.split_axis.sort(reverse=True, key = self.key) + for i, x in enumerate(self.kernel.split_axis) : + x.split_order = i + + + # Tiling principle 1: every low-dimension (innermost) axis that appears in a load/store index expression must become a tiling axis.
+ # Tiling principle 2: for reduction ops, the reduction axis must become a tiling axis. + # Tiling principle 3: for multi-dimensional reductions, only reduction axes may be selected as tiling axes. + # Tiling principle 4: the tiling axes should cover at least 80% of the total numel. + + + # fixme: two tiling axes might be insufficient when there are 3 or more low dims in the indexing + def select_tiling_axis(self ): + self.kernel.tiling_axis.clear() + #longest = self.find_longest_dimension() + # cover the largest axes without exceeding the maximum number of tiling axes + def meet_stop_condition() : + total_numel = reduce(lambda x,y : x + y, map(lambda x:x.length, self.kernel.sorted_axis)) if self.kernel.sorted_axis else 1 + tiling_numel = reduce(lambda x,y :x + y, map(lambda x:x.length, self.kernel.tiling_axis)) if self.kernel.tiling_axis else 1 + if self.kernel.numof_reduction_axis() > 1 and all(self.kernel.range_tree_nodes[var].is_tiling_axis for var in self.kernel.reduction_axis_list()) : + return True + # currently, the maximum number of transpose dims that triton-ascend supports is 2 + max_transpose_dims = 2 + if (self.possible_need_permute or tiling_numel / total_numel >= 0.8) and \ + len(self.kernel.tiling_axis) >= min(max_transpose_dims, len(self.kernel.sorted_axis)) : + return True + return False + + def select_tiling(low_dim = True, reduction = True ) : + for axis in reversed(self.kernel.sorted_axis) : + if low_dim and axis.sorted_order in self.kernel.low_dims and axis not in self.kernel.tiling_axis: + axis.is_tiling_axis = True + self.kernel.tiling_axis.append(axis) + if reduction and axis.prefix == 'r' and axis not in self.kernel.tiling_axis: + axis.is_tiling_axis = True + self.kernel.tiling_axis.append(axis) + if low_dim or reduction : + continue + # using principle 4, select the longest remaining axis + longest = axis #self.find_longest_dimension(check_in_tiling = True) + if longest and longest not in self.kernel.tiling_axis: + self.kernel.tiling_axis.append(longest) + longest.is_tiling_axis = True + if meet_stop_condition(): + break + + select_tiling(low_dim=True, reduction=True) + count = 0 + while not meet_stop_condition(): + select_tiling(low_dim=False, reduction=False) + count += 1 + if count > 10 : + break + self.kernel.tiling_axis.sort(reverse=True, key = self.key) + for i , x in enumerate(self.kernel.tiling_axis) : + x.tiling_order = i + + + def select_split_tiling_axis(self) : + self.select_split_axis() + self.select_tiling_axis() + log.info(f"split_tiling numels:{self.kernel.numels} split_axis: {','.join([x.name for x in self.kernel.split_axis])} " + f"tiling_axis: {','.join([x.name for x in self.kernel.tiling_axis])} low_dims:{self.kernel.low_dims}, " + f"indexing: {self.indexing} possible_need_permute:{self.possible_need_permute}" ) + + # fixme: the logic below doesn't work when there are two reduction axes but only one needs outer reduction + def should_outer_reduce_me(self, x): + should_outer = self.kernel.is_higher_order_reduction(True) and SplitTiling.great_than(x.length, 32768 ) and x.is_loop + if should_outer : + self.should_outer_reduce = True + self.kernel.split_axis = x + self.kernel.split_axis.is_split_axis = True + return should_outer + + def find_longest_dimension(self, check_in_tiling = False ): + longest = None + for axis in self.kernel.sorted_axis: + if (longest is None or axis.length > longest.length) and \ + (not check_in_tiling or axis not in self.kernel.tiling_axis ) : + longest = axis + return longest + + # return True when x is a low dim in the indexing + def is_lowest_dimension(self, x): + return x.sorted_order in self.kernel.low_dims + + def find_lowest_dimension(self): + def construct_low_dim() : + for index in self.indexing: + coefficients_dict =
index.as_coefficients_dict() + for key, value in coefficients_dict.items(): + if not key.free_symbols: + continue + key = list(key.free_symbols)[0] + if key not in self.kernel.range_tree_nodes: + continue + + if value == sympy.Integer(1): + axis = self.kernel.range_tree_nodes[key] + self.kernel.low_dims.add(axis.sorted_order) + + # all read index should be considered + buf_names = [node.node.name for node in self.kernel.node_schedule if + node not in (EnableReduction, DisableReduction)] + for node in self.kernel.node_schedule: + if node in (EnableReduction, DisableReduction): + continue + names = [] + + for read in node._body.memory_usage[MemoryUsageType.LOAD]: + name = read.index_name + arg = read.buffer_name + read_is_inptr = False if arg[:3] != 'arg' and arg in buf_names else True + if read_is_inptr: + names.append(name) + for key, index in node._body.indexing.items(): + if key in names and index not in self.indexing: + self.indexing.append(index) + + if self.kernel.inside_reduction : + construct_low_dim() + return + + # for non-reduction, write index should be considered + for node in self.kernel.node_schedule: + if node in (EnableReduction, DisableReduction): + continue + names = [] + for write in node._body.memory_usage[MemoryUsageType.STORE]: + names.append(write.index_name) + for write in node._body.memory_usage[MemoryUsageType.STORE_REDUCTION]: + names.append(write.index_name) + for key, index in node._body.indexing.items(): + if key in names and index not in self.indexing: + self.indexing.append(index) + + construct_low_dim() + + @staticmethod + def convert(x, y): + xnumel = x + ynumel = y + if isinstance(xnumel, (sympy.Symbol, sympy.Expr)) and not isinstance(xnumel, sympy.Integer): + xnumel = xnumel.subs(V.graph.sizevars.var_to_val) + + if isinstance(ynumel, (sympy.Symbol, sympy.Expr)) and not isinstance(ynumel, sympy.Integer): + ynumel = ynumel.subs(V.graph.sizevars.var_to_val) + + if isinstance(xnumel, sympy.Integer) and isinstance(ynumel, int): + ynumel = sympy.Integer(ynumel) + + if isinstance(ynumel, sympy.Integer) and isinstance(xnumel, int): + xnumel = sympy.Integer(xnumel) + + return (xnumel, ynumel) + + + @staticmethod + def less_than(x, y): + xnumel, ynumel = SplitTiling.convert(x, y) + return xnumel < ynumel + + @staticmethod + def great_than(x, y): + xnumel, ynumel = SplitTiling.convert(x, y) + return xnumel > ynumel + + @staticmethod + def ge_than(x, y): + xnumel, ynumel = SplitTiling.convert(x, y) + return xnumel >= ynumel diff --git a/torch_npu/_inductor/codegen/tile_generator.py b/torch_npu/_inductor/codegen/tile_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..40e60bd3cafbb15d757c167170ca9480fafb2573 --- /dev/null +++ b/torch_npu/_inductor/codegen/tile_generator.py @@ -0,0 +1,253 @@ +import copy +import math + +from torch._inductor.runtime.triton_heuristics import Config +from torch._inductor.runtime.runtime_utils import next_power_of_2 +from .triton_utils import byte_per_numel +import functools +from ..config import num_vector_core +import sys +# generate tiling configs +class TileGenerator: + + def __init__(self, numels, axis_names, tiling_axis, split_axis, low_dims, persistent_reduction, + configs, dtype, dual_reduction = False) : + self.numels = numels.copy() + + self.blocks = [x for x in self.numels] + self.candidate_blocks=[] + self.sub_blocks = self.blocks.copy() + self.axis_name = axis_names + self.tiling_axis = tiling_axis + self.split_axis = split_axis + self.low_dims = low_dims + self.configs = configs + 
self.dtype_bytes = self.get_byte_per_numel(dtype) + self.stop_numel = 1024 // self.dtype_bytes + self.block_name = {} + self.sub_block_name = {} + self.persistent_reduction = persistent_reduction + self.dual_reduction = dual_reduction + for axis, name in enumerate(self.axis_name) : + if axis not in tiling_axis and axis not in split_axis : + self.blocks[axis] = 1 + self.sub_blocks[axis] =1 + continue + if axis in self.split_axis : + self.block_name[axis] = f"{name.upper()}BLOCK" + if axis in self.tiling_axis : + self.sub_block_name[axis] = f"{name.upper()}BLOCK_SUB" + + + def aligned_numel(self, numel): + aligned = next_power_of_2(numel) + return aligned + + + def get_byte_per_numel(self, dtype): + if dtype is None : + return 1 + return byte_per_numel[dtype] + + + def valid_tile_numel(self, total_numel): + bytes = self.dtype_bytes + max_numel = 16384 * 4 // bytes + return total_numel <= max_numel + + + def calculate_config_numel(self, config) : + total_numel = 1 + # for axis in self.split_axis : + # if axis not in self.tiling_axis : + # total_numel = total_numel * config[self.block_name[axis]] + for axis in self.tiling_axis : + total_numel = total_numel * config[self.sub_block_name[axis]] + return total_numel + + def calculate_total_numel(self) : + smallest = sys.maxsize + def calculate_total_numel_candi( blocks) : + total_numel = 1 + # for axis in self.split_axis : + # if axis not in self.tiling_axis : + # total_numel = total_numel * blocks[axis] + for axis in self.tiling_axis : + total_numel = total_numel * self.sub_blocks[axis] + return total_numel + for candi_blocks in self.candidate_blocks : + numel = calculate_total_numel_candi(candi_blocks) + if numel < smallest : + smallest = numel + return smallest + + def fill_config(self, config, blocks) : + for axis in self.split_axis : + config[self.block_name[axis]] = blocks[axis] + for axis in self.tiling_axis : + tiling_numel = self.aligned_numel(self.sub_blocks[axis] ) + config[self.sub_block_name[axis]] = tiling_numel + def find_config(self, cfg) : + for config in self.configs : + if config.kwargs == cfg : + return True + return False + + def add_to_configs(self, candi_block) : + newcfg = {} + self.fill_config(newcfg, candi_block ) + total_numel = self.calculate_config_numel(newcfg) + if self.valid_tile_numel(total_numel) and not self.find_config(newcfg): + self.configs.append(Config(newcfg, num_warps=1, num_stages=1)) + + def descend_one_axis(self, axis, is_split = False ): + def calc_total_programs(): + grids = [] + for axis in self.split_axis : + numel = self.numels[axis] + block_size = self.blocks[axis] + programs = (numel + block_size -1 ) // block_size + grids.append(programs) + + total_programs = functools.reduce(lambda x,y : x * y, grids) if grids else 1 + return total_programs + + reached_stop_numel = False + slow_decend_split = False + + while True : + total_numel = self.stop_numel + 100 + for candi_block in self.candidate_blocks : + self.add_to_configs(candi_block) + + # tile numel reached threshold + total_numel = self.calculate_total_numel() + if total_numel <= self.stop_numel: + self.add_to_configs(self.blocks) + reached_stop_numel = True + break + + numel = self.blocks[axis] if is_split else self.sub_blocks[axis] + if numel == 1 : + self.add_to_configs(self.blocks) + break + + if is_split : + if self.persistent_reduction and self.axis_name[axis][0] == "r" : + reached_stop_numel = True + break + total_programs = calc_total_programs() + if total_programs > num_vector_core : + break + if total_programs > num_vector_core // 2 or 
self.dual_reduction: + if len(self.candidate_blocks) > 2 : + self.candidate_blocks.pop(0) + self.candidate_blocks.append(tuple(self.blocks)) + + self.blocks[axis] = numel // 2 + self.sub_blocks[axis] = self.blocks[axis] + total_programs = calc_total_programs() + if total_programs > num_vector_core: + slow_decend_split = True + step = numel // 4 if numel // 4 > 1 else 1 + self.blocks[axis] = numel // 2 if not slow_decend_split else numel -step + self.sub_blocks[axis] = self.blocks[axis] + else : + if numel >= 128 : + self.sub_blocks[axis] = next_power_of_2(numel // 2 ) + else :# numel >4 and numel < 128 : + self.slow_descend_axis(axis) + # else : + # break + return reached_stop_numel + + def slow_descend_axis(self, axis) : + numel = self.sub_blocks[axis] + self.sub_blocks[axis] = self.aligned_numel( numel // 2 ) + # numel = self.aligned_numel( max(numel - 4, numel //2 )) + # if (numel == self.sub_blocks[axis]) : + # numel = self.aligned_numel( max(numel - 8, numel //2 )) + # self.sub_blocks[axis] = numel + + def descend_all_low_dims(self) : + low_dim_numels = [self.sub_blocks[x] for x in self.low_dims] + if not low_dim_numels : + return + + def descent_all_axis(min_numel ) : + for axis in self.low_dims : + if self.axis_name[axis][0] == "r" and self.persistent_reduction : + continue + numel = self.sub_blocks[axis] + if numel == 1 : + continue + if min_numel > 1 and abs(numel - min_numel) / min_numel < 0.2 : + continue + if numel >= 128 : + self.sub_blocks[axis] = next_power_of_2(numel // 2 ) + else :# numel >4 and numel < 128 : + self.slow_descend_axis(axis) + + count = 0 + total_numel = self.calculate_total_numel() + while total_numel > self.stop_numel and count < 100: + count += 1 + total_numel = self.calculate_total_numel() + for candi_block in self.candidate_blocks : + self.add_to_configs(candi_block) + min_numel = min(low_dim_numels) + descent_all_axis(min_numel) + total_numel_2 = self.calculate_total_numel() + if total_numel == total_numel_2 : + descent_all_axis(0) + + return total_numel < self.stop_numel + + def descend_split_tiling(self ): + + tiling_not_low_dims = [x for x in self.tiling_axis if x not in self.low_dims ] + def descend_split_axis () : + + for axis in self.split_axis : + if self.descend_one_axis(axis, is_split=True) : + return True + + total = self.calculate_total_numel() + return total <= self.stop_numel + + def desceond_tiling_not_low_dims() : + for axis in tiling_not_low_dims : + if self.axis_name[axis][0] == "r" and self.persistent_reduction : + continue + if self.descend_one_axis( axis) : + return True + total = self.calculate_total_numel() + return total <= self.stop_numel + + #fixme, need to all low dims fairly + def descend_low_dims() : + for axis in self.tiling_axis : + if self.axis_name[axis][0] == "r" and self.persistent_reduction : + continue + if axis in tiling_not_low_dims : + continue + if self.descend_one_axis(axis) : + return True + total = self.calculate_total_numel() + return total <= self.stop_numel + + while True : + # descend split axis + if descend_split_axis() : + break + if len(self.candidate_blocks) > 0 : + self.sub_blocks = list(self.candidate_blocks[0]) + # descend tiling but not low dims + if desceond_tiling_not_low_dims() : + break + # descend low dims, fixme, need to descend all axis at the same time + # descend_low_dims() + self.descend_all_low_dims() + break + + \ No newline at end of file diff --git a/torch_npu/_inductor/codegen/triton.py b/torch_npu/_inductor/codegen/triton.py new file mode 100644 index 
0000000000000000000000000000000000000000..dd86d9f723e588e343f84c4994ba2e685b52140c --- /dev/null +++ b/torch_npu/_inductor/codegen/triton.py @@ -0,0 +1,1908 @@ +import os +from typing import List, Set, Iterable, Callable, Sequence +import operator +import itertools +from enum import Enum +import functools + +from typing import ( + Optional, + Union, + Tuple, + Any, + cast, + Dict +) + +import re +import textwrap +import sympy + +import torch +from torch._inductor.utils import sympy_subs +from torch._inductor.scheduler import SchedulerNode + +from torch._inductor.codegen.simd import CantSplit, DisableReduction, EnableReduction +from torch._inductor.codegen.common import free_symbol_is_type +from torch._inductor.codegen.triton import ( + IndexingOptions, + triton_reshape, + TritonCSEVariable, + OpsHandler, +) +from torch._inductor.runtime.hints import ReductionHint +from torch._inductor.codegen.triton import ( + TritonKernel, + TritonKernelOverrides, + IterationRangesRoot, + IterationRangesEntry, + CSEVariable, + gen_common_triton_imports, + BlockPtrOptions, + triton_acc_type, + constant_repr, + is_welford_reduction, FixedTritonConfig, + prefix_is_reduction, upcast_acc_dtype, + get_kernel_category_by_source_code, + get_fused_kernel_name +) + +from torch.utils._sympy.functions import FloorDiv, Identity, ModularIndexing +from torch._inductor.utils import sympy_index_symbol, generate_assert +from torch.utils import _pytree as pytree +from torch.utils._sympy.value_ranges import ValueRanges +from torch._inductor import config, ir +from torch._inductor.virtualized import ( + V, + StoreMode, + ReductionType, + _ops as ops, +) + +from torch._inductor.utils import ( + Placeholder, +) +from torch._inductor.runtime.runtime_utils import next_power_of_2 +from torch._inductor.codegen.common import ( + IndentedBuffer, + SizeArg, + DeferredLine, +) +from torch._inductor.codegen.triton_utils import config_of, signature_of, signature_to_meta +from torch.utils._sympy.symbol import SymT, symbol_is_type +from torch.utils._sympy.value_ranges import bound_sympy, ValueRangeAnalysis, ValueRanges +from torch.utils._sympy.numbers import int_oo +from torch._inductor.dtype_propagation import DtypePropagationOpsHandler +from .kernel_analysis import IndexAnalysis, ReductionAnalysis +import torch_npu._inductor.config as inductor_npu_config + +from ..runtime import NPUDeviceProperties +from .npu_kernel_features import NumelList + + +def flatten(nums): + res = [] + for i in nums: + if isinstance(i, list): + res.extend(flatten(i)) + else: + res.append(i) + return res + +class NPUTritonKernelOverrides(TritonKernelOverrides): + + @staticmethod + def exp(x): + return f"tl_math.exp({x})" + + @staticmethod + def sqrt(x): + return f"tl_math.sqrt({x})" + + @staticmethod + def tanh(x): + return f"tl_math.tanh({x})" + + @staticmethod + def rsqrt(x): + return f"tl.rsqrt({x})" + + @staticmethod + def floor(x): + return f"tl_math.floor({x})" + + @staticmethod + def erf(x): + return f"tl_math.erf({x})" + + @staticmethod + def ceil(x): + return f"tl_math.ceil({x})" + + +def group_fn(self, sizes): + groups = list() + for s in sizes : + if not s : + groups.append(1) + elif isinstance(s, list): + group = flatten(s) + groups.append(NumelList(tuple(group)) if isinstance(group, list) else group) + else : + groups.append(s) + return tuple(groups) + + +@staticmethod +def select_index_dtype(node_schedule, numel, reduction_numel): + return "tl.int32" + + + +class IterationRangesEntryNPUIndex(IterationRangesEntry): + def __init__( + self, + *args, 
**kwargs): + super().__init__(*args, **kwargs) + self.is_tiling_axis = False + self.is_split_axis = False + self.indexing_code = IndentedBuffer() + self.sorted_order = None + self.tiling_order = None + self.split_order = None + self.var_directions = {} + self.directions = [] + # don't use functools.lru_cache(None), so that previous indexing_code produdec by previous index, + # could be overwritten + self.codegen = self._codegen + # axis mask + def _codegen_mask(self): + + if self.is_tiling_axis : + BLOCK_NAME = f"{self.name.upper()}BLOCK" + upper = f"min({BLOCK_NAME}+{self.symbol()}_offset, {self.name}_numel)" if self.is_split_axis else f"{self.name}_numel" + line = f"{self.name}_mask = {self.name} < {upper}" + self.writeline(line) + for var in self.var_directions.keys(): + line = f"{var.name}_mask = {var.name} < {upper}" + self.writeline(line) + else: + pass + + def get_axis_direction(self ) : + + #assume self.golden_var_list is to be correct axis order + + if self.directions: + return f"[{','.join(self.directions)}]" + tiling_axis = [x.symbol() for x in self.kernel.tiling_axis] + + rev_orders = [x for x in self.kernel.golden_var_list if x in tiling_axis] + self.directions = ["None"] * len(tiling_axis) + assert len(tiling_axis) == len(rev_orders), f"tiling len={len(tiling_axis)}, golden varlist len ={len(rev_orders)}" + var_orders = list(reversed(rev_orders)) + index = var_orders.index(self.symbol()) + self.directions[index] = ":" + return f"[{','.join(self.directions)}]" + + # axis var, FIXME, need to define var with diffent direction + def _codegen(self): + self.indexing_code.clear() + index = None + # for multiple reduce dims, don't need this + if not self.is_tiling_axis : + return self.name + + direction = self.get_axis_direction() + index = f"{self.name} = {self.codegen_index(direction)}" + for var, dir in self.var_directions.items(): + line = f"{var.name} = {self.codegen_index(dir)}" + self.writeline(line) + + # reduction axis + if self.prefix == 'r': + if V.kernel.inside_reduction and V.kernel.current_node \ + and isinstance(V.kernel.current_node, SchedulerNode) \ + and V.kernel.current_node.node \ + and V.kernel.current_node.node.data \ + and isinstance(V.kernel.current_node.node.data, ir.Reduction): + reduction_type = V.kernel.current_node.node.data.reduction_type + if reduction_type in {"argmax", "argmin"} : + self.writeline(f"{self.parent.prefix}index = " + f"{self.codegen_index(None)}") + if index: + self.writeline(index) + self._codegen_mask() + return self.name + + def writeline(self, line): + self.indexing_code.writeline(line) + + def is_1d_persisent_reduction(self) : + return len(V.kernel.tiling_axis) == 1 and V.kernel.persistent_reduction + + def codegen_index(self, direction): + BLOCK_NAME = f"{self.name.upper()}BLOCK" + BLOCK_NAME_SUB = f"{BLOCK_NAME}_SUB" + index = None + if self.prefix == 'r' : + if V.kernel.persistent_reduction : + if self.is_1d_persisent_reduction() : + index = f"tl.arange(0, {BLOCK_NAME_SUB})" + else : + index = f"base_{self.name}" + else: + index = f"(loop_{self.name} * {BLOCK_NAME_SUB}) + base_{self.name}" + else : + if self.is_split_axis : + offset = f"{self.symbol()}_offset" + index = f"{offset} + (loop_{self.name} * {BLOCK_NAME_SUB}) + base_{self.name}" + else : + index = f"(loop_{self.name} * {BLOCK_NAME_SUB}) + base_{self.name}" + + if len(V.kernel.tiling_axis) > 1 and direction is not None : + index += direction + + return index + + + def codegen_header(self, code): + # generate offset index loop + lines = [] + BLOCK_NAME = 
f"{self.name.upper()}BLOCK" + BLOCK_NAME_SUB = f"{BLOCK_NAME}_SUB" + + if self.is_1d_persisent_reduction() : + return + + if self.is_split_axis : + lines.append(f"{self.symbol()}_offset = tl.program_id({self.split_order}) * {BLOCK_NAME}") + + if self.is_tiling_axis : + lines.append(f"base_{self.name}= tl.arange(0, {BLOCK_NAME_SUB})") + block = f"{BLOCK_NAME}" if self.is_split_axis else f"{self.symbol()}_numel" + lines.append(f"loops_{self.name} = ({block} + {BLOCK_NAME_SUB} - 1) // {BLOCK_NAME_SUB}") + + else: + pass + + code.writelines(lines) + + def precomputed_args(self): + # for dynamic shapes, find parts of indexing expressions that have to be precomputed + precomputed_args: List[sympy.Expr] = [] + if isinstance(self.expr, (sympy.Symbol, sympy.Integer)): + return precomputed_args + + if not isinstance(self.expr, (FloorDiv, ModularIndexing)): + raise RuntimeError("assert isinstance(self.expr, (FloorDiv, ModularIndexing)), type(self.expr)") + for arg in self.expr.args[1:]: + if not isinstance(arg, (sympy.Integer, sympy.Symbol)): + symbols = arg.free_symbols + if len(symbols) > 0 and all( + symbol_is_type(s, SymT.SIZE) for s in symbols + ): + precomputed_args.append(arg) + return precomputed_args + + +class IterationRangesRootNPUIndex(IterationRangesRoot): + def __init__( + self, + name: str, + numel: sympy.Expr, + prefix: str, + index: int, + kernel: TritonKernel, + pid_cache=None, + *, + is_loop: bool, + tensor_dim: Optional[int], + grid_dim: Optional[int], + ): + super().__init__(name, numel, prefix, index, kernel, pid_cache, is_loop=is_loop, tensor_dim=tensor_dim, + grid_dim=grid_dim, has_zdim= False ) + + def __repr__(self): + return f"IterationRangesRootNPUIndex({self.name!r}, {self.numel}, ...)" + + def remove_entry(self, name): + if name in self.var_ranges : + del self.var_ranges[name] + if name in self.var_list: + del self.var_list[self.var_list.index(name)] + if name in V.kernel.range_tree_nodes: + V.kernel.range_tree_nodes_removed[name] = V.kernel.range_tree_nodes[name] + del V.kernel.range_tree_nodes[name] + if name in self.nodes: + del self.nodes[name] + + def duplicated_check(self, divisor, length): + """ + Lookup a given RangeTreeEntry, creating it if needed + """ + if V.graph.sizevars.statically_known_equals(divisor * length, self.numel): + expr = FloorDiv(sympy_index_symbol(f"{self.prefix}index"), divisor) + else: + expr = ModularIndexing( + sympy_index_symbol(f"{self.prefix}index"), divisor, length + ) + + return expr not in self.nodes + + + def lookup(self, divisor, length): + """ + Lookup a given RangeTreeEntry, creating it if needed + """ + if V.graph.sizevars.statically_known_equals(divisor * length, self.numel): + expr = FloorDiv(sympy_index_symbol(f"{self.prefix}index"), divisor) + else: + expr = ModularIndexing( + sympy_index_symbol(f"{self.prefix}index"), divisor, length + ) + + if expr not in self.nodes: + node = IterationRangesEntryNPUIndex( + f"{self.prefix}{next(V.kernel.iter_vars_count)}", + divisor, + length, + expr, + self, + ) + V.kernel.range_tree_nodes[node.symbol()] = node + self.var_list.append(node.symbol()) + self.var_ranges[node.symbol()] = length + self.nodes[expr] = node + + + return self.nodes[expr] + + +def is_compatible(groups: Iterable[sympy.Expr], lengths: Sequence[Sequence[sympy.Expr]]): + try: + groups = flatten(groups) + NPUIndexTritonKernel._split_iteration_ranges(groups, lengths) + return True + except CantSplit: + return False + + +class NPUIndexTritonKernel(TritonKernel): + overrides = NPUTritonKernelOverrides + + def __init__( + 
self, + tiling: Dict[str, sympy.Expr], + min_elem_per_thread=0, + optimize_mask=True, + fixed_config: Optional[FixedTritonConfig] = None, + **kwargs,): + + super().__init__(tiling=tiling, + min_elem_per_thread=min_elem_per_thread, + optimize_mask=optimize_mask, + fixed_config=fixed_config, + **kwargs) + self.first_node = True + self.inside_high_order_reduction = False + self.low_dims = set() + self.split_axis = [] + self.tiling_axis = [] + self.range_tree_nodes_removed: Dict[sympy.Symbol, IterationRangesEntry] = {} + self.range_tree_nodes_substituted = {} + self.expr_substituted = {} + self.sorted_axis = [] + self.prefix: IndentedBuffer = IndentedBuffer() + self.index_analysis = {} # var_list -> indexAnalysis + self.golden_var_list = None + self.reduce_analysis = None + self.load_store_indexing = None + + def gen_triton_ext_imports(self): + imports = IndentedBuffer() + imports.splice( + """ + from torch._inductor.runtime import triton_helpers + from torch_npu._inductor import npu_triton_heuristics + from torch_npu._inductor import npu_triton_helpers + from torch_npu._inductor.runtime import NPUDeviceProperties + from torch_npu._inductor.npu_triton_helpers import libdevice, math as tl_math + import torch + import torch_npu + """ + ) + return imports.getvalue() + + + def patch_triton_hash(self): + # remove this method once the original invocation is fixed + import hashlib + from triton.compiler.compiler import triton_key, make_backend + from triton.runtime.driver import driver + backend = make_backend(driver.active.get_current_target()) + key = f"{triton_key()}-{backend.hash()}" + return hashlib.sha256(key.encode("utf-8")).hexdigest() + + def numof_tiling_axis(self): + return len(self.tiling_axis) + + #do nothing in NpuTritonKernel + def codegen_range_tree(self): + pass + + + def initialize_range_tree(self, pid_cache): + #self.numels = flatten(self.numels) + self.total_numels = 0 + for k, x in self.numels.items() : + if not isinstance(x, sympy.Integer) : + x = x.subs(V.graph.sizevars.var_to_val) + self.numels[k] = x + if x > 1 : + self.total_numels +=1 + + no_r_dim = not self.inside_reduction or self.numels["r"] == 1 + prefixes = "wvtzyxr" + active_prefixes = prefixes[-len(self.numels) :] + #prefix can not be 's', 'u', 'ps' , 'i', 'z' + #prefix can not be 'p' but can be 'z' since 2.6 + grid_dims = "xyztvw" + if self.no_x_dim: + tensor_dims = "r" + elif no_r_dim: + tensor_dims = "xyztvw" + else: + tensor_dims = "xyztvwr" + tensor_dims = "".join(p for p in tensor_dims if p in active_prefixes) + for i, prefix in enumerate(active_prefixes): + is_reduction = prefix_is_reduction(prefix) + tensor_dim = tensor_dims.find(prefix) if prefix in tensor_dims else None + grid_dim = None if is_reduction else grid_dims.find(prefix) + index = i if grid_dim is None else grid_dim + self.range_trees.append( + IterationRangesRootNPUIndex( + f"{prefix}index", + self.numels[prefix], + prefix, + index, + self, + pid_cache=pid_cache, + is_loop=is_reduction and not self.persistent_reduction, + tensor_dim=tensor_dim, + grid_dim=grid_dim + ) + ) + + + + def get_axis_dtype(self, axis): + dtype = None + if axis is None : + return None + for node in self.node_schedule : + if node in (EnableReduction, DisableReduction) : + continue + if axis.symbol() in node._body.indexing_map : + dtype = V.graph.get_dtype(node.node.name) + break + if dtype is None : + should_break_all = False + for node in self.node_schedule: + if should_break_all: + break + if node in (EnableReduction, DisableReduction): + continue + for key, value in 
node._body.indexing_map.items(): + if key in self.range_tree_nodes : + dim = self.range_tree_nodes[key] + else : + dim = self.range_tree_nodes_removed[key] + + if dim.parent == axis.parent : + dtype = V.graph.get_dtype(node.node.name) + should_break_all = True + break + return dtype + + def create_inductor_meta(self): + mutated_args = set() + for mutation in self.mutations: + if mutation in self.args.input_buffers: + mutated_args.add(self.args.input_buffers[mutation]) + if ( + mutation in self.args.inplace_buffers + and mutation not in V.graph.removed_buffers + and mutation not in self.removed_buffers + ): + mutated_args.add(self.args.inplace_buffers[mutation].inner_name) + if mutation in self.args.output_buffers: + mutated_args.add(self.args.output_buffers[mutation]) + mutated_args = sorted(mutated_args) + tiling_axis = [x.sorted_order for x in self.tiling_axis] + split_axis = [x.sorted_order for x in self.split_axis] + axis_names = [x.name for x in self.sorted_axis] + split_axis_dtype = self.get_axis_dtype(self.split_axis[0]) if self.split_axis else None + inductor_meta = { + "autotune_hints": set(self.autotune_hints), + "kernel_name": str(Placeholder.DESCRIPTIVE_NAME), + "mutated_arg_names": mutated_args, + + # Due to breaking change of triton 3.0, the original invocation is broken + "backend_hash": self.patch_triton_hash(), # torch.utils._triton.triton_hash_with_backend(), + "split_axis" : split_axis, + "tiling_axis" : tiling_axis, + "axis_names" : axis_names, + "low_dims" : self.low_dims, + "numof_reduction_axis": self.numof_reduction_axis(), + "split_axis_dtype": split_axis_dtype, + "dual_reduction": self.numof_reduction_axis() > 1, + "traced_graph_hash": "TRACED_GRAPH_HASH" + #"coordinate_descent_tuning" : True + + } + return inductor_meta + + # numels sent to autotune configs + def get_size_hints(self): + size_hints = [] + if (len(self.range_tree_nodes.values()) == 0): + return [v for _,v in self.numels.items()] + + for i, node in enumerate(self.sorted_axis): + if isinstance(node.expr, ModularIndexing): + numel_expr = node.length + else: + numel_expr = node.expr.subs({sympy_index_symbol(r.name): r.numel for r in self.range_trees}) + + numel_expr = V.graph.sizevars.symbolic_hint(numel_expr) + + size_hints.append(numel_expr) + return size_hints + + # torch251 done + def add_numel_to_call_args_and_grid(self, name, call_args, arg_types, grid): + for node in self.sorted_axis: + if isinstance(node.expr, ModularIndexing) : + numel_expr = node.length + else : + numel_expr = node.expr.subs({sympy_index_symbol(r.name): r.numel for r in self.range_trees}) + + if isinstance(numel_expr, (sympy.Integer, sympy.Symbol)): + expr = numel_expr + else: + expr = V.graph.wrapper_code.generate_node_numel_expr(name, node, numel_expr) + call_args.append(expr) + arg_types.append(type(expr)) + if node.parent.grid_dim is not None: + grid.append(expr) + + def gen_numel_args(self, signature, triton_meta_signature, argdefs ): + for node in self.sorted_axis: + arg_name = f"{node.name}_numel" + if not inductor_npu_config.inductor_static_mode: + sizearg = SizeArg(arg_name, node.length) + signature.append(sizearg) + triton_meta_signature[arg_name] = signature_of( + sizearg, size_dtype=self.index_dtype + ) + argdefs.append(arg_name) + else : + argdefs.append(f"{arg_name}: tl.constexpr") + self.triton_meta["constants"][arg_name] = node.length + + # BLOCK and SUB_BLOCK definitions + def add_autotune_args(self, argdefs): + for axis in self.split_axis : + argdefs.append(f"{axis.name.upper()}BLOCK: tl.constexpr") + + for 
axis in self.tiling_axis : + if axis.name[0] == 'r' and self.persistent_reduction: + continue + argdefs.append(f"{axis.name.upper()}BLOCK_SUB: tl.constexpr") + + def _get_heuristic(self): + if self.persistent_reduction: + assert self.inside_reduction + return "persistent_reduction_npu_index" + elif self.inside_reduction: + return "reduction_npu_index" + return "pointwise_npu_index" + + def get_kernel_name(self, src_code, node_schedule, kernel): + wrapper = V.graph.wrapper_code + if src_code in wrapper.src_to_kernel: + kernel_name = wrapper.src_to_kernel[src_code] + else: + fused_name = ( + get_fused_kernel_name(node_schedule, config.triton.descriptive_names) + if config.triton.descriptive_names + else "" + ) + kernel_category = get_kernel_category_by_source_code(src_code)[:3] + kernel_name = "_".join( + ["triton", kernel_category, fused_name, wrapper.get_next_kernel_suffix()] + ) + return kernel_name + + # modify triton_meta, inductor_meta , etc. + def codegen_kernel(self, name=None): + code = IndentedBuffer() + size_hints = self.get_size_hints() + heuristics = self._get_heuristic() + if name is None: + code.splice(gen_common_triton_imports()) + # Note: add extra imports for extensions + code.splice(self.gen_triton_ext_imports()) + + if config.benchmark_kernel: + code.splice(self.imports_for_benchmark_kernel()) + + argdefs, _, signature, _ = self.args.python_argdefs() + + for i, arg in enumerate(signature): + if isinstance(arg, SizeArg): + symbol = cast(sympy.Symbol, arg.expr) + if symbol in V.graph.sizevars.inv_precomputed_replacements: + signature[i] = SizeArg( + arg.name, V.graph.sizevars.inv_precomputed_replacements[symbol] + ) + + triton_meta_signature = signature_to_meta( signature, size_dtype=self.index_dtype, argdefs = argdefs ) + + triton_meta = { + "signature": triton_meta_signature, + "device": + NPUDeviceProperties.create( + V.graph.get_current_device_or_throw() + ), + "constants": {}, + # special config for NPU, specify compile target + "mix_mode": "aiv", + } + + inductor_meta = self.create_inductor_meta() + num_gb = None + if config.benchmark_kernel or config.profile_bandwidth: + num_gb = self.estimate_kernel_num_bytes() / 1e9 + inductor_meta["kernel_num_gb"] = num_gb + + self.triton_meta = triton_meta + self.gen_numel_args(signature, triton_meta_signature, argdefs) + + #add in tiling args + self.add_autotune_args(argdefs) + #for scalar codegen + if len(self.range_tree_nodes) == 0: + self.write_scalar() + else: + self.codegen_body() + + for helper in self.helper_functions: + code.writeline("") + code.splice(helper) + + + # Note: override original triton_heuristics + if self.inside_reduction: + reduction_hint = self.features.get_reduction_hint() + heuristics_line = f""" + @npu_triton_heuristics.{heuristics}( + size_hints={size_hints}, + reduction_hint={reduction_hint}, + filename=__file__, + triton_meta={triton_meta!r}, + inductor_meta={inductor_meta!r} + ) + @triton.jit + """ + else: + tile_hint = "" + if len(size_hints) == 2: + if len(signature) == 4: # input, output and 2 args + tile_hint = "tile_hint=TileHint.SQUARE," + else: + tile_hint = "tile_hint=TileHint.DEFAULT," + heuristics_line = f""" + @npu_triton_heuristics.{heuristics}( + size_hints={size_hints!r}, {tile_hint} + filename=__file__, + triton_meta={triton_meta!r}, + inductor_meta={inductor_meta!r}, + min_elem_per_thread={self.min_elem_per_thread} + ) + @triton.jit + """ + code.splice(heuristics_line) + code.writeline( + f"def {name or str(Placeholder.KERNEL_NAME)}({', '.join(argdefs)}):" + ) + with code.indent(): 
+ self.codegen_static_numels(code) + for old, new in self.args.aliases(): + code.writeline(f"{old} = {new}") + code.splice(self.body) + + if config.benchmark_kernel: + code.splice(self.codegen_kernel_benchmark(num_gb)) + + return code.getvalue() + + + + def codegen_static_numels(self, code): + for symbol in self.reduction_axis_list(): + if symbol.name[0] != "r" or not self.persistent_reduction: + continue + + node = self.range_tree_nodes[symbol] + simplified_tree_numel = V.graph.sizevars.simplify(node.length) + if isinstance(simplified_tree_numel, (sympy.Integer, int)): + val = int(simplified_tree_numel) + else: + continue + val = next_power_of_2(val) + code.writeline(f"{node.name.upper()}BLOCK_SUB: tl.constexpr = {val}") + + + def lowest_axis_variable(self): + if len(self.tiling_axis) == 0 : + return None + return self.tiling_axis[-1] + + + def is_isolated_symbol(self, input_str, range): + patterns = [r'\b' + re.escape(range.name) + r'\b'] + for var in range.var_directions.keys(): + pattern = r'\b' + re.escape(var.name) + r'\b' + patterns.append(pattern) + + for pattern in patterns : + if re.search(pattern, input_str) : + return True + return False + + + def find_axis_in_load_store(self, range): + if not range : + return False + for line in self.loads._lines : + if line.find('tl.load') >= 0 and self.is_isolated_symbol(line, range): + return True + for line in self.compute._lines : + if line.find('tl.load') >= 0 and self.is_isolated_symbol(line, range): + return True + for line in self.post_loop_store._lines : + if line.find('tl.store') >= 0 and self.is_isolated_symbol(line, range): + return True + for line in self.stores._lines : + if isinstance(line,DeferredLine) : + line = line.line + if line.find('tl.store') >= 0 and self.is_isolated_symbol(line, range): + return True + return False + + def write_scalar(self): + self.body.splice(self.indexing_code) + self.body.splice(self.loads) + self.body.splice(self.compute) + self.body.splice(self.stores) + self.loads.clear() + self.compute.clear() + self.stores.clear() + self.post_loop_store.clear() + self.prefix.clear() + + def codegen_body(self): + if not ( + self.loads + or self.stores + or self.compute + or self.post_loop_store + ): + return + + def write_pointwise() : + self.body.splice(self.indexing_code) + self.body.splice(self.loads) + self.body.splice(self.compute) + self.body.splice(self.stores) + + def codegen_range(index) : + def is_1d_reduction() : + return self.numels["r"] > 1 and len(self.numels) == 1 + + def loop_body(index, indexing_code, is_last_axis, do_indent = True ) : + if do_indent: + self.body.do_indent() + if indexing_code : + self.body.splice(indexing_code) + if is_last_axis: + write_pointwise() + else: + codegen_range(index + 1) + if do_indent : + self.body.do_unindent() + + if index < 0 or index >= len(self.range_tree_nodes): + return + + range = self.sorted_axis[index] + numof_tilings = len(self.tiling_axis) + last_tiling = range.is_tiling_axis and numof_tilings >=1 and range.tiling_order == len(self.tiling_axis) -1 + next_is_dual_reduction_tiling = index == len(self.sorted_axis) - numof_tilings -1 and self.numof_reduction_axis() + + is_last_axis = index == len(self.sorted_axis) -1 + indexing_code = getattr(range, "indexing_code") + reduction_1d = is_1d_reduction() + do_indent = False + # do nothing except for writing porintwise + if len(self.loads._lines) == 0 and len(self.stores._lines) == 0: + do_indent = False + indexing_code = None + #loop_body(index, indexing_code, is_last_axis, do_indent = do_indent) + #return + 
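# ---------------------------------------------------------------------------
# Illustrative sketch (not emitted verbatim): for a simple pointwise kernel
# split on a single axis "x", the codegen_header()/codegen_range() machinery
# above aims to produce a loop nest of roughly this shape, where XBLOCK and
# XBLOCK_SUB are the tl.constexpr tiling arguments added by
# add_autotune_args(), and in_ptr0/out_ptr0/x_numel are placeholder names:
#
#     x_offset = tl.program_id(0) * XBLOCK
#     base_x = tl.arange(0, XBLOCK_SUB)
#     loops_x = (XBLOCK + XBLOCK_SUB - 1) // XBLOCK_SUB
#     for loop_x in range(loops_x):
#         x = x_offset + (loop_x * XBLOCK_SUB) + base_x
#         x_mask = x < min(XBLOCK + x_offset, x_numel)
#         tmp0 = tl.load(in_ptr0 + (x), x_mask)
#         tl.store(out_ptr0 + (x), tmp0, x_mask)
#
# Reduction kernels additionally splice the accumulator prefix before the
# innermost loop and the post_loop_store lines after it, which is what the
# branches of codegen_range() below arrange.
# ---------------------------------------------------------------------------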
# tiling axis and last tiling + if range.is_tiling_axis and last_tiling: + do_indent = False + need_axis_loop = self.find_axis_in_load_store(range) + if not need_axis_loop : + indexing_code = None + if (range.prefix != 'r' or not self.persistent_reduction) and need_axis_loop: + self.body.splice(self.prefix) + self.body.writeline(f"for loop_{range.name} in range(loops_{range.name}):") + do_indent = True + loop_body(index, indexing_code, is_last_axis, do_indent) + self.body.splice(self.post_loop_store) + self.post_loop_store.clear() + + # tiling axis and but not last tiling + elif range.is_tiling_axis : + do_indent = False + if len(self.loads._lines) == 0 and len(self.stores._lines) == 0: + do_indent = False + indexing_code = None + if self.numof_reduction_axis() <= 1 : + do_indent = True + self.body.writeline(f"for loop_{range.name} in range(loops_{range.name}):") + loop_body(index, indexing_code, is_last_axis, do_indent = do_indent) + + elif not is_last_axis : + do_indent = True + if range.is_split_axis : + offset = f"{range.name}_offset" + self.body.writeline(f"for {range.name} in range({offset}, " + f"min({offset} + {range.name.upper()}BLOCK, {range.name}_numel)):") + else : + self.body.writeline(f"for {range.name} in range({range.name}_numel):") + + if not reduction_1d and self.persistent_reduction : + self.body.do_indent() + self.body.splice(self.prefix) + self.prefix.clear() + self.body.do_unindent() + + loop_body(index, indexing_code, is_last_axis, do_indent = do_indent) + else : + write_pointwise() + + if self.first_node: + for node in self.sorted_axis: + node.codegen_header(self.body) + + while True : + if not self.sorted_axis[-1].is_tiling_axis : + x = self.sorted_axis[-1] + self.sorted_axis.pop(-1) + self.sorted_axis.insert(0, x) + else : + break + + if self.first_node: + codegen_range(0) + else : + last_axis_order = self.tiling_axis[-1].sorted_order + if self.persistent_reduction and self.numof_reduction_axis() > 1 : + last_axis_order = last_axis_order - self.numof_reduction_axis() + 1 + for _ in range(last_axis_order) : + self.body.do_indent() + codegen_range(last_axis_order) + for _ in range(last_axis_order) : + self.body.do_unindent() + + self.cse.invalidate(self.outside_loop_vars) + self.loads.clear() + self.compute.clear() + self.stores.clear() + self.post_loop_store.clear() + self.prefix.clear() + self.first_node = False + + # for creat constant tensor, if have two axis, constant=tl.full([1,1]) else tl.full([1]) + def triton_tensor_ndim(self): + if self.numof_reduction_axis() > 1 : + return 1 + + return len(self.tiling_axis) + + # fixme, indexing.mask_str is None , see varmean_test.py + def store_reduction(self, name: str, index: sympy.Expr, value: CSEVariable): + if not self.inside_reduction: + raise RuntimeError("assert self.inside_reduction") + + self.inside_reduction = False + indexing = self.indexing(index, block_ptr=True) + self.inside_reduction = True + var = self.args.output(name) + if isinstance(indexing, BlockPtrOptions): + self.post_loop_store.writeline( + DeferredLine( + name, + self.codegen_block_ptr_store_line( + name, + indexing, + indexing.format(var), + value, + f", boundary_check={indexing.boundary_check()!r}", + ), + ) + ) + else: + if not isinstance(indexing, IndexingOptions): + raise RuntimeError("assert isinstance(indexing, IndexingOptions)") + line = f"tl.store({var} + ({indexing.index_str} ), {value}, {indexing.mask_str})" + if self.numof_reduction_axis() > 1 : + line = f"tl.store({var} + ({indexing.index_str} + tl.arange(0,1) ), {value}, 
{indexing.mask_str})" + self.post_loop_store.writeline( + DeferredLine( name, line ) + ) + + + # apply new var in case dim are permuted/broadcast + def store( + self, name: str, index: sympy.Expr, value: CSEVariable, mode: StoreMode = None + ) -> None: + + var = self.args.output(name) + original_index = index + index_analyze = IndexAnalysis(self, index, is_store_index=True) + index_analyze.analyze_index() + indexing = self.indexing(index, dense_indexing=True, block_ptr=mode is None, index_analyze=index_analyze) + index_str = indexing.index_str + value_str = f"{value}" + mask_str = indexing.mask_str + + if index_analyze.need_permute : + value_str = value_str.replace(f"{value}", f"{value}{index_analyze.generate_statement()}") + + advance_block_ptr = None + if isinstance(indexing, BlockPtrOptions): + block_ptr, advance_block_ptr, other = self.codegen_block_ptr( + name, var, indexing + ) + # block_ptr stores don't do implicit casting + line = self.codegen_block_ptr_store_line( + name, indexing, block_ptr, value, other + ) + elif mode is None: + line = f"tl.store({var} + ({index_str}), {value_str}, {mask_str})" + if self.numof_reduction_axis() > 1 : + line = f"tl.store({var} + ({index_str} + tl.arange(0,1) ), {value_str}, {indexing.mask_str})" + + elif mode == "atomic_add": + line = f"tl.atomic_add({var} + ({index_str}), {value_str}, {indexing.mask_str})" + else: + raise NotImplementedError(f"store mode={mode}") + + self.stores.writeline(DeferredLine(name, line)) + if advance_block_ptr: + self.stores.writeline(advance_block_ptr) + + if not self.inside_reduction: + self.outside_loop_vars.add(value) + + def find_reduction_node(self): + node = self.current_node + if node is not None and isinstance(node, SchedulerNode) : + reduction = node.node.data + if reduction is not None and isinstance(reduction, ir.Reduction) : + return reduction + + for node in self.node_schedule: + if node in (EnableReduction, DisableReduction): + continue + reduction = node.node.data + if reduction is not None and isinstance(reduction, ir.Reduction) : + return reduction + + return None + + # select the golden varlist, from to which to deduce permute, broadcast shape + def select_golden_varlist(self) : + longest = None + maximum_length = 0 + self.golden_var_list = None + def all_tiling_in_var_list(var_list) : + return all([x in var_list for x in self.tiling_axis]) + # all are load indexings, select the longest as gold + for index in self.load_store_indexing: + index = index.subs(V.graph.sizevars.var_to_val) + analyze = IndexAnalysis(self, index) + if len(analyze.var_list) > maximum_length and all_tiling_in_var_list(analyze.var_list) : + longest = analyze.var_list + maximum_length = len(longest) + #fixme , this may cause problems + if not longest : + self.golden_var_list = tuple([x.symbol() for x in self.tiling_axis]) if self.tiling_axis else [] + else : + self.golden_var_list = tuple([x for x in longest if x in self.tiling_axis]) if self.tiling_axis else [] + assert self.golden_var_list is not None + + # to generate shape of the tile + def dense_size_list(self) -> List[str]: + if self.inside_reduction : + if not self.reduce_analysis: + self.reduce_analysis = ReductionAnalysis(self) + return self.reduce_analysis.dense_size_list() + + if not self.golden_var_list : + self.select_golden_varlist() + + golden_var_list = self.golden_var_list if self.golden_var_list else [x.symbol() for x in self.tiling_axis] + assert golden_var_list is not None + #shape = range(len(self.golden_var_list)) + sizes = [None for _ in golden_var_list ] 
+ for i, var in enumerate(reversed(golden_var_list)) : + axis = self.range_tree_nodes[var] + sizes[i] = f"{axis.name.upper()}BLOCK_SUB" + return sizes + + def dense_size_str(self): + if self.inside_reduction : + if not self.reduce_analysis: + self.reduce_analysis = ReductionAnalysis(self) + return self.reduce_analysis.dense_size_str() + sizes = self.dense_size_list() + return f"[{', '.join(sizes)}]" + + # and add to shape to value + def reduction_resize(self, value, dim): + ndims = self.triton_tensor_ndim() + if ndims == 1: + return f"triton_helpers.promote_to_tensor({value})" + dense_list = self.dense_size_list() + dense_list[dim] = "1" + expand_str = ", ".join(dense_list) + return f"{value}.reshape({expand_str})" + #return f"{value}" + + + # FIXME, to determine reduction_dim + def reduction_dim(self): + if not self.reduce_analysis: + self.reduce_analysis = ReductionAnalysis(self) + return self.reduce_analysis.reduced_dim + + def filter_masks(self, mask_vars): + for node in self.sorted_axis: + if not(node.is_tiling_axis ): + mask_vars.discard(f"{node.name}_mask") + + def numof_reduction_axis(self): + root = self.range_trees[-1] + if root is None : + return 0 + + return len(root.var_list) + + + def reduction_axis_list(self): + root = self.range_trees[-1] + if root is None : + return [] + return root.var_list + + def reduction( + self, + dtype: torch.dtype, + src_dtype: torch.dtype, + reduction_type: ReductionType, + value: Union[CSEVariable, Tuple[CSEVariable, ...]], + ) -> Union[CSEVariable, Tuple[CSEVariable, ...]]: + if not self.inside_reduction: + raise RuntimeError("assert self.inside_reduction") + masks = {f"{node.symbol()}_mask" for node in self.sorted_axis} + self.filter_masks(masks) + masks = sorted(masks) + if self._load_mask: + masks.append(self._load_mask) + reduction_range_prefix = self.range_trees[-1].prefix + if not self.reduce_analysis: + self.reduce_analysis = ReductionAnalysis(self) + dense_size_str = self.dense_size_str() + + if len(dense_size_str) > 2: + value = self._map_tuple_or_scalar( + lambda v: self.cse.generate( + self.compute, f"tl.reshape({v}, {dense_size_str})", dtype=v.dtype, + ), + value, + + ) + + dim: int + root_op: str + + def final_reduction(value): + #use_helper = reduction_type in {"any", "max", "min", "prod"} + module = "tl" # use tl + if reduction_type in {"max", "min"}: + return self.reduction_resize(f"{module}.{reduction_type}({value}, {dim})", dim) + return self.reduction_resize(f"{module}.{reduction_type}({value}, {dim})", dim) + + def final_argreduce(buffer, result_var, value, index): + buffer.splice( + f"""\ + _, {result_var}_tmp = triton_helpers.{root_op}_with_index({value}, {index}, {dim}) + {result_var} = {self.reduction_resize(f'{result_var}_tmp', dim)} + """ + ) + + def get_reduction_axis() : + return list(self.range_tree_nodes.values())[-1] + + cache_key = (src_dtype, reduction_type, value) + if cache_key in self.cse.reduction_cache: + return self.cse.reduction_cache[cache_key] + + dim = self.reduction_dim() + acc_type = triton_acc_type(src_dtype) + torch_acc_type = upcast_acc_dtype(src_dtype) + result_var: Any = self.cse.newvar(dtype=torch_acc_type) + result_var.mask_vars = {var for var in masks if var[0] != "r"} + cond = " & ".join(masks) + + + def where_cond(tval, fval): + if not cond: + return tval + return TritonKernelOverrides.where(cond, tval, fval) + + if self.persistent_reduction: + default = ir.Reduction.default_value(reduction_type, src_dtype) + default = self._map_tuple_or_scalar(constant_repr, default) + + def 
_mask_value(value, default): + return self.cse.generate(self.compute, where_cond(value, default) , dtype=value.dtype) + # fixme masked_value doesn't work dual reduction + if self.numof_reduction_axis() == 1 : + if isinstance(value, tuple): + masked_value = [_mask_value(v, d) for v, d in zip(value, default)] + else: + masked_value = _mask_value(value, default) + else : + masked_value = value + + if reduction_type in {"argmax", "argmin", "max", "min"}: + reduce_axis = get_reduction_axis() + broadcast_string: str + reshape_str = self.reduce_analysis.get_reduce_dim_reshape(reduce_axis) + broadcast_string = f"tl.broadcast_to({reduce_axis.symbol()}.reshape({reshape_str}), {masked_value}.shape)" + accumulator_index = str( + self.cse.generate( + self.compute, + broadcast_string, + dtype=torch.int64 + ) + ) + if reduction_type == "argmax" or reduction_type == "argmin": + root_op = {"argmax": "max", "argmin": "min"}[reduction_type] + final_argreduce( + self.compute, result_var, masked_value, accumulator_index + ) + elif reduction_type == "max" or reduction_type == "min": + result_var = self.cse.generate( + self.compute, final_reduction(masked_value), dtype=masked_value.dtype, + ) + elif reduction_type == "welford_reduce": + raise RuntimeError("assert False, welford_reduction and is not supported now..") + elif reduction_type == "welford_combine": + raise RuntimeError("assert False, welford_combine and is not supported now..") + else: + result_var = self.cse.generate( + self.compute, final_reduction(masked_value), dtype=masked_value.dtype, + ) + else: + accumulator = self.cse.namedvar(f"_{result_var}", dtype=torch_acc_type) + default = ir.Reduction.default_accumulator(reduction_type, src_dtype) + default = self._map_tuple_or_scalar(constant_repr, default) + if not isinstance(default, tuple): + self.prefix.writeline( + f"{accumulator} = tl.full({self.dense_size_str()}, {default}, {acc_type})" + ) + + if reduction_type in {"argmax", "argmin"}: + accumulator_index = f"_{result_var}_index" + long_max = torch.iinfo(torch.int64).max + self.prefix.writeline( + f"{accumulator_index} = tl.full({self.dense_size_str()}, {long_max}, tl.int64)" + ) + root_op = {"argmax": "max", "argmin": "min"}[reduction_type] + + self.compute.splice( + f"""\ + {accumulator}_next, {accumulator_index}_next = triton_helpers.{root_op}imum_with_index( + {accumulator}, {accumulator_index}, {value}, {reduction_range_prefix}index + ) + {accumulator} = {where_cond(f'{accumulator}_next', accumulator)} + {accumulator_index} = {where_cond(f'{accumulator_index}_next', accumulator_index)} + """ + ) + final_argreduce(self.post_loop_store, result_var, accumulator, accumulator_index) + elif is_welford_reduction(reduction_type): + raise RuntimeError("assert False, welford_reduction and is not supported now..") + else: + combine_fn = ir.get_reduction_combine_fn(reduction_type, src_dtype) + updated = combine_fn(accumulator, value) + self.compute.writeline( + f"{accumulator} = {where_cond(updated, accumulator)}" + ) + + if src_dtype == torch.bool: + accumulator = f"{accumulator}.to(tl.int8)" + result_type = triton_compute_type(dtype) + self.post_loop_store.writeline( + f"{result_var} = {final_reduction(accumulator)}.to({result_type})" + ) + else: + self.post_loop_store.writeline( + f"{result_var} = {final_reduction(accumulator)}" + ) + + self.cse.reduction_cache[cache_key] = result_var + + if isinstance(result_var, tuple): + self.outside_loop_vars |= set(result_var) + else: + self.outside_loop_vars.add(result_var) + + return result_var + + 
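# ---------------------------------------------------------------------------
# Illustrative sketch (assumed variable names, not generated verbatim): for a
# non-persistent sum over reduction axis "r0" with tile shape
# [XBLOCK_SUB, R0BLOCK_SUB] and the reduction on dim 1, reduction() above
# emits roughly:
#
#     # prefix: accumulator initialised to the reduction default
#     _tmp2 = tl.full([XBLOCK_SUB, R0BLOCK_SUB], 0, tl.float32)
#     # compute (inside the r0 loop): masked combine
#     _tmp2 = tl.where(x0_mask & r0_mask, _tmp2 + tmp1, _tmp2)
#     # post_loop_store: final reduction, reshaped so it broadcasts back
#     tmp2 = tl.sum(_tmp2, 1).reshape([XBLOCK_SUB, 1])
#
# argmax/argmin follow the same pattern but carry a second index accumulator
# and go through triton_helpers.maximum_with_index / max_with_index (or the
# minimum variants) instead of a plain tl.sum.
# ---------------------------------------------------------------------------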
#broadcast, permute handling + def load(self, name: str, index: sympy.Expr): + var = self.args.input(name) + original_index = index + store_cache = self.cse.store_cache + if name in store_cache: + index_analyze = IndexAnalysis(self, index) + index_analyze.analyze_index() + result_var = store_cache[name] + if index_analyze.need_permute: + line = f"{result_var}{index_analyze.generate_statement()}" + buffer = self.compute if self.persistent_reduction else self.loads + result_var = self.cse.generate(buffer, line, dtype=result_var.dtype) + return result_var + + index_analyze = IndexAnalysis(self, index) + index_analyze.analyze_index() + indirect_indexing = self.is_indirect_indexing(index) + indexing = self.indexing(index, block_ptr=True) + has_rindex = indexing.has_rindex() + has_tmpmask = indexing.has_tmpmask() + is_coalesced = any( + i == 1 for i in self.get_strides_of_load(original_index).values() + ) + ep = "" + if ( + (has_tmpmask or has_rindex) + and V.graph.get_dtype(name) != torch.bool + and indexing.has_mask() + ): + other = ", other=0.0" + else: + other = "" + + advance_block_ptr = None + append_broadcast = None + dtype = V.graph.get_dtype(name) + + if V.graph.is_unspec_arg(name): + line = var + else: + if isinstance(indexing, BlockPtrOptions): + block_ptr, advance_block_ptr, other = self.codegen_block_ptr( + name, var, indexing, other + ) + line = f"tl.load({block_ptr}{other}{ep})" + # add needed size=1 dimensions + line = triton_reshape( + line, indexing.block_shape, indexing.reshape_suffix + ) + elif isinstance(original_index, sympy.Integer): + line = f"tl.load({var} + ({original_index}))" + full_list = ["1"] * (len(self.tiling_axis) if self.tiling_axis else 1 ) + append_broadcast = f"[{', '.join(full_list)} ]" + else: + index_str = indexing.index_str + mask_str = indexing.mask_str + line = f"tl.load({var} + ({index_str}), {mask_str}{ep}{other})" + + dtype = V.graph.get_dtype(name) + if dtype in (torch.bfloat16, ): + line += ".to(tl.float32)" + if dtype == torch.bool and torch.version.hip is None: + line += ".to(tl.int1)" + if has_tmpmask: + # Masked loads must come after the mask is computed + load_buffer = self.compute + elif ( + self.inside_reduction + and self.range_trees[-1].is_loop + and not indirect_indexing + and not has_rindex + ): + # can lift a common load outside of reduction loop + # One exception is when this is an indirect_load. 
+ load_buffer = self.prefix + + else: + load_buffer = self.loads + + result_var = self.cse.generate(load_buffer, line, dtype=dtype) + if not (isinstance(result_var, TritonCSEVariable)): + raise RuntimeError("assert isinstance(result_var, TritonCSEVariable)") + result_var.mask_vars = indexing.mask_vars # type: ignore[assignment] + + if append_broadcast and append_broadcast != '[]': + line = f"tl.broadcast_to({result_var}, {append_broadcast})" + result_var = self.cse.generate(load_buffer, line, dtype = dtype) + # triton can handle broadcast + # elif need_broadcast and not indirect_indexing: + # line = f"{result_var}.broadcast_to({self.get_broadcast_dense_str(broadcast_shape)})" + # result_var = self.cse.generate(load_buffer, line, dtype = dtype) + elif index_analyze.need_permute : + line = f"{result_var}{index_analyze.generate_statement()}" + result_var = self.cse.generate(self.loads, line, dtype = dtype) + + if advance_block_ptr: + load_buffer.writeline(advance_block_ptr) + + if not self.inside_reduction or (not indexing.has_rmask() and not has_rindex): + self.outside_loop_vars.add(result_var) + + return result_var + + # don't call symlify_indexing + def prepare_indexing( + self, + index: sympy.Expr, + index_analyze + ): + #index = self.simplify_indexing(index) + index = sympy_subs(index, V.graph.sizevars.precomputed_replacements) + # if simple replacements didn't get rid of floor/ceil, try full subs + if len(index.atoms(sympy.floor)) or len(index.atoms(sympy.ceiling)): + index = index.subs(V.graph.sizevars.precomputed_replacements) + + if len(index.atoms(sympy.ceiling)): + for a in index.atoms(sympy.ceiling): + # for nested exprs, atoms yields top level first (?) + # so if everything goes fine, lower level replacements will come up empty + symbols = a.free_symbols + if len(symbols) > 0 and all( + symbol_is_type(s, (SymT.SIZE, SymT.PRECOMPUTED_SIZE)) + for s in symbols + ): + replacements = {a: V.graph.sizevars.lookup_precomputed_size(a)} + index = sympy_subs(index, replacements) + + #simp_index = self.simplify_indexing(index) + simp_index = index + + simp_index = ( + simp_index if not isinstance(simp_index, Identity) else simp_index.args[0] + ) + + # to generate range.var_directions for permuted axis + index_analyze.analyze_index() + return self.codegen_indexing(simp_index) + + + def replace_index_vars(self, index, index_analyze) : + + new_index = index + if index_analyze.var_replacements : + new_index = sympy_subs(index, index_analyze.var_replacements) + return new_index + + + def index_to_str(self, index: sympy.Expr) -> str: + if isinstance(index, list): + return f"[{', '.join(map(self.index_to_str, index))}]" + index = self.rename_indexing(index) + return self.kexpr(index) # type: ignore[call-arg] + + #1. only remove the line which asserts index var should be in "xyr" + #2. don't do simplify_indexing, which combine continuous dims + #3. 
removed block_ptr, removed dense mask/broadcast support + # fixme, dense_mask_vars should be generated from sorted_axis + # upgraded to torch251 + def indexing( + self, + index: sympy.Expr, + *, + copy_shape=None, + dense_indexing=False, + override_mask=None, + block_ptr=False, + index_analyze = None + ) -> Union[IndexingOptions, BlockPtrOptions]: + """ + Compute the index and mask to pass to tl.load() or tl.store() + """ + if not index_analyze : + index_analyze = IndexAnalysis(self, index) + index_analyze.analyze_index() + + index = self.prepare_indexing(index, index_analyze) + index_vars = index.free_symbols + has_rindex = False + #index = self.simplify_indexing(index) + index = sympy_subs(index, V.graph.sizevars.precomputed_replacements) + # if simple replacements didn't get rid of floor/ceil, try full subs + if len(index.atoms(sympy.floor)) or len(index.atoms(sympy.ceiling)): + index = index.subs(V.graph.sizevars.precomputed_replacements) + if len(index.atoms(sympy.ceiling)): + for a in index.atoms(sympy.ceiling): + # for nested exprs, atoms yields top level first (?) + # so if everything goes fine, lower level replacements will come up empty + symbols = a.free_symbols + if len(symbols) > 0 and all( + s.name.startswith("s") or s.name.startswith("ps") for s in symbols + ): + replacements = {a: V.graph.sizevars.lookup_precomputed_size(a)} + index = sympy_subs(index, replacements) + + #if not self.inside_reduction : + index = self.replace_index_vars(index, index_analyze) + #index = self.simplify_indexing(index) + index_vars = index.free_symbols + has_rindex = False + + mask_vars: Set[str] = set() + for var in index_vars: + if not (isinstance(var, sympy.Symbol)): + raise RuntimeError("assert isinstance(var, sympy.Symbol)") + + has_rindex = has_rindex or var.name.startswith("r") + if override_mask: + pass + elif var.name.startswith("tmp"): + # indirect indexing + cse_var = self.cse.varname_map[var.name] + mask_vars.update(cse_var.mask_vars) + elif var.name.startswith(("s", "ps", "i")): + pass + else: + # var is one of xN, yN or rN + mask_vars.add(f"{var.name}_mask") + + expand_str = None + index_str = self.index_to_str(index) + + if isinstance(index, sympy.Integer): + expand_str = f"{copy_shape}.shape" if copy_shape else self.dense_size_str() + if (index != 0): + index_str = f"tl.full({expand_str}, {index_str}, tl.int32)" + else: + index_str = f"tl.arange(0,1)" + return IndexingOptions(index_str, set(), "None", expand_str, has_rindex, index) + + if override_mask: + mask_vars = {override_mask} + if self._load_mask: + mask_vars.add(self._load_mask) + self.filter_masks(mask_vars) + mask_str = " & ".join(sorted(map(str, mask_vars))) if mask_vars else "None" + return IndexingOptions(index_str, mask_vars, mask_str, expand_str, has_rindex, index) # type: ignore[arg-type] + + + + def codegen_indexing(self, expr: sympy.Expr): + expr = V.graph.sizevars.simplify_with_ranges(expr, self.var_ranges()) + for sym in sorted(expr.free_symbols, key=str): + if sym in self.range_tree_nodes: + # if indexing expression is complicated, we precompute it on the host side + # and send the result as a kernel argument + replacements = {} + for ps in self.range_tree_nodes[sym].precomputed_args(): # type: ignore[index] + replacements[ps] = V.graph.sizevars.lookup_precomputed_size(ps) + if len(replacements) > 0: + self.range_tree_nodes[sym].expr = sympy_subs( # type: ignore[index] + self.range_tree_nodes[sym].expr, replacements # type: ignore[index] + ) + self.range_tree_nodes[sym].codegen() # type: ignore[index] + return 
expr + + #FIXME, when xindex(16) -> x2:2,x3:8, when new length:16 in , should return (x2,x3) + def split_and_set_ranges(self, lengths: Sequence[Sequence[sympy.Expr]]): + groups = [rt.numel for rt in self.range_trees] + if not self.inside_reduction: + groups[-1] = sympy.S.One + + return self.map_kernel_groups_to_node_sizes(groups, lengths, self.set_ranges) + + #support split multiple ranges (instead of double) from one flatten range, triple-ranges are needed in mamba model + @staticmethod + def _split_iteration_ranges( + groups: Iterable[sympy.Expr], lengths: Sequence[Sequence[sympy.Expr]] + ): + sv = V.graph.sizevars + new_ranges: List[List[sympy.Expr]] = [[] for _ in groups] + remaining = [sv.simplify(g) for g in groups] + for i, group in enumerate(remaining): + if isinstance(group, (list, tuple)): + remaining[i] = NumelList(group).numels() + + var_count = itertools.count() + + def add_range(i, expr): + expr = sv.simplify(expr) + if not sv.statically_known_multiple_of(remaining[i], expr): + raise CantSplit() + # guard on the last item out + remaining[i] = FloorDiv(remaining[i], expr) + new_ranges[i].append(expr) + return next(var_count) + + def make_combined(strides, index_list): + def getter(flat_vars): + expr = sympy.Integer(0) + for stride, index in zip(strides, index_list): + expr = stride * flat_vars[index] + expr + return expr + + return getter + + def size_hints(group): + if isinstance(group, (list, tuple)): + return sv.size_hint(NumelList(group).numels()) + return sv.size_hint(group) + + def add_multiple_range(size, return_getters): + # need to break size in multiple + index_list = [] + stride_list = [] + group = current_group + remained_size = size + # Two checks: + # 1. remaining sizes to be merged + # 2. remained_size is already divided to 1 + while (group < len(remaining) and remaining[group] > 1) and (remained_size > 1): + group_size = remaining[group] + # size should be divisible by group_size + if not sv.statically_known_multiple_of(remained_size, group_size): + raise CantSplit() + index_list.append(add_range(group, group_size)) + remained_size = FloorDiv(remained_size, group_size) + stride_list.append(remained_size) + group = group + 1 + if remained_size != 1: + raise CantSplit() + return_getters.append(make_combined(stride_list, index_list)) + + return_getters_groups = [] + current_group = 0 + + for length_group in lengths: + return_getters = [] + for size in length_group: + if sv.statically_known_equals(size, 1): # type: ignore[arg-type] + return_getters.append(lambda _: sympy.Integer(0)) + continue + + while ( + current_group < len(remaining) + and size_hints(remaining[current_group]) == 1 + ): + # scroll to next group with remaining elements + current_group += 1 + size_hint = sv.size_hint(size) + if size_hint > size_hints(remaining[current_group]): + #add multiple ranges (two or more) to the list, as well as the getter funcs + add_multiple_range(size_hint, return_getters) + else: + return_getters.append( + operator.itemgetter(add_range(current_group, size_hint)) + ) + return_getters_groups.append(return_getters) + + if not (all(V.graph.sizevars.size_hint(s) == 1 for s in remaining)): + raise RuntimeError("assert all(V.graph.sizevars.size_hint(s) == 1 for s in remaining)") + + return new_ranges, return_getters_groups + + # torch260 done + # just to override load method of CSEProxy, however, CSEProxy is an inner which can not be monkey patched, + # we need to override the whole inner class + def __enter__(self): + class CSEProxy: + self.name = "CSEProxy" + vr_analysis = 
ValueRangeAnalysis() + + @staticmethod + def __getattr__(name: str) -> Callable[..., CSEVariable]: # type: ignore[misc] + def inner(*args, **kwargs): + bounds = CSEProxy._bound_variable(name, *args, **kwargs) + + value = getattr(parent_handler, name)(*args, **kwargs) # type: ignore[has-type] + dtype_handler = DtypePropagationOpsHandler() + + output_idx = 0 + + def do_cse(v): + # cpp backend doesnt set current device - TODO: fix + if V.graph.current_device is not None: + device_str = V.graph.get_current_device_or_throw().type + triton_backend = ( + config.cpu_backend == "triton" + if device_str == "cpu" + else config.cuda_backend == "triton" + ) + else: + triton_backend = False + + # only triton backend tracks dtype currently + if triton_backend: + if name == "masked": + output_dtype = value.dtype + else: + output_dtype = getattr( + dtype_handler, + name, + )(*args, **kwargs) + else: + # cpp backend doesnt track dtype yet + output_dtype = None + + csevar = V.kernel.cse.generate( + V.kernel.compute, + v, + bounds=bounds, + dtype=output_dtype, + ) + + nonlocal output_idx + if ( + config.test_configs.runtime_triton_dtype_assert + and triton_backend + ): + from torch._inductor.codegen.triton import triton_type + + # we tree_map over the output, so we need to fetch corresponding dtype + if isinstance(output_dtype, (list, tuple)): + output_dtype = output_dtype[output_idx] + + V.kernel.compute.writeline( + f"tl.static_assert({csevar}.dtype == {triton_type(output_dtype)})" + ) + output_idx += 1 + + csevar.update_on_args(name, args, kwargs) + + return csevar + + return pytree.tree_map(do_cse, value) + + return inner + + @staticmethod + def _bound_variable(name, *args, **kwargs): + """ + If the variable comes from an FX node, we forward the bound we have already computed + Else, if the variable when codegen'ing another op, we try to compute its bounds + """ + from torch._inductor.select_algorithm import TritonTemplateKernel + + if isinstance(V.kernel, TritonTemplateKernel): + return ValueRanges.unknown() + + fx_node = V.interpreter.current_node + if fx_node.target == name and self.node_to_bounds is not None: + if not (isinstance(self.node_to_bounds, dict)): + raise RuntimeError("assert isinstance(self.node_to_bounds, dict)") + + return self.node_to_bounds.get(fx_node, ValueRanges.unknown()) + elif config.compute_all_bounds and hasattr(ValueRangeAnalysis, name): + # These create lots of inner strings. We would need to compute the bounds at the ops + # We will also likely not get much from computing VRs on these nodes + if any( + s in fx_node.target + for s in ("set_indirect", "reduction", "scan") + ): + return ValueRanges.unknown() + + # We assume that the inputs come from `ops.` and are not strings. If you want to generate + # intermediary strings, wrap them in CSE variables with properly initialised bounds. 
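# ---------------------------------------------------------------------------
# The interception pattern used by this re-declared CSEProxy, in miniature
# (illustrative only; cache_into_cse is a hypothetical stand-in for
# V.kernel.cse.generate): every ops.* call is routed through __getattr__,
# evaluated by the parent handler, then cached as a CSE kernel variable:
#
#     class MiniProxy:
#         def __init__(self, parent):
#             self._parent = parent
#         def __getattr__(self, name):
#             def inner(*args, **kwargs):
#                 value = getattr(self._parent, name)(*args, **kwargs)
#                 return cache_into_cse(value)
#             return inner
#
# Because the upstream CSEProxy is an inner class, a single method such as
# load() cannot be monkey-patched, which is why the whole class is re-declared
# in this override.
# ---------------------------------------------------------------------------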
+ + # If there is no FX bound but we know how to compute one we do so + if (kwargs): + raise RuntimeError("assert not kwargs") + + def arg_to_bound(x): + if isinstance(x, CSEVariable): + return x.bounds + elif isinstance(x, sympy.Expr): + return bound_sympy(x) + else: + return x + + arg_bounds = list(map(arg_to_bound, args)) + return getattr(CSEProxy.vr_analysis, name)(*arg_bounds) + return ValueRanges.unknown() + + @staticmethod + def indirect_indexing( + var: CSEVariable, + size: Union[sympy.Expr, int], + check: bool = True, + wrap_neg=True, + ): + if isinstance(size, int): + size = sympy.Integer(size) + if not (isinstance(size, sympy.Expr)): + raise RuntimeError("assert isinstance(size, sympy.Expr), size") + # Skip CSE since this doesn't return an expression + + if var.bounds.lower < 0: # type: ignore[operator] + if wrap_neg: + stm = ops.add(var, ops.index_expr(size, torch.long)) + # Mixed negative and non-negative + if var.bounds.upper >= 0: # type: ignore[operator] + lt = ops.lt(var, 0) + stm = ops.where(lt, stm, var) + else: + stm = var + + # Propagate bounds as we know how to compute them properly + new_bounds = ValueRanges.unknown() + if var.bounds != ValueRanges.unknown() and isinstance( + size, sympy.Number + ): + # Take the negative part of the bound and add size to it + # Then take union of that and the positive part + # This is a tighter bound than that of a generic ops.where, as we have info on the cond + neg_bounds = var.bounds & ValueRanges(-int_oo, -1) + new_bounds = ValueRanges( + neg_bounds.lower + size, neg_bounds.upper + size + ) + # We don't have a good way of representing the empty range + if var.bounds.upper >= 0: # type: ignore[operator] + pos = var.bounds & ValueRanges(0, int_oo) + new_bounds = new_bounds | pos + + var = self.cse.generate(self.compute, stm, bounds=new_bounds) + + sympy_var = parent_handler.indirect_indexing(var, size, check) + if generate_assert(check): + assert_lower = not (var.bounds.lower >= 0) + # value ranges cannot x < s when x and s are symbols + assert_upper = not isinstance(size, sympy.Number) or not ( + var.bounds.upper < size + ) + self.check_bounds(sympy_var, size, assert_lower, assert_upper) + return sympy_var + + @staticmethod + def check_bounds( + expr: sympy.Expr, size: sympy.Expr, lower: bool, upper: bool + ): + return self.check_bounds(expr, size, lower, upper) + + @staticmethod + def load(name: str, index: sympy.Expr) -> CSEVariable: + if name in self.cse.invalidated_stores: + # A load from an invalidated store requires us to + # keep the actual buffer around + V.kernel.must_keep_buffers.add(name) + if free_symbol_is_type(index, SymT.TMP): + return self.indirect_load(name, index) + store_cache = self.cse.store_cache + if name in store_cache: + return self.load(name, index) + out = self.load(name, index) + # count load that is not in the store_cache, and also not in the + # cse cache. 
+ if out.use_count == 1: + self.num_load += 1 + return out + + @staticmethod + def _update_store_cache(name: str, value: CSEVariable): + self.cse.store_cache[name] = value + if self.current_node and name in V.graph.name_to_buffer: + buf = self.current_node.get_output(name) + for other_name in buf.get_mutations(): + self.cse.store_cache[other_name] = value + + @staticmethod + def store( + name: str, index: sympy.Expr, value: CSEVariable, mode: StoreMode = None + ) -> None: + self.store_buffer_names.add(name) + if mode is None: + CSEProxy._update_store_cache(name, value) + if name not in V.graph.removed_buffers: + return self.store(name, index, value, mode=mode) + return None # type: ignore[return-value] + + @staticmethod + def store_reduction(name: str, index: sympy.Expr, value: CSEVariable): + self.store_buffer_names.add(name) + CSEProxy._update_store_cache(name, value) + + if name not in V.graph.removed_buffers: + return self.store_reduction(name, index, value) + raise RuntimeError("store_reduction") + + @staticmethod + def reduction( + dtype: torch.dtype, + src_dtype: torch.dtype, + reduction_type: ReductionType, + value: Union[CSEVariable, Tuple[CSEVariable, ...]], + ) -> Union[CSEVariable, Tuple[CSEVariable, ...]]: + self.num_reduction += 1 + return self.reduction(dtype, src_dtype, reduction_type, value) + + @staticmethod + def scan( + dtypes: Tuple[torch.dtype, ...], + combine_fn: Callable[ + [Tuple[CSEVariable, ...], Tuple[CSEVariable, ...]], + Tuple[CSEVariable, ...], + ], + values: Tuple[CSEVariable, ...], + ) -> Tuple[CSEVariable, ...]: + return self.scan(dtypes, combine_fn, values) + + @staticmethod + def sort( + dtypes: Tuple[torch.dtype, ...], + values: Tuple[CSEVariable, ...], + stable: bool, + descending: bool, + ) -> Tuple[CSEVariable, ...]: + return self.sort(dtypes, values, stable, descending) + + @staticmethod + def bucketize( + values: CSEVariable, + boundaries: Tuple[str, sympy.Expr, sympy.Expr, sympy.Expr], + boundary_indices: CSEVariable, + indexing_dtype: torch.dtype, + right: bool, + sorter: Optional[Tuple[str, sympy.Expr]] = None, + sorter_indices: Optional[CSEVariable] = None, + ) -> CSEVariable: + return self.bucketize( + values, + boundaries, + boundary_indices, + indexing_dtype, + right, + sorter, + sorter_indices, + ) + + # Use mypy to check protocol implemented correctly + def _typecheck_CSEProxy(h: CSEProxy) -> OpsHandler[CSEVariable]: + return h + + super().__enter__() + if not (self.overrides): + raise RuntimeError("assert self.overrides") + parent_handler = self.overrides(V.get_ops_handler()) + self.exit_stack.enter_context(V.set_ops_handler(CSEProxy())) + self.exit_stack.enter_context(V.set_kernel_handler(self)) + return self diff --git a/torch_npu/_inductor/codegen/triton_utils.py b/torch_npu/_inductor/codegen/triton_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..5acd971ba027980ccede319c0a8870b1419ae37a --- /dev/null +++ b/torch_npu/_inductor/codegen/triton_utils.py @@ -0,0 +1,29 @@ + +import torch + +# wrapper npu 32 bytes align, get and pass unalign info to triton meta +# then autotune choose tiling param and send them to bishengIR +byte_per_numel = { + torch.float32: 4, # torch.float32 or torch.float + torch.float64: 8, # torch.float64 or torch.double + torch.float16: 2, # torch.float16 or torch.half + torch.bfloat16: 2, # torch.bfloat16 + torch.int32: 4, # torch.int32 or torch.int + torch.int64: 8, # torch.int64 or torch.long + torch.int16: 2, # torch.int16 or torch.short + torch.int8: 1, # torch.int8 + torch.uint8: 1, 
# torch.uint8 + torch.bool: 1, # torch.bool + torch.complex32: 4, # torch.complex32 (not yet available in PyTorch as of the latest stable release) + torch.complex64: 8, # torch.complex64 + torch.complex128: 16 # torch.complex128 +} + + +def get_aligned_numel(dtype): + if dtype in byte_per_numel: + return 32 // byte_per_numel[dtype] + else: + return 1 + + diff --git a/torch_npu/_inductor/codegen/wrapper.py b/torch_npu/_inductor/codegen/wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..2f6772b5e58e985a8035134fa620eeb117afb96b --- /dev/null +++ b/torch_npu/_inductor/codegen/wrapper.py @@ -0,0 +1,87 @@ +from torch._inductor.codegen.wrapper import PythonWrapperCodegen, SymbolicCallArg, SubgraphPythonWrapperCodegen +from torch._inductor.virtualized import V +from torch._inductor.utils import ( + cache_on_self, +) +from torch._inductor.runtime import triton_heuristics +from torch._inductor import config +import copy + +class NPUWrapperCodeGen(PythonWrapperCodegen): + def __init__(self): + super().__init__() + + @staticmethod + def create( + is_subgraph: bool, subgraph_name: str, parent_wrapper: PythonWrapperCodegen + ): + if is_subgraph: + return SubgraphPythonWrapperCodegen(subgraph_name, parent_wrapper) + return NPUWrapperCodeGen() + + def write_header(self) -> None: + super().write_header() + self.imports.splice( + f""" + import torch_npu + """, + strip=True, + ) + + @cache_on_self + def write_triton_header_once(self) -> None: + import_str = f""" + import triton + import triton.language as tl + from {triton_heuristics.__name__} import ( + split_scan_grid, + grid_combo_kernels, + start_graph, + end_graph, + cooperative_reduction_grid, + ) + from torch_npu._inductor.npu_triton_heuristics import grid + import torch_npu + """ + if config.triton.autotune_at_compile_time: + self.kernel_autotune_calls.splice(import_str) + self.kernel_autotune_calls.writeline( + V.graph.device_ops.import_get_raw_stream_as("get_raw_stream") + ) + if not V.graph.cpp_wrapper: + self.imports.splice(import_str, strip=True) + self.imports.writeline( + V.graph.device_ops.import_get_raw_stream_as("get_raw_stream") + ) + + #generate numel expr for range_tree_node + def generate_node_numel_expr(self, kernel_name: str, node, numel_expr): + expr = f"{kernel_name}_{node.name}_numel" + if (expr, V.graph) not in self.kernel_numel_expr: + # declare expr once in each graph (scope) + self.kernel_numel_expr.add((expr, V.graph)) + self.writeline( + f"{self.declare}{expr} = {self.expr_printer(numel_expr)}{self.ending}" + ) + else: + self.writeline(f"{expr} = {self.expr_printer(numel_expr)}{self.ending}") + # We can get symbolic expressions here, like s0*64 + # It is fine to have them here, but we need to handle them correctly as their own type + # This is tricky to do, so we wrap in a custom type, distinct from scalars, but also from sympy* + # scalars as well. + # This is handled in `generate_args_decl` which has a correct comment of: TODO: only works for + # constant now, need type info. I agree, this needs type info, and while this is not true type info + # it suffices as a type hint for the purposes of producing the correct code for this type. 
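+        # The returned SymbolicCallArg pairs the declared numel variable name with its
+        # sympy expression, so callers can pass it through as a kernel call argument.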
+ return SymbolicCallArg(expr, numel_expr) + + # don't free anything + def make_buffer_free(self, buffer): + return "" + + # don't assert + def codegen_input_size_asserts(self) -> None: + pass + + def get_next_kernel_suffix(self) -> str: + iter = copy.copy(self._names_iter) + return f"{next(iter)}" diff --git a/torch_npu/_inductor/config.py b/torch_npu/_inductor/config.py new file mode 100644 index 0000000000000000000000000000000000000000..e7bc04606560950e6cea939803b88bfc47c19b2b --- /dev/null +++ b/torch_npu/_inductor/config.py @@ -0,0 +1,58 @@ +import os # noqa: C101 +import torch +import logging +from typing import Any, Callable, Dict, Optional, TYPE_CHECKING +from triton.runtime.driver import driver +from torch._inductor import config +enable_npu_indexing = True + +config.triton.unique_kernel_names = True +# avoid test_opensora_cases_model_16_forward reinterpre_tensor issue +config.allow_buffer_reuse = False +#inductor debug switch +config.trace.enabled = True + +# npu hardware params from trion +target = driver.active.get_current_target() +device = driver.active.get_current_device() +prop = driver.active.utils.get_device_properties(device) + +num_cube_core = prop["num_aicore"] +num_vector_core = prop["num_aicore"] + +# unit byte +npu_block = 32 + +traced_fx_graph_cache = os.environ.get("INDUCTOR_ASCEND_FX_GRAPH_CACHE", None) +check_accuracy = os.environ.get("INDUCTOR_ASCEND_CHECK_ACCURACY", False) +auto_fallback = os.environ.get("INDUCTOR_ASCEND_AUTO_FALLBACK", True) +fallback_warning = os.environ.get("INDUCTOR_ASCEND_FALLBACK_WARNING", False) + +acc_comp_tol = { + torch.float32: {'rtol': 1.3e-6, 'atol': 1e-5}, + torch.float16: {'rtol': 1e-3, 'atol': 1e-5}, + torch.bfloat16: {'rtol': 1.6e-2, 'atol': 1e-5}, + "default": {'rtol': 1.3e-6, 'atol': 1e-5}, +} + +if ("Ascend910B" in target.arch): + num_vector_core = num_cube_core * 2 + +log_level_env = os.getenv('INDUCTOR_ASCEND_LOG_LEVEL', 'INFO').upper() +log_level_mapping = { + 'DEBUG': logging.DEBUG, + 'INFO': logging.INFO, + 'WARNING': logging.WARNING, + 'ERROR': logging.ERROR, + 'CRITICAL': logging.CRITICAL +} +log_level = log_level_mapping.get(log_level_env.upper(), logging.INFO) +logging.basicConfig( + level=log_level, + format='%(asctime)s - %(levelname)s - %(message)s' +) +log = logging.getLogger(__name__) + +aggresive_autotune = os.getenv("INDUCTOR_ASCEND_AGGRESSIVE_AUTOTUNE", '0').lower() in ('1', 'true') +inductor_static_mode = os.environ.get('INDUCTOR_STATIC_MODE', '0').lower() in ('1', 'yes', 'true') +profile_path = "./profile_result/" \ No newline at end of file diff --git a/torch_npu/_inductor/decomposition.py b/torch_npu/_inductor/decomposition.py new file mode 100644 index 0000000000000000000000000000000000000000..17a9b00adc5c9c91dc01b7a253e1002e15bcd8ed --- /dev/null +++ b/torch_npu/_inductor/decomposition.py @@ -0,0 +1,50 @@ +from torch._inductor.decomposition import decompositions, pw_cast_for_opmath +from torch._inductor.decomposition import register_decomposition +import torch._ops +from .lowering import _init_set + + +aten = torch.ops.aten + +DECOMPOSITION_OVERLOAD_OP = [ + aten._log_softmax, + aten.nll_loss_forward, + # aten.gelu_backward, + # aten.gelu, + aten.nll_loss_backward, + aten._log_softmax_backward_data, + aten.embedding_dense_backward, + aten.addmm, + aten.gelu +] + + +def _register_npu_inductor_decompositons(): + + overload_op_set = set() + _init_set(DECOMPOSITION_OVERLOAD_OP, overload_op_set) + + for op in overload_op_set: + if (op in decompositions): + del decompositions[op] + + 
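+    # Register NPU-specific decompositions: scatter.src, expm1 and erfc are rewritten
+    # below in terms of basic pointwise ops that the NPU backend lowers directly.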
+    @register_decomposition([aten.scatter.src])
+    @pw_cast_for_opmath
+    def scatter_src(input_tensor, dim, index_tensor, source_tensor):
+        # aten.scatter.src passes (self, dim, index, src), so this decomposition takes
+        # exactly four arguments; `input_tensor` is the destination tensor.
+        # NOTE: assumes a 2-D input scattered along the last dimension; positions not
+        # selected by `index_tensor` are filled with 0.
+        (XNUMEL, YS) = input_tensor.shape
+        index_rblock = torch.arange(YS).npu().reshape((1, YS)).repeat((XNUMEL, 1))
+
+        index_tensor_brd = index_tensor.to(torch.int32).broadcast_to(XNUMEL, YS)
+        source_tensor_brd = source_tensor.broadcast_to(XNUMEL, YS).to(torch.float32)
+        scatter1 = torch.where(index_rblock == index_tensor_brd, 1.0, 0.0) * source_tensor_brd
+        return scatter1
+
+    @register_decomposition([aten.expm1])
+    def expm1(x):
+        tensor = torch.exp(x) - torch.ones_like(x)
+        return tensor
+
+    @register_decomposition([aten.erfc])
+    def erfc(x):
+        # erfc(x) = 1 - erf(x)
+        tensor = torch.ones_like(x) - torch.erf(x)
+        return tensor
\ No newline at end of file
diff --git a/torch_npu/_inductor/dynamo_embedding_backward_dispatch.py b/torch_npu/_inductor/dynamo_embedding_backward_dispatch.py
new file mode 100644
index 0000000000000000000000000000000000000000..6584c99e7a864619f73b564a060cafc52f3654fa
--- /dev/null
+++ b/torch_npu/_inductor/dynamo_embedding_backward_dispatch.py
@@ -0,0 +1,10 @@
+import torch
+from torch.library import Library, impl
+python_dispatcher_lib = Library("aten", "IMPL", "PythonDispatcher")
+
+
+@impl(python_dispatcher_lib, "embedding_backward")
+def embedding_backward(grad, indices, num_weights, padding_idx, scale_grad_by_freq, sparse):
+    if sparse:
+        raise RuntimeError("sparse tensors are not yet supported on NPU; embedding_backward was called with sparse=True")
+    return torch.ops.aten.embedding_dense_backward(grad, indices, num_weights, padding_idx, scale_grad_by_freq)
\ No newline at end of file
diff --git a/torch_npu/_inductor/lowering.py b/torch_npu/_inductor/lowering.py
new file mode 100644
index 0000000000000000000000000000000000000000..fb1dd5ce1e62cc0d0d2df559a1ac02f5632c3fff
--- /dev/null
+++ b/torch_npu/_inductor/lowering.py
@@ -0,0 +1,362 @@
+import sympy
+from torch._inductor.ir import Reduction
+from torch._inductor.utils import sympy_product
+from torch._inductor import ir
+from torch._inductor.ir import ExpandView, TensorBox, ops_wrapper
+from torch._inductor.lowering import sum_
+from torch._inductor import lowering
+from torch._prims_common import (
+    is_boolean_dtype,
+    is_integer_dtype,
+    get_computation_dtype,
+)
+from torch._inductor.decomposition import decompositions, pw_cast_for_opmath
+import torch._ops
+
+
+def make_reduction(reduction_type: str, override_return_dtype=None):
+    def inner(x, axis=None, keepdims=False, *, dtype=None):
+        kwargs = _make_reduction_inner(
+            x,
+            axis=axis,
+            keepdims=keepdims,
+            dtype=dtype,
+            override_return_dtype=override_return_dtype,
+        )
+        result = Reduction.create(reduction_type=reduction_type, input_node=x, **kwargs)
+        if isinstance(
+            result.data.data, Reduction
+        ):  # Only realize if reduction isn't unrolled
+            size = x.get_size()
+            axis = set(_validate_reduction_axis(x, axis))
+            kept_idx = []
+            reduced_idx = []
+            for i in range(len(size)):
+                if i in axis:
+                    reduced_idx.append(i)
+                else:
+                    kept_idx.append(i)
+
+            object.__setattr__(result.data.data, "kept_idx", kept_idx)
+            object.__setattr__(result.data.data, "reduced_idx", reduced_idx)
+
+            result.realize()
+        return result
+
+    return inner
+
+
+lowering.make_reduction = make_reduction
+
+from torch._inductor.lowering import (
+    lowerings,
+    make_fallback,
+    register_lowering,
+    to_dtype,
+    # make_reduction,
+    # reduce_amax,
+    # reduce_amin,
+    fallback_cumsum,
+    _validate_reduction_axis,
+    div,
+    squeeze,
+    square,
+    sub,
+    fallback_handler,
+    is_boolean_type,
+    logical_and,
+    make_pointwise,
+    _make_reduction_inner,
+    _validate_reduction_axis,
+)
+
+aten = torch.ops.aten
+tr_c10d = torch.ops.tr_c10d
+prims = torch.ops.prims
+
+import torch_npu
+
+from torch_npu import npu_dtype_cast
+
+
+def _init_set(input_list, output_set):
+    for fn in input_list:
+        output_set.add(fn)
+        if isinstance(fn, torch._ops.OpOverloadPacket):
+            for overload in fn.overloads():
+                other_fn = getattr(fn, overload)
+                output_set.add(other_fn)
+
+
+GENERATE_LIST = [
+    prims.iota,
+    aten.full,
+    aten.mul,
+    aten.add,
+    aten.sub,
+    aten.div,
+    aten.exp,
+    aten.maximum,
+    aten.sum,
+    aten.select,
+    aten.unsqueeze,
+    aten.repeat,
+    aten.clone,  # removing this breaks the permute_reshape case
+    aten.reshape,
+    aten.where,
+    aten.lt,
+    aten.minimum,
+    aten.gt,
+    aten.le,
+    aten.ceil,
+    aten.floor,
+    aten.rsqrt,
+    aten.abs,
+    aten.log,
+    aten.bitwise_xor,
+    aten.amax,
+    # backward
+    prims.convert_element_type,
+    aten.min,
+    aten.max,
+    aten.erf,
+    aten.argmax,
+    aten.argmin,
+    aten.clamp_min,
+    aten.slice,
+    aten.neg,
+    aten.cat,
+    aten.arange,
+    aten.expand,
+    aten.eq,
+    aten.where,
+    aten.scalar_tensor,
+    aten.ge,
+    aten.permute,
+    aten.sqrt,
+    aten.relu,
+    aten.clamp,
+    aten.clamp_max,
+    aten.mean,
+    # npu.npu_dtype_cast
+    npu_dtype_cast,
+    aten.select_scatter,
+    prims.broadcast_in_dim,
+    prims.maximum,
+    aten.ne,
+    aten.sigmoid,
+    aten.sign,
+    aten.logical_and,
+    aten.logical_or,
+    aten.logical_not,
+    aten.pow,
+    aten.gelu,
+    aten.tanh,
+    aten.isnan,
+    aten.bitwise_and,
+    aten.squeeze,
+    aten.copy,
+    aten.reciprocal
+]
+
+GENERATE_LIST2 = [
+    "foreach"
+]
+
+FALLBACK_LIST = []
+
+# Delete these ops from the already-registered lowerings before re-registering them;
+# otherwise lowering would pick up the versions registered in stock torch.
+LOWERING_OVERLOAD_OP = [
+    aten.cumsum,
+    aten.mean,
+    aten.max,
+    aten.min,
+    aten.amin,
+    aten.amax,
+    aten.argmax,
+    aten.argmin,
+    aten.sum,
+
+    aten.var_mean,
+    aten.var,
+
+    aten.embedding,
+    aten.split,
+    aten.split_with_sizes,
+    aten.nll_loss_forward,
+    aten.gather,
+    aten.cat,
+    aten.slice_scatter,
+    # aten.clone: the permute_reshape case will fail if this is enabled
+]
+
+
+def _register_npu_inductor_fallbacks():
+    gen_set = set()
+    _init_set(GENERATE_LIST, gen_set)
+    overload_op_set = set()
+    _init_set(LOWERING_OVERLOAD_OP, overload_op_set)
+
+    # Fall back every op that is not in the allowlist above.
+    for op in lowerings:
+        if op not in decompositions and op not in gen_set:
+            if isinstance(op, torch._ops.OpOverloadPacket) or \
+                    isinstance(op, (torch._ops.OpOverload, torch._ops.HigherOrderOperator)):
+                flag = False
+                for gens in GENERATE_LIST2:
+                    if str(op).find(gens) != -1:
+                        flag = True
+                if flag:
+                    continue
+                else:
+                    make_fallback(op)
+                    FALLBACK_LIST.append(op)
+    # Remove the ops that will be overloaded from the existing lowerings.
+    for op in overload_op_set:
+        if op in lowerings:
+            del lowerings[op]
+
+    # register the reductions using the custom make_reduction
+    reduce_amax = register_lowering(aten.amax)(make_reduction("max"))
+    reduce_amin = register_lowering(aten.amin)(make_reduction("min"))
+    reduce_argmax = register_lowering(aten.argmax)(
+        make_reduction("argmax", override_return_dtype=torch.int64)
+    )
+    reduce_argmin = register_lowering(aten.argmin)(
+        make_reduction("argmin", override_return_dtype=torch.int64)
+    )
+    @register_lowering([aten.sum, prims.sum])
+    def sum_(x, axis=None, keepdims=False, *, dtype=None):
+        if (
+            is_integer_dtype(x.get_dtype()) or is_boolean_dtype(x.get_dtype())
+        ) and dtype is None:
+            dtype = torch.int64
+
+        fn = make_reduction("sum", override_return_dtype=dtype)
+        return fn(x, axis, keepdims, dtype=dtype)
+
+
+    @register_lowering(aten.max, type_promotion_kind=None)
+    def reduce_max(x,
dim=None, keepdim=False): + if dim is not None: + return ( + reduce_amax(x, axis=dim, keepdims=keepdim), + reduce_argmax(x, axis=dim, keepdims=keepdim), + ) + + return reduce_amax(x, axis=None, keepdims=keepdim) + + + @register_lowering(aten.min, type_promotion_kind=None) + def reduce_min(x, dim=None, keepdim=False): + if dim is not None: + return ( + reduce_amin(x, axis=dim, keepdims=keepdim), + reduce_argmin(x, axis=dim, keepdims=keepdim), + ) + + return reduce_amin(x, axis=None, keepdims=keepdim) + + + @register_lowering(aten.mean) + def mean(x, axis=None, keepdim=False, *, dtype=None): + if dtype is not None: + x = to_dtype(x, dtype) + size = x.get_size() + axis = _validate_reduction_axis(x, axis) + # compute in higher-precision until end of mean lowering + output_dtype = x.get_dtype() + if output_dtype in (torch.float16, torch.bfloat16): + x = to_dtype(x, torch.float) + sum_result = sum_(x, axis, keepdim) + denom = sympy_product(size[i] for i in axis) + denom = ir.IndexingConstant(index=denom, dtype=x.get_dtype(), device=x.get_device()) + denom = ExpandView.create(denom, list(sum_result.get_size())) + return to_dtype(div(sum_result, denom), output_dtype) + + + @register_lowering(aten.cumsum) + def cumsum(x, axis=None, dtype=None): + if ( + is_integer_dtype(x.get_dtype()) or is_boolean_dtype(x.get_dtype()) + ) and dtype is None: + # torch.int64->torch.int32 + dtype = torch.int32 + if len(x.get_size()) == 0: + if axis not in [0, -1]: + raise ValueError("axis must be 0 or -1") + dtype = dtype or x.get_dtype() + return to_dtype(x, dtype, copy=True) + return fallback_cumsum(x, dim=axis, dtype=dtype) + + @register_lowering(npu_dtype_cast, type_promotion_kind=None) + def _convert_npu_type(x: TensorBox, dtype: torch.dtype): + return to_dtype(x, dtype, copy=True) + + + def var_mean_sum_(x, axis, correction, keepdim, return_mean): + if correction is None: + correction = 1 + + size = x.get_size() + axis = _validate_reduction_axis(x, axis) + x_mean = mean(x, axis, keepdim=True) + if return_mean: + x_mean.realize() + + diffs = square(sub(x, x_mean)) + sum_result = sum_(diffs, axis, keepdim) + denom = sympy_product(size[i] for i in axis) + if correction: + denom = sympy.Max(denom - correction, 0) + denom = ir.IndexingConstant(index=denom, dtype=x.get_dtype(), device=x.get_device()) + denom = ExpandView.create(denom, list(sum_result.get_size())) + x_var = div(sum_result, denom) + if not return_mean: + return (x_var,) + + x_mean = x_mean if keepdim else squeeze(x_mean, axis) + return x_var, x_mean + + + def var_mean_helper_(x, *, axis, correction, keepdim, return_mean): + out_dtype = x.get_dtype() + compute_dtype = get_computation_dtype(out_dtype) + x = to_dtype(x, compute_dtype, copy=False) + kwargs = dict( + x=x, + axis=axis, + correction=correction, + keepdim=keepdim, + return_mean=return_mean, + ) + output = ( + var_mean_sum_(**kwargs) + ) + output = tuple(to_dtype(x, out_dtype, copy=False) for x in output) + return output[0] if not return_mean else output + + @register_lowering(aten.var_mean) + def var_mean(x, axis=None, *, correction=None, keepdim=False): + return var_mean_helper_( + x, axis=axis, correction=correction, keepdim=keepdim, return_mean=True + ) + + @register_lowering([aten.var, prims.var]) + def var_(x, axis=None, *, correction=None, keepdim=False): + return var_mean_helper_( + x, axis=axis, correction=correction, keepdim=keepdim, return_mean=False + ) + + @register_lowering(aten.embedding, type_promotion_kind=None) + def embedding(weight, indices, padding_idx=-1, 
scale_grad_by_freq=False, sparse=False): + return fallback_handler(aten.embedding.default)(weight, indices, padding_idx=-1, scale_grad_by_freq=False, + sparse=False) + + @register_lowering(aten.cat) + def cat(inputs, dim=0): + return fallback_handler(aten.cat.default)(inputs, dim) + + make_fallback(aten._log_softmax) + make_fallback(aten.gather) + make_fallback(aten.nll_loss_forward) diff --git a/torch_npu/_inductor/lowering_fx.py b/torch_npu/_inductor/lowering_fx.py new file mode 100644 index 0000000000000000000000000000000000000000..e0ff1308becd95ce3de809a4d8208b18d435c4ad --- /dev/null +++ b/torch_npu/_inductor/lowering_fx.py @@ -0,0 +1,2404 @@ +import itertools +import functools + +import os +import textwrap + +import sympy +from sympy.core import Expr, Integer, Symbol +from torch._inductor.ir import Reduction +from torch._inductor.utils import sympy_product +from torch._inductor import ir +from torch._inductor.ir import ExpandView, TensorBox +from torch._inductor.lowering import sum_ +from torch._inductor import lowering +from torch._prims_common import ( + is_boolean_dtype, + is_integer_dtype, + get_computation_dtype, +) +from torch._inductor.decomposition import decompositions, pw_cast_for_opmath +import torch._ops + +from typing import ( + Any, + Callable, + Dict, + List, + Optional, + Set, + Tuple, + Union, + ) + +from torch._prims_common import ( + canonicalize_dims, + check, + dtype_to_type, + ELEMENTWISE_TYPE_PROMOTION_KIND, + get_computation_dtype, + is_boolean_dtype, + is_float_dtype, + is_integer_dtype, + Number, +) + +from torch.utils._sympy.functions import ( + FloorDiv, + Identity, + ModularIndexing, +) + + +from torch._inductor.ir import ( + ExpandView, + IndexingConstant, + is_triton, + ops_wrapper, + PermuteView, + Pointwise, + Reduction, + SqueezeView, + TensorBox, + IRNode, + validate_ir, + View, +) + +from torch._inductor.utils import ( + decode_device, + sympy_product, + +) + + +from torch.fx.experimental.proxy_tensor import make_fx +from torch._inductor.fx_passes.post_grad import view_to_reshape +from torch._inductor import scheduler +from torch._inductor.utils import ModularIndexing, FloorDiv +import sympy + + +from torch._inductor.virtualized import ops, V + +from torch._inductor import scheduler + +from torch._inductor.ir import Reduction +from torch._inductor.utils import sympy_product +from torch._inductor import ir +from torch._inductor.ir import ExpandView, TensorBox +from torch._inductor import lowering +from torch._prims_common import ( + is_boolean_dtype, + is_integer_dtype, + get_computation_dtype, +) +from torch._inductor.decomposition import decompositions +import torch._ops + +aten = torch.ops.aten +tr_c10d = torch.ops.tr_c10d +prims = torch.ops.prims +npu = torch.ops.npu + + +def _init_set(input_list, output_set): + for fn in input_list: + output_set.add(fn) + if isinstance(fn, torch._ops.OpOverloadPacket): + for overload in fn.overloads(): + other_fn = getattr(fn, overload) + output_set.add(other_fn) + + +GENERATE_LIST = [ + aten.mul, + aten.add, + aten.sub, + aten.div, + aten.exp, + aten.maximum, + aten.sum, + aten.select, + aten.unsqueeze, + aten.repeat, + #aten.clone, + aten.reshape, + aten.where, + aten.lt, + aten.minimum, + aten.gt, + aten.le, + aten.ceil, + aten.floor, + aten.rsqrt, + aten.abs, + aten.log, + aten.bitwise_xor, + aten.amax, + # backward + prims.convert_element_type, + aten.min, + aten.max, + aten.erf, + aten.argmax, + aten.argmin, + aten.clamp_min, + aten.slice, + aten.neg, + aten.cat, + aten.arange, + aten.expand, + aten.eq, + 
aten.where, + aten.scalar_tensor, + aten.ge, + aten.permute, + aten.sqrt, + aten.relu, + aten.clamp, + aten.clamp_max, + aten.mean, + # npu.npu_dtype_cast + npu.npu_dtype_cast, + aten.select_scatter, + aten.slice_scatter, + prims.broadcast_in_dim, + prims.maximum, + aten.ne, + aten.sigmoid, + aten.sign, + aten.logical_and, + aten.logical_or, + aten.logical_not, + aten.pow, + aten.gelu, + aten.tanh, + aten.isnan, + aten.bitwise_and, + aten.squeeze, + aten.copy, + aten.reciprocal +] + +GENERATE_LIST2 = [ + "foreach" +] + +FALLBACK_LIST = [] + +# 先删除从lowering已经注册的op,再更新,不然会lowering的时候找到在torch注册的op +LOWERING_OVERLOAD_OP = [ + aten.cumsum, + aten.mean, + # aten.max, + # aten.min, + # aten.mul, + aten.var_mean, + aten.var, + + aten.embedding, + aten.split, + aten.split_with_sizes, + aten.nll_loss_forward, + aten.gather, + aten.cat, + aten.clone +] + +LOWERING_OVERLOAD_OP = list(set(GENERATE_LIST) | set(LOWERING_OVERLOAD_OP)) + + +fn_to_aten_fn = {} +node_id = itertools.count(0) + +def register_fn_to_aten_fn(fn, aten_fn=None): + if fn not in fn_to_aten_fn: + fn_to_aten_fn[fn] = aten_fn + return fn + +def register_to_aten(aten_fn=None): + def decorator(fn): + if fn not in fn_to_aten_fn: + fn_to_aten_fn[fn] = aten_fn + return fn + return decorator + +reduction_type_to_aten_fn = { + "sum": aten.sum, + "prod": aten.prod, + "xor_sum": prims.xor_sum, + "any": aten.any, + "max": aten.amax, + "min": aten.amin, + "argmax": aten.argmax, + "argmin": aten.argmin +} + +operator_to_string = { + '+': 'a', + '-': 'sub', + '*': 'm', + '/': 'd', + '(': 'l', + ')': 'r', + '.': 'p', +} + +string_to_operator = {v: k for k, v in operator_to_string.items()} + +def map_operators_to_strings(expr_str: str): + expr_str = expr_str.replace(' ', '') + for op, string in operator_to_string.items(): + expr_str = expr_str.replace(op, string) + return '_' + expr_str + +def map_strings_to_operators(expr_str: str): + for op, string in string_to_operator.items(): + expr_str = expr_str.replace(op, string) + return expr_str[1:] + + +class TracedGraph: + def __init__(self): + self.graph = torch.fx.Graph() + self.last_node: Optional[torch.fx.Node] = None + self.sym_nodes: Dict[str, torch.fx.Node] = {} + + def __str__(self): + return str(self.graph) + + def get_placeholder_names(self): + placeholder_names = set() + for node in self.graph.nodes: + if node.op == 'placeholder' and node.name not in self.sym_nodes: + placeholder_names.add(node.name) + return placeholder_names + + __repr__ = __str__ + + + +def create_fake_input(size, stride, device, dtype): + size = [V.graph.sizevars.shape_env.create_symintnode(s, hint=None) \ + if isinstance(s, Expr) and not isinstance(s, Integer) else s for s in size] + stride = [V.graph.sizevars.shape_env.create_symintnode(s, hint=None) \ + if isinstance(s, Expr) and not isinstance(s, Integer) else s for s in stride] + with V.graph.fake_mode: + fake_input = torch.empty_strided(size, stride, device=device, dtype=dtype) + return fake_input + + +def create_sym_inputs(traced_graph: TracedGraph, size: List[Expr]): + for s in size: + if isinstance(s, (List, Tuple)): + create_sym_inputs(traced_graph, s) + continue + if isinstance(s, Expr) and not isinstance(s, Integer): + s_name = str(s) + if not isinstance(s, Symbol): + s_name = map_operators_to_strings(s_name) + if s_name in traced_graph.sym_nodes: + continue + new_node = traced_graph.graph.placeholder(s_name) + new_node.meta['val'] = V.graph.sizevars.shape_env.create_symintnode(s, hint=None) + traced_graph.sym_nodes.update({s_name: new_node}) + + +def 
process_ir_constant(inp: ExpandView) -> Union[TracedGraph, int, float]: + skip = False + if isinstance(inp.data, IndexingConstant): + dtype = inp.data.dtype + inp = inp.data.index + # convert to original dtype. + if dtype in [torch.float32, torch.float16, torch.bfloat16]: + # sympy inputs + if isinstance(inp, Expr) and not isinstance(inp, sympy.core.numbers.Number): + traced_graph = TracedGraph() + create_sym_inputs(traced_graph, [inp]) + s_name = str(inp) + if not isinstance(inp, Symbol): + s_name = map_operators_to_strings(str(inp)) + traced_graph.last_node = traced_graph.sym_nodes[s_name] + inp = traced_graph + else: + inp = float(inp) + elif isinstance(inp.data, ir.Constant): + dtype = inp.data.dtype + inp = inp.data.value + else: + skip = True + return inp, skip + + +def fetch_graphs(inputs: Optional[List[TensorBox]]): + if isinstance(inputs, (TensorBox, ir.StorageBox, ir.View, sympy.Symbol, ir.Constant)): + inputs = [inputs] + input_graphs = [] + for inp in inputs: + if isinstance(inp, List): + input_graphs.append(fetch_graphs(inp)) + continue + if not isinstance(inp, (TensorBox, ir.StorageBox, ir.View, ir.ReinterpretView, ir.PermuteView, ir.SliceView, ir.ExpandView)): + input_graphs.append(inp) + continue + if isinstance(inp, ExpandView): + inp, skip = process_ir_constant(inp) + if not skip: + input_graphs.append(inp) + continue + name = inp.get_name() + traced_graph = inp.get_traced_graph() + if traced_graph is not None: + input_graphs.append(traced_graph) + continue + traced_graph = TracedGraph() + device = inp.get_device() + dtype = inp.get_dtype() + size = inp.get_size() + stride = inp.get_stride() + new_node = traced_graph.graph.placeholder(name) + fake_input = create_fake_input(size, stride, device, dtype) + new_node.meta['val'] = fake_input + traced_graph.last_node = new_node + input_graphs.append(traced_graph) + return input_graphs + + +def merge_traced_graphs(input_graphs: List[TracedGraph], origin_fn, node_name, **kwargs): + new_graph = TracedGraph() + exist_nodes: Dict[str, torch.fx.Node] = {} + def merge_graph(input_graphs: List[TracedGraph]): + for input_graph in input_graphs: + if isinstance(input_graph, List): + merge_graph(input_graph) + continue + if not isinstance(input_graph, TracedGraph): + continue + for node in input_graph.graph.nodes: + if node.name in exist_nodes: + continue + new_node = new_graph.graph.node_copy(node, lambda n: exist_nodes[n.name]) + exist_nodes[node.name] = new_node + if node.name in input_graph.sym_nodes: + new_graph.sym_nodes.update({node.name: new_node}) + + def parse_args(input_graphs, exist_nodes): + args = [] + for input_graph in input_graphs: + if isinstance(input_graph, TracedGraph): + args.append(exist_nodes[input_graph.last_node.name]) + elif isinstance(input_graph, (List, Tuple)): + args.append(parse_args(input_graph, exist_nodes)) + else: + if isinstance(input_graph, Expr) and not isinstance(input_graph, Integer): + if not isinstance(input_graph, Symbol): + input_graph = map_operators_to_strings(str(input_graph)) + args.append(new_graph.sym_nodes[str(input_graph)]) + else: + args.append(input_graph) + return args + + num_args = len(input_graphs) + + for k, v in kwargs.items(): + if isinstance(v, Expr) and not isinstance(v, Integer): + traced_graph = TracedGraph() + create_sym_inputs(traced_graph, [v]) + s_name = str(v) + if not isinstance(v, Symbol): + s_name = map_operators_to_strings(str(v)) + traced_graph.last_node = traced_graph.sym_nodes[s_name] + kwargs[k] = traced_graph.sym_nodes[s_name] + 
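+            # Symbolic kwargs (e.g. a size expression like s0*64) get their own placeholder
+            # graph here so the merged FX graph receives them as explicit inputs.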
input_graphs.append(traced_graph) + merge_graph(input_graphs) + input_graphs = input_graphs[:num_args] + # if inputs do not have any valid graphs, like full/iota + create_sym_inputs(new_graph, input_graphs) + args = parse_args(input_graphs, exist_nodes) + with new_graph.graph.inserting_after(new_graph.last_node): + new_node = new_graph.graph.call_function(origin_fn, args=tuple(args), kwargs=kwargs) + new_node.name = node_name + new_graph.last_node = new_node + return new_graph + +def merge_fx_graphs(traced_graphs: List[TracedGraph]): + new_graph = TracedGraph() + exist_nodes: Dict[str, torch.fx.Node] = {} + last_nodes = [] + def merge_graph(input_graphs: List[TracedGraph]): + for input_graph in input_graphs: + if isinstance(input_graph, List): + merge_graph(input_graph) + continue + if not isinstance(input_graph, TracedGraph): + continue + for node in input_graph.graph.nodes: + if node.name in exist_nodes: + continue + new_node = new_graph.graph.node_copy(node, lambda n: exist_nodes[n.name]) + exist_nodes[node.name] = new_node + last_nodes.append(exist_nodes[input_graph.last_node.name]) + merge_graph(traced_graphs) + new_graph.last_node = last_nodes + return new_graph + +def subtract_graph(graph1: TracedGraph, graph2: TracedGraph, node_name=None) -> Tuple[TracedGraph, torch.fx.Node]: + new_graph = TracedGraph() + last_node2 = graph2.last_node + graph1_node_names = {node.name for node in graph1.graph.nodes} + graph2_node_names = {node.name for node in graph2.graph.nodes} + placeholder = None + exist_nodes: Dict[str, torch.fx.Node] = {} + if node_name not in graph1_node_names: + placeholder = new_graph.graph.placeholder(last_node2.name if node_name is None else node_name) + exist_nodes[last_node2.name] = placeholder + for node in graph1.graph.nodes: + if node.name in graph2_node_names and node.name not in graph1.sym_nodes: + continue + new_node = new_graph.graph.node_copy(node, lambda n: exist_nodes[n.name]) + exist_nodes[node.name] = new_node + new_graph.last_node = exist_nodes[graph1.last_node.name] + new_graph.sym_nodes = graph1.sym_nodes + return new_graph, placeholder + + +def get_last_node(gm: torch.fx.GraphModule): + last_node = None + for node in gm.graph.nodes: + last_node = node + return last_node + +def create_fx_from_snodes_by_traced_graph(snodes: List[scheduler.SchedulerNode]): + fx_call_inputs = [] + for snode in snodes: + snode.node.data.traced_graph.last_node.name = snode.node.get_name() + if len(snodes) == 1: + traced_graph = snodes[0].node.data.traced_graph + else: + traced_graph = merge_fx_graphs([snode.node.data.traced_graph for snode in snodes]) + fx_inputs = [] + for node in traced_graph.graph.nodes: + if node.op == 'placeholder': + fx_call_inputs.append(node.target) + fx_inputs.append(node.meta['val']) + non_contiguous_indices = {} + non_contiguous_indices["inputs"] = [i for i, inp in enumerate(fx_inputs) if torch.is_tensor(inp) and not inp.is_contiguous()] + num_inputs = len(fx_call_inputs) + fx_call_outputs = [] + for snode in snodes: + if snode.has_aliasing_or_mutation(): + for buf in snode.get_outputs(): + if len(buf.get_mutations()): + fx_call_outputs.extend(buf.get_mutations()) + elif len(buf.get_aliases()): + fx_call_outputs.append(buf.get_name()) + elif snode.node.get_name() not in (V.graph.removed_buffers | V.graph.inplaced_to_remove): + fx_call_outputs.append(snode.node.get_name()) + num_outputs = len(fx_call_outputs) + outputs = traced_graph.last_node if isinstance(traced_graph.last_node, List) \ + else [traced_graph.last_node] + outputs = [output for output 
in outputs if output.name not in (V.graph.removed_buffers | V.graph.inplaced_to_remove)] + fx_call_args = fx_call_inputs + fx_call_outputs + traced_graph.graph.output(tuple(outputs)) + traced_graph.graph.lint() + orig_module = torch.nn.Module() + gm = torch.fx.GraphModule(orig_module, traced_graph.graph) + gm.recompile() + def runnable_gm(*args): + return torch.fx.Interpreter(gm).run(*args) + with V.graph.fake_mode: + gm = make_fx(runnable_gm)(*fx_inputs) + view_to_reshape(gm) + last_node = get_last_node(gm) + fx_output_nodes = last_node.args[0] + fx_outputs = [node.meta['val'] for node in fx_output_nodes] + non_contiguous_indices["outputs"] = [i + num_inputs for i, call_output in enumerate(fx_call_outputs) \ + if not V.graph.try_get_buffer(call_output).layout.is_contiguous()] + fx_args = fx_inputs + fx_outputs + + return gm, fx_call_args, fx_args, { + "num_inputs": num_inputs, + "num_outputs": num_outputs, + "non_contiguous_indices": non_contiguous_indices, + } + + +def create_compile_kwargs(final_kernel, fx_call_args, fx_args): + + _, kernel_call_args, _, arg_types = final_kernel.args.python_argdefs() + for idx, call_arg in enumerate(fx_call_args): + if call_arg in final_kernel.args.inplace_buffers: + fx_call_args[idx] = final_kernel.args.inplace_buffers[call_arg].other_names[-1] + fx_arg_shapes = [fx_arg.shape for fx_arg in fx_args if isinstance(fx_arg, torch.Tensor)] + + if set(kernel_call_args) != set(fx_call_args): + return None + grid: List[Any] = [] + final_kernel.add_numel_to_call_args_and_grid(final_kernel.kernel_name, kernel_call_args, arg_types, grid) + + index_map = {element: idx for idx, element in enumerate(kernel_call_args)} + call_args_mapping = [index_map[element] for element in fx_call_args] + + mismatch_indices_shapes = {} + + for i in range(len(fx_call_args)): + mismatch_indices_shapes[i] = fx_arg_shapes[i] + + return { + "call_args_mapping": call_args_mapping, + 'grid': tuple(grid), + "mismatch_indices_shapes": mismatch_indices_shapes, + } + +def generate_fx_graph_code(code, kernel_code, kernel_name, compile_kwargs): + code = textwrap.indent(code, ' ') + code_template = f""" +import os +import torch +from torch._inductor.compile_fx import clone_preserve_strides +from torch._dynamo.testing import rand_strided +from torch import device + +import math +import random +import os +import tempfile +from math import inf, nan +from torch._inductor.hooks import run_intermediate_hooks +from torch._inductor.utils import maybe_profile +from torch._inductor.codegen.memory_planning import _align as align +from torch import device, empty_strided +from torch._inductor.async_compile import AsyncCompile +from torch._inductor.select_algorithm import extern_kernels +from torch._inductor.codegen.multi_kernel import MultiKernelCall +import triton +import triton.language as tl +from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid, grid_combo_kernels, start_graph, end_graph +from torch_npu._inductor.npu_triton_heuristics import grid +from torch_npu._inductor import get_current_raw_stream as get_raw_stream +from torch_npu._inductor import config as npu_config + +aten = torch.ops.aten +inductor_ops = torch.ops.inductor +_quantized = torch.ops._quantized +assert_size_stride = torch._C._dynamo.guards.assert_size_stride +empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda +empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor 
+alloc_from_pool = torch.ops.inductor._alloc_from_pool + +file_path = os.path.abspath(__file__) +dir_path = os.path.dirname(file_path) + + +class GraphModule(torch.nn.Module): + def __init__(self): + super().__init__() +{code} +model = GraphModule().npu() +call_args_mapping = {compile_kwargs['call_args_mapping']} +num_inputs = {compile_kwargs['num_inputs']} +num_outputs = {compile_kwargs['num_outputs']} +non_contiguous_indices = {compile_kwargs['non_contiguous_indices']} +mismatch_indices_shapes = {compile_kwargs['mismatch_indices_shapes']} + +def run(): + async_compile = AsyncCompile() + {kernel_name} = async_compile.triton('{kernel_name}', ''' +{kernel_code} + ''', device_str='npu') + + async_compile.wait(globals()) + del async_compile + + stream0 = get_raw_stream(0) + + + args = torch.load(os.path.join(dir_path, "data.pth")) + + call_inputs_indices = call_args_mapping[:num_inputs] + call_outputs_indices = call_args_mapping[num_inputs:] + + args = [arg.npu() if isinstance(arg, torch.Tensor) else arg for arg in args] + + fx_args = [] + for idx in call_args_mapping: + arg = args[idx] + if isinstance(arg, torch.Tensor): + fx_arg = clone_preserve_strides(arg).float() if arg.dtype == torch.bfloat16 else clone_preserve_strides(arg) + fx_args.append(fx_arg) + + fx_inputs = [fx_args[idx].contiguous() if idx in non_contiguous_indices['inputs'] else fx_args[idx] for idx in range(num_inputs)] + if len(mismatch_indices_shapes): + for ind, shape in mismatch_indices_shapes.items(): + if ind >= num_inputs: + break + fx_inputs[ind] = fx_inputs[ind].reshape(shape) + model_outputs = model.forward(*fx_inputs) + for idx, (out1, out2) in enumerate(zip(model_outputs, fx_args[num_inputs:(num_inputs + num_outputs)])): + out1 = out1.reshape(out2.shape) + if idx in non_contiguous_indices['outputs']: + out2.copy_(out1) + else: + out2.data = out1.data + + {kernel_name}.run(*args, grid=grid{compile_kwargs['grid']}, stream=stream0) + + for actual, expected in zip([args[i] for i in call_outputs_indices], fx_args[num_inputs:]): + if actual.dtype != expected.dtype: + expected = expected.to(actual.dtype) + acc_comp_tol = npu_config.acc_comp_tol.get(actual.dtype, npu_config.acc_comp_tol['default']) + rtol = acc_comp_tol['rtol'] + atol = acc_comp_tol['atol'] + try: + torch.testing.assert_close(actual, expected, rtol=rtol, atol=atol, equal_nan=False) + except Exception as e: + print(e) + +if __name__ == "__main__": + run() +""" + return code_template + + +def dump_fx_graph_code(code, dump_path, traced_graph_hash): + py_path = os.path.join(dump_path, traced_graph_hash + '.py') + with open(py_path, 'w') as f: + f.write(code) + + +def clone(x, *, memory_format=None): + # TODO(jansel): memory format + input_graphs = fetch_graphs(x) + node_name = f'clone_{next(node_id)}' + new_graph = merge_traced_graphs(input_graphs, aten.clone, node_name) + return Pointwise.create( + device=x.get_device(), + dtype=x.get_dtype(), + inner_fn=x.make_loader(), + ranges=list(x.get_size()), + traced_graph=new_graph, + node_name=node_name + ) + + +def _register_npu_inductor_fallbacks(): + gen_set = set() + _init_set(GENERATE_LIST, gen_set) + overload_op_set = set() + _init_set(LOWERING_OVERLOAD_OP, overload_op_set) + + # 把不在白名单的op fallback + for op in lowering.lowerings: + if op not in decompositions and op not in gen_set: + if isinstance(op, torch._ops.OpOverloadPacket) or \ + isinstance(op, (torch._ops.OpOverload, torch._ops.HigherOrderOperator)): + flag = False + for gens in GENERATE_LIST2: + if str(op).find(gens) != -1: + flag = True + if flag: + 
continue + else: + lowering.make_fallback(op) + FALLBACK_LIST.append(op) + + # 把需要overload的op在lowering里删除 + for op in overload_op_set: + if op in lowering.lowerings: + del lowering.lowerings[op] + + + def transform_args( + args: List[Any], + kwargs: Dict[str, Any], + broadcast: bool, + type_promotion_kind: Optional[ELEMENTWISE_TYPE_PROMOTION_KIND], + convert_input_to_bool: bool, + ) -> Tuple[List[Any], Dict[str, Any]]: + args_indices = [i for i, x in enumerate(args) if isinstance(x, TensorBox)] + kwargs_indices = [k for k, v in kwargs.items() if isinstance(v, TensorBox)] + # check that there's something to transform + if not args_indices and not kwargs_indices: + return args, kwargs + + if type_promotion_kind or convert_input_to_bool: + if convert_input_to_bool: + dtype = torch.bool + else: + # FIXME this is a crude approximation for promoting args + promoting_args = [ + a + for a in args + if isinstance(a, (Number, sympy.Basic)) or hasattr(a, "dtype") + ] + # only consider tensor kwargs for promotion, for now + promoting_args.extend(a for a in kwargs.values() if hasattr(a, "dtype")) + dtype = lowering.get_promoted_dtype( + *promoting_args, type_promotion_kind=type_promotion_kind # type: ignore[arg-type] + ) + + device = ( + args[args_indices[0]] if args_indices else kwargs[kwargs_indices[0]] + ).get_device() + + # sometimes args are an immutable list so we can't mutate them + def promote(arg): + if isinstance(arg, TensorBox): + return to_dtype(arg, dtype) + elif isinstance(arg, ir.Constant): + return ir.Constant(value=arg.value, dtype=dtype, device=device) + else: + return arg + + args = [promote(a) for a in args] + kwargs = {k: promote(v) for k, v in kwargs.items()} + + if broadcast: + broadcasted = broadcast_tensors( + *list( + itertools.chain( + (args[i] for i in args_indices), + (kwargs[k] for k in kwargs_indices), + ) + ) + ) + size = list(broadcasted[0].get_size()) + + for i, x in zip(args_indices, broadcasted[: len(args_indices)]): + args[i] = x + for k, x in zip(kwargs_indices, broadcasted[len(args_indices) :]): + kwargs[k] = x + + for i in range(len(args)): + if isinstance(args[i], ir.Constant): + args[i] = ExpandView.create(args[i], size) + for k in kwargs: + if isinstance(kwargs[k], ir.Constant): + kwargs[k] = ExpandView.create(kwargs[k], size) + + return args, kwargs + + + def _register_lowering( + aten_fn, decomp_fn, broadcast, type_promotion_kind, convert_input_to_bool + ): + + """ + Add a lowering to lowerings dict + + Arguments: + aten_fn: torch.ops.aten.* fn we are lowering + decomp_fn: alternate implementation on our IR + broadcast: True to apply broadcasting to tensor inputs + type_promotion_kind: kind of type promotion applied to tensor inputs, `None` means no type promotion + convert_input_to_bool: some logical ops require inputs are converted to bool + """ + + @functools.wraps(decomp_fn) + def wrapped(*args, **kwargs): + args: List[Any] = list(args) + kwargs: Dict[str, Any] = dict(kwargs) + unpacked = False + # TODO maybe we need to use pytrees here + if len(args) == 1 and isinstance(args[0], (list, tuple)): + unpacked = True + args = list(args[0]) + + if not all( + (fn in lowering.fallbacks or lowering.in_namespace(fn, "_c10d_functional")) for fn in aten_fn + ): + # explicitly assert for "out=" ops for better error messages + assert not any( + x == "out" for x in kwargs.keys() + ), "out= ops aren't yet supported" + + args, kwargs = transform_args( + args, kwargs, broadcast, type_promotion_kind, convert_input_to_bool + ) + + if unpacked: + args = [args] + + out = 
decomp_fn(*args, **kwargs) + validate_ir(out) + + return out + + aten_fn = lowering.get_overloads(aten_fn) + + lowering.lowerings.update(dict.fromkeys(aten_fn, wrapped)) + return wrapped + + + def register_lowering( + aten_fn, + broadcast=False, + type_promotion_kind=lowering.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, + convert_input_to_bool=False, + ): + + """ + Shim to support decorator syntax. + """ + return functools.partial( + _register_lowering, + aten_fn, + broadcast=broadcast, + type_promotion_kind=type_promotion_kind, + convert_input_to_bool=convert_input_to_bool, + ) + + + def _make_reduction_inner(x, *, axis, keepdims, dtype, override_return_dtype): + if dtype is not None: + x = to_dtype(x, dtype) + size = x.get_size() + axis = set(lowering._validate_reduction_axis(x, axis)) + + kept_sizes = [] + kept_idx = [] + reduced_sizes = [] + reduced_idx = [] + for i in range(len(size)): + if i in axis: + reduced_idx.append(i) + reduced_sizes.append(size[i]) + else: + kept_idx.append(i) + kept_sizes.append(size[i]) + + def loader(index, reduction_index): + assert len(reduction_index) == len(reduced_idx) + if keepdims: + assert len(index) == len(size) + index = [index[i] for i in kept_idx] + assert len(index) == len(kept_idx) + new_index = [None] * (len(index) + len(reduction_index)) + for idx, var in itertools.chain( + zip(kept_idx, index), zip(reduced_idx, reduction_index) + ): + new_index[idx] = var + return inner_loader(new_index) + + if keepdims: + new_size = list(size) + for i in reduced_idx: + new_size[i] = sympy.S.One + else: + new_size = kept_sizes + + inner_loader = x.make_loader() + return dict( + device=x.get_device(), + dst_dtype=override_return_dtype or x.get_dtype(), + src_dtype=x.get_dtype(), + inner_fn=loader, + ranges=new_size, + reduction_ranges=reduced_sizes, + ) + + + def make_reduction(reduction_type: str, override_return_dtype=None): + def inner(x, axis=None, keepdims=False, *, dtype=None): + kwargs = _make_reduction_inner( + x, + axis=axis, + keepdims=keepdims, + dtype=dtype, + override_return_dtype=override_return_dtype, + ) + node_name = f'reduction_{next(node_id)}' + input_graphs = fetch_graphs([x, axis if axis is not None else list(range(len(x.get_size())))]) + new_graph = merge_traced_graphs(input_graphs, reduction_type_to_aten_fn[reduction_type], + node_name, keepdim=keepdims) + + result = Reduction.create(reduction_type=reduction_type, + input_node=x, + node_name=node_name, + traced_graph=new_graph, + **kwargs) + if isinstance( + result.data.data, Reduction + ): + #Only realize if reduction isn't unrolled + size = x.get_size() + axis = set(lowering._validate_reduction_axis(x, axis)) + kept_idx = [] + reduced_idx = [] + for i in range(len(size)): + if i in axis: + reduced_idx.append(i) + else: + kept_idx.append(i) + + object.__setattr__(result.data.data, "kept_idx", kept_idx) + object.__setattr__(result.data.data, "reduced_idx", reduced_idx) + + result.realize() + return result + + return inner + + + lowering.make_reduction = make_reduction + + + def to_dtype(x: TensorBox, dtype: torch.dtype, copy=False): + src_dtype = x.get_dtype() + if src_dtype == dtype: + return clone(x) if copy else x + + def _to_dtype(x): + return ops.to_dtype(x, dtype, src_dtype=src_dtype) + register_fn_to_aten_fn(_to_dtype, aten.to.dtype) + return make_pointwise(_to_dtype, override_return_dtype=dtype, dtype=dtype)(x) + + + @register_lowering(prims.convert_element_type, type_promotion_kind=None) + def _convert_element_type(x: TensorBox, dtype: torch.dtype): + if dtype.is_complex or 
x.get_dtype().is_complex: + if x.get_size(): + # Decompose since aa aten fallback is more friendly for c++ codegen. + # This decomposition doesn't work for empty tensor, which needs more investigation. + dst = empty_like(x, dtype=dtype) + ir.InplaceCopyFallback.create(dst, x) + return dst + else: + return lowering.fallback_handler( + prims.convert_element_type.default, add_to_fallback_set=False + )(x, dtype) + return to_dtype(x, dtype, copy=True) + + + def register_pointwise( + aten_fn, + name=None, + broadcast=True, + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, + convert_input_to_bool=False, + override_return_dtype=None, + override_fn_when_input_bool=None, + allow_alpha=False, + use_libdevice_for_f64=False, + triton_fallback=None, + ): + """A pointwise function that maps ops.{name} to inputs""" + name = name or aten_fn.__name__ + fn = ops_wrapper(name) + if use_libdevice_for_f64: + fn_libdevice = ops_wrapper("libdevice_" + name) + lowering.register_op_dtype_propagation_rules( + "libdevice_" + name, type_promotion_kind, override_return_dtype + ) + + lowering.register_op_dtype_propagation_rules( + name, type_promotion_kind, override_return_dtype + ) + + if override_fn_when_input_bool is not None: + override_fn_when_input_bool = ops_wrapper(override_fn_when_input_bool) + + fn = register_fn_to_aten_fn(fn, aten_fn) + + fn = make_pointwise( + fn, + override_return_dtype=override_return_dtype, + override_fn_when_input_bool=override_fn_when_input_bool, + override_fn_when_gpu_float64=fn_libdevice if use_libdevice_for_f64 else None, # type: ignore[possibly-undefined] + allow_alpha=allow_alpha, + triton_fallback=triton_fallback, + ) + fn = register_lowering( + aten_fn, + broadcast=broadcast, + type_promotion_kind=type_promotion_kind, + convert_input_to_bool=convert_input_to_bool, + )(fn) + + if hasattr(prims, name): + register_lowering( + getattr(prims, name), + type_promotion_kind=None, + convert_input_to_bool=convert_input_to_bool, + )(fn) + return fn + + + def make_pointwise( + fn, + override_return_dtype=None, + override_device=None, + override_fn_when_input_bool=None, + override_fn_when_gpu_float64=None, + allow_alpha=False, + triton_fallback=None, + **kwargs + ): + def inner(*inputs: TensorBox, alpha=None): + if triton_fallback is not None and any( + isinstance(inp, IRNode) and is_triton(inp) for inp in inputs + ): + assert not allow_alpha # not implemented + return triton_fallback(*inputs) + + inputs = lowering.promote_constants(inputs, override_return_dtype) + if allow_alpha: + if alpha is not None and alpha != 1: + inputs = list(inputs) + inputs[-1] = mul(inputs[-1], alpha) + else: + assert alpha is None + loaders = [x.make_loader() for x in inputs] + ranges = inputs[0].get_size() + dtype = override_return_dtype or inputs[0].get_dtype() + is_gpu_device = lowering.is_gpu(decode_device(inputs[0].get_device()).type) + + for other in inputs[1:]: + assert isinstance(other, ir.BaseConstant) or len(ranges) == len( + other.get_size() + ), f"ndim mismatch {fn} {ranges} {other.get_size()}" + + # in tracing, we will annotate pointwise nodes that correspond to the output of + # a pointwise node that would have been run in eager. intermediary pointwise nodes + # during decompositions are not annotated. 
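+            # Precision casts are emulated only for fp16/bf16 outputs whose FX node carries
+            # the low_precision_pointwise_barrier flag; all other pointwise ops are unaffected.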
+ emulate_precision_casts = ( + V.graph is not None + and getattr(V.graph, "current_node", None) is not None + and V.graph.current_node.meta is not None + and V.graph.current_node.meta.get("low_precision_pointwise_barrier", False) + and dtype in (torch.bfloat16, torch.float16) + ) + + def inner_fn(index): + assert len(index) == len(ranges), f"wrong ndim {index} {ranges}" + if dtype == torch.bool and override_fn_when_input_bool is not None: + return override_fn_when_input_bool(*[load(index) for load in loaders]) + elif ( + override_fn_when_gpu_float64 + and is_gpu_device + and dtype == torch.float64 + ): + return override_fn_when_gpu_float64(*[load(index) for load in loaders]) + else: + inputs_loaded = [] + for load in loaders: + out = load(index) + if emulate_precision_casts: + downcast = ops.to_dtype(out, dtype, use_compute_types=False) + out = ops.to_dtype(downcast, dtype) + inputs_loaded.append(out) + + out = fn(*inputs_loaded) + if emulate_precision_casts: + # fp16/bf16 kernels are computed in fp32. Casting down to fp16/bf16 here, + # then upcasting again, to emulate casts that eager would do. + downcast = ops.to_dtype(out, dtype, use_compute_types=False) + return ops.to_dtype(downcast, dtype) + return out + + if not override_device: + device = None + for i in inputs: + if lowering.is_gpu(i.get_device().type): + device = i.get_device() + break + if not device: + device = inputs[0].get_device() + + device = override_device or device + + input_graphs = fetch_graphs(inputs) + node_name = f'pointwise_{next(node_id)}' + origin_fn = fn_to_aten_fn[fn] + new_graph = merge_traced_graphs(input_graphs, origin_fn, node_name, **kwargs) + + return Pointwise.create( + device=device, + dtype=dtype, + inner_fn=inner_fn, + ranges=ranges, + node_name=node_name, + traced_graph=new_graph, + ) + + return inner + + @register_lowering(aten.where, broadcast=False, type_promotion_kind=None) + def where(cond, a, b): + def fn(*args): + return ops.where(*args) + + if isinstance(a, (float, int)): + a = lowering.constant_like(a)(b) + if isinstance(b, (float, int)): + b = lowering.constant_like(b)(a) + + args = [cond, a, b] + dtype = lowering.get_promoted_dtype( + args[1], args[2], type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ) + indices = [i for i, x in enumerate(args) if isinstance(x, TensorBox)] + for i, x in zip(indices, broadcast_tensors(*[args[i] for i in indices])): + args[i] = x + for i in range(len(args)): + if isinstance(args[i], ir.Constant): + args[i] = ExpandView.create(args[i], list(args[indices[0]].get_size())) + register_fn_to_aten_fn(fn, aten.where) + return make_pointwise(fn, override_return_dtype=dtype)( + args[0], to_dtype(args[1], dtype), to_dtype(args[2], dtype) + ) + + + @register_lowering(aten.broadcast_tensors, broadcast=False, type_promotion_kind=None) + def broadcast_tensors(*inputs): + if len(inputs) == 1 and isinstance(inputs[0], (list, tuple)): + return broadcast_tensors(*inputs[0]) + target: List[sympy.Expr] = functools.reduce( + lowering.broadcast_symbolic_shapes, [x.get_size() for x in inputs], [] + ) + outputs = [] + for x in inputs: + sizes = x.get_size() + if len(sizes) != len(target) or any( + ( + ( + V.graph.sizevars.shape_env.evaluate_expr( + sympy.Eq(a, 1), size_oblivious=True + ) + and not V.graph.sizevars.shape_env.evaluate_expr( + sympy.Eq(b, 1), size_oblivious=True + ) + ) + or ( + not V.graph.sizevars.shape_env.evaluate_expr( + sympy.Eq(a, 1), size_oblivious=True + ) + and V.graph.sizevars.shape_env.evaluate_expr( + sympy.Eq(b, 1), size_oblivious=True + ) 
+ ) + ) + for a, b in zip(sizes, target) + ): + x = expand(x, target) + outputs.append(x) + return outputs + + + @register_lowering(aten.squeeze, type_promotion_kind=None) + def squeeze(x, dim=None): + assert isinstance(x, TensorBox) + if dim is None: + return TensorBox(SqueezeView.create(x.data)) + + dim = ( + V.graph.sizevars.evaluate_static_shape(dim) + if isinstance(dim, (int, sympy.Expr)) + else tuple(V.graph.sizevars.evaluate_static_shape(d) for d in dim) + ) + dim = canonicalize_dims(len(x.get_size()), dim) # type: ignore[call-overload] + dims = set((dim,) if not isinstance(dim, tuple) else dim) + + new_shape = [] + for d, s in enumerate(x.get_size()): + if not ( + d in dims + and V.graph.sizevars.evaluate_expr(sympy.Eq(s, 1, size_oblivious=True)) + ): + new_shape.append(s) + + # squeeze does nothing if the size isn't 1 + return view(x, new_shape) if new_shape != x.get_size() else x + + + @register_lowering([aten.squeeze_]) + def squeeze_(x, dim=None): + val = squeeze(x, dim) + assert isinstance(x, TensorBox) + assert isinstance(val, TensorBox) + x.data = val.data + return x + + + @register_lowering(aten.isinf) + def isinf(x): + if lowering.is_integer_type(x): + return full_like(x, False, dtype=torch.bool) + fn = ops_wrapper("isinf") + register_fn_to_aten_fn(fn, aten.isinf) + return make_pointwise(fn, override_return_dtype=torch.bool)(x) + + + @register_lowering(aten.isnan) + def isnan(x): + if lowering.is_integer_type(x): + return full_like(x, False, dtype=torch.bool) + fn = ops_wrapper("isnan") + register_fn_to_aten_fn(fn, aten.isnan) + return make_pointwise(fn, override_return_dtype=torch.bool)(x) + + + @register_lowering(aten.ceil) + def ceil(x): + if lowering.is_integer_type(x): + return clone(x) + fn = ops_wrapper("ceil") + register_fn_to_aten_fn(fn, aten.ceil) + return make_pointwise(fn)(x) + + + @register_lowering(aten.floor) + def floor(x): + if lowering.is_integer_type(x): + return clone(x) + fn = ops_wrapper("floor") + register_fn_to_aten_fn(fn, aten.floor) + return make_pointwise(fn)(x) + + + @register_lowering(aten.round.default) + def round(x): + if lowering.is_integer_type(x): + return clone(x) + else: + fn = ops_wrapper("round") + register_fn_to_aten_fn(fn, aten.round) + return make_pointwise(fn)(x) + + + @register_lowering(aten.trunc) + def trunc(x): + if lowering.is_integer_type(x): + return clone(x) + fn = ops_wrapper("trunc") + register_fn_to_aten_fn(fn, aten.trunc) + return make_pointwise(fn)(x) + + + @register_lowering(aten.expand, type_promotion_kind=None) + def expand(x, sizes): + from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols + + (x,) = lowering.promote_constants([x]) + if isinstance(x, ir.BaseConstant): + return ExpandView.create(x, tuple(sizes)) + assert isinstance(x, TensorBox) + assert isinstance(sizes, (list, tuple)) + if tuple(x.get_size()) == tuple(sizes): + return x + + if not free_unbacked_symbols(x.get_size()): + x_size_product = V.graph.sizevars.size_hint(sympy_product(x.get_size())) + # TODO: It would be better to realize the input if any of its sizes + # are unbacked, because typically the size will be non-zero. 
However, + # this cannot be done directly as below as we'll choke on the size_hint + # here + if x_size_product > 0 and not free_unbacked_symbols(sizes): + # maybe realize input before broadcasting it + x.mark_reuse( + V.graph.sizevars.size_hint(sympy_product(sizes)) // x_size_product + ) + input_graphs = fetch_graphs([x.data, tuple(sizes)]) + node_name = f'expand_{next(node_id)}' + new_graph = merge_traced_graphs(input_graphs, aten.expand, node_name) + return TensorBox(ExpandView.create(x.data, tuple(sizes), traced_graph=new_graph, node_name=node_name)) + + + @register_lowering(aten.expand_as, type_promotion_kind=None) + def expand_as(x, y): + return expand(x, y.get_size()) + + + @register_lowering(aten.repeat) + def repeat(x, repeats): + input_graphs = fetch_graphs([x, repeats]) + node_name = f'repeat_{next(node_id)}' + new_graph = merge_traced_graphs(input_graphs, aten.repeat, node_name) + old_size = list(x.get_size()) + if len(repeats) > len(old_size): + old_size = [sympy.S.One] * (len(repeats) - len(old_size)) + old_size + x = view(x, list(old_size)) + assert len(repeats) == len(x.get_size()) + + new_size = list(x.get_size()) + + zero_tensor = False + for i in range(len(repeats)): + if repeats[i] == 0: + zero_tensor = True + new_size[i] = new_size[i] * repeats[i] + + if zero_tensor: + return empty(new_size, dtype=x.get_dtype(), device=x.get_device()) + if all((a == 1 or b == 1) for a, b in zip(repeats, old_size)): + return clone(expand(x, new_size)) + + x_loader: Callable[[Any], Any] + + def inner_fn(index): + assert len(index) == len(repeats) + index = list(index) + for i in range(len(repeats)): + if repeats[i] != 1: + if old_size[i] == 1: + index[i] = sympy.S.Zero + else: + index[i] = ModularIndexing(index[i], 1, old_size[i]) + return x_loader(index) + + old_size_product = V.graph.sizevars.size_hint(sympy_product(old_size)) + if old_size_product > 0: + # maybe realize the input + x.mark_reuse( + V.graph.sizevars.size_hint(sympy_product(new_size)) // old_size_product + ) + + x_loader = x.make_loader() + return Pointwise.create( + device=x.get_device(), + dtype=x.get_dtype(), + inner_fn=inner_fn, + ranges=list(new_size), + traced_graph=new_graph, + node_name=node_name + ) + + @register_lowering(aten._unsafe_view, type_promotion_kind=None) + @register_lowering(aten.view, type_promotion_kind=None) + @register_lowering(aten.reshape, type_promotion_kind=None) + def view(x, sizes): + assert isinstance(x, TensorBox) + assert isinstance(sizes, (list, tuple)) + input_graphs = fetch_graphs([x.data, sizes]) + node_name = f'view_{next(node_id)}' + new_graph = merge_traced_graphs(input_graphs, aten.reshape, node_name) + return TensorBox(View.create(x.data, sizes, traced_graph=new_graph, node_name=node_name)) + + + @register_lowering(aten.permute, type_promotion_kind=None) + def permute(x, dims): + assert isinstance(x, TensorBox) + assert isinstance(dims, (list, tuple)) + input_graphs = fetch_graphs([x.data, dims]) + node_name = f'permute_{next(node_id)}' + new_graph = merge_traced_graphs(input_graphs, aten.permute, node_name) + return TensorBox(PermuteView.create(x.data, tuple(dims), traced_graph=new_graph, node_name=node_name)) + + + @register_lowering(aten.slice, type_promotion_kind=None) + def slice_(x, dim=0, start=0, end=2**63, step=1, clamp=True): + assert isinstance(x, TensorBox) + dim = _validate_dim(x, dim, 0) + input_graphs = fetch_graphs([x.data]) + node_name = f'slice_{next(node_id)}' + new_graph = merge_traced_graphs(input_graphs, aten.slice, node_name, dim=dim, start=start, end=end, 
step=step) + + return TensorBox(ir.SliceView.create(x.data, dim, start, end, step, traced_graph=new_graph, node_name=node_name)) + + + @register_lowering(aten.select, type_promotion_kind=None) + def select(x, dim, idx): + idx = View.handle_negative_index(idx, x.get_size()[dim]) + return squeeze(slice_(x, dim, idx, idx + 1), dim) + + + @register_lowering(aten.split, type_promotion_kind=None) + def split(x, sizes, dim=0): + dim = _validate_dim(x, dim, 0) + sizes_ = sizes + + # If sizes is an integer (or a SymInt), we turn it into a list of sizes + # by computing what the actual size of each chunk should be. + if not isinstance(sizes, (list, tuple)): + x_size = x.get_size()[dim] + chunks = V.graph.sizevars.evaluate_static_shape( + FloorDiv(x_size + sizes - 1, sizes) + ) + sizes_ = [sizes] * chunks + # The last chunk might have a smaller size than the rest. + sizes_[-1] = x_size - (chunks - 1) * sizes + + # From this point, we assume that the sum of the sizes of all chunks + # equals the size of the base tensor. + result = [] + start = 0 + for size in sizes_: + end = start + size + # No need for clamping here, since we compute the exact + # start and end values. + result.append(slice_(x, dim, start, end, clamp=False)) + start = end + return result + + + @register_lowering(aten.split_with_sizes, type_promotion_kind=None) + def split_with_sizes(x, sizes, dim=0): + return split(x, sizes, dim) + + + @register_lowering(aten.unbind, type_promotion_kind=None) + def unbind(x, dim=0): + dim = _validate_dim(x, dim, 0) + x_size = V.graph.sizevars.evaluate_static_shape(x.get_size()[dim]) + result = [select(x, dim, i) for i in range(x_size)] + return result + + + @register_lowering(aten.unsqueeze, type_promotion_kind=None) + def unsqueeze(x, dim): + dim = _validate_dim(x, dim, 1) + new_shape = list(x.get_size()) + new_shape.insert(dim, sympy.S.One) + return view(x, new_shape) + + + @register_lowering(aten.unsqueeze_, type_promotion_kind=None) + def unsqueeze_(x, dim): + val = unsqueeze(x, dim) + assert isinstance(x, TensorBox) + assert isinstance(val, TensorBox) + x.data = val.data + return x + + + def _validate_dim(x, dim, offset=0): + dim = V.graph.sizevars.shape_env.evaluate_expr(sympy.sympify(dim)) + ndim = len(x.get_size()) + if dim < 0: + dim += ndim + offset + assert 0 <= dim < ndim + offset + return dim + + + @register_lowering(aten.copy, type_promotion_kind=None) + def copy(self, src, non_blocking=False): + x = src + if self.get_device() != src.get_device(): + x = lowering.to_device(x, self.get_device()) + if self.get_dtype() != src.get_dtype(): + x = to_dtype(x, self.get_dtype()) + + if self.get_size() != src.get_size(): + out = expand(x, self.get_size()) + return clone(out) + return clone(x) + + + @register_lowering(prims.iota) + def iota( + length, + *, + start, + step, + dtype, + device, + requires_grad, + ): + def fn(index): + return ops.index_expr(step * index[0] + start, dtype=dtype) + + node_name = f'iota_{next(node_id)}' + new_graph = merge_traced_graphs([length], prims.iota, node_name, \ + start=start, step=step, \ + dtype=dtype, device=device, \ + requires_grad=requires_grad) + return Pointwise.create( + device=decode_device(device), + dtype=dtype, + inner_fn=fn, + ranges=[length], + traced_graph=new_graph, + node_name=node_name + ) + + + @register_lowering(aten.select_scatter, type_promotion_kind=None) + def select_scatter(x, src, dim: int, index: int): + assert x.get_dtype() == src.get_dtype() + input_graphs = fetch_graphs([x, src, dim, index]) + node_name = 
f'select_scatter_{next(node_id)}' + new_graph = merge_traced_graphs(input_graphs, aten.select_scatter, node_name) + x_loader = x.make_loader() + dim = _validate_dim(x, dim, 0) + if V.graph.sizevars.evaluate_expr(sympy.Lt(index, 0)): + index = index + x.get_size()[dim] + V.graph.sizevars.guard_leq(0, index) # type: ignore[arg-type] + V.graph.sizevars.guard_lt(index, x.get_size()[dim]) # type: ignore[arg-type] + src = expand(unsqueeze(src, dim), x.get_size()) + src_loader = src.make_loader() + + def inner_fn(idx): + return ops.where( + ops.eq( + ops.index_expr(idx[dim], torch.int32), + ops.index_expr(index, torch.int32), + ), + src_loader(idx), + x_loader(idx), + ) + + return Pointwise.create( + device=x.get_device(), + dtype=x.get_dtype(), + inner_fn=inner_fn, + ranges=list(x.get_size()), + traced_graph=new_graph, + node_name=node_name + ) + + + @register_lowering(aten.slice_scatter, type_promotion_kind=None) + def slice_scatter(x, src, dim=0, start=None, end=None, step=1): + assert x.get_dtype() == src.get_dtype() + input_graphs = fetch_graphs([x, src]) + node_name = f'slice_scatter_{next(node_id)}' + new_graph = merge_traced_graphs(input_graphs, aten.slice_scatter, node_name, \ + dim=dim, + start=start, + end=end, + step=step) + x_loader = x.make_loader() + dim = _validate_dim(x, dim, 0) + dim_size = x.get_size()[dim] + + start, end = ir.SliceView.normalize_start_end(x, dim, start, end) + + src_size = list(x.get_size()) + src_size[dim] = FloorDiv(end - start + (step - 1), step) + src = expand(src, src_size) + src_loader = src.make_loader() + + def inner_fn(idx): + if start == 0 and end == dim_size and step == 1: + # selecting every element is the same as just src.clone() + return src_loader(idx) + + idx_dim = ops.index_expr(idx[dim], torch.int64) + src_idx = list(idx) + src_idx[dim] = FloorDiv(idx[dim] - start, step) + + mask = [] + if start != 0: + mask.append( + ops.ge( + idx_dim, + ops.index_expr(sympy.expand(start), torch.int64), + ) + ) + if end != dim_size: + mask.append( + ops.lt( + idx_dim, + ops.index_expr(sympy.expand(end), torch.int64), + ) + ) + if step != 1: + mask.append( + ops.eq( + ops.index_expr( + ModularIndexing(idx[dim] - start, 1, step), torch.int64 + ), + ops.constant(0, torch.int64), + ) + ) + assert mask + mask = functools.reduce(ops.and_, mask) + src_val = ops.masked( + mask, + lambda: src_loader(src_idx), + 0 if lowering.is_integer_type(x) else 0.0, + ) + return ops.where( + mask, + src_val, + x_loader(idx), + ) + + return Pointwise.create( + device=x.get_device(), + dtype=x.get_dtype(), + inner_fn=inner_fn, + ranges=list(x.get_size()), + traced_graph=new_graph, + node_name=node_name + ) + + + @register_lowering([torch.tensor, aten.scalar_tensor]) + def tensor(data, *, dtype=None, device=None, layout=None, pin_memory=False): + lowering.assert_nyi(layout in (None, torch.strided), f"layout={layout}") + lowering.assert_nyi(not pin_memory, "pin_memory") + input_graphs = fetch_graphs([data]) + node_name = f'tensor_{next(node_id)}' + new_graph = merge_traced_graphs(input_graphs, aten.scalar_tensor, node_name, \ + dtype=dtype, + device='npu', + layout=layout, + pin_memory=False) + if isinstance(lowering._unwrap(data), int): + dtype = dtype or torch.int64 + else: + dtype = dtype or torch.get_default_dtype() + + ranges: List[sympy.Expr] = [] + + if isinstance(data, sympy.Basic): + + def inner_fn(index): + return ops.index_expr(data, dtype) + + elif isinstance(data, (float, int)): + + def inner_fn(index): + return ops.constant(data, dtype) + + elif len(data) == 0 or 
isinstance(data[0], (float, int)) and len(data) <= 8: + # inline small tensors + ranges.append(sympy.Integer(len(data))) + + def inner_fn(index): + def binary_search(start, end): + assert start < end + if end - start == 1: + return ops.constant(data[start], dtype) + mid = (end - start) // 2 + start + return ops.where( + ops.lt( + ops.index_expr(index[0], torch.int64), + ops.constant(mid, torch.int64), + ), + binary_search(start, mid), + binary_search(mid, end), + ) + + if len(data) == 0: + return ops.constant(0, dtype) + return binary_search(0, len(data)) + + else: + return V.graph.add_tensor_constant( + torch.tensor(data, dtype=dtype, device=device) + ) + + return Pointwise.create( + device=decode_device(device), + dtype=dtype, + inner_fn=inner_fn, + ranges=ranges, + traced_graph=new_graph, + node_name=node_name + ) + + + def tensor_constructor(fill_value): + # torch.zeros, torch.ones, etc + def inner( + *size, + names=None, + dtype=None, + device=None, + layout=None, + pin_memory=False, + memory_format=None, + ): + lowering.assert_nyi(names is None, "named tensors") + lowering.assert_nyi(layout in (None, torch.strided), f"layout={layout}") + lowering.assert_nyi(not pin_memory, "pin_memory") + device = decode_device(device) + dtype = dtype or torch.get_default_dtype() + if len(size) == 1 and isinstance(size[0], (list, tuple, torch.Size)): + size = tuple(size[0]) + # See https://github.com/pytorch/pytorch/issues/118102 + # All sizes at lowering time should be sympy.Symbol, not SymInt! + for s in size: + assert not isinstance(s, torch.SymInt) + size = [sympy.expand(s) for s in size] + return _full(fill_value, device, dtype, size) + + return inner + + + def _full(fill_value, device, dtype, size): + value = fill_value + if not isinstance(fill_value, (int, float)) and hasattr(value, "value"): + value = value.value + + if isinstance(value, (int, float)): + + def inner_fn(index): + return ops.constant(value, dtype) + + elif isinstance(value, sympy.Basic): + + def inner_fn(index): + return ops.index_expr(value, dtype) + + else: + assert len(value.get_size()) == 0 + value_loader = value.make_loader() + + def inner_fn(index): + return value_loader([]) + + node_name = f'full_{next(node_id)}' + new_graph = merge_traced_graphs([size, fill_value], aten.full.default, node_name, \ + device='npu', dtype=dtype, layout = torch.strided, pin_memory = False) + + return Pointwise.create( + device=device, + dtype=dtype, + inner_fn=inner_fn, + ranges=list(size), + traced_graph=new_graph, + node_name=node_name + ) + + + @register_lowering(aten.empty_strided) + def empty_strided( + size, stride, *, dtype=None, layout=None, device=None, pin_memory=None + ): + assert isinstance(size, (list, tuple)) + assert isinstance(stride, (list, tuple, type(None))) + lowering.assert_nyi(not pin_memory, "pin_memory") + lowering.assert_nyi(layout in (None, torch.strided), f"layout={layout}") + dtype = lowering.decode_dtype(dtype) or torch.get_default_dtype() + device = device or torch.tensor(0.0).device + device = decode_device(device) + pointwise = _full(fill_value=0, device=device, dtype=dtype, size=size) + pointwise.realize() + buffer = pointwise.data.data + # explicitly set ranges to zeros in order to make a NopKernelSchedulerNode + buffer.data = lowering.dataclasses.replace(buffer.data, ranges=[0] * len(size)) + assert isinstance(buffer, ir.ComputedBuffer) + size = [sympy.expand(s) for s in size] + stride = ( + [sympy.expand(s) for s in stride] + if stride + else ir.FlexibleLayout.contiguous_strides(size) + ) + buffer.layout = 
ir.FixedLayout( + device=device, + dtype=dtype, + size=size, + stride=stride, + ) + return pointwise + + + @register_lowering([torch.empty, aten.empty]) + def empty( + *size, + names=None, + dtype=None, + layout=None, + device=None, + pin_memory=None, + memory_format=None, + ): + lowering.assert_nyi(names is None, "named tensors") + device = decode_device(device) + if len(size) == 1 and isinstance(size[0], (list, tuple, torch.Size)): + size = tuple(size[0]) + return empty_strided( + size, None, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory + ) + + + @register_lowering([torch.full, aten.full]) + def full(size, fill_value, **kwargs): + assert kwargs.get("dtype") is not None, "dtype should be handled by decomposition" + return tensor_constructor(fill_value)(size, **kwargs) + + + register_lowering(aten.clone)(clone) + + + @register_lowering(aten.constant_pad_nd, type_promotion_kind=None) + def constant_pad_nd(x, padding, fill_value=0): + assert (len(padding) % 2) == 0 + + input_graphs = fetch_graphs([x, padding]) + node_name = f'constand_pad_nd_{next(node_id)}' + new_graph = merge_traced_graphs(input_graphs, aten.constant_pad_nd, node_name, value=fill_value) + + if all(p == 0 for p in padding): + return clone(x) + + sizes = x.get_size() + + bounds = list(reversed(list(zip(padding[::2], padding[1::2])))) + n = len(sizes) - len(bounds) + + # if padding is a complicated expression, hoist it + bounds_precomp: List[Tuple[sympy.Symbol, Any]] = [] + for l, h in bounds: + bounds_precomp.append((V.graph.sizevars.lookup_precomputed_size(l), h)) # type: ignore[arg-type] + + output_size = list(sizes[:n]) + mask_sizes = [] + for (low, high), size in zip(bounds, sizes[n:]): + mask_sizes.append(size) + output_size.append(sympy.expand(size + low + high)) + assert len(output_size) == len(sizes) + fill_value = dtype_to_type(x.get_dtype())(fill_value) + + def mask(index): + mask = [] + for idx, (low, high), length in zip(index[n:], bounds, mask_sizes): + if low != 0: + mask.append(lowering.range_mask_low(idx, 0)) + if high != 0: + mask.append(lowering.range_mask_high(idx, length)) + mask = functools.reduce(ops.and_, mask) + return ops.masked(mask, lambda: x_loader(index), fill_value) + + def offset_fn(index): + new_index = list(index[:n]) + for idx, (low, high) in zip(index[n:], bounds_precomp): + new_index.append(idx - low) + assert len(new_index) == len(index) + return mask(new_index) + + x_loader = x.make_loader() + return Pointwise.create( + device=x.get_device(), + dtype=x.get_dtype(), + inner_fn=offset_fn, + ranges=output_size, + traced_graph=new_graph, + node_name=node_name + ) + + + @make_pointwise + @register_to_aten(aten_fn=aten.pow) + def pow_native(a, b): + return ops.pow(a, b) + + + @register_lowering(aten.pow, broadcast=True) + def pow(a, b): + if isinstance(b, float) and b == int(b): + return pow(a, int(b)) + elif isinstance(b, float) and b == 0.5: + return sqrt(a) + elif isinstance(b, int) and b == 1: + return clone(a) + + input_graphs = fetch_graphs([a, b]) + node_name = f'pointwise_{next(node_id)}' + new_graph = merge_traced_graphs(input_graphs, aten.pow, node_name) + + # Type promotion ensures all tensor arguments have the same type + dtype = next(x.get_dtype() for x in (a, b) if isinstance(x, ir.TensorBox)) + is_integer_pow = is_integer_dtype(dtype) + + # Optimize away small fixed powers, or for integers avoid falling back to ATen + embed_exponent = isinstance(b, int) and ( + -32 < b < 32 or (is_integer_pow and b >= 0) + ) + if embed_exponent: + loader = a.make_loader() + + 
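+            # fn() below expands the small constant exponent via lowering.pow_recursive
+            # (repeated multiplication / squaring), so no ops.pow call is emitted for e.g. x ** 3.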
def fn(idx): + return lowering.pow_recursive(loader(idx), b, a.get_dtype()) + + return Pointwise.create( + device=a.get_device(), + dtype=a.get_dtype(), + inner_fn=fn, + ranges=a.get_size(), + node_name=node_name, + traced_graph=new_graph, + ) + + if isinstance(a, Number): + if a == 1: + return full_like(b, 1) + if a == 2 and is_float_dtype(b.get_dtype()): + return exp2(b) + + if is_integer_pow: + # ops.pow doesn't work for integers + if isinstance(a, Number): + return lowering.fallback_pow_scalar(a, b) + elif isinstance(b, Number): + return lowering.fallback_pow_tensor_scalar(a, b) + else: + return lowering.fallback_pow_tensor_tensor(a, b) + + return pow_native(a, b) + + + def mutate_to(changed, val, unsafe_alias=False): + if isinstance(changed, TensorBox): + changed_data = changed.data + else: + changed_data = changed + if isinstance(val, TensorBox): + val = val.data + + if not isinstance(val, ir.StorageBox): + # introduce a copy to handle views + input_graphs = fetch_graphs([changed, val]) + node_name = f'copy__{next(node_id)}' + new_graph = merge_traced_graphs(input_graphs, aten.copy_, node_name) + val = Pointwise.create( + device=changed.get_device(), + dtype=changed.get_dtype(), + inner_fn=val.make_loader(), + ranges=changed.get_size(), + traced_graph=new_graph, + node_name=node_name + ).data + assert isinstance(val, ir.StorageBox) + + if isinstance(changed_data, ir.StorageBox) and not ( + changed_data.is_input_buffer() + # In AOTI, module parameters and buffers are not lifted as graph inputs + or changed_data.is_module_buffer() + or isinstance(changed_data.data, ir.NopKernel) + ): + # Fast path, just swing the data pointer + val.realize() + changed_data.data = val.data + return changed + + ir.MutationLayoutSHOULDREMOVE.realize_into( + val, changed_data, unsafe_alias=unsafe_alias + ) + return changed + + + empty_like = register_lowering(aten.empty_like)(lowering.create_tensor_like(empty)) + ones_like = lowering.create_tensor_like(tensor_constructor(1)) + zeros_like = lowering.create_tensor_like(tensor_constructor(0)) + + + @register_lowering(aten.full_like, type_promotion_kind=None) + def full_like(x, fill_value, **kwargs): + return lowering.create_tensor_like(tensor_constructor(fill_value))(x, **kwargs) + + + @register_lowering(aten.fill_) + def fill_(x, fill_value): + return mutate_to(x, full_like(x, fill_value)) + + + @register_lowering(aten.copy_, type_promotion_kind=None) + def copy_(dst, src, non_blocking=False): + if dst is src: + # dst.copy_(dst) can happen from the reinplacing pass + return dst + src = lowering.to_device(src, dst.get_device()) + src = to_dtype(src, dst.get_dtype()) + src = expand(src, dst.get_size()) + return mutate_to(dst, src) + + + @make_pointwise + def floordiv(a, b): + return ops.floordiv(a, b) + + + @make_pointwise + def truncdiv(a, b): + return ops.truncdiv(a, b) + + + @register_lowering(aten.div, broadcast=True) + def div_mode(a, b, rounding_mode=None): + both_integer = lowering.is_integer_type(a) and lowering.is_integer_type(b) + both_boolean = lowering.is_boolean_type(a) and lowering.is_boolean_type(b) + + # floordiv and truncdiv need special handling for integer tensors on Triton, + # see the discussion at https://github.com/openai/triton/issues/605 + if rounding_mode == "floor": + assert not both_boolean, "floordiv operands can not be boolean at the same time" + return floordiv(a, b) if both_integer else floor(div(a, b)) + if rounding_mode == "trunc": + assert not both_boolean, "truncdiv operands can not be boolean at the same time" + return 
truncdiv(a, b) if both_integer else trunc(div(a, b)) + return div(a, b) + + + @register_lowering([aten.mul], broadcast=True) + def mul(a, b): + both_bool = lowering.is_boolean_type(a) and lowering.is_boolean_type(b) + if both_bool: + return logical_and(a, b) + else: + fn = ops_wrapper(aten.mul.__name__) + fn = register_fn_to_aten_fn(fn, aten.mul) + return make_pointwise(fn)(a, b) + + + @register_lowering([aten.reciprocal], broadcast=True,) + def reciprocal(a): + return div(1.0, a) + + + @register_lowering([prims.div], broadcast=True) + def div_prim(a, b): + is_integral = all(lowering.is_boolean_type(x) or lowering.is_integer_type(x) for x in [a, b]) + + if is_integral: + return truncdiv(a, b) + + def fn(*args): + return ops.truediv(*args) + + fn = register_fn_to_aten_fn(fn, aten.div) + return make_pointwise(fn)(a, b) + + + @register_lowering( + [aten.true_divide, aten.div.Tensor], + broadcast=True, + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + ) + def div(a, b): + a, b = lowering.promote_constants( + (a, b), type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT + ) + return div_prim(a, b) + + + @register_lowering(aten.rsqrt) + def rsqrt(x): + dtype = x.get_dtype() + if is_integer_dtype(dtype) or is_boolean_dtype(dtype): + x = to_dtype(x, torch.get_default_dtype()) + + def _rsqrt(x): + return ops.rsqrt(x) + + register_fn_to_aten_fn(_rsqrt, aten.rsqrt) + return make_pointwise(_rsqrt)(x) + + + @register_lowering([aten.sum, prims.sum]) + def sum_(x, axis=None, keepdims=False, *, dtype=None): + if ( + is_integer_dtype(x.get_dtype()) or is_boolean_dtype(x.get_dtype()) + ) and dtype is None: + dtype = torch.int64 + + fn = make_reduction("sum", override_return_dtype=dtype) + return fn(x, axis, keepdims, dtype=dtype) + + + @register_lowering(aten.prod) + def prod(x, axis=None, keepdims=False, *, dtype=None): + if ( + is_integer_dtype(x.get_dtype()) or is_boolean_dtype(x.get_dtype()) + ) and dtype is None: + dtype = torch.int64 + + fn = make_reduction("prod", override_return_dtype=dtype) + return fn(x, axis, keepdims, dtype=dtype) + + + @register_lowering(aten.any) + def reduce_any(x, dim=None, keepdim=False): + x = to_dtype(x, torch.bool) + return make_reduction("any")(x, axis=dim, keepdims=keepdim) + + + @register_lowering(aten.max, type_promotion_kind=None) + def reduce_max(x, dim=None, keepdim=False): + if dim is not None: + return ( + reduce_amax(x, axis=dim, keepdims=keepdim), + reduce_argmax(x, axis=dim, keepdims=keepdim), + ) + + return reduce_amax(x, axis=None, keepdims=keepdim) + + + @register_lowering(aten.min, type_promotion_kind=None) + def reduce_min(x, dim=None, keepdim=False): + if dim is not None: + return ( + reduce_amin(x, axis=dim, keepdims=keepdim), + reduce_argmin(x, axis=dim, keepdims=keepdim), + ) + + return reduce_amin(x, axis=None, keepdims=keepdim) + + + register_lowering(prims.xor_sum)(make_reduction("xor_sum")) + reduce_amax = register_lowering(aten.amax)(make_reduction("max")) + reduce_amin = register_lowering(aten.amin)(make_reduction("min")) + reduce_argmax = register_lowering(aten.argmax)( + make_reduction("argmax", override_return_dtype=torch.int64) + ) + reduce_argmin = register_lowering(aten.argmin)( + make_reduction("argmin", override_return_dtype=torch.int64) + ) + + add = register_pointwise( + aten.add, allow_alpha=True, override_fn_when_input_bool="logical_or" + ) + + def register_pointwise_numeric(op, name=None, triton_fallback=None): + return register_pointwise( + op, + name=name, + 
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + triton_fallback=triton_fallback, + ) + + + def register_pointwise_numeric_ldf64(op): + return register_pointwise( + op, + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + use_libdevice_for_f64=True, + ) + + + def register_inplace(aten_op, outplace_op): + @register_lowering(aten_op, type_promotion_kind=None) + def fn(*args, **kwargs): + result = outplace_op(*args, **kwargs) + result = to_dtype(result, args[0].get_dtype()) + return mutate_to(args[0], result) + + return fn + + + rsqrt = register_pointwise_numeric(aten.rsqrt) + exp = register_pointwise_numeric_ldf64(aten.exp) + exp2 = register_pointwise_numeric(aten.exp2) + expm1 = register_pointwise_numeric(aten.expm1) + relu = register_pointwise(aten.relu) + sigmoid = register_pointwise_numeric_ldf64(aten.sigmoid) + sqrt = register_pointwise_numeric_ldf64(aten.sqrt) + square = register_pointwise(aten.square) + sub = register_pointwise(aten.sub, allow_alpha=True) + register_pointwise_numeric_ldf64(aten.cos) + register_pointwise_numeric_ldf64(aten.sin) + abs = register_pointwise(aten.abs) + bitwise_and = register_pointwise(aten.bitwise_and) + bitwise_left_shift = register_pointwise(aten.bitwise_left_shift) + bitwise_not = register_pointwise( + aten.bitwise_not, override_fn_when_input_bool="logical_not" + ) + bitwise_or = register_pointwise(aten.bitwise_or) + bitwise_right_shift = register_pointwise(aten.bitwise_right_shift) + bitwise_xor = register_pointwise(aten.bitwise_xor) + register_pointwise_numeric(aten.lgamma) + erf = register_pointwise_numeric(aten.erf) + register_lowering( + aten.special_erf, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT + )(erf) + + register_pointwise_numeric(aten.log1p) + register_pointwise_numeric(aten.tan) + register_pointwise_numeric(aten.tanh) + register_pointwise_numeric_ldf64(aten.log) + logical_and = register_pointwise( + aten.logical_and, + type_promotion_kind=None, + convert_input_to_bool=True, + override_return_dtype=torch.bool, + ) + logical_not = register_pointwise( + aten.logical_not, + type_promotion_kind=None, + convert_input_to_bool=True, + override_return_dtype=torch.bool, + ) + logical_or = register_pointwise( + aten.logical_or, + type_promotion_kind=None, + convert_input_to_bool=True, + override_return_dtype=torch.bool, + ) + logical_xor = register_pointwise( + aten.logical_xor, + type_promotion_kind=None, + convert_input_to_bool=True, + override_return_dtype=torch.bool, + ) + maximum = register_pointwise(aten.maximum) + minimum = register_pointwise(aten.minimum) + clamp_min = register_pointwise(aten.clamp_min, name='maximum') + clamp_max = register_pointwise(aten.clamp_max, name='minimum') + neg = register_pointwise(aten.neg) + abs = register_pointwise(aten.abs) + register_pointwise(aten.remainder) + sign = register_pointwise(aten.sign, override_fn_when_input_bool="identity") + register_pointwise(aten.ceil) + register_pointwise(aten.signbit, override_return_dtype=torch.bool) + + register_lowering(aten._neg_view)(neg) + + register_pointwise(aten.le, override_return_dtype=torch.bool) + register_pointwise(aten.lt, override_return_dtype=torch.bool) + register_pointwise(aten.ge, override_return_dtype=torch.bool) + gt = register_pointwise(aten.gt, override_return_dtype=torch.bool) + register_pointwise(aten.eq, override_return_dtype=torch.bool) + register_pointwise(aten.ne, override_return_dtype=torch.bool) + + register_pointwise_numeric(aten.cosh) + register_pointwise_numeric(aten.sinh) + 
register_pointwise_numeric(aten.acos) + register_pointwise_numeric(aten.acosh) + register_pointwise_numeric(aten.asin) + register_pointwise_numeric(aten.asinh) + register_pointwise_numeric(aten.atan2) + register_pointwise_numeric(aten.atan) + register_pointwise_numeric(aten.atanh) + register_pointwise_numeric(aten.copysign) + register_pointwise_numeric(aten.erfc) + register_pointwise_numeric(aten.erfinv) + register_pointwise_numeric(aten.hypot) + register_pointwise_numeric(aten.log10) + register_pointwise_numeric(aten.log2) + register_pointwise_numeric(aten.nextafter) + + + register_inplace(aten.add_, add) + register_inplace(aten.bitwise_and_, bitwise_and) + register_inplace(aten.bitwise_left_shift_, bitwise_left_shift) + register_inplace(aten.bitwise_not_, bitwise_not) + register_inplace(aten.bitwise_or_, bitwise_or) + register_inplace(aten.bitwise_right_shift_, bitwise_right_shift) + register_inplace(aten.bitwise_xor_, bitwise_xor) + register_inplace(aten.mul_, mul) + register_inplace(aten.div_.Tensor, div) + register_inplace(aten.div_.Tensor_mode, div_mode) + register_inplace(aten.logical_and_, logical_and) + register_inplace(aten.logical_not_, logical_not) + register_inplace(aten.logical_or_, logical_or) + register_inplace(aten.logical_xor_, logical_xor) + register_inplace(aten.sub_, sub) + register_inplace(aten.relu_, relu) + register_inplace(aten.sigmoid_, sigmoid) + + + register_lowering(aten.__and__)(bitwise_and) + register_lowering(aten.__lshift__)(bitwise_left_shift) + register_lowering(aten.__or__)(bitwise_or) + register_lowering(aten.__rshift__)(bitwise_right_shift) + register_lowering(aten.__xor__)(bitwise_xor) + + register_inplace(aten.__iand__, aten.__and__) + register_inplace(aten.__ilshift__, aten.__lshift__) + register_inplace(aten.__ior__, aten.__or__) + register_inplace(aten.__irshift__, aten.__rshift__) + register_inplace(aten.__ixor__, aten.__xor__) + + + +########################################################################## + + @register_lowering(aten.mean) + def mean(x, axis=None, keepdim=False, *, dtype=None): + if dtype is not None: + x = to_dtype(x, dtype) + size = x.get_size() + axis = lowering._validate_reduction_axis(x, axis) + # compute in higher-precision until end of mean lowering + output_dtype = x.get_dtype() + if output_dtype in (torch.float16, torch.bfloat16): + x = to_dtype(x, torch.float) + sum_result = sum_(x, axis, keepdim) + denom = sympy_product(size[i] for i in axis) + denom = ir.IndexingConstant(index=denom, dtype=x.get_dtype(), device=x.get_device()) + denom = ExpandView.create(denom, list(sum_result.get_size())) + return to_dtype(div(sum_result, denom), output_dtype) + + + @register_lowering(aten.cumsum) + def cumsum(x, axis=None, dtype=None): + if ( + is_integer_dtype(x.get_dtype()) or is_boolean_dtype(x.get_dtype()) + ) and dtype is None: + # torch.int64->torch.int32 + dtype = torch.int32 + if len(x.get_size()) == 0: + if axis not in [0, -1]: + raise ValueError("axis must be 0 or -1") + dtype = dtype or x.get_dtype() + return to_dtype(x, dtype, copy=True) + return lowering.fallback_cumsum(x, dim=axis, dtype=dtype) + + + @register_lowering(npu.npu_dtype_cast, type_promotion_kind=None) + def _convert_npu_type(x: TensorBox, dtype: torch.dtype): + return to_dtype(x, dtype, copy=True) + + + def var_mean_sum_(x, axis, correction, keepdim, return_mean): + if correction is None: + correction = 1 + + size = x.get_size() + axis = lowering._validate_reduction_axis(x, axis) + x_mean = mean(x, axis, keepdim=True) + if return_mean: + x_mean.realize() 
+ + diffs = square(sub(x, x_mean)) + sum_result = sum_(diffs, axis, keepdim) + denom = sympy_product(size[i] for i in axis) + if correction: + denom = sympy.Max(denom - correction, 0) + denom = ir.IndexingConstant(index=denom, dtype=x.get_dtype(), device=x.get_device()) + denom = ExpandView.create(denom, list(sum_result.get_size())) + x_var = div(sum_result, denom) + if not return_mean: + return (x_var,) + + x_mean = x_mean if keepdim else squeeze(x_mean, axis) + return x_var, x_mean + + + def var_mean_helper_(x, *, axis, correction, keepdim, return_mean): + out_dtype = x.get_dtype() + compute_dtype = get_computation_dtype(out_dtype) + x = to_dtype(x, compute_dtype, copy=False) + kwargs = dict( + x=x, + axis=axis, + correction=correction, + keepdim=keepdim, + return_mean=return_mean, + ) + output = ( + var_mean_sum_(**kwargs) + ) + output = tuple(to_dtype(x, out_dtype, copy=False) for x in output) + return output[0] if not return_mean else output + + @register_lowering(aten.var_mean) + def var_mean(x, axis=None, *, correction=None, keepdim=False): + return var_mean_helper_( + x, axis=axis, correction=correction, keepdim=keepdim, return_mean=True + ) + + @register_lowering([aten.var, prims.var]) + def var_(x, axis=None, *, correction=None, keepdim=False): + return var_mean_helper_( + x, axis=axis, correction=correction, keepdim=keepdim, return_mean=False + ) + + @register_lowering(aten.embedding, type_promotion_kind=None) + def embedding(weight, indices, padding_idx=-1, scale_grad_by_freq=False, sparse=False): + return lowering.fallback_handler(aten.embedding.default)(weight, indices, padding_idx=-1, scale_grad_by_freq=False, + sparse=False) + + @register_lowering(aten.cat) + def cat(inputs, dim=0): + return lowering.fallback_handler(aten.cat.default)(inputs, dim) + + lowering.make_fallback(aten._log_softmax) + lowering.make_fallback(aten.gather) + lowering.make_fallback(aten.nll_loss_forward) \ No newline at end of file diff --git a/torch_npu/_inductor/npu_choices.py b/torch_npu/_inductor/npu_choices.py new file mode 100644 index 0000000000000000000000000000000000000000..ff0c11b4ff0e8bf20d91f480e5db0da327096751 --- /dev/null +++ b/torch_npu/_inductor/npu_choices.py @@ -0,0 +1,38 @@ +import typing +from typing import Any, Dict, List, Type, TYPE_CHECKING + +import sympy + +from torch._inductor import config +from torch._inductor.runtime.hints import ReductionHint +from torch._inductor.virtualized import V +from torch._inductor.codegen.simd_kernel_features import SIMDKernelFeatures +from torch._inductor.codegen.triton import TritonKernel + + + + + +@staticmethod +def should_use_persistent_reduction( + features: SIMDKernelFeatures, cooperative_reduction: bool +) -> bool: + """ + Heuristic to decide if a persistent reduction should be used. 
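+    Both ReductionHint.INNER and ReductionHint.DEFAULT use a 1024-element threshold here,
+    other hints fall back to 64; with the default config a 512-element inner reduction
+    stays persistent, while a 4096-element one falls back to a looped reduction.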
+ """ + if not config.triton.persistent_reductions: + return False + threshold = { + ReductionHint.INNER: 1024, + ReductionHint.DEFAULT: 1024 + }.get(features.get_reduction_hint(), 64) + if cooperative_reduction: + # The RSPLIT of cooperative reductions means each thread block is operating on fewer elements + try: + threshold *= 32 // min(V.graph.sizevars.size_hint(features.numel), 32) + except ValueError: + pass # unbacked symint + + if config.triton.multi_kernel: + threshold *= 16 + return V.graph.sizevars.statically_known_leq(features.reduction_numel, threshold) # type: ignore[arg-types] \ No newline at end of file diff --git a/torch_npu/_inductor/npu_device.py b/torch_npu/_inductor/npu_device.py new file mode 100644 index 0000000000000000000000000000000000000000..2a41cc978ece1449d2841b703dbf12e597653ca1 --- /dev/null +++ b/torch_npu/_inductor/npu_device.py @@ -0,0 +1,247 @@ +import torch +from torch_npu.utils._inductor import NPUDeviceOpOverrides +from torch_npu.utils._dynamo_device import NpuInterface, current_device, set_device +from torch_npu.npu.utils import device_count + +## Override original inductor device overrides in torch_npu +class NewNPUDeviceOpOverrides(NPUDeviceOpOverrides): + def import_get_raw_stream_as(self, name): + return f"from torch_npu._inductor import get_current_raw_stream as {name}" + + def set_device(self, device_idx): + return f"torch.npu.set_device({device_idx})" + + def synchronize(self): + return """ + stream = torch.npu.current_stream() + stream.synchronize() + """ + + def device_guard(self, device_idx): + return f"torch.npu._DeviceGuard({device_idx})" + + def cpp_aoti_device_guard(self): + raise NotImplementedError + + def cpp_aoti_stream_guard(self): + return "AOTICudaStreamGuard" + + def kernel_driver(self): + source_codes = """ + static std::unordered_map registered_names; + static std::unordered_map> func_stubs; + + namespace { + + struct Grid { + Grid(uint32_t x, uint32_t y, uint32_t z) + : grid_x(x), grid_y(y), grid_z(z) {} + uint32_t grid_x; + uint32_t grid_y; + uint32_t grid_z; + + bool is_non_zero() { + return grid_x > 0 && grid_y > 0 && grid_z > 0; + } + }; + + } // anonymous namespace + + extern "C" { + typedef int (* callback)(unsigned int type, void* data, unsigned int len); + extern int MsprofReportApi(unsigned int agingFlag, const MsprofApi *api); + extern unsigned long int MsprofSysCycleTime(); + extern int MsprofRegisterCallback(unsigned int moduleId, callback handle); + static unsigned int __MsprofFlagL0 = 0; + static unsigned int __MsprofFlagL1 = 0; + + int ProfCtrlHandle(unsigned int CtrlType, void* CtrlData, unsigned int DataLen) { + if ((CtrlData == nullptr) || (DataLen == 0U)) { + return 1; + } + + if (CtrlType == 1) { + MsprofCommandHandle* handle = (MsprofCommandHandle *)(CtrlData); + if (handle->type >= 6) // 6 is not used here + return 1; + if (handle->type == 1) { // init - 0 , start - 1 + __MsprofFlagL0 = ((0x00000800ULL & handle->profSwitch) == 0x00000800ULL) ? 1 : 0; + __MsprofFlagL1 = ((0x00000002ULL & handle->profSwitch) == 0x00000002ULL) ? 
1 : 0; + } + } + return 0; + } + } + + std::vector stringSplit(const std::string& s) { + std::vector tokens; + std::istringstream iss(s); + std::string token; + while (iss >> token) { + tokens.push_back(token); + } + return tokens; + } + + static inline void * loadKernel( + std::string filePath, + const std::string &nameFuncMode, + uint32_t sharedMemBytes, + const std::optional &cubinDir = std::nullopt) { + if (cubinDir) { + std::filesystem::path p1{*cubinDir}; + std::filesystem::path p2{filePath}; + filePath = (p1 / p2.filename()).string(); + } + auto splitNameMode = stringSplit(nameFuncMode); + if (splitNameMode.size() != 2) { + throw std::runtime_error(std::string("funcName not right: ") + nameFuncMode); + } + auto funcName = splitNameMode[0]; + auto kernel_mode_str = splitNameMode[1]; + std::ifstream file(std::string(filePath), std::ios::binary | std::ios::ate); + if (!file.is_open()) { + throw std::runtime_error(std::string("open npubin failed")); + } + + std::streamsize data_size = file.tellg(); + + file.seekg(0, std::ios::beg); + char* buffer = new char[data_size]; + if (!file.read(buffer, data_size)) { + throw std::runtime_error(std::string("read npubin failed")); + } + + rtError_t rtRet; + + rtDevBinary_t devbin; + devbin.data = buffer; + devbin.length = data_size; + const std::string kernel_mode{kernel_mode_str}; + if (kernel_mode == "aiv") + devbin.magic = RT_DEV_BINARY_MAGIC_ELF_AIVEC; + else + devbin.magic = RT_DEV_BINARY_MAGIC_ELF; + devbin.version = 0; + + int device = 0; + rtRet = rtSetDevice(device); + if (rtRet != RT_ERROR_NONE) { + throw std::runtime_error(std::string("rtSetDevice failed, 0x") + std::to_string(rtRet)); + } + + void *devbinHandle = NULL; + rtRet = rtDevBinaryRegister(&devbin, &devbinHandle); + if (rtRet != RT_ERROR_NONE) { + throw std::runtime_error(std::string("rtDevBinaryRegister failed, 0x") + std::to_string(rtRet)); + } + + const char* name = funcName.c_str(); + + std::string stubName(name); + stubName += "_" + std::to_string(registered_names[name]); + registered_names[name]++; + auto registered = func_stubs.emplace(stubName, std::make_unique(0)); + void *func_stub_handle = registered.first->second.get(); + rtRet = rtFunctionRegister(devbinHandle, func_stub_handle, stubName.c_str(), + (void *)name, 0); + if (rtRet != RT_ERROR_NONE) { + throw std::runtime_error(std::string("rtFunctionRegister failed, stubName = ") + stubName + + std::string(" , 0x") + std::to_string(rtRet)); + } + + return func_stub_handle; + } + + static void launchKernel(std::string kernelName, const void* func, rtStream_t stream, int gridX, int gridY, int gridZ, void *kernelArgs, int32_t kernelArgsSize) {{ + std::string name = ""; + name.append(kernelName); + char *kargs = new char[kernelArgsSize]; + memcpy(kargs, kernelArgs, kernelArgsSize); + auto launch_call = [=]() {{ + uint32_t blockNum = gridX * gridY * gridZ; + + rtError_t ret; + unsigned long int beginTime = 0; + unsigned long int endTime = 0; + unsigned long int opName = 0; + unsigned int threadId = 0; + const char* kernelNameC = kernelName.c_str(); + size_t length = name.length(); + {{ + beginTime = MsprofSysCycleTime(); + }} + ret = rtKernelLaunch(func, blockNum, kargs, kernelArgsSize, NULL, stream); + delete[] kargs; + return ret; + }}; + at_npu::native::OpCommand cmd; + cmd.Name(name.c_str()) + .SetCustomHandler(launch_call) + .Run(); + }} + """ + return source_codes + + def abi_compatible_header(self): + return """ + #include + #include + #include + #include + #include + #include + #include + #include + + #include + 
#include + #include + #include + #include + #include "experiment/runtime/runtime/rt.h" + """ + + def cpp_stream_type(self): + return "cudaStream_t" + + def aoti_get_stream(self): + return "aoti_torch_get_current_cuda_stream" + + def cpp_kernel_type(self): + return "void *" + + def cpp_device_ptr(self): + return "void*" + +## Override original dynamo device interface in torch_npu +class NewNpuInterface(NpuInterface): + + @staticmethod + def is_available() -> bool: + return device_count() > 0 + + @staticmethod + def get_compute_capability(device=None): + # npu has no concept of cc. triton-npu compiler depends on subarch instead + return torch.npu.get_device_name(device) + + @staticmethod + def exchange_device(device: int) -> int: + curr_device = current_device() + set_device(device) + return curr_device + + @staticmethod + def maybe_exchange_device(device: int) -> int: + return device + + @staticmethod + def is_bf16_supported(including_emulation: bool = False): + return True + + # @staticmethod + # def get_device_properties(device=None): + # props = NpuInterface.get_device_properties(device) + # setattr(props, "multi_processor_count", num_vector_core ) + # return props \ No newline at end of file diff --git a/torch_npu/_inductor/npu_fusion_attention_graph.py b/torch_npu/_inductor/npu_fusion_attention_graph.py new file mode 100644 index 0000000000000000000000000000000000000000..2e421b6a1c2201a7a678d91eacf06ade5da9f105 --- /dev/null +++ b/torch_npu/_inductor/npu_fusion_attention_graph.py @@ -0,0 +1,238 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. +import functools +import sympy +import torch +from torch.autograd import Function +from torch.library import Library, impl +import torch.nn.functional as F +import torch_npu + + + +npu_def = Library("npu_graph", "DEF") +npu_lib = Library("npu_graph", "IMPL", "PrivateUse1") +meta_lib = Library("npu_graph", "IMPL", "Meta") + +npu_def.define("npu_fa(Tensor query, Tensor key, Tensor value, int head_num, str input_layout, Tensor? pse=None, Tensor? padding_mask=None, Tensor? atten_mask=None, float scale=1., float keep_prob=1., int pre_tockens=2147483647, int next_tockens=2147483647, int inner_precise=0, int[]? prefix=None, int[]? actual_seq_qlen=None, int[]? actual_seq_kvlen=None, int sparse_mode=0, bool gen_mask_parallel=True, bool sync=False) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor)") +npu_def.define("npu_fa_backward(Tensor query, Tensor key, Tensor value, Tensor dy, int head_num, str input_layout, *, Tensor? pse=None, Tensor? padding_mask=None, Tensor? atten_mask=None, Tensor? softmax_max=None, Tensor? softmax_sum=None, Tensor? softmax_in=None, Tensor? attention_in=None, float scale_value=1., float keep_prob=1., int pre_tockens=2147483647, int next_tockens=2147483647, int inner_precise=0, Tensor? seed=None, Tensor? offset=None, Tensor? numels=None, int[]? prefix=None, int[]? actual_seq_qlen=None, int[]? 
actual_seq_kvlen=None, int sparse_mode=0, bool gen_mask_parallel=True, bool sync=False) -> (Tensor, Tensor, Tensor, Tensor)")
+
+
+@impl(npu_lib, "npu_fa")
+def npu_fa(*args, **kwargs):
+    if len(args) > 8:
+        args = list(args)
+        # for scale
+        try:
+            args[8] = 1.0 / args[8]
+        except ZeroDivisionError:
+            args[8] = 1.0 / (args[8] + 1e-6)
+            print("args[8] is zero and cannot be used as a divisor; adding 1e-6")
+    r1, r2, r3, r4, seed, offset, numel = torch_npu.npu_fusion_attention(*args, **kwargs)
+    r2.requires_grad = False
+    r3.requires_grad = False
+    r4.requires_grad = False
+    return r1, r2, r3, r4, torch.tensor([seed], requires_grad=False), torch.tensor([offset], requires_grad=False), torch.tensor([numel], requires_grad=False)
+
+
+@impl(npu_lib, "npu_fa_backward")
+def npu_fa_backward(*args, **kwargs):
+    if 'scale_value' in kwargs:
+        kwargs['scale_value'] = 1.0 / kwargs['scale_value']
+    return torch_npu.npu_fusion_attention_grad(*args, **kwargs)
+
+
+@impl(meta_lib, "npu_fa")
+def npu_fa(query, key, value, head_num, input_layout, pse=None, padding_mask=None,
+           atten_mask=None, scale=1.0, keep_prob=1.0, pre_tockens=2147483647, next_tockens=2147483647,
+           inner_precise=0, prefix=None, actual_seq_qlen=None, actual_seq_kvlen=None, sparse_mode=0, gen_mask_parallel=True, sync=False):
+    B = query.size(0)
+    N = head_num
+    S1 = query.size(2)
+    S2 = key.size(2)
+
+    if input_layout == "BSH":
+        B = query.size(0)
+        S1 = query.size(1)
+        S2 = key.size(1)
+
+    if input_layout == "SBH":
+        B = query.size(1)
+        S1 = query.size(0)
+        S2 = key.size(0)
+
+    attention_score = torch.empty_like(query, dtype=query.dtype, device='meta').contiguous()
+    softmax_max = torch.empty([B, head_num, S1, 8], dtype=torch.float32, device='meta')
+    softmax_sum = torch.empty([B, head_num, S1, 8], dtype=torch.float32, device='meta')
+    softmax_out = torch.empty([0], dtype=query.dtype, device='meta')
+    return (torch.empty_like(attention_score),
+            torch.empty_like(softmax_max),
+            torch.empty_like(softmax_sum),
+            torch.empty_like(softmax_out),
+            torch.tensor([0], device='meta', requires_grad=False),
+            torch.tensor([0], device='meta', requires_grad=False),
+            torch.tensor([0], device='meta', requires_grad=False))
+
+
+@impl(meta_lib, "npu_fa_backward")
+def npu_fa_backward(query, key, value, dy, head_num, input_layout, *, pse=None, padding_mask=None, atten_mask=None,
+                    softmax_max=None, softmax_sum=None, softmax_in=None, attention_in=None, scale_value=1.0,
+                    keep_prob=1.0, pre_tockens=2147483647, next_tockens=2147483647, inner_precise=0, seed=0, offset=0,
+                    numels=0, prefix=None, actual_seq_qlen=None, actual_seq_kvlen=None, sparse_mode=0, gen_mask_parallel=True, sync=False):
+
+    dq = torch.empty_like(query, dtype=query.dtype, device='meta').contiguous()
+    dk = torch.empty_like(key, dtype=query.dtype, device='meta').contiguous()
+    dv = torch.empty_like(value, dtype=query.dtype, device='meta').contiguous()
+    dpse = torch.empty([0], dtype=query.dtype, device='meta').contiguous()
+    return (torch.empty_like(dq), torch.empty_like(dk), torch.empty_like(dv), torch.empty_like(dpse) if pse else None)
+
+
+class NpuGraphAttentionFunction(Function):
+    @staticmethod
+    def forward(ctx, query, key, value, head_num, input_layout, pse=None, padding_mask=None, atten_mask=None, scale=1.0, keep_prob=1.0, pre_tockens=2147483647, next_tockens=2147483647, inner_precise=0, prefix=None, actual_seq_qlen=None, actual_seq_kvlen=None, sparse_mode=0, gen_mask_parallel=True, sync=False):
+        # Forward pass logic
+        # Delegates to torch.ops.npu_graph.npu_fa registered above
+        result0, result1, result2, result3, result4, result5, result6 = torch.ops.npu_graph.npu_fa(
+            query, key, value, head_num, input_layout, pse=pse, padding_mask=padding_mask, atten_mask=atten_mask, scale=scale, keep_prob=keep_prob, pre_tockens=pre_tockens, next_tockens=next_tockens, inner_precise=inner_precise, prefix=prefix, actual_seq_qlen=actual_seq_qlen, actual_seq_kvlen=actual_seq_kvlen, sparse_mode=sparse_mode, gen_mask_parallel=gen_mask_parallel, sync=sync
+        )
+        # Save intermediate results for use in the backward pass
+        ctx.save_for_backward(query, key, value, pse, padding_mask, atten_mask, result1, result2, result3, result0, result4, result5, result6)
+        ctx.head_num = head_num
+        ctx.input_layout = input_layout
+        ctx.scale = scale
+        ctx.keep_prob = keep_prob
+        ctx.pre_tockens = pre_tockens
+        ctx.next_tockens = next_tockens
+        ctx.inner_precise = inner_precise
+        ctx.prefix = prefix
+        ctx.actual_seq_qlen = actual_seq_qlen
+        ctx.actual_seq_kvlen = actual_seq_kvlen
+        ctx.sparse_mode = sparse_mode
+        ctx.gen_mask_parallel = gen_mask_parallel
+        ctx.sync = sync
+
+        return result0, result1, result2, result3, result4, result5, result6
+
+    @staticmethod
+    def backward(ctx, grad_result0, grad_result1, grad_result2, grad_result3, grad_result4, grad_result5, grad_result6):
+        # Retrieve the saved intermediate results
+        query, key, value, pse, padding_mask, atten_mask, result1, result2, result3, result0, result4, result5, result6 = ctx.saved_tensors
+        # Backward pass logic
+        # Delegates to torch.ops.npu_graph.npu_fa_backward registered above
+        grad_query, grad_key, grad_value, grad_pse = torch.ops.npu_graph.npu_fa_backward(
+            query, key, value, grad_result0, ctx.head_num, ctx.input_layout, pse=pse, padding_mask=padding_mask, atten_mask=atten_mask, softmax_max=result1, softmax_sum=result2, softmax_in=result3, attention_in=result0, scale_value=ctx.scale, keep_prob=ctx.keep_prob, pre_tockens=ctx.pre_tockens, next_tockens=ctx.next_tockens, inner_precise=ctx.inner_precise, seed=result4, offset=result5, numels=result6, prefix=ctx.prefix, actual_seq_qlen=ctx.actual_seq_qlen, actual_seq_kvlen=ctx.actual_seq_kvlen, sparse_mode=ctx.sparse_mode, gen_mask_parallel=ctx.gen_mask_parallel, sync=ctx.sync
+        )
+        return (grad_query, grad_key, grad_value, None, None, grad_pse, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None)
+
+
+def npu_fusion_attention_graph(query, key, value, head_num, input_layout, pse=None, padding_mask=None,
+                               atten_mask=None, scale=1.0, keep_prob=1.0, pre_tockens=2147483647, next_tockens=2147483647,
+                               inner_precise=0, prefix=None, actual_seq_qlen=None, actual_seq_kvlen=None, sparse_mode=0, gen_mask_parallel=True, sync=False):
+    return NpuGraphAttentionFunction.apply(query, key, value, head_num, input_layout, pse, padding_mask,
+                                           atten_mask, scale, keep_prob, pre_tockens, next_tockens,
+                                           inner_precise, prefix, actual_seq_qlen, actual_seq_kvlen, sparse_mode, gen_mask_parallel, sync)
+torch_npu.npu_fusion_attention_graph = npu_fusion_attention_graph
+
+
+def register_fx_pass():
+    TOKEN_MAX = 2147483647
+    from torch._inductor.pattern_matcher import register_replacement, fwd_only, joint_fwd_bwd
+    from torch._inductor.fx_passes.joint_graph import patterns
+    from torch._dynamo.utils import counters
+    from torch._inductor.fx_passes.fuse_attention import partialize_and_update_signature
+
+    def _npu_fusion_attention_graph_pattern_1(query, key, value, inv_scale_factor, dropout_p):
+        q = query.permute(0, 2, 1, 3)
+        k = key.permute(0, 2, 1, 3)
+        v = value.permute(0, 2, 1, 3)
+        return torch.nn.functional.dropout(
+            torch.matmul(q, k.transpose(-2, -1)).div(inv_scale_factor).softmax(dim=-1),
+ p=dropout_p, + ).matmul(v) + + + def _npu_fusion_attention_graph_replacement_1(query, key, value, inv_scale_factor, dropout_p): + counters["inductor"]["fuse_attention"] += 1 + head_num = query.size(2) + input_layout = "BNSD" + return torch_npu.npu_fusion_attention_graph( + query.transpose(1, 2), + key.transpose(1, 2), + value.transpose(1, 2), + head_num, + input_layout, + None, + atten_mask=None, + scale=inv_scale_factor, + keep_prob=1.0 - dropout_p, + )[0] + + def _get_sfdp_patterns(): + device = 'npu' + g_inp = functools.partial( + torch.empty, (2, 4, 8, 16), device=device, requires_grad=True + ) + c_inp = functools.partial(torch.tensor, 2.0, device=device) + d = {"dropout_p": 0.113377} + candidates = [] + for dtype in [torch.float]: + g = functools.partial(g_inp, dtype=dtype) + c = functools.partial(c_inp, dtype=dtype) + candidates.append(( + _npu_fusion_attention_graph_pattern_1, + _npu_fusion_attention_graph_replacement_1, + [g(), g(), g(), c()], + d, + )) + + for pattern, replacement, args, workaround in candidates: + # gets serialized to a python file and does not require tracing at runtime. + if not isinstance(workaround, dict): + raise ValueError("workaround not dict") + name = pattern.__name__ + + if dtype != torch.float: + name += "_half" + + if args[0].size(0) == 1: + name += "_bs1" + + training_name = name + "_training" + yield training_name, { + "search_fn": pattern, + "replace_fn": replacement, + "example_inputs": args, + "trace_fn": joint_fwd_bwd, + "pass_dicts": patterns, + "scalar_workaround": workaround, + } + + if workaround: + if not (len(workaround) == 1 and "dropout_p" in workaround): + raise ValueError("not (len(workaround) == 1 and dropout_p in workaround)") + # functools.partial insufficient because we look at signature downstream + pattern = partialize_and_update_signature(pattern, dropout_p=0.0) + replacement = partialize_and_update_signature( + replacement, dropout_p=0.0 + ) + workaround = {} + + inference_name = name + "_inference" + yield inference_name, { + "search_fn": pattern, + "replace_fn": replacement, + "example_inputs": args, + "trace_fn": fwd_only, + "pass_dicts": patterns, + "scalar_workaround": workaround, + } + + for _, register_replacement_kwargs in _get_sfdp_patterns(): + register_replacement( + **register_replacement_kwargs, + ) + +register_fx_pass() + + + diff --git a/torch_npu/_inductor/npu_triton_helpers.py b/torch_npu/_inductor/npu_triton_helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..5f60f5dd499eec9257d4e37bb6cf9094059b506f --- /dev/null +++ b/torch_npu/_inductor/npu_triton_helpers.py @@ -0,0 +1,20 @@ +import triton +import triton.language as tl + +import triton.language.extra.ascend.libdevice as libdevice +from torch._inductor.runtime import triton_helpers +libdevice = tl.extra.ascend.libdevice +math = tl.math + + +@triton.jit +def maximum(a, b): + return tl.maximum(a, b) + + +@triton.jit +def minimum(a, b): + return tl.minimum(a, b) + +triton_helpers.maximum = maximum +triton_helpers.minimum = minimum diff --git a/torch_npu/_inductor/npu_triton_heuristics.py b/torch_npu/_inductor/npu_triton_heuristics.py new file mode 100644 index 0000000000000000000000000000000000000000..43acb8dd8c1b2b8e499c7b60880dd8346eef7187 --- /dev/null +++ b/torch_npu/_inductor/npu_triton_heuristics.py @@ -0,0 +1,1091 @@ +# This file is based on triton_heuristics with heuristics designed for NPU +import os +import sys +import functools +import time +import copy +import importlib +from typing import Any, Callable, List, Optional 
+import logging +import re +import hashlib +import json + +import torch +from torch._inductor import config +from torch._dynamo.utils import dynamo_timed +from torch._inductor.runtime.triton_heuristics import ( + CachingAutotuner, + HeuristicType, + unique_configs, + hash_configs, + Config, + ASTSource, + _find_names, + get_first_attr, + collected_calls, + _dump_launch_params, + builtins +) +from torch._inductor.runtime.benchmarking import benchmarker +from torch._inductor.runtime.autotune_cache import AutotuneCache + + +from torch._inductor.runtime.runtime_utils import ( + create_bandwidth_info_str, + get_num_bytes, + +) + +from torch._inductor.compile_fx import clone_preserve_strides + +import triton +from triton.compiler import CompiledKernel + +try: + from triton.backends.compiler import GPUTarget + from triton.runtime.autotuner import OutOfResources + import torch.autograd.profiler as autograd_profiler +except ImportError: + GPUTarget = None + OutOfResources = None + autograd_profiler = None + +from .codegen.split_tiling import SplitTiling +from .utils import get_current_raw_stream +from .codegen.tile_generator import TileGenerator +from .codegen.triton_utils import get_aligned_numel +from .config import aggresive_autotune +from .config import log +from . import config as npu_config + + +# torch-261 +class NPUCachingAutotuner(CachingAutotuner): + def __init__( + self, + fn, + triton_meta, # passed directly to triton + configs, + save_cache_hook, + mutated_arg_names: List[str], # see [Note: clone mutated buffers] + optimize_mem, + heuristic_type, + size_hints=None, + inductor_meta=None, # metadata not relevant to triton + custom_kernel=False, # whether the kernel is inductor-generated or custom + filename: Optional[str] = None, + reset_to_zero_arg_names: Optional[List[str]] = None, + ): + super().__init__(fn, triton_meta, configs, save_cache_hook, mutated_arg_names, optimize_mem, heuristic_type, + size_hints, inductor_meta, custom_kernel, filename, reset_to_zero_arg_names) + + self.exceptions = [] + + def precompile(self, warm_cache_only=False): + # xpu_graph changed TORCHINDUCTOR_CACHE_DIR. + # When TORCHINDUCTOR_COMPILE_THREADS > 1, multiprocessing's fork method + # does not propagate TORCHINDUCTOR_CACHE_DIR into the child threads. + # However, after all the child threads finished, the main thread reaches + # here and inherits xpu_graph's TORCHINDUCTOR_CACHE_DIR. Then the main + # thread finds the cache dir does not have any compiled kernel. It will + # compile all kernels one by one. + # So we directly replace TORCHINDUCTOR_CACHE_DIR with the standard cache dir. 
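+        # The rebuilt path mirrors Inductor's default cache layout (typically
+        # /tmp/torchinductor_<user> on Linux), with triton/0 as the Triton cache subdirectory.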
+ if ("xpu_graph" in os.getenv("TORCHINDUCTOR_CACHE_DIR", "")): + import getpass + import tempfile + sanitized_username = re.sub(r'[\\/:*?"<>|]', "_", getpass.getuser()) + cache_dir = os.path.join( + tempfile.gettempdir(), + "torchinductor_" + sanitized_username, + ) + os.environ["TORCHINDUCTOR_CACHE_DIR"] = cache_dir + os.environ["TRITON_CACHE_DIR"] = os.path.join(cache_dir, "triton", "0") + with self.lock: + if self.launchers: + return + self.launchers = [] + compiled_binaries = [] + if not self.configs: + raise RuntimeError("No triton configs are available") + for c in self.configs: + try: + print(f"start compile kernel {self.inductor_meta['kernel_name']} config:{c.kwargs}", flush=True) + compiled_binary, launcher = self._precompile_config( + c, warm_cache_only + ) + except Exception as e: + log.debug(f"[thread {os.getpid()}][InductorNPU.precompile] Exception = {e}, kernel = {self.fn.__name__} config = {c}") + # Skip the config if the compilation fails + continue + if launcher is not None: + self.launchers.append(launcher) + compiled_binaries.append(compiled_binary) + + if len(self.launchers) == 0: + raise RuntimeError( + "No valid triton configs. Report a fatal compilation error" + ) + + self.configs = None + + + def _precompile_config(self, cfg: Config, warm_cache_only: bool): + """Ahead of time compile a given autotuner config.""" + compile_meta = copy.deepcopy(self.triton_meta) + + for k, v in cfg.kwargs.items(): + if k not in self.fn.arg_names: + continue + compile_meta["constants"][k] = v + + compile_meta["num_warps"] = cfg.num_warps + compile_meta["num_stages"] = cfg.num_stages + + compile_meta["debug"] = ( + os.getenv("INDUCTOR_ASCEND_DEBUG", 'false').lower() in ('true', '1') and + config.assert_indirect_indexing and torch.version.hip is None + ) + + # device type will be "hip" rather than "cuda" here + compile_meta["device_type"] = self.device_props.type + compile_meta["cc"] = self.device_props.cc + + if ASTSource: + compile_args = ( + ASTSource( + self.fn, + compile_meta["signature"], + compile_meta["constants"], + ), + ) + + cc_str = str(compile_meta["cc"]) + if "gfx10" in cc_str or "gfx11" in cc_str: + rocm_warp_size = 32 + else: + rocm_warp_size = 64 + + if GPUTarget: + target = GPUTarget( + compile_meta["device_type"], + compile_meta["cc"], + rocm_warp_size if torch.version.hip else 32, + ) + else: + target = ( + (compile_meta["device_type"], compile_meta["cc"]) + if not torch.version.hip + else [ + compile_meta["device_type"], + compile_meta["cc"], + rocm_warp_size, + ] + ) + + options = { + "num_warps": compile_meta["num_warps"], + "num_stages": compile_meta["num_stages"], + "debug": compile_meta["debug"], + } + if self.device_props.type == "hip": + if "waves_per_eu" in compile_meta: + options["waves_per_eu"] = compile_meta["waves_per_eu"] + if "matrix_instr_nonkdim" in compile_meta: + options["matrix_instr_nonkdim"] = compile_meta[ + "matrix_instr_nonkdim" + ] + compile_kwargs = { + "target": target, + "options": options, + } + else: + compile_args = (self.fn,) + compile_kwargs = compile_meta + if warm_cache_only: + return ( + triton.compile(*compile_args, **compile_kwargs), + None, + ) + + # importing from torch is safe now that precompile has returned + from torch._dynamo.device_interface import DeviceGuard + + device_interface = self.get_device_interface() + + # load binary to the correct device + with DeviceGuard(device_interface, compile_meta["device"]): # type: ignore[attr-defined] + # need to initialize context + 
device_interface.synchronize(device_interface.current_device()) + + try: + + binary = triton.compile(*compile_args, **compile_kwargs) + binary._init_handles() + + except Exception: + log.exception( + "Triton compilation failed: %s\n%s\nmetadata: %s", + self.inductor_meta.get("kernel_name", "triton_"), + self.fn.src, + compile_meta, + ) + raise + + call_args = [ + arg + for i, arg in enumerate(self.fn.arg_names) + if i not in self.fn.constexprs + ] + def_args = [name for name in self.fn.arg_names if name not in cfg.kwargs] + + binary_shared = ( + binary.shared if hasattr(binary, "shared") else binary.metadata.shared + ) + + scope = { + "grid_meta": cfg.kwargs, + "bin": binary, + "launch_enter_hook": CompiledKernel.launch_enter_hook, + "launch_exit_hook": CompiledKernel.launch_exit_hook, + "metadata": binary.packed_metadata + if hasattr(binary, "packed_metadata") + else binary.metadata, + "shared": binary_shared, + } + + scope["num_warps"] = ( + binary.num_warps + if hasattr(binary, "num_warps") + else binary.metadata.num_warps + ) + + scope["cta_args"] = ( + (binary.num_ctas, *get_first_attr(binary, "cluster_dims", "clusterDims")) + if hasattr(binary, "num_ctas") + else ( + (binary.metadata.num_ctas, *binary.metadata.cluster_dims) + if hasattr(binary, "metadata") + else () + ) + ) + + scope["function"] = get_first_attr(binary, "function", "cu_function") + + def get_launch_args_without_kernel_launch_metadata( + input_grid, + grid_0, + grid_1, + grid_2, + stream, + function, + metadata, + input_bin, + launch_enter_hook, + launch_exit_hook, + num_warps, + shared, + cta_args, + args, + ): + """ + Construct launch args before CompiledKernel.launch_metadata is added. + """ + return ( + grid_0, + grid_1, + grid_2, + num_warps, + *cta_args, + shared, + stream, + function, + launch_enter_hook, + launch_exit_hook, + metadata, + ) + + # Getting the kernel launch args is extremely perf-sensitive. Evaluating + # `bin.launch_metadata` is relatively expensive, and returns None unless a + # `launch_enter_hook` is installed. So if we don't have that hook installed, + # we want to burn None in to the launch args with zero overhead. 
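+ # Launch-arg layout summary (taken from the helpers defined just above and
+ # below): without launch_metadata the tuple is (grid_0, grid_1, grid_2,
+ # num_warps, *cta_args, shared, stream, function, enter_hook, exit_hook,
+ # metadata); with launch_metadata it is (grid_0, grid_1, grid_2, stream,
+ # function, metadata, launch_metadata-or-None, enter_hook, exit_hook).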
+ if binary.launch_enter_hook: + + def get_launch_args_with_kernel_launch_metadata( + input_grid, + grid_0, + grid_1, + grid_2, + stream, + function, + metadata, + input_bin, + launch_enter_hook, + launch_exit_hook, + num_warps, + shared, + cta_args, + args, + ): + """ + Construct launch args after CompiledKernel.launch_metadata is added + """ + return ( + grid_0, + grid_1, + grid_2, + stream, + function, + metadata, + input_bin.launch_metadata(input_grid, stream, *args), + launch_enter_hook, + launch_exit_hook, + ) + + else: + + def get_launch_args_with_kernel_launch_metadata( + input_grid, + grid_0, + grid_1, + grid_2, + stream, + function, + metadata, + input_bin, + launch_enter_hook, + launch_exit_hook, + num_warps, + shared, + cta_args, + args, + ): + """ + Construct launch args after CompiledKernel.launch_metadata is added + """ + return ( + grid_0, + grid_1, + grid_2, + stream, + function, + metadata, + None, + launch_enter_hook, + launch_exit_hook, + ) + + scope["get_launch_args"] = ( + get_launch_args_with_kernel_launch_metadata + if hasattr(binary, "launch_metadata") + else get_launch_args_without_kernel_launch_metadata + ) + + scope["runner"] = get_first_attr(binary, "run", "c_wrapper") + + exec( + f""" + def launcher({', '.join(def_args)}, grid, stream): + if callable(grid): + grid_0, grid_1, grid_2 = grid(grid_meta) + else: + grid_0, grid_1, grid_2 = grid + + args = {', '.join(call_args)}, + launch_args = get_launch_args( + grid, grid_0, grid_1, grid_2, stream, function, + metadata, bin, launch_enter_hook, launch_exit_hook, + num_warps, shared, cta_args, args + ) + runner(*launch_args, *args) + return bin + """.lstrip(), + scope, + ) + + launcher = scope["launcher"] + launcher.config = cfg + launcher.n_regs = getattr(binary, "n_regs", None) + launcher.n_spills = getattr(binary, "n_spills", None) + launcher.shared = binary_shared + launcher.store_cubin = True + # store this global variable to avoid the high overhead of reading it when calling run + if launcher.store_cubin: + launcher.fn = self.fn + launcher.bin = binary + + return binary, launcher + + def save_gpu_kernel(self, input_grid, input_stream, input_launcher): + self.save_npu_kernel(input_grid, input_stream, input_launcher) + + def save_npu_kernel(self, input_grid, input_stream, input_launcher): + if callable(input_grid): + grid_x, grid_y, grid_z = input_grid(input_launcher.config.kwargs) + else: + grid_x, grid_y, grid_z = input_grid + + key = self.inductor_meta.get("kernel_name", None) # unique kernel name + + if key is None: + raise RuntimeError("assert key is not None, kernel_name can not be None") + params = { + "mangled_name": ( + input_launcher.bin.metadata.name + if hasattr(input_launcher.bin.metadata, "name") + else input_launcher.bin.metadata["name"] + ), + "grid_x": grid_x, + "grid_y": grid_y, + "grid_z": grid_z, + "num_warps": ( + input_launcher.bin.num_warps + if hasattr(input_launcher.bin, "num_warps") + else input_launcher.bin.metadata.num_warps + ), + "shared_mem": ( + input_launcher.bin.shared + if hasattr(input_launcher.bin, "shared") + else input_launcher.bin.metadata.shared + ), + "stream": input_stream, + # User defined triton kernels will have arbitrary kwarg names + "meta": input_launcher.config.kwargs, + } + from torch._inductor.codecache import CudaKernelParamCache + + bin_type = "npubin" + binary = input_launcher.bin.asm[bin_type] # npubin type = npubin + CudaKernelParamCache.set(key, params, binary, bin_type='cubin') # CudaKernelParam + + self.cuda_kernel_saved = True + + # bench method is 
called by torch, grid can not be modified + def bench(self, launcher, *args, grid, with_profiler=False, **kwargs): + """Measure the performance of a given launcher""" + + if not self.custom_kernel and launcher.n_spills > self.inductor_meta.get( + "spill_threshold", 16 + ): + log.debug( + "Skip config %s because of register spilling: %d", + launcher.config, + launcher.n_spills, + ) + return float("inf") + + device_interface = self.get_device_interface() + stream = device_interface.get_raw_stream(device_interface.current_device()) + + def kernel_call(): + cloned_args, cloned_kwargs = self.clone_args(*args, **kwargs) + launcher( + *cloned_args, + **cloned_kwargs, + grid=grid, + stream=stream, + ) + + if with_profiler: + from torch._inductor.utils import do_bench_using_profiling + ret = do_bench_using_profiling(kernel_call, warmup=10, rep=1) + + + print(f"start bench for kernel {self.inductor_meta['kernel_name']} config:{launcher.config}", flush=True) + # remove fast_flush=True for high version triton + ret = benchmarker.benchmark_gpu(kernel_call, rep=1) + print(f"do bench ret = {ret} ",flush=True) + return ret + + def autotune_to_one_config(self, *args, **kwargs): + """Do the actual autotuning""" + start_time = time.time_ns() + timings = self.benchmark_all_configs(*args, **kwargs) + benchmark_time_taken_ns = time.time_ns() - start_time + self.launchers = [builtins.min(timings, key=timings.get)] + self.autotune_time_taken_ns = ( + self.precompile_time_taken_ns + benchmark_time_taken_ns + ) + if self.save_cache_hook: + self.save_cache_hook(self.launchers[0].config, self.autotune_time_taken_ns) + print(f"saved best_config:{self.launchers[0].config.kwargs}", flush=True) + + + def get_fx_graph_call(self, auto_fallback=False): + kernel_name = self.inductor_meta.get("kernel_name", "triton_") + traced_graph_hash = self.inductor_meta.get("traced_graph_hash") + dump_path = os.getenv(traced_graph_hash, None) + if not dump_path: + return None + sys.path.append(dump_path) + fx_module = importlib.import_module(traced_graph_hash) + sys.path.remove(dump_path) + + model = fx_module.model + num_inputs = fx_module.num_inputs + num_outputs = fx_module.num_outputs + non_contiguous_indices = fx_module.non_contiguous_indices + mismatch_indices_shapes = fx_module.mismatch_indices_shapes + + def fx_graph_call(*fx_args): + fx_inputs = [fx_args[idx].contiguous() if idx in non_contiguous_indices['inputs'] else \ + fx_args[idx] for idx in range(num_inputs)] + if len(mismatch_indices_shapes): + for ind, shape in mismatch_indices_shapes.items(): + if ind >= num_inputs: + break + fx_inputs[ind] = fx_inputs[ind].reshape(shape) + model_outputs = model.forward(*fx_inputs) + for idx, (out1, out2) in enumerate(zip(model_outputs, fx_args[num_inputs:(num_inputs + num_outputs)])): + out1 = out1.reshape(out2.shape) + if idx in non_contiguous_indices['outputs']: + out2.copy_(out1) + else: + out2.data = out1.data + + def fallback_call(*args): + fx_args = [args[idx] for idx in fx_module.call_args_mapping] + return fx_graph_call(*fx_args) + if auto_fallback: + return fallback_call, kernel_name + return fx_graph_call, kernel_name, dump_path, fx_module + + def data_dump(self, *args, dump_path=None): + data_dump_path = os.path.join(dump_path, 'data.pth') + torch.save(args, data_dump_path) + + def check_accuracy(self, *args, launcher, grid, stream, **kwargs): + fx_call_and_kwargs = self.get_fx_graph_call() + if not fx_call_and_kwargs: + return None + fx_graph_call, kernel_name, dump_path, fx_module = self.get_fx_graph_call() + 
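+ # What follows: the launch args are dumped to <dump_path>/data.pth, the saved
+ # FX graph runs on cloned copies of the args (bfloat16 upcast to float32) as
+ # the reference, the Triton launcher runs on the real args, and the outputs
+ # are compared with the per-dtype rtol/atol from npu_config.acc_comp_tol; on
+ # a mismatch the reference result overwrites the kernel output.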
call_outputs_indices = fx_module.call_args_mapping[fx_module.num_inputs:] + self.data_dump(*args, dump_path=dump_path) + + fx_args = [] + for idx in fx_module.call_args_mapping: + arg = args[idx] + if isinstance(arg, torch.Tensor): + fx_arg = clone_preserve_strides(arg).float() if arg.dtype == torch.bfloat16 else clone_preserve_strides(arg) + fx_args.append(fx_arg) + + fx_graph_call(*fx_args) + + ret = launcher( + *args, + **kwargs, + grid=grid, + stream=stream, + ) + for actual, expected in zip([args[i] for i in call_outputs_indices], fx_args[fx_module.num_inputs:]): + if actual.dtype != expected.dtype: + expected = expected.to(actual.dtype) + acc_comp_tol = npu_config.acc_comp_tol.get(actual.dtype, npu_config.acc_comp_tol['default']) + rtol = acc_comp_tol['rtol'] + atol = acc_comp_tol['atol'] + + matches = torch.isclose( + actual, expected, rtol=rtol, atol=atol, equal_nan=False + ) + if not matches.all(): + abs_diff = torch.abs(actual - expected) + rel_diff = abs_diff / torch.abs(expected) + rel_diff.masked_fill_(matches, 0) + print(f"CHECK ACCURACY FAILED! Greatest Relative Difference: {rel_diff.max().item()}, " f"Kernel Name: {kernel_name}", flush=True) + print(f"kernel {kernel_name} Dump Path: {dump_path}") + actual.copy_(expected) + del matches + for arg in fx_args: + del arg + return True + + + def run( + self, *args, grid, stream, benchmark_run=False, **kwargs + ): # type:ignore[override] + if self.triton_interpret: + return self.fn[grid]( + *args, + **kwargs, + **self.configs[0].kwargs, + ) + + if hasattr(self.launchers[0], "fallback"): + return self.launchers[0]( + *args, + **kwargs, + ) + + if len(self.launchers) != 1: + if len(self.launchers) == 0: + start_time = time.time_ns() + self.precompile() + self.precompile_time_taken_ns = time.time_ns() - start_time + if len(self.launchers) > 1: + self.autotune_to_one_config(*args, grid=grid, **kwargs) + + if not getattr( + self.launchers[0].config, "found_by_coordesc", False + ) and self.inductor_meta.get("coordinate_descent_tuning", False): + self.launchers = [ + self.coordinate_descent_tuning( + self.launchers[0], *args, grid=grid, **kwargs + ) + ] + + (launcher,) = self.launchers + if launcher.store_cubin and (not benchmark_run or not self.cuda_kernel_saved): + self.save_gpu_kernel(grid, stream, launcher) + + if self.dump_launch_params: + _dump_launch_params(args, kwargs, launcher, self.fn.__name__) + + if npu_config.check_accuracy: + if self.check_accuracy(*args, launcher=launcher, grid=grid, stream=stream, **kwargs): + return + + # it is faster than entering and exiting a context manager, even if the context + # manager is a nullcontext. + if autograd_profiler._is_profiler_enabled: + # grid can be a tuple of ints or a string. 
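+ # e.g. a static grid such as (32, 1, 1), or the grid_fn closure returned by
+ # grid(*numels) later in this file, whose "grid_fn_str" (when present) is
+ # what gets recorded for the profiler.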
+ if isinstance(grid, tuple): + grid_info = str(grid) + else: + grid_info = getattr(grid, "grid_fn_str", "") + + with torch._C._profiler._RecordFunctionFast( + self.inductor_meta.get("kernel_name", "triton kernel"), + args, + { + "kernel_file": (self.filename or ""), + "kernel_hash": self.kernel_hash, + "kernel_backend": "triton", + "grid": grid_info, + "stream": stream, + }, + ): + return launcher( + *args, + **kwargs, + grid=grid, + stream=stream, + ) + else: + return launcher( + *args, + **kwargs, + grid=grid, + stream=stream, + ) + + +class NPUDebugAutotuner(NPUCachingAutotuner): + def __init__(self, *args, regex_filter="", **kwargs): + self.regex_filter = regex_filter + super().__init__(*args, **kwargs) + self.cached = None + + def run(self, *args, input_grid, stream): + possible_names = _find_names(self) + kernel_name = f"{max(possible_names, key=len)}" + if not re.match(self.regex_filter, kernel_name): + return + super().run(*args, grid=input_grid, stream=stream) + (launcher,) = self.launchers + + if self.cached is None: + ms = self.bench(launcher, *args, input_grid=input_grid) + num_in_out_ptrs = len( + [ + arg_name + for arg_name in self.fn.arg_names + if arg_name.startswith("in_out_ptr") + ] + ) + num_gb = get_num_bytes(*args, num_in_out_args=num_in_out_ptrs) / 1e9 + gb_per_s = num_gb / (ms / 1e3) + self.cached = (ms, num_gb, gb_per_s, kernel_name) + else: + ms, num_gb, gb_per_s, kernel_name = self.cached + collected_calls.append((ms, num_gb, gb_per_s, kernel_name)) + print( + create_bandwidth_info_str(ms, num_gb, gb_per_s, suffix=f" \t {kernel_name}") + ) + + +def cached_autotune( + size_hints: Optional[List[int]], + configs: List[Config], + triton_meta, + heuristic_type, + filename=None, + inductor_meta=None, + custom_kernel=False, +): + """ + A copy of triton.autotune that calls our subclass. Our subclass + has additional debugging, error handling, and on-disk caching. 
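+
+ Hedged usage sketch (kernel and argument names are illustrative, not taken
+ from this repo): the decorator returned here is stacked on top of the
+ generated Triton kernel, e.g.
+
+ @pointwise_npu_index(size_hints, triton_meta, inductor_meta=inductor_meta)
+ @triton.jit
+ def triton_poi_fused_add_0(in_ptr0, out_ptr0, xnumel, XBLOCK: tl.constexpr):
+ ...
+
+ and produces an NPUCachingAutotuner (or NPUDebugAutotuner when
+ inductor_meta["profile_bandwidth"] is set).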
+ """ + configs = unique_configs(configs) + if not (len(configs) == 1 or filename): + raise RuntimeError("assert len(configs) == 1 or filename") + + inductor_meta = {} if inductor_meta is None else inductor_meta + + disabled = inductor_meta.get("force_disable_caches", False) + + # on disk caching logic and/or remote caching + autotune_cache = None + if ( + not disabled + and filename is not None + and (len(configs) > 1 or inductor_meta.get("coordinate_descent_tuning")) + and not os.environ.get("TRITON_INTERPRET", "0") == "1" + ): + configs_hash = hash_configs(configs) + + autotune_cache = AutotuneCache.create(inductor_meta, filename, configs_hash) + if autotune_cache: + if best_config := autotune_cache.read_best(inductor_meta, configs): + configs = [best_config] + print(f"loaded best_config: {best_config.kwargs}", flush=True) + else: + if disabled: + log.debug("autotune caching is disabled by config.force_disable_caches") + + mutated_arg_names = inductor_meta.pop("mutated_arg_names", ()) + optimize_mem = inductor_meta.pop("optimize_mem", True) + + if "restore_value" in triton_meta: + mutated_arg_names += triton_meta.pop("restore_value") + + reset_to_zero_arg_names: List[str] = [] + if "reset_to_zero" in triton_meta: + reset_to_zero_arg_names.extend(triton_meta.pop("reset_to_zero")) + + def decorator(fn): + + if inductor_meta.get("profile_bandwidth"): + return NPUDebugAutotuner( + fn, + triton_meta=triton_meta, + inductor_meta=inductor_meta, + regex_filter=inductor_meta["profile_bandwidth_regex"], + with_profiler=inductor_meta[ + "profile_bandwidth_with_do_bench_using_profiling" + ], + configs=configs, + save_cache_hook=autotune_cache and autotune_cache.save, + mutated_arg_names=mutated_arg_names, + reset_to_zero_arg_names=reset_to_zero_arg_names, + optimize_mem=optimize_mem, + heuristic_type=heuristic_type, + size_hints=size_hints, + custom_kernel=custom_kernel, + filename=filename, + with_bandwidth_info=True, + ) + return NPUCachingAutotuner( + fn, + triton_meta=triton_meta, + inductor_meta=inductor_meta, + configs=configs, + save_cache_hook=autotune_cache and autotune_cache.save, + mutated_arg_names=mutated_arg_names, + reset_to_zero_arg_names=reset_to_zero_arg_names, + optimize_mem=optimize_mem, + heuristic_type=heuristic_type, + size_hints=size_hints, + custom_kernel=custom_kernel, + filename=filename, + ) + + return decorator + + +###################################################### +## Main entry points for triton kernel invocation ## +## adapts original heuristics for NPU arch, and ## +## redirect to NPUCaching autotuner ## +###################################################### + +def grid(*numels): + def grid_fn(meta): + split_axis = meta["split_axis"] + split_blocks = meta["split_blocks"] + programs = [ ] + for i, order in enumerate(split_axis) : + if not numels : + continue + numel = numels[order] + block = split_blocks[i] + programs.append((numel + block -1) // block) + + for _ in range(3 - len(programs)) : + programs.append(1) + #log.debug("launch grid(numels:%s), programs:%s, meta:%s", numels, programs, meta) + return tuple(programs) + + return grid_fn + + +# split:sizeof split, xblock:axis1 length, rblock:axis2 length +def triton_config_npu_index( + size_hints, + inductor_meta, + triton_meta=None, + reduction=False, + persistent_reduction=False, + +) -> List[Config]: + num_warps = 1 + num_stages = 1 + configs = [] + log.info("[InductorNPU] processing kernel %s", inductor_meta['kernel_name']) + split_axis = inductor_meta["split_axis"] + tiling_axis = 
inductor_meta["tiling_axis"] + low_dims = inductor_meta["low_dims"] + split_axis_dtype = inductor_meta["split_axis_dtype"] + axis_names = inductor_meta["axis_names"] + dual_reduction = inductor_meta["dual_reduction"] + + tile_generator = TileGenerator(size_hints, axis_names, tiling_axis, split_axis, low_dims, + persistent_reduction = persistent_reduction, configs=configs, + dtype = split_axis_dtype, dual_reduction=dual_reduction ) + + tile_generator.descend_split_tiling() + + if not configs : + cfg = {} + for x in split_axis : + cfg[f"{axis_names[x].upper()}BLOCK"] = size_hints[x] + if not cfg : + cfg["dummy"] = 1 + tmp = Config(cfg, num_warps=num_warps, num_stages=num_stages) + configs.append(tmp) + + for cfg in configs : + split_blocks = [None for x in split_axis] + for i,axis in enumerate(split_axis) : + name = axis_names[axis] + block_name = f"{name.upper()}BLOCK" + split_blocks[i] = cfg.kwargs[block_name] + cfg.kwargs["split_axis"] = tuple(split_axis) + cfg.kwargs["split_blocks"] = tuple(split_blocks) + #log.info("generated tiling configs %s", cfg.kwargs) + + + return configs + + +def pointwise_npu_index( + size_hints, + triton_meta, + tile_hint=None, + filename=None, + min_elem_per_thread=0, + inductor_meta=None, +): + + inductor_meta = {} if inductor_meta is None else inductor_meta + triton_config_with_settings = functools.partial( + triton_config_npu_index + ) + return cached_autotune( + size_hints, + triton_config_with_settings(size_hints, inductor_meta=inductor_meta), + triton_meta=triton_meta, + inductor_meta=inductor_meta, + heuristic_type=HeuristicType.POINTWISE, + filename=filename, + ) + + +def reduction_npu_index( + size_hints, + reduction_hint=False, + triton_meta=None, + filename=None, + inductor_meta=None, +): + + """args to @triton.heuristics()""" + inductor_meta = {} if inductor_meta is None else inductor_meta + inductor_meta["reduction_hint"] = reduction_hint + if triton_meta is None: + raise RuntimeError("assert triton_meta is not None") + + contiguous_config = triton_config_npu_index(size_hints, inductor_meta=inductor_meta, reduction=True) + return cached_autotune( + size_hints, + [ + *contiguous_config, + ], + triton_meta=triton_meta, + inductor_meta=inductor_meta, + filename=filename, + heuristic_type=HeuristicType.REDUCTION, + ) + + +def persistent_reduction_npu_index( + size_hints, + reduction_hint=False, + triton_meta=None, + filename=None, + inductor_meta=None, +): + inductor_meta = {} if inductor_meta is None else inductor_meta + inductor_meta["reduction_hint"] = reduction_hint + configs = triton_config_npu_index(size_hints, inductor_meta=inductor_meta, reduction=True, + persistent_reduction=True) + + + return cached_autotune( + size_hints, + configs, + triton_meta=triton_meta, + inductor_meta=inductor_meta, + filename=filename, + heuristic_type=HeuristicType.PERSISTENT_REDUCTION, + ) + + +def foreach(triton_meta, num_warps, filename=None, inductor_meta=None): + """ + Compile a triton foreach kernel + """ + return cached_autotune( + None, + [triton.Config({}, num_stages=1, num_warps=num_warps)], + triton_meta=triton_meta, + inductor_meta=inductor_meta, + heuristic_type=HeuristicType.TEMPLATE, + filename=filename, + ) + + +@dynamo_timed +def benchmark_all_configs(self, *args, input_grid, **kwargs): + print(f"candidate launcher count = {len(self.launchers)}") + + tilling_kernel_list = [] + + def kernel_call(launcher): + def call_kernel(): + if launcher.config.pre_hook is not None: + launcher.config.pre_hook( + {**dict(zip(self.arg_names, args)), 
**launcher.config.kwargs} + ) + cloned_args, cloned_kwargs = self.clone_args(*args, **kwargs) + launcher( + *cloned_args, + **cloned_kwargs, + grid=input_grid, + stream=stream, + ) + return call_kernel + + for launcher in self.launchers: + if not self.custom_kernel and launcher.n_spills > config.triton.spill_threshold: + log.debug( + "Skip config %s because of register spilling: %d", + launcher.config, + launcher.n_spills, + ) + return float("inf") + + stream = self.gpu_device.get_raw_stream( # type: ignore[call-arg] + self.gpu_device.current_device() + ) + tilling_kernel_list.append(kernel_call(launcher)) + + def do_batch_benchmark(tilling_kernel_list): + + def delete_file(base_path): + import shutil + if os.path.exists(base_path): + shutil.rmtree(base_path) + + import torch_npu + + stream = torch.npu.current_stream() + experimental_config = torch_npu.profiler._ExperimentalConfig( + aic_metrics=torch_npu.profiler.AiCMetrics.PipeUtilization, + profiler_level=torch_npu.profiler.ProfilerLevel.Level1, + l2_cache=False, + data_simplification=False + ) + + import uuid + random_uuid = uuid.uuid4().hex + md5_hash = hashlib.md5(random_uuid.encode()).hexdigest() + + from torch_npu._inductor.config import profile_path + + torch_path = profile_path + md5_hash + rep = 1 + with torch_npu.profiler.profile( + activities=[ + torch_npu.profiler.ProfilerActivity.NPU + ], + schedule=torch_npu.profiler.schedule(wait=0, warmup=1, active=rep, repeat=1, skip_first=1), + on_trace_ready=torch_npu.profiler.tensorboard_trace_handler(torch_path), + record_shapes=False, + profile_memory=False, + with_stack=False, + with_flops=False, + with_modules=False, + experimental_config=experimental_config) as prof: + stream.synchronize() + for _ in range(rep + 3): + for fn in tilling_kernel_list: + fn() + prof.step() + stream.synchronize() + + import pandas as pd + for root, _, files in os.walk(torch_path): + for file in files: + if file != 'kernel_details.csv': + continue + target_file = os.path.join(root, file) + df = pd.read_csv(target_file) + triton_rows = df[df['Name'].str.startswith('triton', na=False)] + ret = triton_rows['Duration(us)'].astype(float).tolist() + delete_file(torch_path) + return ret + + delete_file(torch_path) + return [] + + try: + timinglist = do_batch_benchmark(tilling_kernel_list) + if not len(timinglist) == len(self.launchers): + raise RuntimeError("not len(timinglist) == len(self.launchers)") + timings = {launcher: timing for launcher, timing in zip(self.launchers, timinglist)} + except Exception as e: + print("some cases in batch benchmark has error! 
Logging Exception as:") + print(e) + print("switched to single bench...") + timings = { + launcher: self.bench(launcher, *args, **kwargs) + for launcher in self.launchers + } + + for k, v in timings.items(): + self.coordesc_tuner.cache_benchmark_result(k.config, v) + + if log.isEnabledFor(logging.DEBUG): + log.debug("Benchmark all input configs for %s, get:", self.fn.__name__) + for k, v in timings.items(): + log.debug( + "%s: %f, nreg %d, nspill %d, #shared-mem %s", + k.config, + v, + k.n_regs, + k.n_spills, + k.shared, + ) + print(f"final valid tillings count = {len(timings)}") + return timings \ No newline at end of file diff --git a/torch_npu/_inductor/patch/.keep b/torch_npu/_inductor/patch/.keep new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/torch_npu/_inductor/patch/CMakeLists.txt b/torch_npu/_inductor/patch/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..a0de479ff269dc7e2bc6cbdcbcd97a35bb6eb4cd --- /dev/null +++ b/torch_npu/_inductor/patch/CMakeLists.txt @@ -0,0 +1,36 @@ +# CMAKE_PREFIX_PATH=/home/wangmingfa/miniconda3/envs/wz_torch260/lib/python3.9/site-packages/torch/share/cmake cmake -DCMAKE_BUILD_TYPE=Debug .. +# make +cmake_minimum_required(VERSION 3.18 FATAL_ERROR) +project(deberta_aoti) + +find_package(Torch REQUIRED) +# npu +include_directories("/usr/local/lib/python3.11/dist-packages/torch_npu/include") +include_directories("/host/zcl/pta_v2.6/libtorch_npu/include") +include_directories("/usr/local/Ascend/ascend-toolkit/latest/x86_64-linux/include") +# include_directories("/usr/local/Ascend/T115/ascend-toolkit/latest/x86_64-linux/include") + +link_directories("/host/zcl/aoti_files") +link_directories("/host/zcl/pta_v2.6/libtorch_npu/lib") +link_directories("/usr/local/Ascend/ascend-toolkit/latest/x86_64-linux/lib64") +# link_directories("/usr/local/Ascend/T115/ascend-toolkit/latest/x86_64-linux/lib64") +link_directories("/usr/lib/x86_64-linux-gnu") + +message("-----${TORCH_LIBRARIES}") + +add_executable(deberta_aoti + _runner.cpp + # test_libtorch.cpp +) + +SET(CMAKE_BUILD_TYPE "Debug") + +## npu +target_link_libraries(deberta_aoti aoti_npu) +target_link_libraries(deberta_aoti aoti_runner_npu) +target_link_libraries(deberta_aoti aoti_npuops) +target_link_libraries(deberta_aoti torch_npu) +target_link_libraries(deberta_aoti ascendcl) + +target_link_libraries(deberta_aoti "${TORCH_LIBRARIES}") +set_property(TARGET deberta_aoti PROPERTY CXX_STANDARD 17) \ No newline at end of file diff --git a/torch_npu/_inductor/patch/__init__.py b/torch_npu/_inductor/patch/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/torch_npu/_inductor/patch/_runner.cpp b/torch_npu/_inductor/patch/_runner.cpp new file mode 100644 index 0000000000000000000000000000000000000000..555c8f32e2c6ca56db3deca11a7b6d99688d7415 --- /dev/null +++ b/torch_npu/_inductor/patch/_runner.cpp @@ -0,0 +1,170 @@ +#include +#include + +#include "torch/script.h" + +#include +#include +#include + +#include "torch_npu/csrc/core/npu/register/OptionRegister.h" +#include "torch_npu/csrc/core/npu/register/OptionsManager.h" + +#include +#include + +#include +#include +#include +#include + +void removeWhitespace(std::string& str) { + std::string result; + for (char c : str) { + if (!std::isspace(c)) { + result += c; + } + } + str = result; +} + +int extractValue(const std::string& json_str, const std::string& key) { + std::string 
target_key = "\"" + key + "\":"; // "input_ids": + size_t pos = json_str.find(target_key); + + if (pos == std::string::npos) { + std::cerr << "Key '" << target_key << "' not found!" << std::endl; + return -1; + } + + pos += target_key.length(); + size_t end = json_str.find_first_not_of("0123456789", pos); + return std::stoi(json_str.substr(pos, end - pos)); +} + +std::string parseArgMapJson(const std::string &argMapPath){ + std::ifstream jsonfile(argMapPath); + if (!jsonfile.is_open()) { + std::cerr << "Failed to open file!" << std::endl; + return nullptr; + } + std::string json_str{ + std::istreambuf_iterator(jsonfile), + std::istreambuf_iterator() + }; + removeWhitespace(json_str); + return json_str; +} + +void loadDebertaWeights(std::vector &inputs, std::string weightArgPath, int num){ + if(inputs.size()==num)return; + inputs.reserve(num); + torch::jit::script::Module weightTensors = torch::jit::load(weightArgPath); + //FIXME + for(int i=0;i getDebertaFilepathFromBatch(int batchSize){ + std::map ret; + // "/host/zcl/deberta_aoti/deberta_new.pt2", + // "/host/zcl/weights/args_aoti.pt", + // "/host/zcl/weights/input_args_map_aoti.json" + std::string batchString = std::to_string(batchSize); + std::string basePath = "/host/deberta_files/batch_" + batchString; + + ret["pt2Path"] = basePath + "/deberta_" + batchString + ".pt2"; + ret["weightArgPath"] = basePath + "/data/aotinductor/model/weight_args_" + batchString + ".pt"; + ret["argMapPath"] = basePath + "/data/aotinductor/model/args_map_" + batchString + ".json"; + return ret; +} + +std::vector runDebertaModelInference( + const std::map &userInputs, const int batchSize){ + + const auto paths = getDebertaFilepathFromBatch(batchSize); + const std::string pt2Path = paths.at("pt2Path"); + const std::string weightArgPath = paths.at("weightArgPath"); + const std::string argMapPath = paths.at("argMapPath"); + + std::cerr<<"pt path : "< inputs(4); + + // loadDebertaWeights(inputs, weightArgPath, extractValue(json_str, "input_arg_length")); + //FIXME + // inputs[extractValue(json_str, "input_ids")] = userInputs.at("input_ids"); + // inputs[extractValue(json_str, "segment_ids")] = userInputs.at("segment_ids"); + // inputs[extractValue(json_str, "input_mask")] = userInputs.at("input_mask"); + + inputs[0] = userInputs.at("input_ids"); + inputs[1] = userInputs.at("segment_ids"); + inputs[2] = userInputs.at("input_mask"); + inputs[3] = userInputs.at("batching_index"); + + torch::inductor::AOTIModelPackageLoader loader(pt2Path); + torch::inductor::AOTIModelContainerRunner* runner = loader.get_runner(); + std::vector outputs = runner->run(inputs); + + + // aclError error_flag = c10_npu::npuSynchronizeDevice(); + // if(error_flag!=ACL_SUCCESS){ + // std::cout<<"fxxk 0"< userInputs + c10_npu::option::SetOption("ALLOW_INTERNAL_FORMAT","disable"); + + torch::jit::script::Module tensors = torch::jit::load("/host/deberta_files/inputs/deberta_inputs_1.pth"); + + // tensors.to(at::kPrivateUse1); + torch::Tensor input_ids = tensors.attr("input_ids").toTensor().to(at::kPrivateUse1); + torch::Tensor segment_ids = tensors.attr("segment_ids").toTensor().to(at::kPrivateUse1); + torch::Tensor input_mask = tensors.attr("input_mask").toTensor().to(at::kPrivateUse1); + torch::Tensor batching_index = tensors.attr("batching_index").toTensor().to(at::kPrivateUse1); + + // std::cout< userInputs={ + {"input_ids", input_ids}, + {"segment_ids", segment_ids}, + {"input_mask", input_mask}, + {"batching_index", batching_index} + }; + + std::vector outputs = 
runDebertaModelInference(userInputs, batchSize); + + // aclError error_flag = c10_npu::npuSynchronizeDevice(); + // if(error_flag!=ACL_SUCCESS){ + // std::cout<<"fxxk"< +#include +#include +#include + +#include +#include +#include +#include + +#include + +namespace { +enum class DeviceType : int8_t { + CPU = 0, + CUDA = 1, // CUDA. + MKLDNN = 2, // Reserved for explicit MKLDNN + OPENGL = 3, // OpenGL + OPENCL = 4, // OpenCL + IDEEP = 5, // IDEEP. + HIP = 6, // AMD HIP + FPGA = 7, // FPGA + MAIA = 8, // ONNX Runtime / Microsoft + XLA = 9, // XLA / TPU + Vulkan = 10, // Vulkan + Metal = 11, // Metal + XPU = 12, // XPU + MPS = 13, // MPS + Meta = 14, // Meta (tensors with no data) + HPU = 15, // HPU / HABANA + VE = 16, // SX-Aurora / NEC + Lazy = 17, // Lazy Tensors + IPU = 18, // Graphcore IPU + MTIA = 19, // Meta training and inference devices + PrivateUse1 = 20, // PrivateUse1 device + // NB: If you add more devices: + // - Change the implementations of DeviceTypeName and isValidDeviceType + // in DeviceType.cpp + // - Change the number below + COMPILE_TIME_MAX_DEVICE_TYPES = 21, + }; +} + +#ifdef __cplusplus +extern "C" { +#endif +int32_t aoti_torch_device_type_npu() { + return (int32_t)DeviceType::PrivateUse1; +} + +#ifdef __cplusplus +} // extern "C" +#endif + +namespace { + static c10::Device c10_device(int32_t device_type, int32_t device_index) { + if (device_type == aoti_torch_device_type_cpu()) { + return c10::Device(static_cast(device_type)); + } else { + return c10::Device( + static_cast(device_type), + static_cast(device_index)); + } + } +} // namespace + +AOTITorchError aoti_torch_create_tensor_from_blob_npu( + void* data, + int64_t ndim, + const int64_t* sizes_ptr, + const int64_t* strides_ptr, + int64_t storage_offset, + int32_t dtype, + int32_t device_type, + int32_t device_index, + AtenTensorHandle* ret_new_tensor) { +AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({ + c10::IntArrayRef sizes(sizes_ptr, ndim); + c10::IntArrayRef strides(strides_ptr, ndim); + c10::Device device = c10_device(device_type, device_index); + c10::TensorOptions options = c10::TensorOptions().device(device).dtype( + static_cast(dtype)); + *ret_new_tensor = torch::aot_inductor::new_tensor_handle( + // data == nullptr can happen for a 0-size tensor + (data != nullptr) ? at_npu::native::from_blob(data, sizes, strides, storage_offset, options, device) + : at::empty_strided(sizes, strides, options)); + // (data != nullptr) ? 
c10_npu::native::for_blob(data, sizes) + // .strides(strides) + // .storage_offset(storage_offset) + // .options(options) + // .make_tensor() + // : at::empty_strided(sizes, strides, options)); +}); +} + +AOTITorchError aoti_torch_create_tensor_from_blob_npu_v2( + void* data, + int64_t ndim, + const int64_t* sizes_ptr, + const int64_t* strides_ptr, + int64_t storage_offset, + int32_t dtype, + int32_t device_type, + int32_t device_index, + AtenTensorHandle* ret_new_tensor, + int32_t layout, + const uint8_t* opaque_metadata, + int64_t opaque_metadata_size) { +AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({ + if (layout == static_cast(at::kMkldnn)) { + throw std::runtime_error("do not support mkldnn on npu."); + } else { + aoti_torch_create_tensor_from_blob_npu( + data, + ndim, + sizes_ptr, + strides_ptr, + storage_offset, + dtype, + device_type, + device_index, + ret_new_tensor); + } +}); +} + + diff --git a/torch_npu/_inductor/patch/ascend_aot_package.py b/torch_npu/_inductor/patch/ascend_aot_package.py new file mode 100644 index 0000000000000000000000000000000000000000..016c1fdc09d191a5db8e1c461d45d5459756889b --- /dev/null +++ b/torch_npu/_inductor/patch/ascend_aot_package.py @@ -0,0 +1,548 @@ +import torch +import torch_npu + +import os +import re +import sys +import torch_npu._inductor +from torch_npu.contrib import transfer_to_npu +import torch._inductor.package as inductor_package + +from typing import Dict, Any + +import importlib +import json +import shutil +import shlex +import subprocess +import logging + +from abc import ABC, abstractmethod + +DEBUG_MODE = os.getenv("DEBUG", 0) + +def MakePath(directory, name): + return os.path.abspath(os.path.join(directory, name)) + +DEPLOY_KERNEL_PATH = "/host/deberta_files" + +def modify_class_name(module_code: str) -> str: + """ replace '' with 'testModule' """ + modified_code = re.sub( + r'class \(torch\.nn\.Module\):', + "class testModule(torch.nn.Module):", + module_code, + count=1 + ) + header = """ +import torch +from torch import device +import torch_npu +import xpu_graph + +from xpu_graph.passes.patterns.targets.npu.triton_kernel.fused_brc_permute_sum import fused_brc_permute_sum +from xpu_graph.passes.patterns.targets.npu.triton_kernel.fused_div_mul_sum import fused_div_mul_sum + +import os +import torch_npu._inductor +from torch_npu.contrib import transfer_to_npu\n\n +""" + return header + modified_code + +# analysis forward func string and generate input tensors +def generate_inputs(code: str) -> Dict[str, torch.Tensor]: + # arg0_1: "i64[11, 12, 256, 256]" + pattern = r"(arg\d+_\d+): \"([if]\d+)\s*\[(.*?)\]\"" + + # 使用正则表达式查找所有匹配项 + matches = re.findall(pattern, code) + + from torch._dynamo.testing import rand_strided + # 解析结果 + fake_params = {} + dtype_map = { + "i64": torch.int64, + "i32": torch.int32, + "i16": torch.int16, + "i8": torch.int8, + "i1": torch.bool, + "f16": torch.float16, + "f32": torch.float32, + "bf16": torch.bfloat16, + } + for match in matches: + param_name = match[0] + dtype = match[1] + shape = tuple(int(dim) for dim in match[2].split(',')) + + fake_params[param_name] = torch.zeros(shape,dtype=dtype_map[dtype],device="npu") + + return fake_params + +def import_from_path(input_path): + module_name = os.path.basename(input_path).replace('.py', '') + spec = importlib.util.spec_from_file_location(module_name, input_path) + if not spec: + raise ImportError(f"can not create package: {input_path}") + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + spec.loader.exec_module(module) + 
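+ # The dumped graph file was executed above with the module pre-registered in
+ # sys.modules, so references to it by module name resolve while it runs;
+ # process_and_run_model() then instantiates module.testModule() from the
+ # returned object.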
return module + +def process_and_run_model(input_path: str, do_aoti = False): + # 1. read and replace class + with open(input_path) as f: + code = f.read() + modified_code = modify_class_name(code) + + # 2. generate inputs + forward_str = re.search(r"def forward\(.*?\):\n", modified_code).group() + fake_inputs = generate_inputs(forward_str) + + # 3. declare module + # module_dict = {} + output_path = os.path.join(os.path.dirname(input_path),"decorated_graph.py") + with open(output_path, "w") as f: + f.write(modified_code) + + module = import_from_path(output_path) + + # 4. create a module object + model = module.testModule().to("npu") + + # 5. run module + with torch.no_grad(): + if do_aoti: + exported = torch.export.export(model, tuple(fake_inputs.values())) + output_path = torch._inductor.aoti_compile_and_package( + exported, + # [Optional] Specify the generated shared library path. If not specified, + # the generated artifact is stored in your system temp directory. + package_path=os.path.join(os.path.dirname(input_path),"origin.pt2"), + ) + else: + logging.info(model(*(fake_inputs.values()))) + + return output_path + +class OpCodeGenerator(ABC): + @abstractmethod + def generate(self, num, opname, arglist, outbuf): + pass + +class LibraryOpGenerator(OpCodeGenerator): + def generate(self, node_id, opname, arglist, outbuf): + NewLines = [" // PATCHED_CODE :"] + # generate input declare + new_input_argnames = [] + for i, arg in enumerate(arglist): + if isinstance(arg, str): + new_input = f"node{node_id}_{opname}_{i}" + new_input_argnames.append(new_input) + NewLines.append(f" at::Tensor {new_input} = *reinterpret_cast({arg}.get());") + elif isinstance(arg, int): + new_input_argnames.append(str(arg)) + else : + raise TypeError(f"can not generate unsupport argtype: {type(arg).__name__}") + + # generate func call + func_arg_str = ", ".join(new_input_argnames) + + output_argname = f"{outbuf}_tensor" + NewLines.append(f" auto {output_argname} = at::{opname}({func_arg_str});") + # generate output tensor + NewLines.append(f" RAIIAtenTensorHandle {outbuf}(reinterpret_cast(new at::Tensor({output_argname})));") + + return NewLines + +class CustomOpGenerator(OpCodeGenerator): + def __init__(self, kernel_path, kernel_name, scalar_params=None, grid_size=(44, 1, 1)): + self.kernel_path = kernel_path + self.kernel_name = kernel_name + self.scalar_params = scalar_params or [] + self.grid_size = grid_size + + def generate(self, num, opname, arglist, outbuf): + code = [] + kernel_var = f"kernels.{self.kernel_name}" + + # Kernel加载逻辑 + code.append(f"if ({kernel_var} == nullptr) {{") + code.append(f" {kernel_var} = loadKernel(\"{self.kernel_path}\", \"{self.kernel_name}\", 1);") + code.append("}") + + # 网格配置 + code.append(f"Grid {self.kernel_name}_grid = Grid({self.grid_size[0]}, {self.grid_size[1]}, {self.grid_size[2]});") + + # 生成Tensor指针变量 + for i, arg in enumerate(arglist): + code.append(f"void* var_{i} = reinterpret_cast({arg}.data_ptr());") + + # 生成标量参数 + for name, value in self.scalar_params: + code.append(f"int {name} = {value};") + + # FFT地址获取 + code.append("rtError_t ret;") + code.append("void* ffts_addr = nullptr;") + code.append("uint32_t ffts_len;") + code.append("ret = rtGetC2cCtrlAddr((uint64_t*)&ffts_addr, &ffts_len);") + code.append("if (ret != RT_ERROR_NONE) return;") + code.append("void* workspace_addr = nullptr;") + + # 生成参数结构体 + code.append("struct __attribute__((packed)) {") + code.append(" void* ffts_addr __attribute__((aligned(8)));") + code.append(" void* workspace_addr 
__attribute__((aligned(8)));") + for i in range(len(arglist)): + code.append(f" void* var_{i} __attribute__((aligned(8)));") + for name, _ in self.scalar_params: + code.append(f" int {name} __attribute__((aligned(4)));") + code.append(" int32_t gridX __attribute__((aligned(4)));") + code.append(" int32_t gridY __attribute__((aligned(4)));") + code.append(" int32_t gridZ __attribute__((aligned(4)));") + code.append("} kernel_args = {") + code.append(" ffts_addr, workspace_addr,") + code.append(" " + ", ".join([f"var_{i}" for i in range(len(arglist))]) + ",") + code.append(" " + ", ".join([name for name, _ in self.scalar_params]) + ",") + code.append(f" {self.kernel_name}_grid.grid_x, {self.kernel_name}_grid.grid_y, {self.kernel_name}_grid.grid_z") + code.append("};") + + # 内核启动 + code.append(f"if ({self.kernel_name}_grid.is_non_zero()) {{") + code.append(f' launchKernel("{self.kernel_name}", {kernel_var}, stream, ' + f'{self.kernel_name}_grid.grid_x, {self.kernel_name}_grid.grid_y, ' + f'{self.kernel_name}_grid.grid_z, &kernel_args, sizeof(kernel_args));') + code.append("}") + return code + +class FallbackData: + def __init__(self, json_path: str): + with open(json_path, 'r') as f: + self.nodes = json.load(f)['nodes'] + + def get_node_info(self, node_id: int): + if node_id < 0 or node_id >= len(self.nodes): + raise ValueError(f"Invalid node_id: {node_id}. Total nodes: {len(self.nodes)}") + + # 提取目标节点 + node = self.nodes[node_id]['node'] + + # 遍历所有输入参数 + arg_types = [] + for input_item in node['inputs']: + arg = input_item['arg'] + arg_type = next(iter(arg.keys())) # as_tensor | as_float | as_int + arg_types.append(arg_type) + + return node['target'], arg_types + +class CodeManager: + OP_REGISTRY = { + "aten::addmm": {"revertlines": 2, "generator": LibraryOpGenerator()}, + "aten::gather": {"revertlines": 2, "generator": LibraryOpGenerator()}, + "aten::gelu": {"revertlines": 2, "generator": LibraryOpGenerator()}, + } + + def __init__(self, directory, cpp_name, json_name, batch_size): + self.code_list = [] + self.cpp_path = MakePath(directory, cpp_name) + self.batch_name = f"batch_{batch_size}" + self.proxy_data = FallbackData(MakePath(directory, json_name)) + + def clear(self): + self.code_list.clear() + + def pop_lines(self, num): + for _ in range(num): + self.code_list.pop() + + def append_lines(self, newlines): + self.code_list.extend(newlines) + + def save_new_file(self, new_file_path): + with open(new_file_path, "w", encoding="utf-8") as f: + for line in self.code_list: + f.write(line + "\n") + + def extract_proxy_executor_line(self, line: str, argtypes: list[str]): + # 匹配 int64_t vector + int64_pattern = r'std::vector\s*{\s*([^}]+?)\s*}\s*\.data\(\)' + int64_match = re.search(int64_pattern, line) + + # 匹配 AtenTensorHandle vector + aten_pattern = r'std::vector\s*{\s*([^}]+?)\s*}\s*\.data\(\)' + aten_match = re.search(aten_pattern, line) + + # 提取 int64_t 元素 + int64_list = [] + if int64_match: + int64_content = int64_match.group(1) + int64_list = [e.strip() for e in int64_content.split(',')] + + # 提取 AtenTensorHandle 元素 + aten_list = [] + if aten_match: + aten_content = aten_match.group(1) + aten_list = [e.strip() for e in aten_content.split(',')] + + arglist = [] + iptr = 0 + aptr = 0 + for argtype in argtypes: + if argtype == "as_tensor": + arglist.append(aten_list[aptr]) + aptr+=1 + elif argtype == "as_int": + arglist.append(int(int64_list[iptr])) + iptr+=1 + else: + raise ValueError(f"meeting unsupported argtype:{argtype}") + if aptr != len(aten_list) - 1: + raise 
ValueError(f"mismatched argtype length and arglist length!") + + return arglist, aten_list[aptr] + + def process_cpp_file(self): + fallbackOpPrefixPattern = re.compile( + r'^\s*aoti_torch_proxy_executor_call_function\(\s*proxy_executor\s*,\s*(\d+),' + ) + + compileCmdPattern = re.compile( + r'^//\sg\+\+\s+\S+\.cpp' + ) + + linkCmdPattern = re.compile( + r'^//\sg\+\+\s+\S+\.o' + ) + + loadKernelPattern = re.compile( + r'loadKernel\(\"(.*/)([^/]+\.cubin)\"' + ) + + kernelPathPattern = r'/tmp/(?:.*/)*([^/]+\.cubin)' + + launchKernelPattern = re.compile( + r'launchKernel\(\"' + ) + launch_cnt=0 + + with open(self.cpp_path, 'r', encoding='utf-8') as f: + for line in f: + # 保留原始行 + self.code_list.append(line.rstrip('\n')) + + if launchKernelPattern.search(line) and DEBUG_MODE: + self.code_list.append(" {") + self.code_list.append(" aclError error_flag = c10_npu::npuSynchronizeDevice();") + self.code_list.append(" if(error_flag!=ACL_SUCCESS){") + self.code_list.append(f" std::cerr<<\"[DEBUG] failed to synchronize TT_kernel {launch_cnt}\"< None: + """ + unzip /*.pt2 to /pt2tmp + and return /pt2tmp/data/aotinductor/model + """ + + self.pt2_dir = os.path.dirname(pt2_path) + extract_dir = os.path.join(self.pt2_dir, "pt2tmp") + extract_dir = os.path.abspath(extract_dir) + + if os.path.exists(extract_dir): + def handle_error(func, path, exc_info): + import stat + if not os.access(path, os.W_OK): + os.chmod(path, stat.S_IWUSR) + func(path) + else: + raise + shutil.rmtree(extract_dir, onerror=handle_error) + os.makedirs(extract_dir, exist_ok=True) + + import zipfile + with zipfile.ZipFile(pt2_path, 'r') as zip_ref: + zip_ref.extractall(extract_dir) + + return os.path.join(extract_dir,"data/aotinductor/model") + + def rewrite_cpp_wrapper(self): + old_compile_cmd, old_link_cmd = self.code_manager.process_cpp_file() + self.code_manager.save_new_file(self.new_cpp_path) + + compile_list = shlex.split(old_compile_cmd) + link_list = shlex.split(old_link_cmd) + + compile_list[1] = self.new_cpp_path + tmp_path = MakePath(self.extract_dir, "tmp.o") + compile_list[-1] = tmp_path + + link_list[1] = tmp_path + if link_list[2].endswith(".o"): + link_list[2] = self.weight_path + link_list[-1] = self.new_so_path + + logging.info(" after rewrite_cpp_wrapper:") + logging.info(f" new_cpp_path = {self.new_cpp_path}") + logging.info(f" compile_list = {compile_list}") + logging.info(f" link_list = {link_list}") + + return compile_list, link_list + + def recompile(self, compile_cmd, link_cmd): + + try: + subprocess.run(compile_cmd, check=True) + except Exception as e: + raise e + + try: + subprocess.run(link_cmd, check=True) + except Exception as e: + raise e + + + def repackage(self, new_pt2_directory, extra_files): + new_proxy_json_path = MakePath(self.extract_dir, self.new_name_prefix + ".json") + new_metadata_json_path = MakePath(self.extract_dir, self.new_name_prefix + "_metadata.json") + + shutil.copy(MakePath(self.extract_dir,self.proxy_json_name), new_proxy_json_path) + shutil.copy(MakePath(self.extract_dir, self.metadata_json_name),new_metadata_json_path) + + file_list = [ + self.new_cpp_path, + self.new_so_path, + new_proxy_json_path, + new_metadata_json_path, + ] + + for filename in self.binfiles: + file_list.append(MakePath(self.extract_dir, filename)) + + from pathlib import Path + for extra_file in extra_files: + try: + Path(extra_file).resolve(strict=True) + except Exception as e: + raise e + file_list.append(extra_file) + + if len(new_pt2_directory)==0: + new_pkg_path = MakePath(self.pt2_dir, 
self.new_name_prefix + ".pt2") + else: + new_pkg_path = MakePath(new_pt2_directory, self.new_name_prefix + ".pt2") + + inductor_package.package_aoti(new_pkg_path, file_list) + logging.info(f" OUTPUT NEW AOTI PACKAGE TO: {new_pkg_path}") + + return new_pkg_path + + + def make_new_pt2(self, new_pt2_directory="", extra_files = []): + compile_cmd, link_cmd = self.rewrite_cpp_wrapper() + self.recompile(compile_cmd, link_cmd) + new_pkg_path = self.repackage(new_pt2_directory, extra_files) + logging.info(f" ---------- SUCCESS MAKE NEW PT2 BATCH {self.batch_size} ----------") + return new_pkg_path \ No newline at end of file diff --git a/torch_npu/_inductor/patch/c_shim_npu.cpp b/torch_npu/_inductor/patch/c_shim_npu.cpp new file mode 100644 index 0000000000000000000000000000000000000000..7db50caba9b0322e20c4845c84119bd1c1cb4a2c --- /dev/null +++ b/torch_npu/_inductor/patch/c_shim_npu.cpp @@ -0,0 +1,69 @@ + + +// WARNING: THIS FILE IS AUTOGENERATED BY torchgen. DO NOT MODIFY BY HAND. +// See https://github.com/pytorch/pytorch/blob/7e86a7c0155295539996e0cf422883571126073e/torchgen/gen.py#L2424-L2436 for details +#include +#include + +// 基础支持 +#include +#include +#include +#include + +// NPU扩展 +// #include +#include "torch_npu/torch_npu.h" +// #include "torch_npu/csrc/framework/utils/NpuUtils.h" +// #include + +// 算子定义 +#include +#include +#include + +// 模板工具 +#include +#include // 智能指针支持 + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#include +#include +#include +#else + +#endif + +using namespace torch::aot_inductor; + + +AOTITorchError aoti_torch_npu_index_Tensor(AtenTensorHandle self, const AtenTensorHandle** indices, int64_t indices_len_, AtenTensorHandle* ret0) { + AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({ + auto tmp_result = at::index( + *tensor_handle_to_tensor_pointer(self), c10::List<::std::optional>(c10::ArrayRef<::std::optional>(pointer_to_list<::std::optional>(indices, indices_len_))) + ); + *ret0 = new_tensor_handle(std::move(tmp_result));; + }); +} + +AOTITorchError aoti_torch_npu_cat(const AtenTensorHandle* tensors, int64_t tensors_len_, int64_t dim, AtenTensorHandle* ret0) { + AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({ + auto tmp_result = at::cat( + pointer_to_list(tensors, tensors_len_), dim + ); + *ret0 = new_tensor_handle(std::move(tmp_result));; + }); +} + +AOTITorchError aoti_torch_npu_convolution(AtenTensorHandle input, AtenTensorHandle weight, AtenTensorHandle* bias, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, const int64_t* dilation, int64_t dilation_len_, int32_t transposed, const int64_t* output_padding, int64_t output_padding_len_, int64_t groups, AtenTensorHandle* ret0) { + AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({ + auto tmp_result = at::compositeexplicitautograd::convolution_symint( + *tensor_handle_to_tensor_pointer(input), *tensor_handle_to_tensor_pointer(weight), pointer_to_optional(bias), pointer_to_list(stride, stride_len_), pointer_to_list(padding, padding_len_), pointer_to_list(dilation, dilation_len_), transposed, pointer_to_list(output_padding, output_padding_len_), groups + ); + *ret0 = new_tensor_handle(std::move(tmp_result));; + }); +} + +// -------------------------------- split line -------------------------------- \ No newline at end of file diff --git a/torch_npu/_inductor/patch/deploy_aoti_model.sh b/torch_npu/_inductor/patch/deploy_aoti_model.sh new file mode 100644 index 0000000000000000000000000000000000000000..5ae9a4a535bdaafaef7dd0ffa67b028aebe48dad --- /dev/null +++ 
b/torch_npu/_inductor/patch/deploy_aoti_model.sh @@ -0,0 +1,38 @@ +#!/bin/bash +set -euo pipefail # 严格模式:任何错误立即终止脚本 + +# 参数校验 +if [[ $# -ne 1 ]]; then + echo "[ERROR] Incorrect usage detected!" + echo "[INFO] Usage: $0 " + exit 1 +fi + +input_path="$1" +output_path=/host/deberta_files + +[[ -d "$input_path" ]] || { echo "[ERROR] Invalid directory: $input_path"; exit 1; } + +for file in "$input_path"/deberta_*.pt2; do + # 提取文件名中的数字编号(如deberta_12.pt2 → 12) + base_name=$(basename "$file" .pt2) + num="${base_name##*_}" + + # 构建目标路径(如output_path/batch_12) + target_dir="$output_path/batch_$num" + if [[ -d "$target_dir" ]]; then + echo "[INFO] Delete old path: $target_dir" + rm -rf "$target_dir" || { echo "[ERROR] Delete old path failed"; exit 1; } + fi + + mkdir -p "$target_dir" || { echo "[ERROR] Create directory failed: $target_dir"; exit 1; } + + # 复制文件到目标目录 + cp -v "$file" "$target_dir/" || { echo "[ERROR] Failed to copy $file to $target_dir"; exit 1; } + + # 解压文件并处理错误 + zip_file="$target_dir/$(basename "$file")" + unzip "$zip_file" -d "$target_dir" || { echo "[ERROR] failed to unzip $zip_file"; exit 1; } +done + +echo "[SUCCESS] All AOTI pt2 files have deployed to $output_path" \ No newline at end of file diff --git a/torch_npu/_inductor/patch/model_container_runner_npu.cpp b/torch_npu/_inductor/patch/model_container_runner_npu.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e8a80a50e35fc19d4a8b6a3e4373d39ebd142941 --- /dev/null +++ b/torch_npu/_inductor/patch/model_container_runner_npu.cpp @@ -0,0 +1,56 @@ +#include +#include +#include + +namespace torch::inductor { + +AOTIModelContainerRunnerNpu::AOTIModelContainerRunnerNpu( + const std::string& model_so_path, + size_t num_models, + const std::string& device_str, + const std::string& cubin_dir) + : AOTIModelContainerRunner( + model_so_path, + num_models, + device_str, + cubin_dir) { + std::cerr <<"[DEBUG] in create func" << std::endl; + } + +AOTIModelContainerRunnerNpu::~AOTIModelContainerRunnerNpu() = default; + + +std::vector AOTIModelContainerRunnerNpu::run( + const std::vector& inputs, void* stream_handle){ + c10_npu::NPUStream npu_stream = c10_npu::getCurrentNPUStream(); + std::cerr<<"[DEBUG] before ModelContainer run, stream = "<(npu_stream.stream())); +} + +std::vector AOTIModelContainerRunnerNpu::run_with_npu_stream( + std::vector& inputs, + c10_npu::NPUStream npu_stream) { + return AOTIModelContainerRunner::run( + inputs, reinterpret_cast(npu_stream.stream())); +} + + +std::unique_ptr create_aoti_runner_npu( + const std::string& model_so_path, + size_t num_models, + const std::string& device_str, + const std::string& cubin_dir) { + std::cout <<"[DEBUG] in create_aoti_runner_npu" << std::endl; + return std::make_unique( + model_so_path, num_models, device_str, cubin_dir); +} + +void RegistNpu() { + std::cout << "[DEBUG] start regist npu" << std::endl; + RegisterAOTIModelRunner register_npu_runner("npu", &create_aoti_runner_npu); + std::cout << "[DEBUG] end regist npu" << std::endl; +} + + +} // namespace torch::inductor diff --git a/torch_npu/_inductor/patch/torch_changes.patch b/torch_npu/_inductor/patch/torch_changes.patch new file mode 100644 index 0000000000000000000000000000000000000000..46ebd614ad3841022893d60c757139720de61abd --- /dev/null +++ b/torch_npu/_inductor/patch/torch_changes.patch @@ -0,0 +1,1138 @@ +diff --git a/_inductor/codecache.py b/_inductor/codecache.py +index de72c7e..f7b52c4 100644 +--- a/_inductor/codecache.py ++++ b/_inductor/codecache.py +@@ -1469,6 +1469,7 @@ class 
AotCodeCompiler: + generated_files.append(input_path) + + output_code_log.info("Output code written to: %s", input_path) ++ print("Output code written to: %s", input_path) + trace_structured( + "graph_dump", + lambda: { +@@ -1544,6 +1545,7 @@ class AotCodeCompiler: + output_dir=object_output_dir, + BuildOption=object_build_options, + ) ++ # import pdb;pdb.set_trace() + compile_cmd = object_builder.get_command_line() + consts_o = object_builder.get_target_file_path() + if fbcode_aot_cpu_re: +@@ -1675,10 +1677,12 @@ class AotCodeCompiler: + output_dir=object_output_dir, + BuildOption=object_build_options, + ) ++ # import pdb;pdb.set_trace() + compile_cmd = object_builder.get_command_line() + output_o = object_builder.get_target_file_path() + + log.debug("aot compilation command: %s", compile_cmd) ++ print(f"aot compilation command: {compile_cmd}") + if not config.aot_inductor.package_cpp_only: + if fbcode_aot_cpu_re: + output_o = os.path.splitext(input_path)[0] + ".o" +@@ -1686,7 +1690,7 @@ class AotCodeCompiler: + os.chmod(output_o, 0o644) + else: + run_command_and_check(compile_cmd) +- ++ # import pdb;pdb.set_trace() + if config.aot_inductor.package_cpp_only: + compile_flags = os.path.splitext(input_path)[0] + "_compile_flags.json" + object_build_options.save_flags_to_file(compile_flags) +@@ -1713,24 +1717,26 @@ class AotCodeCompiler: + kernels_o = " ".join(kernels_o) + + output_name, output_dir = get_name_and_dir_from_output_file_path(output_so) ++ + so_build_options = CppTorchDeviceOptions( + vec_isa=picked_vec_isa, + device_type=device_type, + aot_mode=graph.aot_mode, + use_absolute_path=use_absolute_path, + ) +- + so_builder = CppBuilder( + name=output_name, + sources=[output_o, consts_o, kernels_o], + output_dir=output_dir, + BuildOption=so_build_options, + ) ++ + link_cmd = so_builder.get_command_line() ++ shutil.copy(consts_o, "/host/aoti_weights/weight.o") + output_so = so_builder.get_target_file_path() + + log.debug("aot linkage command: %s", link_cmd) +- ++ print(f"aot linkage command: {link_cmd}") + # Append cmds to the end of codegen-ed wrapper file + with open(input_path, "a") as f: + f.write("\n") +@@ -2000,6 +2006,7 @@ class CppCodeCache: + # And then pass the command_line to below write function as extra parameter to + # guarantee the source code hash contains ISA difference. + vec_isa_cmd = repr(command_gen.get_command_line()) ++ # import pdb;pdb.set_trace() + key, input_path = write(source_code, "cpp", extra=vec_isa_cmd) + + if key not in cls.cache: +diff --git a/_inductor/codegen/aoti_runtime/interface.cpp b/_inductor/codegen/aoti_runtime/interface.cpp +index b270ccb..f9e0a7f 100644 +--- a/_inductor/codegen/aoti_runtime/interface.cpp ++++ b/_inductor/codegen/aoti_runtime/interface.cpp +@@ -1,10 +1,15 @@ + // Definition of AOTI runtime interface functions + ++#include + #include + #include ++#include ++#include + + #include + #include ++#include ++#include + #include + #include + +@@ -55,7 +60,7 @@ AOTIRuntimeError AOTInductorModelContainerCreate( + return AOTInductorModelContainerCreateWithDevice( + container_handle, + num_models, +- is_cpu ? "cpu" : "cuda", ++ is_cpu ? 
"cpu" : "npu", + cubin_dir); + } + +diff --git a/_inductor/codegen/cpp_utils.py b/_inductor/codegen/cpp_utils.py +index 4a62f92..849476f 100644 +--- a/_inductor/codegen/cpp_utils.py ++++ b/_inductor/codegen/cpp_utils.py +@@ -82,6 +82,7 @@ DEVICE_TO_ATEN = { + "cpu": "at::kCPU", + "cuda": "at::kCUDA", + "xpu": "at::kXPU", ++ "npu": "at::kNPU", + } + + LAYOUT_TO_ATEN = { +diff --git a/_inductor/codegen/cpp_wrapper_cpu.py b/_inductor/codegen/cpp_wrapper_cpu.py +index f92da71..532c38d 100644 +--- a/_inductor/codegen/cpp_wrapper_cpu.py ++++ b/_inductor/codegen/cpp_wrapper_cpu.py +@@ -190,7 +190,6 @@ class CppWrapperCpu(PythonWrapperCodegen): + #include + #include + #include +- #include + + #include + typedef at::Half half; +diff --git a/_inductor/codegen/wrapper.py b/_inductor/codegen/wrapper.py +index 4da5e4c..ff7f724 100644 +--- a/_inductor/codegen/wrapper.py ++++ b/_inductor/codegen/wrapper.py +@@ -402,9 +402,8 @@ class EnterDeviceContextManagerLine(WrapperLine): + # associated with a device, so we never expect the device to change. + # CUDAStreamGuard sets the stream and the device. + if self.last_seen_device_guard_index is None: +- code.writeline( +- f"{V.graph.device_ops.cpp_aoti_stream_guard()} stream_guard(stream, this->device_idx_);" +- ) ++ code.writeline(f"c10_npu::NPUStream npuStream = c10_npu::getCurrentNPUStream(this->device_idx_);") ++ code.writeline(f"if(stream != npuStream.stream()){{std::cerr<<\"stream not equal to npuStream!!!\"< List[str]: + ] + # TODO: this is to avoid FC breakage for fbcode. When using newly + # generated model.so on an older verion of PyTorch, need to use +- # the v1 version for aoti_torch_create_tensor_from_blob ++ # the v1 version for aoti_torch_create_tensor_from_blob_npu + create_tensor_from_blob_v1 = "AOTI_USE_CREATE_TENSOR_FROM_BLOB_V1" + + fb_internal_macros.append(create_tensor_from_blob_v1) +@@ -786,6 +787,15 @@ def _get_torch_related_args( + + return include_dirs, libraries_dirs, libraries + ++def _get_torch_npu_related_args( ++ include_pytorch: bool, aot_mode: bool ++): ++ from torch_npu.utils._inductor import _TORCH_NPU_PATH, TORCH_NPU_LIB_PATH ++ ++ include_dirs = [os.path.join(_TORCH_NPU_PATH, "include"), "/host/zcl/aoti_files"] ++ libraries_dirs = ["/host/zcl/pta_v2.6/libtorch_npu/lib" ,"/host/zcl/aoti_files"] ++ libraries = ["torch_npu", "aoti_npu", "aoti_runner_npu", "aoti_npuops"] ++ return include_dirs, libraries_dirs, libraries + + def _get_python_include_dirs() -> List[str]: + include_dir = Path(sysconfig.get_path("include")) +@@ -1043,6 +1053,12 @@ def get_cpp_torch_options( + torch_libraries, + ) = _get_torch_related_args(include_pytorch=include_pytorch, aot_mode=aot_mode) + ++ ( ++ torch_npu_include_dirs, ++ torch_npu_libraries_dirs, ++ torch_npu_libraries, ++ ) = _get_torch_npu_related_args(include_pytorch=include_pytorch, aot_mode=aot_mode) ++ + python_include_dirs, python_libraries_dirs = _get_python_related_args() + + ( +@@ -1070,12 +1086,13 @@ def get_cpp_torch_options( + sys_libs_include_dirs + + python_include_dirs + + torch_include_dirs ++ + torch_npu_include_dirs + + omp_include_dir_paths + ) + cflags = sys_libs_cflags + omp_cflags + ldflags = omp_ldflags +- libraries_dirs = python_libraries_dirs + torch_libraries_dirs + omp_lib_dir_paths +- libraries = torch_libraries + omp_lib ++ libraries_dirs = python_libraries_dirs + torch_libraries_dirs + omp_lib_dir_paths + torch_npu_libraries_dirs ++ libraries = torch_libraries + omp_lib + torch_npu_libraries + passthough_args = ( + sys_libs_passthough_args + + 
isa_ps_args_build_flags +@@ -1233,6 +1250,12 @@ def get_cpp_torch_device_options( + cflags += ["fsycl", "Wno-unsupported-floating-point-opt"] + libraries += ["c10_xpu", "sycl", "ze_loader", "torch_xpu"] + ++ if device_type == "npu": ++ definations.append(" USE_NPU") ++ definations.append(" BUILD_LIBTORCH=ON") ++ # cflags += [""] ++ libraries += ["runtime", "ascendcl"] ++ + if aot_mode: + if config.is_fbcode(): + from torch._inductor.codecache import cpp_prefix_path +@@ -1306,7 +1329,6 @@ class CppTorchDeviceOptions(CppTorchOptions): + device_libraries_dirs: List[str] = [] + device_libraries: List[str] = [] + device_passthough_args: List[str] = [] +- + ( + device_definations, + device_include_dirs, +@@ -1325,6 +1347,7 @@ class CppTorchDeviceOptions(CppTorchOptions): + _append_list(self._libraries_dirs, device_libraries_dirs) + _append_list(self._libraries, device_libraries) + _append_list(self._passthough_args, device_passthough_args) ++ + self._finalize_options() + + def _finalize_options(self) -> None: +@@ -1448,7 +1471,8 @@ class CppBuilder: + self._cflags_args += f"/{cflag} " + else: + self._cflags_args += f"-{cflag} " +- ++ # if self._compile_only: ++ # import pdb;pdb.set_trace() + for defination in BuildOption.get_definations(): + if _IS_WINDOWS: + self._definations_args += f"/D {defination} " +diff --git a/_inductor/graph.py b/_inductor/graph.py +index 3a5942f..8e3018b 100644 +--- a/_inductor/graph.py ++++ b/_inductor/graph.py +@@ -1860,7 +1860,7 @@ class GraphLowering(torch.fx.Interpreter): + """ + For GPU, Triton kernels are autotuned and stored as cubin files + """ +- if any(device in self.device_types for device in ["cuda", "xpu"]): ++ if any(device in self.device_types for device in ["cuda", "xpu", 'npu']): + if config.triton.autotune_at_compile_time: + # If autotune_at_compile_time is True, we can do the codegen in one-pass + # TODO: once autotune_at_compile_time is stable, we should delete the else branch +diff --git a/_inductor/utils.py b/_inductor/utils.py +index d5c096a..fcceb62 100644 +--- a/_inductor/utils.py ++++ b/_inductor/utils.py +@@ -64,6 +64,7 @@ GPU_TYPES = ["cuda", "xpu"] + @functools.lru_cache(None) + def get_gpu_type(): + avail_gpus = [x for x in GPU_TYPES if getattr(torch, x).is_available()] ++ # import pdb;pdb.set_trace() + assert len(avail_gpus) <= 1 + gpu_type = "cuda" if len(avail_gpus) == 0 else avail_gpus.pop() + return gpu_type +@@ -1944,7 +1945,7 @@ def get_cloned_parameter_buffer_name(name: str): + + def is_gpu(device: Optional[str]): + assert isinstance(device, str) or device is None, device +- return device in GPU_TYPES ++ return device in GPU_TYPES or device == "npu" + + + def device_need_guard(device: str): +diff --git a/include/torch/csrc/inductor/aoti_runner/model_container_runner_npu.h b/include/torch/csrc/inductor/aoti_runner/model_container_runner_npu.h +new file mode 100644 +index 0000000..848cab6 +--- /dev/null ++++ b/include/torch/csrc/inductor/aoti_runner/model_container_runner_npu.h +@@ -0,0 +1,32 @@ ++#pragma once ++ ++#include ++#include ++ ++namespace torch::inductor { ++ ++// NOTICE: Following APIs are subject to change due to active development ++// We provide NO BC guarantee for these APIs ++class TORCH_API AOTIModelContainerRunnerNpu : public AOTIModelContainerRunner { ++ public: ++ // @param device_str: cuda device string, e.g. 
"cuda", "cuda:0" ++ AOTIModelContainerRunnerNpu( ++ const std::string& model_so_path, ++ size_t num_models = 1, ++ const std::string& device_str = "npu", ++ const std::string& cubin_dir = ""); ++ ++ ~AOTIModelContainerRunnerNpu(); ++ ++ std::vector run( ++ const std::vector& inputs, ++ void* stream_handle = nullptr) override; ++ ++ std::vector run_with_npu_stream( ++ std::vector& inputs, ++ c10_npu::NPUStream npu_stream); ++}; ++ ++void RegistNpu(); ++ ++} // namespace torch::inductor +\ No newline at end of file +diff --git a/include/torch/csrc/inductor/aoti_runtime/arrayref_tensor.h b/include/torch/csrc/inductor/aoti_runtime/arrayref_tensor.h +index e2f2957..9730c60 100644 +--- a/include/torch/csrc/inductor/aoti_runtime/arrayref_tensor.h ++++ b/include/torch/csrc/inductor/aoti_runtime/arrayref_tensor.h +@@ -229,7 +229,7 @@ class ArrayRefTensor { + + AtenTensorHandle borrowAsTensor() const { + AtenTensorHandle result = nullptr; +- AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_create_tensor_from_blob_v2( ++ AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_create_tensor_from_blob_npu_v2( + data(), + sizes_.size(), + sizes_.data(), +diff --git a/include/torch/csrc/inductor/aoti_runtime/device_utils.h b/include/torch/csrc/inductor/aoti_runtime/device_utils.h +index 7b48f49..b1aa844 100644 +--- a/include/torch/csrc/inductor/aoti_runtime/device_utils.h ++++ b/include/torch/csrc/inductor/aoti_runtime/device_utils.h +@@ -50,6 +50,32 @@ using DeviceStreamType = sycl::queue*; + + } // namespace torch::aot_inductor + ++#elif defined(USE_NPU) ++ ++#include "third_party/acl/inc/acl/acl_base.h" ++#include "third_party/acl/inc/acl/acl_rt.h" ++ ++// DCK_TODO: do we need to support 32bit os. ++typedef void* NPUdeviceptr; ++ ++typedef void* NPUfunction; ++ ++#define AOTI_RUNTIME_DEVICE_CHECK(EXPR) \ ++ do { \ ++ const aclError code = EXPR; \ ++ if (code != ACL_SUCCESS) { \ ++ throw std::runtime_error( \ ++ std::string("NPU error core: ") + std::to_string(code) \ ++ + std::string(" ") + std::string(__FILE__) + std::string(":") + std::to_string(__LINE__)); \ ++ } \ ++ } while (0) ++ ++namespace torch::aot_inductor { ++ ++using DeviceStreamType = aclrtStream; ++ ++} // namespace torch::aot_inductor ++ + #else + + #define AOTI_RUNTIME_DEVICE_CHECK(EXPR) \ +diff --git a/include/torch/csrc/inductor/aoti_runtime/model.h b/include/torch/csrc/inductor/aoti_runtime/model.h +index a8ec3a6..2c3ca04 100644 +--- a/include/torch/csrc/inductor/aoti_runtime/model.h ++++ b/include/torch/csrc/inductor/aoti_runtime/model.h +@@ -59,6 +59,22 @@ GPUPtr RAII_gpuMalloc(size_t num_bytes) { + + #endif // USE_CUDA + ++#ifdef USE_NPU ++ ++using NPUPtr = std::unique_ptr>; ++ ++NPUPtr RAII_npuMalloc(size_t num_bytes) { ++ void* data_ptr; ++ // DCK_TODO: aclrtMalloc bytes cannot be 0, how to adapt. ++ if (num_bytes == 0) num_bytes = 4; ++ // DCK_TODO: ACL_MEM_MALLOC_NORMAL_ONLY ? 
++ AOTI_RUNTIME_DEVICE_CHECK(aclrtMalloc((void**)&data_ptr, num_bytes, ACL_MEM_MALLOC_NORMAL_ONLY)); ++ auto deleter = [](void* ptr) { AOTI_RUNTIME_DEVICE_CHECK(aclrtFree(ptr)); }; ++ return NPUPtr(data_ptr, deleter); ++} ++ ++#endif // USE_NPU ++ + #ifdef USE_XPU + + using GPUPtr = std::unique_ptr>; +@@ -92,9 +108,10 @@ inline void parse_device_str( + const std::string& device_str, + int32_t& device_type, + int32_t& device_idx) { +- std::regex re("(cpu|cuda|xpu)(:([0-9]+))?"); ++ std::regex re("(cpu|cuda|xpu|npu)(:([0-9]+))?"); + std::smatch sm; + bool matched = std::regex_match(device_str, sm, re); ++ std::cout <<"wz 1" << std::endl; + AOTI_RUNTIME_CHECK(matched, "Invalid device: " + device_str); + + if (sm[1].str() == "cpu") { +@@ -104,8 +121,13 @@ inline void parse_device_str( + #ifdef USE_XPU + } else if (sm[1].str() == "xpu") { + device_type = aoti_torch_device_type_xpu(); ++#endif ++#ifdef USE_NPU ++ } else if (sm[1].str() == "npu") { ++ device_type = aoti_torch_device_type_npu(); + #endif + } else { ++ std::cout <<"wz 1" << std::endl; + AOTI_RUNTIME_CHECK(false, "Invalid device: " + device_str); + } + +@@ -153,6 +175,14 @@ class AOTInductorModelBase { + aoti_torch_set_current_xpu_device(device_idx_); + } + #endif // USE_XPU ++#ifdef USE_NPU ++ if (device_idx_ == -1) { ++ // DCK_TODO: which device to set WZ_TODO: match CUDA ++ std::cout << "wzdebugsetnpu device0" << std::endl; ++ AOTI_RUNTIME_DEVICE_CHECK(aclrtSetDevice(0)); ++ AOTI_RUNTIME_DEVICE_CHECK(aclrtGetDevice(&device_idx_)); ++ } ++#endif // USE_NPU + } + + // NOLINTNEXTLINE(modernize-use-equals-default) +@@ -172,6 +202,15 @@ class AOTInductorModelBase { + delete *run_finished_; + } + #endif // USE_XPU ++#ifdef USE_NPU ++ if (run_finished_) { ++ auto code = aclrtDestroyEvent(*run_finished_); ++ if (code != ACL_SUCCESS) { ++ std::cerr << "Failed to destroy NPU event in AOTInductor model erorr code: " ++ << code << std::endl; ++ } ++ } ++#endif // USE_NPU + } + + AOTInductorModelBase(AOTInductorModelBase&&) = delete; +@@ -201,6 +240,12 @@ class AOTInductorModelBase { + delete *run_finished_; + run_finished_.reset(); + } ++#elif defined(USE_NPU) ++ if (!run_finished_) { ++ aclrtEvent run_finished; ++ AOTI_RUNTIME_DEVICE_CHECK(aclrtCreateEvent(&run_finished)); ++ run_finished_.emplace(run_finished); ++ } + #else // !USE_CUDA && !USE_XPU + run_finished_ = false; + #endif +@@ -213,6 +258,8 @@ class AOTInductorModelBase { + #elif defined(USE_XPU) + run_finished_ = std::make_optional(new sycl::event( + static_cast(stream)->ext_oneapi_submit_barrier())); ++#elif defined(USE_NPU) ++ AOTI_RUNTIME_DEVICE_CHECK(aclrtRecordEvent(*run_finished_, stream)); + #else // !USE_CUDA && !USE_XPU + run_finished_ = true; + #endif // USE_CUDA +@@ -234,6 +281,12 @@ class AOTInductorModelBase { + delete *run_finished_; + run_finished_.reset(); + } ++#elif defined(USE_NPU) ++ if (!run_finished_) { ++ aclrtEvent run_finished; ++ AOTI_RUNTIME_DEVICE_CHECK(aclrtCreateEvent(&run_finished)); ++ run_finished_.emplace(run_finished); ++ } + #else // !USE_CUDA && !USE_XPU + run_finished_ = false; + #endif +@@ -250,6 +303,8 @@ class AOTInductorModelBase { + run_finished_ = std::make_optional(new sycl::event( + static_cast(stream)->ext_oneapi_submit_barrier())); + ++#elif defined(USE_NPU) ++ AOTI_RUNTIME_DEVICE_CHECK(aclrtRecordEvent(*run_finished_, stream)); + #else // !USE_CUDA && !USE_XPU + run_finished_ = true; + #endif // USE_CUDA +@@ -267,6 +322,8 @@ class AOTInductorModelBase { + compute_gpu_constant_blob(blob_size, constants_internal_offset); + #if 
defined(USE_CUDA) || defined(USE_XPU) + constant_blob_ = RAII_gpuMalloc(blob_size); ++#elif defined(USE_NPU) ++ constant_blob_ = RAII_npuMalloc(blob_size); + #endif + } + if (!include_weights) { +@@ -276,7 +333,7 @@ class AOTInductorModelBase { + size_t bytes_read = 0; + for (size_t i = 0; i < num_constants; i++) { + bool from_folded = this->constant_from_folded(i); +-#if not defined(USE_XPU) && not defined(USE_CUDA) ++#if not defined(USE_XPU) && not defined(USE_CUDA) && not defined(USE_NPU) + if (from_folded) { + // We do not reallocate and copy for CPU. + continue; +@@ -306,11 +363,11 @@ class AOTInductorModelBase { + AtenTensorHandle tensor_handle = nullptr; + #ifdef AOTI_USE_CREATE_TENSOR_FROM_BLOB_V1 + // When opaque_metadata_size is not 0, we need to have the +- // aoti_torch_create_tensor_from_blob_v2 available ++ // aoti_torch_create_tensor_from_blob_npu_v2 available + AOTI_RUNTIME_CHECK( + opaque_metadata_size == 0, + "Expect opaque_metadata_size to be 0 when AOTI_USE_CREATE_TENSOR_FROM_BLOB_V1 is defined"); +- AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_create_tensor_from_blob( ++ AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_create_tensor_from_blob_npu( + internal_ptr, + ndim, + size, +@@ -321,7 +378,7 @@ class AOTInductorModelBase { + device_idx_, + &tensor_handle)); + #else +- AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_create_tensor_from_blob_v2( ++ AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_create_tensor_from_blob_npu_v2( + internal_ptr, + ndim, + size, +@@ -347,6 +404,11 @@ class AOTInductorModelBase { + return std::move(constant_blob_); + } + #endif ++#ifdef USE_NPU ++ NPUPtr&& release_constant_blob() { ++ return std::move(constant_blob_); ++ } ++#endif + + std::shared_ptr> get_constants_array() { + return constants_; +@@ -361,7 +423,7 @@ class AOTInductorModelBase { + size_t bytes_read, + size_t data_size, + bool skip_copy) { +-#if defined(USE_CUDA) || defined(USE_XPU) ++#if defined(USE_CUDA) || defined(USE_XPU) || defined(USE_NPU) + auto* constants_ptr = static_cast(constant_blob_.get()); + uint8_t* internal_ptr = constants_ptr + constant_offset; + // Copy data to GPU memory +@@ -374,6 +436,13 @@ class AOTInductorModelBase { + ->memcpy(internal_ptr, _get_constants_start() + bytes_read, data_size) + .wait(); + ++#elif defined(USE_NPU) ++ AOTI_RUNTIME_DEVICE_CHECK(aclrtMemcpy( ++ internal_ptr, ++ data_size, ++ _get_constants_start() + bytes_read, ++ data_size, ++ ACL_MEMCPY_HOST_TO_DEVICE)); + #else + AOTI_RUNTIME_DEVICE_CHECK(cudaMemcpy( + internal_ptr, +@@ -394,7 +463,7 @@ class AOTInductorModelBase { + void compute_gpu_constant_blob( + size_t& blob_size, + std::vector& constants_internal_offset) { +-#if defined(USE_CUDA) || defined(USE_XPU) ++#if defined(USE_CUDA) || defined(USE_XPU) || defined(USE_NPU) + size_t num_constants = this->num_constants(); + // Compute required blob size with 64-alignment if on GPU. + blob_size = 0; +@@ -544,6 +613,19 @@ class AOTInductorModelBase { + throw std::runtime_error( + std::string("The model did not finish successfully. 
Error: ") + + cudaGetErrorString(cudaGetLastError())); ++ ++#elif defined(USE_NPU) ++ if (!run_finished_) { ++ throw std::runtime_error{"Model NPU event was not initialized"}; ++ } ++ aclrtEventRecordedStatus recordStatus = ACL_EVENT_RECORDED_STATUS_NOT_READY; ++ AOTI_RUNTIME_DEVICE_CHECK(aclrtQueryEventStatus(*run_finished_, &recordStatus)); ++ ++ if (recordStatus == ACL_EVENT_RECORDED_STATUS_COMPLETE) { ++ return true; ++ } else { ++ return false; ++ } + #elif defined(USE_XPU) + if (!run_finished_) { + throw std::runtime_error{"Model XPU event was not initialized"}; +@@ -648,6 +730,12 @@ class AOTInductorModelBase { + GPUPtr constant_blob_; + #endif // USE_CUDA + ++#ifdef USE_NPU ++ // Holds the blob storage for constants' at::Tensor for CUDA. ++ NPUPtr constant_blob_; ++#endif // USE_NPU ++ ++ + #ifdef USE_MMAP_SELF + uint8_t* self_mmap = NULL; + #endif +@@ -666,6 +754,8 @@ class AOTInductorModelBase { + std::optional run_finished_; + #elif defined(USE_XPU) + std::optional run_finished_; ++#elif defined(USE_NPU) ++ std::optional run_finished_; + #else // !USE_CUDA + bool run_finished_{}; + #endif +diff --git a/include/torch/csrc/inductor/aoti_runtime/model_container.h b/include/torch/csrc/inductor/aoti_runtime/model_container.h +index d94ee86..29cb503 100644 +--- a/include/torch/csrc/inductor/aoti_runtime/model_container.h ++++ b/include/torch/csrc/inductor/aoti_runtime/model_container.h +@@ -52,7 +52,7 @@ class AOTInductorModelContainer { + output_names_.emplace_back(model->output_name(static_cast(i))); + } + model->load_constants(); +-#if defined(USE_CUDA) || defined(USE_XPU) ++#if defined(USE_CUDA) || defined(USE_XPU) || defined(USE_NPU) + constant_blob_ = model->release_constant_blob(); + constants_internal_offset_.resize(model->num_constants()); + model->compute_gpu_constant_blob(blob_size_, constants_internal_offset_); +@@ -299,6 +299,13 @@ class AOTInductorModelContainer { + ->memcpy(internal_constants_ptr, user_constant_ptr, constant_size) + .wait(); + ++#elif defined(USE_NPU) ++AOTI_RUNTIME_DEVICE_CHECK(aclrtMemcpy( ++ internal_constants_ptr, ++ constant_size, ++ user_constant_ptr, ++ constant_size, ++ ACL_MEMCPY_HOST_TO_DEVICE)); + #else + AOTI_RUNTIME_DEVICE_CHECK(cudaMemcpy( + internal_constants_ptr, +@@ -316,7 +323,7 @@ class AOTInductorModelContainer { + AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_strides(tensor, &stride)); + AOTI_TORCH_ERROR_CODE_CHECK( + aoti_torch_get_storage_offset(tensor, &offset)); +- AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_create_tensor_from_blob( ++ AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_create_tensor_from_blob_npu( + internal_constants_ptr, + models_[0]->constant_ndim(idx), + models_[0]->constant_shape(idx), +@@ -325,6 +332,8 @@ class AOTInductorModelContainer { + models_[0]->constant_dtype(idx), + #ifdef USE_XPU + aoti_torch_device_type_xpu(), ++#elif defined(USE_NPU) ++ aoti_torch_device_type_npu(), + #else + aoti_torch_device_type_cuda(), + #endif +@@ -418,6 +427,17 @@ class AOTInductorModelContainer { + std::vector constants_internal_offset_; + #endif // USE_CUDA + ++#ifdef USE_NPU ++ // Holds the blob storage for constants' at::Tensor for CUDA. ++ NPUPtr constant_blob_; ++ NPUPtr constant_blob_secondary_; ++ ++ // Let's place this within USE_NPU at the moment before we fully support ++ // update for CPU cases. ++ size_t blob_size_; ++ std::vector constants_internal_offset_; ++#endif // USE_NPU ++ + // Determine which constants is being used for the model. 
+ // If true, + // constants_map_secondary/constant_blob_secondary/constants_array_secondary +@@ -485,6 +505,20 @@ class AOTInductorModelContainer { + } + #endif // USE_CUDA + ++#ifdef USE_NPU ++ void* get_constant_blob_ptr(bool get_inactive) { ++ if ((get_inactive && use_secondary_) || ++ (!get_inactive && !use_secondary_)) { ++ return constant_blob_.get(); ++ } else { ++ if (!constant_blob_secondary_) { ++ constant_blob_secondary_ = RAII_npuMalloc(blob_size_); ++ } ++ return constant_blob_secondary_.get(); ++ } ++ } ++#endif // USE_NPU ++ + std::shared_ptr get_constants_map(bool get_inactive) { + if ((get_inactive && use_secondary_) || + (!get_inactive && !use_secondary_)) { +diff --git a/include/torch/csrc/inductor/aoti_runtime/thread_local.h b/include/torch/csrc/inductor/aoti_runtime/thread_local.h +index fd931c9..a614bf4 100644 +--- a/include/torch/csrc/inductor/aoti_runtime/thread_local.h ++++ b/include/torch/csrc/inductor/aoti_runtime/thread_local.h +@@ -66,7 +66,7 @@ struct ThreadLocalCachedOutputTensor> { + // NOLINTNEXTLINE(*arrays*) + storage_ = std::make_unique(t.numel()); + AtenTensorHandle handle = nullptr; +- AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_create_tensor_from_blob( ++ AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_create_tensor_from_blob_npu( + storage_.get(), + t.sizes().size(), + t.sizes().data(), +diff --git a/include/torch/csrc/inductor/aoti_runtime/utils_npu.h b/include/torch/csrc/inductor/aoti_runtime/utils_npu.h +new file mode 100644 +index 0000000..2a3d8ea +--- /dev/null ++++ b/include/torch/csrc/inductor/aoti_runtime/utils_npu.h +@@ -0,0 +1,114 @@ ++#pragma once ++ ++#ifdef USE_CUDA ++// WARNING: Be careful when adding new includes here. This header will be used ++// in model.so, and should not refer to any aten/c10 headers except the stable ++// C ABI defined in torch/csrc/inductor/aoti_torch/c/shim.h. The same rule ++// applies to other files under torch/csrc/inductor/aoti_runtime/. ++#include ++ ++#include ++#include ++ ++namespace torch::aot_inductor { ++ ++inline void delete_cuda_guard(void* ptr) { ++ AOTI_TORCH_ERROR_CODE_CHECK( ++ aoti_torch_delete_cuda_guard(reinterpret_cast(ptr))); ++} ++ ++inline void delete_cuda_stream_guard(void* ptr) { ++ AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_delete_cuda_stream_guard( ++ reinterpret_cast(ptr))); ++} ++ ++class AOTICudaGuard { ++ public: ++ AOTICudaGuard(int32_t device_index) : guard_(nullptr, delete_cuda_guard) { ++ CUDAGuardHandle ptr = nullptr; ++ AOTI_TORCH_ERROR_CODE_CHECK( ++ aoti_torch_create_cuda_guard(device_index, &ptr)); ++ guard_.reset(ptr); ++ } ++ ++ void set_index(int32_t device_index) { ++ AOTI_TORCH_ERROR_CODE_CHECK( ++ aoti_torch_cuda_guard_set_index(guard_.get(), device_index)); ++ } ++ ++ private: ++ std::unique_ptr guard_; ++}; ++ ++class AOTICudaStreamGuard { ++ public: ++ AOTICudaStreamGuard(cudaStream_t stream, int32_t device_index) ++ : guard_(nullptr, delete_cuda_stream_guard) { ++ CUDAStreamGuardHandle ptr = nullptr; ++ AOTI_TORCH_ERROR_CODE_CHECK( ++ aoti_torch_create_cuda_stream_guard(stream, device_index, &ptr)); ++ guard_.reset(ptr); ++ } ++ ++ private: ++ std::unique_ptr guard_; ++}; ++ ++} // namespace torch::aot_inductor ++#endif // USE_CUDA ++ ++#ifdef USE_NPU ++// WARNING: Be careful when adding new includes here. This header will be used ++// in model.so, and should not refer to any aten/c10 headers except the stable ++// C ABI defined in torch/csrc/inductor/aoti_torch/c/shim.h. The same rule ++// applies to other files under torch/csrc/inductor/aoti_runtime/. 
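++// The AOTINpuGuard / AOTINpuStreamGuard classes below mirror the CUDA guards in the
++// USE_CUDA branch above, but go through the aoti_torch_create_npu_guard and
++// aoti_torch_create_npu_stream_guard C-shim entry points, which the torch_npu-side
++// shim library is expected to export.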
++#include ++ ++#include ++ ++namespace torch::aot_inductor { ++ ++inline void delete_npu_guard(void* ptr) { ++ AOTI_TORCH_ERROR_CODE_CHECK( ++ aoti_torch_delete_npu_guard(reinterpret_cast(ptr))); ++} ++ ++inline void delete_npu_stream_guard(void* ptr) { ++ AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_delete_npu_stream_guard( ++ reinterpret_cast(ptr))); ++} ++ ++class AOTINpuGuard { ++ public: ++ AOTINpuGuard(int32_t device_index) : guard_(nullptr, delete_npu_guard) { ++ NPUGuardHandle ptr = nullptr; ++ AOTI_TORCH_ERROR_CODE_CHECK( ++ aoti_torch_create_npu_guard(device_index, &ptr)); ++ guard_.reset(ptr); ++ } ++ ++ void set_index(int32_t device_index) { ++ AOTI_TORCH_ERROR_CODE_CHECK( ++ aoti_torch_npu_guard_set_index(guard_.get(), device_index)); ++ } ++ ++ private: ++ std::unique_ptr guard_; ++}; ++ ++class AOTINpuStreamGuard { ++ public: ++ AOTINpuStreamGuard(aclrtStream stream, int32_t device_index) ++ : guard_(nullptr, delete_npu_stream_guard) { ++ NpuStreamGuardHandle ptr = nullptr; ++ AOTI_TORCH_ERROR_CODE_CHECK( ++ aoti_torch_create_npu_stream_guard(stream, device_index, &ptr)); ++ guard_.reset(ptr); ++ } ++ ++ private: ++ std::unique_ptr guard_; ++}; ++ ++} // namespace torch::aot_inductor ++#endif // USE_NPU +diff --git a/include/torch/csrc/inductor/aoti_torch/c/shim.h b/include/torch/csrc/inductor/aoti_torch/c/shim.h +index 4c6c9af..e66a0b9 100644 +--- a/include/torch/csrc/inductor/aoti_torch/c/shim.h ++++ b/include/torch/csrc/inductor/aoti_torch/c/shim.h +@@ -88,6 +88,139 @@ using AOTITorchError = int32_t; + #define AOTI_TORCH_SUCCESS 0 + #define AOTI_TORCH_FAILURE 1 + ++ ++ ++AOTI_TORCH_EXPORT AOTITorchError aoti_torch_npu__adaptive_avg_pool2d(AtenTensorHandle self, const int64_t* output_size, int64_t output_size_len_, AtenTensorHandle* ret0); ++AOTI_TORCH_EXPORT AOTITorchError aoti_torch_npu__adaptive_avg_pool2d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, AtenTensorHandle* ret0); ++AOTI_TORCH_EXPORT AOTITorchError aoti_torch_npu__adaptive_avg_pool3d(AtenTensorHandle self, const int64_t* output_size, int64_t output_size_len_, AtenTensorHandle* ret0); ++AOTI_TORCH_EXPORT AOTITorchError aoti_torch_npu__adaptive_avg_pool3d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, AtenTensorHandle* ret0); ++AOTI_TORCH_EXPORT AOTITorchError aoti_torch_npu__addmm_activation(AtenTensorHandle self, AtenTensorHandle mat1, AtenTensorHandle mat2, double beta, double alpha, int32_t use_gelu, AtenTensorHandle* ret0); ++AOTI_TORCH_EXPORT AOTITorchError aoti_torch_npu__cdist_backward(AtenTensorHandle grad, AtenTensorHandle x1, AtenTensorHandle x2, double p, AtenTensorHandle cdist, AtenTensorHandle* ret0); ++AOTI_TORCH_EXPORT AOTITorchError aoti_torch_npu__cdist_forward(AtenTensorHandle x1, AtenTensorHandle x2, double p, int64_t* compute_mode, AtenTensorHandle* ret0); ++AOTI_TORCH_EXPORT AOTITorchError aoti_torch_npu__efficientzerotensor(const int64_t* size, int64_t size_len_, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0); ++AOTI_TORCH_EXPORT AOTITorchError aoti_torch_npu__embedding_bag(AtenTensorHandle weight, AtenTensorHandle indices, AtenTensorHandle offsets, int32_t scale_grad_by_freq, int64_t mode, int32_t sparse, AtenTensorHandle* per_sample_weights, int32_t include_last_offset, int64_t padding_idx, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3); ++AOTI_TORCH_EXPORT AOTITorchError aoti_torch_npu__embedding_bag_dense_backward(AtenTensorHandle grad, 
AtenTensorHandle indices, AtenTensorHandle offset2bag, AtenTensorHandle bag_size, AtenTensorHandle maximum_indices, int64_t num_weights, int32_t scale_grad_by_freq, int64_t mode, AtenTensorHandle* per_sample_weights, int64_t padding_idx, AtenTensorHandle* ret0); ++AOTI_TORCH_EXPORT AOTITorchError aoti_torch_npu__embedding_bag_forward_only(AtenTensorHandle weight, AtenTensorHandle indices, AtenTensorHandle offsets, int32_t scale_grad_by_freq, int64_t mode, int32_t sparse, AtenTensorHandle* per_sample_weights, int32_t include_last_offset, int64_t padding_idx, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3); ++AOTI_TORCH_EXPORT AOTITorchError aoti_torch_npu__embedding_bag_per_sample_weights_backward(AtenTensorHandle grad, AtenTensorHandle weight, AtenTensorHandle indices, AtenTensorHandle offsets, AtenTensorHandle offset2bag, int64_t mode, int64_t padding_idx, AtenTensorHandle* ret0); ++AOTI_TORCH_EXPORT AOTITorchError aoti_torch_npu__fft_c2c(AtenTensorHandle self, const int64_t* dim, int64_t dim_len_, int64_t normalization, int32_t forward, AtenTensorHandle* ret0); ++AOTI_TORCH_EXPORT AOTITorchError aoti_torch_npu__fft_r2c(AtenTensorHandle self, const int64_t* dim, int64_t dim_len_, int64_t normalization, int32_t onesided, AtenTensorHandle* ret0); ++AOTI_TORCH_EXPORT AOTITorchError aoti_torch_npu__fused_moving_avg_obs_fq_helper(AtenTensorHandle self, AtenTensorHandle observer_on, AtenTensorHandle fake_quant_on, AtenTensorHandle running_min, AtenTensorHandle running_max, AtenTensorHandle scale, AtenTensorHandle zero_point, double averaging_const, int64_t quant_min, int64_t quant_max, int64_t ch_axis, int32_t per_row_fake_quant, int32_t symmetric_quant, AtenTensorHandle* ret0, AtenTensorHandle* ret1); ++AOTI_TORCH_EXPORT AOTITorchError aoti_torch_npu__fused_moving_avg_obs_fq_helper_functional(AtenTensorHandle self, AtenTensorHandle observer_on, AtenTensorHandle fake_quant_on, AtenTensorHandle running_min, AtenTensorHandle running_max, AtenTensorHandle scale, AtenTensorHandle zero_point, double averaging_const, int64_t quant_min, int64_t quant_max, int64_t ch_axis, int32_t per_row_fake_quant, int32_t symmetric_quant, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, AtenTensorHandle* ret4, AtenTensorHandle* ret5); ++AOTI_TORCH_EXPORT AOTITorchError aoti_torch_npu__histogramdd_from_bin_cts(AtenTensorHandle self, const int64_t* bins, int64_t bins_len_, const double** range, int64_t range_len_, AtenTensorHandle* weight, int32_t density, AtenTensorHandle* ret0); ++AOTI_TORCH_EXPORT AOTITorchError aoti_torch_npu__pdist_backward(AtenTensorHandle grad, AtenTensorHandle self, double p, AtenTensorHandle pdist, AtenTensorHandle* ret0); ++AOTI_TORCH_EXPORT AOTITorchError aoti_torch_npu__pdist_forward(AtenTensorHandle self, double p, AtenTensorHandle* ret0); ++AOTI_TORCH_EXPORT AOTITorchError aoti_torch_npu__scaled_dot_product_flash_attention_for_npu(AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, double dropout_p, int32_t is_causal, AtenTensorHandle* attn_mask, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1); ++AOTI_TORCH_EXPORT AOTITorchError aoti_torch_npu__scaled_dot_product_flash_attention_for_npu_backward(AtenTensorHandle grad_out, AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle out, AtenTensorHandle logsumexp, double dropout_p, int32_t is_causal, AtenTensorHandle* attn_mask, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* 
ret1, AtenTensorHandle* ret2); ++AOTI_TORCH_EXPORT AOTITorchError aoti_torch_npu__segment_reduce_backward(AtenTensorHandle grad, AtenTensorHandle output, AtenTensorHandle data, const char* reduce, AtenTensorHandle* lengths, AtenTensorHandle* offsets, int64_t axis, double* initial, AtenTensorHandle* ret0); ++AOTI_TORCH_EXPORT AOTITorchError aoti_torch_npu__to_sparse(AtenTensorHandle self, int32_t* layout, const int64_t** blocksize, int64_t blocksize_len_, int64_t* dense_dim, AtenTensorHandle* ret0); ++AOTI_TORCH_EXPORT AOTITorchError aoti_torch_npu__trilinear(AtenTensorHandle i1, AtenTensorHandle i2, AtenTensorHandle i3, const int64_t* expand1, int64_t expand1_len_, const int64_t* expand2, int64_t expand2_len_, const int64_t* expand3, int64_t expand3_len_, const int64_t* sumdim, int64_t sumdim_len_, int64_t unroll_dim, AtenTensorHandle* ret0); ++AOTI_TORCH_EXPORT AOTITorchError aoti_torch_npu__weight_int8pack_mm(AtenTensorHandle self, AtenTensorHandle mat2, AtenTensorHandle scales, AtenTensorHandle* ret0); ++AOTI_TORCH_EXPORT AOTITorchError aoti_torch_npu_adaptive_max_pool2d(AtenTensorHandle self, const int64_t* output_size, int64_t output_size_len_, AtenTensorHandle* ret0, AtenTensorHandle* ret1); ++AOTI_TORCH_EXPORT AOTITorchError aoti_torch_npu_adaptive_max_pool2d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, AtenTensorHandle indices, AtenTensorHandle* ret0); ++AOTI_TORCH_EXPORT AOTITorchError aoti_torch_npu_adaptive_max_pool3d(AtenTensorHandle self, const int64_t* output_size, int64_t output_size_len_, AtenTensorHandle* ret0, AtenTensorHandle* ret1); ++AOTI_TORCH_EXPORT AOTITorchError aoti_torch_npu_adaptive_max_pool3d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, AtenTensorHandle indices, AtenTensorHandle* ret0); ++AOTI_TORCH_EXPORT AOTITorchError aoti_torch_npu_add_Scalar(AtenTensorHandle self, double other, double alpha, AtenTensorHandle* ret0); ++AOTI_TORCH_EXPORT AOTITorchError aoti_torch_npu_add_Tensor(AtenTensorHandle self, AtenTensorHandle other, double alpha, AtenTensorHandle* ret0); ++AOTI_TORCH_EXPORT AOTITorchError aoti_torch_npu_addbmm(AtenTensorHandle self, AtenTensorHandle batch1, AtenTensorHandle batch2, double beta, double alpha, AtenTensorHandle* ret0); ++AOTI_TORCH_EXPORT AOTITorchError aoti_torch_npu_addmm_out(AtenTensorHandle out, AtenTensorHandle self, AtenTensorHandle mat1, AtenTensorHandle mat2, double beta, double alpha); ++AOTI_TORCH_EXPORT AOTITorchError aoti_torch_npu_addmv(AtenTensorHandle self, AtenTensorHandle mat, AtenTensorHandle vec, double beta, double alpha, AtenTensorHandle* ret0); ++AOTI_TORCH_EXPORT AOTITorchError aoti_torch_npu_angle(AtenTensorHandle self, AtenTensorHandle* ret0); ++AOTI_TORCH_EXPORT AOTITorchError aoti_torch_npu_avg_pool2d(AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, int32_t ceil_mode, int32_t count_include_pad, int64_t* divisor_override, AtenTensorHandle* ret0); ++AOTI_TORCH_EXPORT AOTITorchError aoti_torch_npu_avg_pool2d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, int32_t ceil_mode, int32_t count_include_pad, int64_t* divisor_override, AtenTensorHandle* ret0); ++AOTI_TORCH_EXPORT AOTITorchError aoti_torch_npu_avg_pool3d(AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* stride, 
int64_t stride_len_, const int64_t* padding, int64_t padding_len_, int32_t ceil_mode, int32_t count_include_pad, int64_t* divisor_override, AtenTensorHandle* ret0); ++AOTI_TORCH_EXPORT AOTITorchError aoti_torch_npu_avg_pool3d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, int32_t ceil_mode, int32_t count_include_pad, int64_t* divisor_override, AtenTensorHandle* ret0); ++AOTI_TORCH_EXPORT AOTITorchError aoti_torch_npu_baddbmm_out(AtenTensorHandle out, AtenTensorHandle self, AtenTensorHandle batch1, AtenTensorHandle batch2, double beta, double alpha); ++AOTI_TORCH_EXPORT AOTITorchError aoti_torch_npu_bernoulli__Tensor(AtenTensorHandle self, AtenTensorHandle p, AtenGeneratorHandle* generator); ++AOTI_TORCH_EXPORT AOTITorchError aoti_torch_npu_bernoulli__float(AtenTensorHandle self, double p, AtenGeneratorHandle* generator); ++AOTI_TORCH_EXPORT AOTITorchError aoti_torch_npu_bmm_out(AtenTensorHandle out, AtenTensorHandle self, AtenTensorHandle mat2); ++AOTI_TORCH_EXPORT AOTITorchError aoti_torch_npu_bucketize_Tensor(AtenTensorHandle self, AtenTensorHandle boundaries, int32_t out_int32, int32_t right, AtenTensorHandle* ret0); ++AOTI_TORCH_EXPORT AOTITorchError aoti_torch_npu_cat(const AtenTensorHandle* tensors, int64_t tensors_len_, int64_t dim, AtenTensorHandle* ret0); ++AOTI_TORCH_EXPORT AOTITorchError aoti_torch_npu_cholesky_inverse(AtenTensorHandle self, int32_t upper, AtenTensorHandle* ret0); ++AOTI_TORCH_EXPORT AOTITorchError aoti_torch_npu_cholesky_solve(AtenTensorHandle self, AtenTensorHandle input2, int32_t upper, AtenTensorHandle* ret0); ++AOTI_TORCH_EXPORT AOTITorchError aoti_torch_npu_convolution(AtenTensorHandle input, AtenTensorHandle weight, AtenTensorHandle* bias, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, const int64_t* dilation, int64_t dilation_len_, int32_t transposed, const int64_t* output_padding, int64_t output_padding_len_, int64_t groups, AtenTensorHandle* ret0); ++AOTI_TORCH_EXPORT AOTITorchError aoti_torch_npu_convolution_backward(AtenTensorHandle grad_output, AtenTensorHandle input, AtenTensorHandle weight, const int64_t** bias_sizes, int64_t bias_sizes_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, const int64_t* dilation, int64_t dilation_len_, int32_t transposed, const int64_t* output_padding, int64_t output_padding_len_, int64_t groups, const int32_t* output_mask, int64_t output_mask_len_, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2); ++AOTI_TORCH_EXPORT AOTITorchError aoti_torch_npu_cummax(AtenTensorHandle self, int64_t dim, AtenTensorHandle* ret0, AtenTensorHandle* ret1); ++AOTI_TORCH_EXPORT AOTITorchError aoti_torch_npu_cummin(AtenTensorHandle self, int64_t dim, AtenTensorHandle* ret0, AtenTensorHandle* ret1); ++AOTI_TORCH_EXPORT AOTITorchError aoti_torch_npu_cumprod(AtenTensorHandle self, int64_t dim, int32_t* dtype, AtenTensorHandle* ret0); ++AOTI_TORCH_EXPORT AOTITorchError aoti_torch_npu_cumsum(AtenTensorHandle self, int64_t dim, int32_t* dtype, AtenTensorHandle* ret0); ++AOTI_TORCH_EXPORT AOTITorchError aoti_torch_npu_exponential(AtenTensorHandle self, double lambd, AtenGeneratorHandle* generator, AtenTensorHandle* ret0); ++AOTI_TORCH_EXPORT AOTITorchError aoti_torch_npu_fractional_max_pool2d(AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const 
int64_t* output_size, int64_t output_size_len_, AtenTensorHandle random_samples, AtenTensorHandle* ret0, AtenTensorHandle* ret1); ++AOTI_TORCH_EXPORT AOTITorchError aoti_torch_npu_fractional_max_pool2d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* output_size, int64_t output_size_len_, AtenTensorHandle indices, AtenTensorHandle* ret0); ++AOTI_TORCH_EXPORT AOTITorchError aoti_torch_npu_fractional_max_pool3d(AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* output_size, int64_t output_size_len_, AtenTensorHandle random_samples, AtenTensorHandle* ret0, AtenTensorHandle* ret1); ++AOTI_TORCH_EXPORT AOTITorchError aoti_torch_npu_fractional_max_pool3d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* output_size, int64_t output_size_len_, AtenTensorHandle indices, AtenTensorHandle* ret0); ++AOTI_TORCH_EXPORT AOTITorchError aoti_torch_npu_gcd(AtenTensorHandle self, AtenTensorHandle other, AtenTensorHandle* ret0); ++AOTI_TORCH_EXPORT AOTITorchError aoti_torch_npu_geqrf(AtenTensorHandle self, AtenTensorHandle* ret0, AtenTensorHandle* ret1); ++AOTI_TORCH_EXPORT AOTITorchError aoti_torch_npu_grid_sampler_2d_backward(AtenTensorHandle grad_output, AtenTensorHandle input, AtenTensorHandle grid, int64_t interpolation_mode, int64_t padding_mode, int32_t align_corners, const int32_t* output_mask, int64_t output_mask_len_, AtenTensorHandle* ret0, AtenTensorHandle* ret1); ++AOTI_TORCH_EXPORT AOTITorchError aoti_torch_npu_histc(AtenTensorHandle self, int64_t bins, double min, double max, AtenTensorHandle* ret0); ++AOTI_TORCH_EXPORT AOTITorchError aoti_torch_npu_histogram_bin_ct(AtenTensorHandle self, int64_t bins, const double** range, int64_t range_len_, AtenTensorHandle* weight, int32_t density, AtenTensorHandle* ret0, AtenTensorHandle* ret1); ++AOTI_TORCH_EXPORT AOTITorchError aoti_torch_npu_index_Tensor(AtenTensorHandle self, const AtenTensorHandle** indices, int64_t indices_len_, AtenTensorHandle* ret0); ++AOTI_TORCH_EXPORT AOTITorchError aoti_torch_npu_index_put(AtenTensorHandle self, const AtenTensorHandle** indices, int64_t indices_len_, AtenTensorHandle values, int32_t accumulate, AtenTensorHandle* ret0); ++AOTI_TORCH_EXPORT AOTITorchError aoti_torch_npu_index_reduce(AtenTensorHandle self, int64_t dim, AtenTensorHandle index, AtenTensorHandle source, const char* reduce, int32_t include_self, AtenTensorHandle* ret0); ++AOTI_TORCH_EXPORT AOTITorchError aoti_torch_npu_kthvalue(AtenTensorHandle self, int64_t k, int64_t dim, int32_t keepdim, AtenTensorHandle* ret0, AtenTensorHandle* ret1); ++AOTI_TORCH_EXPORT AOTITorchError aoti_torch_npu_logcumsumexp(AtenTensorHandle self, int64_t dim, AtenTensorHandle* ret0); ++AOTI_TORCH_EXPORT AOTITorchError aoti_torch_npu_lu_unpack(AtenTensorHandle LU_data, AtenTensorHandle LU_pivots, int32_t unpack_data, int32_t unpack_pivots, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2); ++AOTI_TORCH_EXPORT AOTITorchError aoti_torch_npu_masked_scatter(AtenTensorHandle self, AtenTensorHandle mask, AtenTensorHandle source, AtenTensorHandle* ret0); ++AOTI_TORCH_EXPORT AOTITorchError aoti_torch_npu_masked_scatter_backward(AtenTensorHandle grad_output, AtenTensorHandle mask, const int64_t* sizes, int64_t sizes_len_, AtenTensorHandle* ret0); ++AOTI_TORCH_EXPORT AOTITorchError aoti_torch_npu_masked_select(AtenTensorHandle self, AtenTensorHandle mask, 
AtenTensorHandle* ret0); ++AOTI_TORCH_EXPORT AOTITorchError aoti_torch_npu_max_pool2d_with_indices(AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, const int64_t* dilation, int64_t dilation_len_, int32_t ceil_mode, AtenTensorHandle* ret0, AtenTensorHandle* ret1); ++AOTI_TORCH_EXPORT AOTITorchError aoti_torch_npu_max_pool2d_with_indices_backward(AtenTensorHandle grad_output, AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, const int64_t* dilation, int64_t dilation_len_, int32_t ceil_mode, AtenTensorHandle indices, AtenTensorHandle* ret0); ++AOTI_TORCH_EXPORT AOTITorchError aoti_torch_npu_max_pool3d_with_indices(AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, const int64_t* dilation, int64_t dilation_len_, int32_t ceil_mode, AtenTensorHandle* ret0, AtenTensorHandle* ret1); ++AOTI_TORCH_EXPORT AOTITorchError aoti_torch_npu_max_pool3d_with_indices_backward(AtenTensorHandle grad_output, AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, const int64_t* dilation, int64_t dilation_len_, int32_t ceil_mode, AtenTensorHandle indices, AtenTensorHandle* ret0); ++AOTI_TORCH_EXPORT AOTITorchError aoti_torch_npu_max_unpool2d(AtenTensorHandle self, AtenTensorHandle indices, const int64_t* output_size, int64_t output_size_len_, AtenTensorHandle* ret0); ++AOTI_TORCH_EXPORT AOTITorchError aoti_torch_npu_max_unpool3d(AtenTensorHandle self, AtenTensorHandle indices, const int64_t* output_size, int64_t output_size_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, AtenTensorHandle* ret0); ++AOTI_TORCH_EXPORT AOTITorchError aoti_torch_npu_median(AtenTensorHandle self, AtenTensorHandle* ret0); ++AOTI_TORCH_EXPORT AOTITorchError aoti_torch_npu_mm_out(AtenTensorHandle out, AtenTensorHandle self, AtenTensorHandle mat2); ++AOTI_TORCH_EXPORT AOTITorchError aoti_torch_npu_mode(AtenTensorHandle self, int64_t dim, int32_t keepdim, AtenTensorHandle* ret0, AtenTensorHandle* ret1); ++AOTI_TORCH_EXPORT AOTITorchError aoti_torch_npu_mul_Scalar(AtenTensorHandle self, double other, AtenTensorHandle* ret0); ++AOTI_TORCH_EXPORT AOTITorchError aoti_torch_npu_mul_Tensor(AtenTensorHandle self, AtenTensorHandle other, AtenTensorHandle* ret0); ++AOTI_TORCH_EXPORT AOTITorchError aoti_torch_npu_nanmedian(AtenTensorHandle self, AtenTensorHandle* ret0); ++AOTI_TORCH_EXPORT AOTITorchError aoti_torch_npu_native_dropout(AtenTensorHandle input, double p, int32_t* train, AtenTensorHandle* ret0, AtenTensorHandle* ret1); ++AOTI_TORCH_EXPORT AOTITorchError aoti_torch_npu_nonzero(AtenTensorHandle self, AtenTensorHandle* ret0); ++AOTI_TORCH_EXPORT AOTITorchError aoti_torch_npu_normal_functional(AtenTensorHandle self, double mean, double std, AtenGeneratorHandle* generator, AtenTensorHandle* ret0); ++AOTI_TORCH_EXPORT AOTITorchError aoti_torch_npu_ormqr(AtenTensorHandle self, AtenTensorHandle input2, AtenTensorHandle input3, int32_t left, int32_t transpose, AtenTensorHandle* ret0); ++AOTI_TORCH_EXPORT AOTITorchError aoti_torch_npu_polar(AtenTensorHandle abs, AtenTensorHandle angle, AtenTensorHandle* ret0); ++AOTI_TORCH_EXPORT AOTITorchError 
aoti_torch_npu_pow_Scalar(double self, AtenTensorHandle exponent, AtenTensorHandle* ret0); ++AOTI_TORCH_EXPORT AOTITorchError aoti_torch_npu_pow_Tensor_Scalar(AtenTensorHandle self, double exponent, AtenTensorHandle* ret0); ++AOTI_TORCH_EXPORT AOTITorchError aoti_torch_npu_pow_Tensor_Tensor(AtenTensorHandle self, AtenTensorHandle exponent, AtenTensorHandle* ret0); ++AOTI_TORCH_EXPORT AOTITorchError aoti_torch_npu_rand(const int64_t* size, int64_t size_len_, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0); ++AOTI_TORCH_EXPORT AOTITorchError aoti_torch_npu_rand_generator(const int64_t* size, int64_t size_len_, AtenGeneratorHandle* generator, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0); ++AOTI_TORCH_EXPORT AOTITorchError aoti_torch_npu_randint(int64_t high, const int64_t* size, int64_t size_len_, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0); ++AOTI_TORCH_EXPORT AOTITorchError aoti_torch_npu_randint_generator(int64_t high, const int64_t* size, int64_t size_len_, AtenGeneratorHandle* generator, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0); ++AOTI_TORCH_EXPORT AOTITorchError aoti_torch_npu_randint_low(int64_t low, int64_t high, const int64_t* size, int64_t size_len_, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0); ++AOTI_TORCH_EXPORT AOTITorchError aoti_torch_npu_randint_low_out(AtenTensorHandle out, int64_t low, int64_t high, const int64_t* size, int64_t size_len_); ++AOTI_TORCH_EXPORT AOTITorchError aoti_torch_npu_randn(const int64_t* size, int64_t size_len_, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0); ++AOTI_TORCH_EXPORT AOTITorchError aoti_torch_npu_randn_generator(const int64_t* size, int64_t size_len_, AtenGeneratorHandle* generator, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0); ++AOTI_TORCH_EXPORT AOTITorchError aoti_torch_npu_randperm(int64_t n, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0); ++AOTI_TORCH_EXPORT AOTITorchError aoti_torch_npu_repeat_interleave_Tensor(AtenTensorHandle repeats, int64_t* output_size, AtenTensorHandle* ret0); ++AOTI_TORCH_EXPORT AOTITorchError aoti_torch_npu_replication_pad1d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, const int64_t* padding, int64_t padding_len_, AtenTensorHandle* ret0); ++AOTI_TORCH_EXPORT AOTITorchError aoti_torch_npu_replication_pad2d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, const int64_t* padding, int64_t padding_len_, AtenTensorHandle* ret0); ++AOTI_TORCH_EXPORT AOTITorchError aoti_torch_npu_reshape(AtenTensorHandle self, const int64_t* shape, int64_t shape_len_, AtenTensorHandle* ret0); ++AOTI_TORCH_EXPORT AOTITorchError aoti_torch_npu_resize_(AtenTensorHandle self, const int64_t* size, int64_t size_len_, int32_t* memory_format); ++AOTI_TORCH_EXPORT AOTITorchError aoti_torch_npu_resize_as_(AtenTensorHandle self, AtenTensorHandle the_template, int32_t* memory_format); ++AOTI_TORCH_EXPORT AOTITorchError aoti_torch_npu_scatter_src_out(AtenTensorHandle out, AtenTensorHandle self, int64_t dim, AtenTensorHandle index, AtenTensorHandle 
src); ++AOTI_TORCH_EXPORT AOTITorchError aoti_torch_npu_scatter_value_out(AtenTensorHandle out, AtenTensorHandle self, int64_t dim, AtenTensorHandle index, double value); ++AOTI_TORCH_EXPORT AOTITorchError aoti_torch_npu_scatter_reduce_two_out(AtenTensorHandle out, AtenTensorHandle self, int64_t dim, AtenTensorHandle index, AtenTensorHandle src, const char* reduce, int32_t include_self); ++AOTI_TORCH_EXPORT AOTITorchError aoti_torch_npu_searchsorted_Scalar(AtenTensorHandle sorted_sequence, double self, int32_t out_int32, int32_t right, const char** side, AtenTensorHandle* sorter, AtenTensorHandle* ret0); ++AOTI_TORCH_EXPORT AOTITorchError aoti_torch_npu_searchsorted_Tensor(AtenTensorHandle sorted_sequence, AtenTensorHandle self, int32_t out_int32, int32_t right, const char** side, AtenTensorHandle* sorter, AtenTensorHandle* ret0); ++AOTI_TORCH_EXPORT AOTITorchError aoti_torch_npu_segment_reduce(AtenTensorHandle data, const char* reduce, AtenTensorHandle* lengths, AtenTensorHandle* indices, AtenTensorHandle* offsets, int64_t axis, int32_t unsafe, double* initial, AtenTensorHandle* ret0); ++AOTI_TORCH_EXPORT AOTITorchError aoti_torch_npu_slice_Tensor(AtenTensorHandle self, int64_t dim, int64_t* start, int64_t* end, int64_t step, AtenTensorHandle* ret0); ++AOTI_TORCH_EXPORT AOTITorchError aoti_torch_npu_soft_margin_loss_backward(AtenTensorHandle grad_output, AtenTensorHandle self, AtenTensorHandle target, int64_t reduction, AtenTensorHandle* ret0); ++AOTI_TORCH_EXPORT AOTITorchError aoti_torch_npu_sort(AtenTensorHandle self, int64_t dim, int32_t descending, AtenTensorHandle* ret0, AtenTensorHandle* ret1); ++AOTI_TORCH_EXPORT AOTITorchError aoti_torch_npu_sort_stable(AtenTensorHandle self, int32_t* stable, int64_t dim, int32_t descending, AtenTensorHandle* ret0, AtenTensorHandle* ret1); ++AOTI_TORCH_EXPORT AOTITorchError aoti_torch_npu_to_sparse(AtenTensorHandle self, int32_t* layout, const int64_t** blocksize, int64_t blocksize_len_, int64_t* dense_dim, AtenTensorHandle* ret0); ++AOTI_TORCH_EXPORT AOTITorchError aoti_torch_npu_topk(AtenTensorHandle self, int64_t k, int64_t dim, int32_t largest, int32_t sorted, AtenTensorHandle* ret0, AtenTensorHandle* ret1); ++AOTI_TORCH_EXPORT AOTITorchError aoti_torch_npu_triangular_solve(AtenTensorHandle self, AtenTensorHandle A, int32_t upper, int32_t transpose, int32_t unitriangular, AtenTensorHandle* ret0, AtenTensorHandle* ret1); ++AOTI_TORCH_EXPORT AOTITorchError aoti_torch_npu_uniform(AtenTensorHandle self, double from, double to, AtenGeneratorHandle* generator, AtenTensorHandle* ret0); ++AOTI_TORCH_EXPORT AOTITorchError aoti_torch_npu_upsample_bicubic2d_backward(AtenTensorHandle grad_output, const int64_t* output_size, int64_t output_size_len_, const int64_t* input_size, int64_t input_size_len_, int32_t align_corners, double* scales_h, double* scales_w, AtenTensorHandle* ret0); ++AOTI_TORCH_EXPORT AOTITorchError aoti_torch_npu_upsample_linear1d_backward(AtenTensorHandle grad_output, const int64_t* output_size, int64_t output_size_len_, const int64_t* input_size, int64_t input_size_len_, int32_t align_corners, double* scales, AtenTensorHandle* ret0); ++AOTI_TORCH_EXPORT AOTITorchError aoti_torch_npu_upsample_trilinear3d_backward(AtenTensorHandle grad_output, const int64_t* output_size, int64_t output_size_len_, const int64_t* input_size, int64_t input_size_len_, int32_t align_corners, double* scales_d, double* scales_h, double* scales_w, AtenTensorHandle* ret0); ++AOTI_TORCH_EXPORT AOTITorchError aoti_torch_npu_view_dtype(AtenTensorHandle self, 
int32_t dtype, AtenTensorHandle* ret0); ++AOTI_TORCH_EXPORT AOTITorchError aoti_torch_npu_view_as_complex(AtenTensorHandle self, AtenTensorHandle* ret0); ++AOTI_TORCH_EXPORT AOTITorchError aoti_torch_npu_view_as_real(AtenTensorHandle self, AtenTensorHandle* ret0); ++ ++ ++ ++ + // Getter functions for retrieving various constants from the runtime, that + // can subsequently be passed to other aoti_* functions. By hiding these + // behind functions, the precise value of device/dtype is NOT part of the +@@ -97,6 +230,7 @@ using AOTITorchError = int32_t; + AOTI_TORCH_EXPORT int32_t aoti_torch_device_type_cpu(); + AOTI_TORCH_EXPORT int32_t aoti_torch_device_type_cuda(); + AOTI_TORCH_EXPORT int32_t aoti_torch_device_type_xpu(); ++AOTI_TORCH_EXPORT int32_t aoti_torch_device_type_npu(); + AOTI_TORCH_EXPORT int32_t aoti_torch_device_type_privateuse1(); + + AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_float8_e5m2(); +@@ -293,7 +427,7 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_empty_strided( + AtenTensorHandle* ret_new_tensor // returns new reference + ); + +-AOTI_TORCH_EXPORT AOTITorchError aoti_torch_create_tensor_from_blob( ++AOTI_TORCH_EXPORT AOTITorchError aoti_torch_create_tensor_from_blob_npu( + void* data, + int64_t ndim, + const int64_t* sizes_ptr, +@@ -305,7 +439,7 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_create_tensor_from_blob( + AtenTensorHandle* ret // returns new reference + ); + +-AOTI_TORCH_EXPORT AOTITorchError aoti_torch_create_tensor_from_blob_v2( ++AOTI_TORCH_EXPORT AOTITorchError aoti_torch_create_tensor_from_blob_npu_v2( + void* data, + int64_t ndim, + const int64_t* sizes_ptr, +diff --git a/utils/_triton.py b/utils/_triton.py +index 1609a3f..4977801 100644 +--- a/utils/_triton.py ++++ b/utils/_triton.py +@@ -19,7 +19,14 @@ def has_triton_package() -> bool: + def has_triton_tma(): + if has_triton_package(): + import torch +- ++ try: ++ from triton.tools.experimental_descriptor import ( # noqa: F401 ++ create_1d_tma_descriptor, ++ create_2d_tma_descriptor, ++ ) ++ return True ++ except ImportError: ++ pass + if ( + torch.cuda.is_available() + and torch.cuda.get_device_capability() >= (9, 0) +@@ -80,6 +87,7 @@ def has_triton() -> bool: + return True + + triton_supported_devices = { ++ "npu": _return_true, + "cuda": cuda_extra_check, + "xpu": _return_true, + "cpu": cpu_extra_check, + +diff --git a/utils/cpp_extension.py b/utils/cpp_extension.py +index b4a70dc..41cd7a2 100644 +--- a/utils/cpp_extension.py ++++ b/utils/cpp_extension.py +@@ -141,6 +141,22 @@ def _find_rocm_home() -> Optional[str]: + file=sys.stderr) + return rocm_home + ++def _find_npu_home() -> Optional[str]: ++ """Find the NPU install path.""" ++ # Guess #1 ++ npu_home = os.environ.get('ASCEND_HOME_PATH') or os.environ.get('ASCEND_TOOLKIT_HOME') or os.environ.get('TOOLCHAIN_HOME') ++ if npu_home is None: ++ npu_home = '/usr/local/Ascend/ascend-toolkit/latest' ++ if not os.path.exists(npu_home): ++ npu_home = None ++ if not npu_home: ++ print(f"Warning ASCEND_HOME_PATH not found") ++ # TODO NPU runtime check ++ # if npu_home and not torch.cuda.is_available(): ++ # print(f"No CUDA runtime is found, using CUDA_HOME='{cuda_home}'", ++ # file=sys.stderr) ++ return npu_home ++ + def _find_sycl_home() -> Optional[str]: + """Find the OneAPI install path.""" + # Guess #1 +@@ -239,6 +255,7 @@ CUDA_HOME = _find_cuda_home() if torch.cuda._is_compiled() else None + CUDNN_HOME = os.environ.get('CUDNN_HOME') or os.environ.get('CUDNN_PATH') + SYCL_HOME = _find_sycl_home() if torch.xpu._is_compiled() else None + ++NPU_HOME 
= _find_npu_home() + # PyTorch releases have the version pattern major.minor.patch, whereas when + # PyTorch is built from source, we append the git commit hash, which gives + # it the below pattern. +@@ -1235,6 +1252,14 @@ def include_paths(device_type: str = "cpu") -> List[str]: + paths.append(os.path.join(CUDNN_HOME, 'include')) + elif device_type == "xpu": + paths.append(_join_sycl_home('include')) ++ elif device_type == "npu": ++ npu_home_include = _join_npu_home('x86_64-linux/include') ++ paths.append(npu_home_include) ++ npu_exp_include = _join_npu_home('x86_64-linux/include/experiment') ++ paths.append(npu_exp_include) ++ npu_home_prof_include = _join_npu_home('x86_64-linux/include/experiment/msprof') ++ paths.append(npu_home_prof_include) ++ paths.append("") + return paths + + +@@ -1281,6 +1306,10 @@ def library_paths(device_type: str = "cpu") -> List[str]: + lib_dir = 'lib' + + paths.append(_join_sycl_home(lib_dir)) ++ elif device_type == "npu": ++ npu_home_lib = _join_npu_home('lib64') ++ paths.append(npu_home_lib) ++ paths.append("/host/zcl/aoti_files") + + return paths + +@@ -2532,3 +2561,9 @@ def _is_cuda_file(path: str) -> bool: + if IS_HIP_EXTENSION: + valid_ext.append('.hip') + return os.path.splitext(path)[1] in valid_ext ++ ++def _join_npu_home(*paths) -> str: ++ if NPU_HOME is None: ++ raise OSError('ASCEND_HOME_PATH environment variable is not set. ' ++ 'Please set it to your CUDA install root. suggest using set_env.sh') ++ return os.path.join(NPU_HOME, *paths) +\ No newline at end of file \ No newline at end of file diff --git a/torch_npu/_inductor/patch/torch_npu_changes.patch b/torch_npu/_inductor/patch/torch_npu_changes.patch new file mode 100644 index 0000000000000000000000000000000000000000..73da7279c5cf5b8d851315dad4879b712bf9f47c --- /dev/null +++ b/torch_npu/_inductor/patch/torch_npu_changes.patch @@ -0,0 +1,14 @@ +diff --git a/utils/_inductor.py b/utils/_inductor.py +index 9a36ddb..5a3c874 100755 +--- a/utils/_inductor.py ++++ b/utils/_inductor.py +@@ -1,5 +1,9 @@ ++import os + from torch._inductor.codegen.common import DeviceOpOverrides, register_device_op_overrides + ++_HERE = os.path.abspath(__file__) ++_TORCH_NPU_PATH = os.path.dirname(os.path.dirname(_HERE)) ++TORCH_NPU_LIB_PATH = "/host/zcl/pta_v2.6/libtorch_npu/lib" + + class NPUDeviceOpOverrides(DeviceOpOverrides): + def import_get_raw_stream_as(self, name): diff --git a/torch_npu/_inductor/runtime.py b/torch_npu/_inductor/runtime.py new file mode 100644 index 0000000000000000000000000000000000000000..296fff4de1a3cfc2b230d9bb66ab754858e9188a --- /dev/null +++ b/torch_npu/_inductor/runtime.py @@ -0,0 +1,71 @@ +from typing import Optional +import functools + +from torch._inductor.runtime.hints import DeviceProperties +from .config import num_vector_core +from typing import List , Dict +from torch.utils._triton import has_triton, has_triton_package +from torch._inductor.remote_cache import JsonDataTy + + +if has_triton_package(): + from triton import Config + +# overload this to avoid autotune after best_config already generated +def _load_cached_autotuning( + best_config: Dict[str, JsonDataTy], + configs_hash: str, + configs: List[Config], + inductor_meta: Dict, +) -> Optional[Config]: + if best_config is None: + return None + if best_config.pop("configs_hash", None) != configs_hash: + return None + # Remove time taken for comparison + best_config.pop("time_taken_ms", None) + + #if inductor_meta.get("coordinate_descent_tuning") : + num_warps = best_config.pop("num_warps") + num_stages = 
best_config.pop("num_stages") + triton_config = Config(best_config, num_warps=num_warps, num_stages=num_stages) + triton_config.found_by_coordesc = True + return triton_config + + +class NPUDeviceProperties(DeviceProperties): + + + @classmethod + @functools.lru_cache(None) + def create(cls, device) -> DeviceProperties: + import torch + from torch._dynamo.device_interface import get_interface_for_device + + device_type = device.type + + if torch.version.hip and device_type == "cuda": + device_type = "hip" + + device_interface = get_interface_for_device(device) + props = device_interface.get_device_properties(device) + + try: + multi_processor_count = num_vector_core + except AttributeError: + if device_type == "xpu": + multi_processor_count = props.gpu_subslice_count + else: + raise + return cls( + type=device_type, + index=device.index, + multi_processor_count=multi_processor_count, + cc=device_interface.get_compute_capability(device), + major=getattr(props, "major", None), + regs_per_multiprocessor=getattr(props, "regs_per_multiprocessor", None), + max_threads_per_multi_processor=getattr( + props, "max_threads_per_multi_processor", None + ), + warp_size=getattr(props, "warp_size", 32 if device_type != "cpu" else None), + ) diff --git a/torch_npu/_inductor/utils.py b/torch_npu/_inductor/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..697059f7a885b8b2ec3912954728b6a3a84a234c --- /dev/null +++ b/torch_npu/_inductor/utils.py @@ -0,0 +1,7 @@ +import torch +import torch_npu + + +# Not good implementation, but no other way +def get_current_raw_stream(device): + return torch.npu.current_stream(device).npu_stream \ No newline at end of file
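Taken together, the files above wire an NPU backend into the AOTI runner registry: model_container_runner_npu.cpp defines AOTIModelContainerRunnerNpu and registers it for the "npu" device string via RegistNpu(), while torch_changes.patch reroutes memory allocation, events and tensor creation in the AOTI runtime headers through the ACL and torch_npu shims. The sketch below shows how a consumer might drive the runner once a model.so has been AOT-compiled for the npu device. It is illustrative only: the runner header path and class come from the patch above, but the NPUStream include path, the model path, and the tensor shapes are assumptions, and it presumes torch_npu has registered the "npu" device at library load time.

// Minimal consumer sketch for the NPU AOTI runner added above.
// Assumes libtorch, libtorch_npu and the aoti_runner_npu library from this
// diff are on the include/link path; everything marked "illustrative" is an
// assumption, not taken from the patch.
#include <iostream>
#include <vector>

#include <torch/torch.h>
#include <torch/csrc/inductor/aoti_runner/model_container_runner_npu.h>
#include <torch_npu/csrc/core/npu/NPUStream.h>  // include path may differ per install

int main() {
    // Ensure the "npu" entry exists in the AOTI runner registry
    // (RegistNpu is defined in model_container_runner_npu.cpp above).
    torch::inductor::RegistNpu();

    // Path to the AOT-compiled shared library is illustrative.
    torch::inductor::AOTIModelContainerRunnerNpu runner(
        "/path/to/model.so", /*num_models=*/1, /*device_str=*/"npu");

    // Inputs must already live on the NPU; shape and dtype are illustrative.
    std::vector<at::Tensor> inputs{
        torch::randn({8, 128}, torch::dtype(torch::kFloat32).device("npu"))};

    // run() picks up the current NPU stream itself; run_with_npu_stream lets
    // the caller pass a stream explicitly.
    std::vector<at::Tensor> outputs =
        runner.run_with_npu_stream(inputs, c10_npu::getCurrentNPUStream());

    std::cout << "first output sizes: " << outputs[0].sizes() << std::endl;
    return 0;
}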