diff --git a/samples/sample.py b/samples/sample.py index fa84542c052526bf1804e0cf5cbfb9d7b1a9bfa0..4cfb0e9da7e02eac3a072dcec3855dd55b045da4 100644 --- a/samples/sample.py +++ b/samples/sample.py @@ -6,6 +6,7 @@ from tdb.common import tdb_mish_grad, tdb_mish, tdb_fast_gelu, tdb_mse_loss_grad from tdb.common import tdb_mse_loss, tdb_clip_by_value from tdb.common import tdb_group_norm from tdb.common import tdb_mse_loss_grad +from tdb.common import tdb_less_equal def tdb_test(): @@ -62,6 +63,12 @@ def tdb_group_norm_test(): c = torch.randn(320, 1, 1).half().npu() res = tdb_group_norm(a, b, c, 32, 0.00001) print("tdb_group_norm_test result\n: ", res) + +def tdb_less_equal_test(): + x1 = torch.randn(3, 4).half().npu() + x2 = torch.randn(3, 4).half().npu() + res = tdb_less_equal(x1, x2) + print("tdb_less_equal_test result\n: ", res) if __name__ == "__main__": @@ -73,4 +80,5 @@ if __name__ == "__main__": tdb_mse_loss_test() tdb_clip_by_value_test() tdb_group_norm_test() - tdb_mse_loss_grad_test() \ No newline at end of file + tdb_mse_loss_grad_test() + tdb_less_equal_test() diff --git a/tdb/common/__init__.py b/tdb/common/__init__.py index 9024c40333e532effa38000e0f49a4c0deab33f8..966ffde535793fe5073c47b136b35de894de4a40 100644 --- a/tdb/common/__init__.py +++ b/tdb/common/__init__.py @@ -5,4 +5,5 @@ from .ops.mse_loss import tdb_mse_loss from .ops.clip_by_value import tdb_clip_by_value from .ops.mse_loss_grad import tdb_mse_loss_grad from .ops.group_norm import tdb_group_norm -from .ops.mish import tdb_mish \ No newline at end of file +from .ops.mish import tdb_mish +from .ops.less_equal import tdb_less_equal \ No newline at end of file diff --git a/tdb/common/ops/csrc/LessEqualKernelNpu.cpp b/tdb/common/ops/csrc/LessEqualKernelNpu.cpp new file mode 100644 index 0000000000000000000000000000000000000000..404184f07e24dff217a45fd162db8fcb06b50295 --- /dev/null +++ b/tdb/common/ops/csrc/LessEqualKernelNpu.cpp @@ -0,0 +1,26 @@ +// Copyright (c) 2025 Td-Tech Technologies Co., Ltd +// Copyright (c) 2019, Facebook CORPORATION. +// All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include "OpApiCommon.h" +#include "functions.h" + +at::Tensor less_equal(const at::Tensor &x1, const at::Tensor &x2) +{ + at::Tensor out = at::empty(x1.sizes(), x1.options().dtype(at::kBool)); + EXEC_NPU_CMD(aclnnLessEqual, x1, x2, out); + return out; +} diff --git a/tdb/common/ops/csrc/functions.h b/tdb/common/ops/csrc/functions.h index e827961e6c57abae0cfeefb823eaeb77b6f9098f..cea43b99b2a2f4491f931e2bd6ee17d3780b5bef 100644 --- a/tdb/common/ops/csrc/functions.h +++ b/tdb/common/ops/csrc/functions.h @@ -32,4 +32,5 @@ at::Tensor group_norm(const at::Tensor &x, const at::Tensor &gamma, const at::Te at::Tensor mse_loss_grad(const at::Tensor &predict, const at::Tensor &label, const at::Tensor &dout, const char* reduction); at::Tensor clip_by_value(const at::Tensor &x1, const at::Tensor &clip_value_min, const at::Tensor &clip_value_max); at::Tensor mse_loss(const at::Tensor &predict, const at::Tensor &label, const char* reduction); +at::Tensor less_equal(const at::Tensor &x1, const at::Tensor &x2); #endif // __FUNCTIONS_H__ diff --git a/tdb/common/ops/csrc/pybind.cpp b/tdb/common/ops/csrc/pybind.cpp index 37dafad6557ec397eff2b19fa493c9fcdd999b92..6a45c3c05062203781dd72fb8e0531186d0227a3 100644 --- a/tdb/common/ops/csrc/pybind.cpp +++ b/tdb/common/ops/csrc/pybind.cpp @@ -12,4 +12,5 @@ void init_common(pybind11::module &m) m.def("mse_loss_grad", &mse_loss_grad); m.def("clip_by_value", &clip_by_value); m.def("mse_loss", &mse_loss); + m.def("less_equal", &less_equal); } diff --git a/tdb/common/ops/kernels/operators/op_host/less_equal.cpp b/tdb/common/ops/kernels/operators/op_host/less_equal.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a0458957fa2477bd5277ddf11301522a7ece0c48 --- /dev/null +++ b/tdb/common/ops/kernels/operators/op_host/less_equal.cpp @@ -0,0 +1,109 @@ +#include "less_equal_tiling.h" +#include "register/op_def_registry.h" + +namespace optiling { + +constexpr int32_t CORE_NUM = 1; +constexpr int32_t UB_SIZE = 196608; + +constexpr int32_t X1_INDEX = 0; + +constexpr int32_t BYTE_BLOCK = 32; +constexpr int32_t BUFFER_NUM = 1; +constexpr int32_t BUF_COUNT = (2 + 1) * BUFFER_NUM; +constexpr int32_t INT8_BUF_COUNT = (2 + 1) * BUFFER_NUM + 4; + +template +inline T1 CeilDiv(T1 a, T2 b) { + a = int64_t(a); + b = int64_t(b); + return T1(b == 0 ? a : (a + b - 1) / b); +}; + +template +inline T1 CeilAlignA2B(T1 a, T2 b) { + a = int64_t(a); + b = int64_t(b); + return T1(b == 0 ? a : CeilDiv(a, b) * b); +}; + +static ge::graphStatus TilingFunc(gert::TilingContext* context) { + LessEqualTilingData tiling; + + uint32_t totalCoreNum = CORE_NUM; + uint64_t ubSizePlatForm = UB_SIZE; + + const gert::StorageShape* x1Shape = context->GetInputShape(X1_INDEX); + const ge::DataType dataType = context->GetInputDesc(X1_INDEX)->GetDataType(); + int32_t dTypeSize = ge::GetSizeByDataType(dataType); + int32_t elementsPerBlock = BYTE_BLOCK / dTypeSize; + uint64_t totalDataCount = x1Shape->GetStorageShape().GetShapeSize(); + + uint64_t ubMaxProcCount = ubSizePlatForm / BUF_COUNT / BYTE_BLOCK * elementsPerBlock; + context->SetTilingKey(101); + if (dataType == ge::DT_FLOAT16) { + context->SetTilingKey(201); + } else if (dataType == ge::DT_INT8) { + ubMaxProcCount = ubSizePlatForm / INT8_BUF_COUNT / BYTE_BLOCK * elementsPerBlock; + context->SetTilingKey(301); + } else if (dataType == ge::DT_INT32) { + context->SetTilingKey(401); + } + + uint64_t loopCount = totalDataCount / ubMaxProcCount; + uint64_t tailCount = totalDataCount % ubMaxProcCount; + + tiling.set_ubMaxProcCount(ubMaxProcCount); + tiling.set_totalDataCount(totalDataCount); + tiling.set_loopCount(loopCount); + tiling.set_tailCount(tailCount); + + context->SetBlockDim(totalCoreNum); + tiling.SaveToBuffer(context->GetRawTilingData()->GetData(), context->GetRawTilingData()->GetCapacity()); + context->GetRawTilingData()->SetDataSize(tiling.GetDataSize()); + + return ge::GRAPH_SUCCESS; +} +} // namespace optiling + +namespace ge { +static ge::graphStatus InferShape(gert::InferShapeContext* context) { + const gert::Shape* x1_shape = context->GetInputShape(0); + gert::Shape* y_shape = context->GetOutputShape(0); + *y_shape = *x1_shape; + return GRAPH_SUCCESS; +} +} // namespace ge + +namespace ops { +class LessEqual : public OpDef { +public: + explicit LessEqual(const char* name) : OpDef(name) { + this->Input("x1") + .ParamType(REQUIRED) + .DataType({ge::DT_FLOAT16, ge::DT_FLOAT, ge::DT_INT32, ge::DT_INT8}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); + this->Input("x2") + .ParamType(REQUIRED) + .DataType({ge::DT_FLOAT16, ge::DT_FLOAT, ge::DT_INT32, ge::DT_INT8}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); + this->Output("y") + .ParamType(REQUIRED) + .DataType({ge::DT_BOOL, ge::DT_BOOL, ge::DT_BOOL, ge::DT_BOOL}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); + + this->SetInferShape(ge::InferShape); + + this->AICore().SetTiling(optiling::TilingFunc); + this->AICore().AddConfig("ascend910"); + this->AICore().AddConfig("ascend310p"); + this->AICore().AddConfig("ascend310b"); + this->AICore().AddConfig("ascend910b"); + } +}; + +OP_ADD(LessEqual); +} // namespace ops \ No newline at end of file diff --git a/tdb/common/ops/kernels/operators/op_host/less_equal_tiling.h b/tdb/common/ops/kernels/operators/op_host/less_equal_tiling.h new file mode 100644 index 0000000000000000000000000000000000000000..bf6055f44a807ae0bc2b47d86d8d7c660a8ff2ea --- /dev/null +++ b/tdb/common/ops/kernels/operators/op_host/less_equal_tiling.h @@ -0,0 +1,12 @@ +#include "register/tilingdata_base.h" + +namespace optiling { +BEGIN_TILING_DATA_DEF(LessEqualTilingData) + TILING_DATA_FIELD_DEF(uint64_t, ubMaxProcCount); + TILING_DATA_FIELD_DEF(uint64_t, totalDataCount); + TILING_DATA_FIELD_DEF(uint64_t, loopCount); + TILING_DATA_FIELD_DEF(uint64_t, tailCount); +END_TILING_DATA_DEF; + +REGISTER_TILING_DATA_CLASS(LessEqual, LessEqualTilingData) +} // namespace optiling \ No newline at end of file diff --git a/tdb/common/ops/kernels/operators/op_kernel/less_equal.cpp b/tdb/common/ops/kernels/operators/op_kernel/less_equal.cpp new file mode 100644 index 0000000000000000000000000000000000000000..efc1cce5efb41790ae87fc18f0e976585a4a0f58 --- /dev/null +++ b/tdb/common/ops/kernels/operators/op_kernel/less_equal.cpp @@ -0,0 +1,25 @@ +#include "less_equal_base.h" + +extern "C" __global__ __aicore__ void less_equal(GM_ADDR x1, GM_ADDR x2, GM_ADDR y, GM_ADDR workspace, GM_ADDR tiling) { + GET_TILING_DATA(tilingData, tiling); + + if (TILING_KEY_IS(101)) { + LessEqual::KernelLessEqual op; + op.Init(x1, x2, y, &tilingData); + op.Process(); + } else if (TILING_KEY_IS(201)) { + LessEqual::KernelLessEqual op; + op.Init(x1, x2, y, &tilingData); + op.Process(); + } else if (TILING_KEY_IS(301)) { +#if defined(ORIG_DTYPE_INPUT_DATA) && ORIG_DTYPE_INPUT_DATA == DT_INT8 + LessEqual::KernelLessEqual op; + op.Init(x1, x2, y, &tilingData); + op.Process(); +#endif + } else if (TILING_KEY_IS(401)) { + LessEqual::KernelLessEqual op; + op.Init(x1, x2, y, &tilingData); + op.Process(); + } +} \ No newline at end of file diff --git a/tdb/common/ops/kernels/operators/op_kernel/less_equal_base.h b/tdb/common/ops/kernels/operators/op_kernel/less_equal_base.h new file mode 100644 index 0000000000000000000000000000000000000000..5a4a03df3d3fd1893a7402ea5a99c6a5192cf168 --- /dev/null +++ b/tdb/common/ops/kernels/operators/op_kernel/less_equal_base.h @@ -0,0 +1,160 @@ +#ifndef LESS_EQUAL_BASE_H +#define LESS_EQUAL_BASE_H +#include "kernel_operator.h" + +namespace LessEqual { +using namespace AscendC; +constexpr int32_t BUFFER_NUM = 1; +constexpr int32_t BYTE_BLOCK = 32; +constexpr int32_t INT8_TMP_BUF_COUNT = 4; + +template +class KernelLessEqual { +public: + __aicore__ inline KernelLessEqual(){}; + __aicore__ inline void Init(GM_ADDR x1, GM_ADDR x2, GM_ADDR y, const LessEqualTilingData* __restrict tilingData); + __aicore__ inline void Process(); + +private: + template + __aicore__ inline T1 CeilDiv(T1 a, T2 b) { + a = int64_t(a); + b = int64_t(b); + return T1(b == 0 ? a : (a + b - 1) / b); + } + + template + __aicore__ inline T1 CeilAlignA2B(T1 a, T2 b) { + a = int64_t(a); + b = int64_t(b); + return T1(b == 0 ? a : CeilDiv(a, b) * b); + } + + __aicore__ inline void CopyIn(int64_t gmOffset, int64_t dataCount); + __aicore__ inline void Compute(int64_t gmOffset, int64_t dataCount); + __aicore__ inline void ComputeInt8(int64_t gmOffset, int64_t dataCount); + __aicore__ inline void CopyOut(int64_t gmOffset, int64_t dataCount); + +private: + TPipe pipe; + TQue inQueueX1, inQueueX2; + TQue outQueueY; + TBuf tempValBuf; + LocalTensor tempValLT; + GlobalTensor x1GM, x2GM; + GlobalTensor yGM; + + int64_t blockIdx = 0; + uint64_t perBlockCount = 0; + uint64_t perBlockCountOut = 0; + + // tiling params + uint64_t ubMaxProcCount = 0; + uint64_t totalDataCount = 0; + uint64_t loopCount = 0; + uint64_t tailCount = 0; +}; + +template +__aicore__ inline void KernelLessEqual::Init(GM_ADDR x1, GM_ADDR x2, GM_ADDR y, + const LessEqualTilingData* __restrict tilingData) { + blockIdx = GetBlockIdx(); + perBlockCount = BYTE_BLOCK / sizeof(T); + perBlockCountOut = BYTE_BLOCK / sizeof(uint8_t); + ubMaxProcCount = tilingData->ubMaxProcCount; + totalDataCount = tilingData->totalDataCount; + loopCount = tilingData->loopCount; + tailCount = tilingData->tailCount; + + x1GM.SetGlobalBuffer((__gm__ T*)x1, totalDataCount); + x2GM.SetGlobalBuffer((__gm__ T*)x2, totalDataCount); + yGM.SetGlobalBuffer(y, totalDataCount); + + int64_t singleBufferSize = ubMaxProcCount * sizeof(T); + pipe.InitBuffer(inQueueX1, BUFFER_NUM, singleBufferSize); + pipe.InitBuffer(inQueueX2, BUFFER_NUM, singleBufferSize); + pipe.InitBuffer(outQueueY, BUFFER_NUM, singleBufferSize); + +#if defined(ORIG_DTYPE_X1) && ORIG_DTYPE_X1 == DT_INT8 + pipe.InitBuffer(tempValBuf, singleBufferSize * INT8_TMP_BUF_COUNT); + tempValLT = tempValBuf.Get(); +#endif +} + +template +__aicore__ inline void KernelLessEqual::Process() { + int64_t gmOffset = 0; + for (int64_t i = 0; i < loopCount; i++) { + CopyIn(gmOffset, ubMaxProcCount); +#if defined(ORIG_DTYPE_X1) && ORIG_DTYPE_X1 == DT_INT8 + ComputeInt8(gmOffset, ubMaxProcCount); +#else + Compute(gmOffset, ubMaxProcCount); +#endif + CopyOut(gmOffset, ubMaxProcCount); + gmOffset += ubMaxProcCount; + } + if (tailCount) { + int64_t alignDataCount = CeilAlignA2B(tailCount, perBlockCount); + CopyIn(gmOffset, alignDataCount); +#if defined(ORIG_DTYPE_X1) && ORIG_DTYPE_X1 == DT_INT8 + ComputeInt8(gmOffset, alignDataCount); +#else + Compute(gmOffset, alignDataCount); +#endif + CopyOut(gmOffset, alignDataCount); + } +} + +template +__aicore__ inline void KernelLessEqual::CopyIn(int64_t gmOffset, int64_t dataCount) { + LocalTensor x1LT = inQueueX1.AllocTensor(); + LocalTensor x2LT = inQueueX2.AllocTensor(); + + DataCopy(x1LT, x1GM[gmOffset], dataCount); + DataCopy(x2LT, x2GM[gmOffset], dataCount); + inQueueX1.EnQue(x1LT); + inQueueX2.EnQue(x2LT); +} + +template +__aicore__ inline void KernelLessEqual::Compute(int64_t gmOffset, int64_t dataCount) { + LocalTensor x1LT = inQueueX1.DeQue(); + LocalTensor x2LT = inQueueX2.DeQue(); + LocalTensor yLT = outQueueY.AllocTensor(); + + Compare(yLT, x1LT, x2LT, CMPMODE::LE, dataCount); + + outQueueY.EnQue(yLT); + inQueueX1.FreeTensor(x1LT); + inQueueX2.FreeTensor(x2LT); +} + +template +__aicore__ inline void KernelLessEqual::ComputeInt8(int64_t gmOffset, int64_t dataCount) { + LocalTensor x1InLT = inQueueX1.DeQue(); + LocalTensor x2InLT = inQueueX2.DeQue(); + LocalTensor x1LT = tempValLT; + LocalTensor x2LT = tempValLT[ubMaxProcCount]; + + LocalTensor yLT = outQueueY.AllocTensor(); + + Cast(x1LT, x1InLT, RoundMode::CAST_NONE, dataCount); + Cast(x2LT, x2InLT, RoundMode::CAST_NONE, dataCount); + + Compare(yLT, x1LT, x2LT, CMPMODE::LE, dataCount); + + outQueueY.EnQue(yLT); + inQueueX1.FreeTensor(x1LT); + inQueueX2.FreeTensor(x2LT); +} + +template +__aicore__ inline void KernelLessEqual::CopyOut(int64_t gmOffset, int64_t dataCount) { + int64_t alignDataCountOut = CeilAlignA2B(tailCount, perBlockCountOut); + LocalTensor yOutLT = outQueueY.DeQue(); + DataCopy(yGM[gmOffset], yOutLT, alignDataCountOut); + outQueueY.FreeTensor(yOutLT); +} +} // namespace LessEqual +#endif // LESS_EQUAL_BASE_H \ No newline at end of file diff --git a/tdb/common/ops/less_equal.py b/tdb/common/ops/less_equal.py new file mode 100644 index 0000000000000000000000000000000000000000..f0575a54888238f1629ab9fcdfa0e2d054f820de --- /dev/null +++ b/tdb/common/ops/less_equal.py @@ -0,0 +1,19 @@ +import torch +from torch.autograd import Function +from torch.nn import Module + +import torch_npu +import tdb_C + +class LessEqualFunction(Function): + @staticmethod + def forward(ctx, x1, x2): + result = tdb_C.less_equal(x1, x2) + ctx.save_for_backward(result) + return result + + @staticmethod + def backward(ctx, grad_output): + return None + +tdb_less_equal = LessEqualFunction.apply \ No newline at end of file