From 9dd6de5847308f94b63cbaa770abb1be2455fa89 Mon Sep 17 00:00:00 2001 From: junqiang521 Date: Mon, 21 Feb 2022 15:40:55 +0800 Subject: [PATCH 1/6] =?UTF-8?q?=E9=97=A8=E7=A6=81=E5=8A=A0=E5=85=A5resnet?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ci/access_control_test.py | 10 +- ci/pytorch_resnet.py | 451 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 457 insertions(+), 4 deletions(-) create mode 100644 ci/pytorch_resnet.py diff --git a/ci/access_control_test.py b/ci/access_control_test.py index aa61c310b33..005422c798b 100644 --- a/ci/access_control_test.py +++ b/ci/access_control_test.py @@ -22,7 +22,7 @@ from abc import ABCMeta, abstractmethod BASE_DIR = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) -DEFAULT_UT_FILE = os.path.join(BASE_DIR, 'test/test_network_ops/test_add.py') +DEFAULT_UT_FILE = os.path.join(BASE_DIR, 'pytorch_resnet.py') class AccurateTest(metaclass=ABCMeta): @@ -115,9 +115,11 @@ class TestMgr(): if os.path.exists(changed_file): exist_ut_file.append(changed_file) self.ut_files = exist_ut_file - - if len(self.ut_files) == 0: - self.ut_files.append(DEFAULT_UT_FILE) + + for ut in self.ut_files: + if ut.split('/')[-1] == 'run_tests.py': + self.ut_files.remove(ut) + self.ut_files.append(DEFAULT_UT_FILE) def get_ut_files(self): return self.ut_files diff --git a/ci/pytorch_resnet.py b/ci/pytorch_resnet.py new file mode 100644 index 00000000000..81c7ee8d7f2 --- /dev/null +++ b/ci/pytorch_resnet.py @@ -0,0 +1,451 @@ +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import os +import random +import shutil +import time +import warnings + +import torch +import torch_npu +import torch.nn as nn +import torch.nn.parallel +import torch.backends.cudnn as cudnn +import torch.distributed as dist +import torch.optim +import torch.multiprocessing as mp +import torch.utils.data +import torch.utils.data.distributed +import torchvision.transforms as transforms +import torchvision.datasets as datasets +import torchvision.models as models +import torch.npu + +SOURCE_DIR = os.environ.get('SOURCE_DIR') +BATCH_SIZE = 128 +EPOCHS_SIZE = 1 +TRAIN_STEP = 10 +LOG_STEP = 1 + +CALCULATE_DEVICE = "npu:0" +PRINT_DEVICE = "cpu" + + +model_names = sorted(name for name in models.__dict__ + if name.islower() and not name.startswith("__") + and callable(models.__dict__[name])) + +parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') +parser.add_argument('--data', metavar='DIR', default=SOURCE_DIR, + help='path to dataset') +parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18', + choices=model_names, + help='model architecture: ' + + ' | '.join(model_names) + + ' (default: resnet18)') +parser.add_argument('-j', '--workers', default=8, type=int, metavar='N', + help='number of data loading workers (default: 8)') +parser.add_argument('--epochs', default=EPOCHS_SIZE, type=int, metavar='N', + help='number of total epochs to run') +parser.add_argument('--start-epoch', default=0, type=int, metavar='N', + help='manual epoch number (useful on restarts)') +parser.add_argument('-b', '--batch-size', default=BATCH_SIZE, type=int, + metavar='N', + help='mini-batch size (default: 256), this is the total ' + 'batch size of all GPUs on the current node when ' + 'using Data Parallel or Distributed Data Parallel') +parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, + metavar='LR', help='initial learning rate', dest='lr') +parser.add_argument('--momentum', default=0.9, type=float, metavar='M', + help='momentum') +parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float, + metavar='W', help='weight decay (default: 1e-4)', + dest='weight_decay') +parser.add_argument('-p', '--print-freq', default=10, type=int, + metavar='N', help='print frequency (default: 10)') +parser.add_argument('--resume', default='', type=str, metavar='PATH', + help='path to latest checkpoint (default: none)') +parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', + help='evaluate model on validation set') +parser.add_argument('--pretrained', dest='pretrained', action='store_true', + help='use pre-trained model') +parser.add_argument('--world-size', default=-1, type=int, + help='number of nodes for distributed training') +parser.add_argument('--rank', default=-1, type=int, + help='node rank for distributed training') +parser.add_argument('--dist-url', type=str, + help='url used to set up distributed training') +parser.add_argument('--dist-backend', default='nccl', type=str, + help='distributed backend') +parser.add_argument('--seed', default=None, type=int, + help='seed for initializing training. ') +parser.add_argument('--gpu', default=None, type=int, + help='GPU id to use.') +parser.add_argument('--multiprocessing-distributed', action='store_true', + help='Use multi-processing distributed training to launch ' + 'N processes per node, which has N GPUs. This is the ' + 'fastest way to use PyTorch for either single node or ' + 'multi node data parallel training') + +best_acc1 = 0 + + +def main(): + args = parser.parse_args() + + if args.seed is not None: + random.seed(args.seed) + torch.manual_seed(args.seed) + cudnn.deterministic = True + warnings.warn('You have chosen to seed training. ' + 'This will turn on the CUDNN deterministic setting, ' + 'which can slow down your training considerably! ' + 'You may see unexpected behavior when restarting ' + 'from checkpoints.') + + if args.gpu is not None: + warnings.warn('You have chosen a specific GPU. This will completely ' + 'disable data parallelism.') + + if args.dist_url == "env://" and args.world_size == -1: + args.world_size = int(os.environ["WORLD_SIZE"]) + + args.distributed = args.world_size > 1 or args.multiprocessing_distributed + + ngpus_per_node = torch.cuda.device_count() + if args.multiprocessing_distributed: + # Since we have ngpus_per_node processes per node, the total world_size + # needs to be adjusted accordingly + args.world_size = ngpus_per_node * args.world_size + # Use torch.multiprocessing.spawn to launch distributed processes: the + # main_worker process function + mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args)) + else: + # Simply call main_worker function + main_worker(args.gpu, ngpus_per_node, args) + + +def main_worker(gpu, ngpus_per_node, args): + global best_acc1 + args.gpu = gpu + + if args.gpu is not None: + print("Use GPU: {} for training".format(args.gpu)) + + if args.distributed: + if args.dist_url == "env://" and args.rank == -1: + args.rank = int(os.environ["RANK"]) + if args.multiprocessing_distributed: + # For multiprocessing distributed training, rank needs to be the + # global rank among all the processes + args.rank = args.rank * ngpus_per_node + gpu + dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url, + world_size=args.world_size, rank=args.rank) + # create model + if args.pretrained: + print("=> using pre-trained model '{}'".format(args.arch)) + model = models.__dict__[args.arch](pretrained=True) + else: + print("=> creating model '{}'".format(args.arch)) + model = models.__dict__[args.arch]() + + if args.distributed: + # For multiprocessing distributed, DistributedDataParallel constructor + # should always set the single device scope, otherwise, + # DistributedDataParallel will use all available devices. + if args.gpu is not None: + torch.cuda.set_device(args.gpu) + model.cuda(args.gpu) + # When using a single GPU per process and per + # DistributedDataParallel, we need to divide the batch size + # ourselves based on the total number of GPUs we have + args.batch_size = int(args.batch_size / ngpus_per_node) + args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node) + model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) + else: + model.cuda() + # DistributedDataParallel will divide and allocate batch_size to all + # available GPUs if device_ids are not set + model = torch.nn.parallel.DistributedDataParallel(model) + elif args.gpu is not None: + torch.cuda.set_device(args.gpu) + model = model.cuda(args.gpu) + else: + # DataParallel will divide and allocate batch_size to all available GPUs + if args.arch.startswith('alexnet') or args.arch.startswith('vgg'): + model.features = torch.nn.DataParallel(model.features) + model.cuda() + else: + model = model.to(CALCULATE_DEVICE) + + # define loss function (criterion) and optimizer + criterion = nn.CrossEntropyLoss().to(CALCULATE_DEVICE) + + optimizer = torch.optim.SGD(model.parameters(), args.lr, + momentum=args.momentum, + weight_decay=args.weight_decay) + + # optionally resume from a checkpoint + if args.resume: + if os.path.isfile(args.resume): + print("=> loading checkpoint '{}'".format(args.resume)) + if args.gpu is None: + checkpoint = torch.load(args.resume) + else: + # Map model to be loaded to specified single gpu. + loc = 'cuda:{}'.format(args.gpu) + checkpoint = torch.load(args.resume, map_location=loc) + args.start_epoch = checkpoint['epoch'] + best_acc1 = checkpoint['best_acc1'] + if args.gpu is not None: + # best_acc1 may be from a checkpoint from a different GPU + best_acc1 = best_acc1.to(args.gpu) + model.load_state_dict(checkpoint['state_dict']) + optimizer.load_state_dict(checkpoint['optimizer']) + print("=> loaded checkpoint '{}' (epoch {})" + .format(args.resume, checkpoint['epoch'])) + else: + print("=> no checkpoint found at '{}'".format(args.resume)) + + cudnn.benchmark = True + + # Data loading code + traindir = os.path.join(args.data, 'train') + valdir = os.path.join(args.data, 'val') + normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]) + + train_dataset = datasets.ImageFolder( + traindir, + transforms.Compose([ + transforms.RandomResizedCrop(224), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + normalize, + ])) + + if args.distributed: + train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) + else: + train_sampler = None + + train_loader = torch.utils.data.DataLoader( + train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None), + num_workers=args.workers, pin_memory=True, sampler=train_sampler) + + val_loader = torch.utils.data.DataLoader( + datasets.ImageFolder(valdir, transforms.Compose([ + transforms.Resize(256), + transforms.CenterCrop(224), + transforms.ToTensor(), + normalize, + ])), + batch_size=args.batch_size, shuffle=False, + num_workers=args.workers, pin_memory=True) + + if args.evaluate: + validate(val_loader, model, criterion, args) + return + + for epoch in range(args.start_epoch, args.epochs): + if args.distributed: + train_sampler.set_epoch(epoch) + adjust_learning_rate(optimizer, epoch, args) + + # train for one epoch + train(train_loader, model, criterion, optimizer, epoch, args) + + # evaluate on validation set + acc1 = validate(val_loader, model, criterion, args) + + # remember best acc@1 and save checkpoint + is_best = acc1 > best_acc1 + best_acc1 = max(acc1, best_acc1) + +def train(train_loader, model, criterion, optimizer, epoch, args): + batch_time = AverageMeter('Time', ':6.3f') + data_time = AverageMeter('Data', ':6.3f') + losses = AverageMeter('Loss', ':.4e') + top1 = AverageMeter('Acc@1', ':6.2f') + top5 = AverageMeter('Acc@5', ':6.2f') + progress = ProgressMeter( + len(train_loader), + [batch_time, data_time, losses, top1, top5], + prefix="Epoch: [{}]".format(epoch)) + + # switch to train mode + model.train() + + end = time.time() + for i, (images, target) in enumerate(train_loader): + # measure data loading time + data_time.update(time.time() - end) + + if args.gpu is not None: + images = images.cuda(args.gpu, non_blocking=True) + if 'npu' in CALCULATE_DEVICE: + target = target.to(torch.int32) + images, target = images.to(CALCULATE_DEVICE, non_blocking=True), target.to(CALCULATE_DEVICE, non_blocking=True) + + # compute output + output = model(images) + + loss = criterion(output, target) + + # measure accuracy and record loss + acc1, acc5 = accuracy(output, target, topk=(1, 5)) + losses.update(loss.item(), images.size(0)) + top1.update(acc1[0], images.size(0)) + top5.update(acc5[0], images.size(0)) + + # compute gradient and do SGD step + optimizer.zero_grad() + loss.backward() + optimizer.step() + + # measure elapsed time + batch_time.update(time.time() - end) + end = time.time() + + if i % LOG_STEP == 0: + progress.display(i) + + if i == TRAIN_STEP: + break + + +def validate(val_loader, model, criterion, args): + batch_time = AverageMeter('Time', ':6.3f') + losses = AverageMeter('Loss', ':.4e') + top1 = AverageMeter('Acc@1', ':6.2f') + top5 = AverageMeter('Acc@5', ':6.2f') + progress = ProgressMeter( + len(val_loader), + [batch_time, losses, top1, top5], + prefix='Test: ') + + # switch to evaluate mode + model.eval() + + with torch.no_grad(): + end = time.time() + for i, (images, target) in enumerate(val_loader): + if args.gpu is not None: + images = images.cuda(args.gpu, non_blocking=True) + if 'npu' in CALCULATE_DEVICE: + target = target.to(torch.int32) + images, target = images.to(CALCULATE_DEVICE, non_blocking=True), target.to(CALCULATE_DEVICE, non_blocking=True) + # compute output + output = model(images) + + loss = criterion(output, target) + + # measure accuracy and record loss + acc1, acc5 = accuracy(output, target, topk=(1, 5)) + losses.update(loss.item(), images.size(0)) + top1.update(acc1[0], images.size(0)) + top5.update(acc5[0], images.size(0)) + + # measure elapsed time + batch_time.update(time.time() - end) + end = time.time() + + if i % LOG_STEP == 0: + progress.display(i) + break + # TODO: this should also be done with the ProgressMeter + print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}' + .format(top1=top1, top5=top5)) + + return top1.avg + + +def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'): + torch.save(state, filename) + if is_best: + shutil.copyfile(filename, 'model_best.pth.tar') + + +class AverageMeter(object): + """Computes and stores the average and current value""" + def __init__(self, name, fmt=':f'): + self.name = name + self.fmt = fmt + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + def __str__(self): + fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' + return fmtstr.format(**self.__dict__) + + +class ProgressMeter(object): + def __init__(self, num_batches, meters, prefix=""): + self.batch_fmtstr = self._get_batch_fmtstr(num_batches) + self.meters = meters + self.prefix = prefix + + def display(self, batch): + entries = [self.prefix + self.batch_fmtstr.format(batch)] + entries += [str(meter) for meter in self.meters] + print('\t'.join(entries)) + + def _get_batch_fmtstr(self, num_batches): + num_digits = len(str(num_batches // 1)) + fmt = '{:' + str(num_digits) + 'd}' + return '[' + fmt + '/' + fmt.format(num_batches) + ']' + + +def adjust_learning_rate(optimizer, epoch, args): + """Sets the learning rate to the initial LR decayed by 10 every 30 epochs""" + lr = args.lr * (0.1 ** (epoch // 30)) + for param_group in optimizer.param_groups: + param_group['lr'] = lr + + +def accuracy(output, target, topk=(1,)): + """Computes the accuracy over the k top predictions for the specified values of k""" + with torch.no_grad(): + maxk = max(topk) + batch_size = target.size(0) + + _, pred = output.topk(maxk, 1, True, True) + pred = pred.t() + correct = pred.eq(target.view(1, -1).expand_as(pred)) + + res = [] + for k in topk: + correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) + res.append(correct_k.mul_(100.0 / batch_size)) + return res + + +if __name__ == '__main__': + if 'npu' in CALCULATE_DEVICE: + torch.npu.set_device(CALCULATE_DEVICE) + main() -- Gitee From d2f58bf45ebd73e785d030199f3202b0310accfc Mon Sep 17 00:00:00 2001 From: junqiang521 Date: Mon, 21 Feb 2022 16:54:26 +0800 Subject: [PATCH 2/6] =?UTF-8?q?=E4=BF=AE=E6=94=B9pytorch=5Fresnet.py=20?= =?UTF-8?q?=E8=BF=90=E8=A1=8C=E8=B7=AF=E5=BE=84?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ci/access_control_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ci/access_control_test.py b/ci/access_control_test.py index 005422c798b..266345be4b8 100644 --- a/ci/access_control_test.py +++ b/ci/access_control_test.py @@ -22,7 +22,7 @@ from abc import ABCMeta, abstractmethod BASE_DIR = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) -DEFAULT_UT_FILE = os.path.join(BASE_DIR, 'pytorch_resnet.py') +DEFAULT_UT_FILE = os.path.join(BASE_DIR, 'ci/pytorch_resnet.py') class AccurateTest(metaclass=ABCMeta): -- Gitee From eb38842cc7c32609fac108806b66d0402ce205d2 Mon Sep 17 00:00:00 2001 From: junqiang521 Date: Mon, 21 Feb 2022 18:40:09 +0800 Subject: [PATCH 3/6] =?UTF-8?q?=E5=88=A0=E9=99=A4run=5Ftests.py=E7=9B=B8?= =?UTF-8?q?=E5=85=B3=E4=BB=A3=E7=A0=81=EF=BC=8C=E8=BF=99=E4=B8=AA=E6=96=87?= =?UTF-8?q?=E4=BB=B6=E4=B8=8D=E5=AD=98=E5=9C=A8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ci/access_control_test.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/ci/access_control_test.py b/ci/access_control_test.py index 266345be4b8..3ebebbc8848 100644 --- a/ci/access_control_test.py +++ b/ci/access_control_test.py @@ -115,10 +115,6 @@ class TestMgr(): if os.path.exists(changed_file): exist_ut_file.append(changed_file) self.ut_files = exist_ut_file - - for ut in self.ut_files: - if ut.split('/')[-1] == 'run_tests.py': - self.ut_files.remove(ut) self.ut_files.append(DEFAULT_UT_FILE) def get_ut_files(self): -- Gitee From 427364dee04570d37b529c1f07875790d389501b Mon Sep 17 00:00:00 2001 From: wangxiao Date: Fri, 18 Feb 2022 11:28:42 +0800 Subject: [PATCH 4/6] argmin, argsort, asin, atan, atan2, bitwise_not, cdist, cdist_backward, ceil, celu, gelu_backward, isnan, unfold --- test/test_network_ops/test_argmin.py | 69 ++++++ test/test_network_ops/test_argsort.py | 81 ++++++ test/test_network_ops/test_asin.py | 63 +++++ test/test_network_ops/test_atan.py | 50 ++++ test/test_network_ops/test_atan2.py | 90 +++++++ test/test_network_ops/test_bitwise_not.py | 101 ++++++++ test/test_network_ops/test_cdist.py | 191 ++++++++++++++ test/test_network_ops/test_cdist_backward.py | 117 +++++++++ test/test_network_ops/test_ceil.py | 60 +++++ test/test_network_ops/test_celu.py | 234 ++++++++++++++++++ test/test_network_ops/test_default.py | 41 +++ test/test_network_ops/test_gelu_backward.py | 77 ++++++ torch_npu/csrc/aten/ops/ArgminKernelNpu.cpp | 53 ++++ torch_npu/csrc/aten/ops/ArgsortKernelNpu.cpp | 95 +++++++ torch_npu/csrc/aten/ops/AsinKernelNpu.cpp | 58 +++++ torch_npu/csrc/aten/ops/Atan2KernelNpu.cpp | 77 ++++++ torch_npu/csrc/aten/ops/AtanKernelNpu.cpp | 57 +++++ .../csrc/aten/ops/BitwiseNotKernelNpu.cpp | 63 +++++ .../csrc/aten/ops/CdistBackwardKernelNpu.cpp | 93 +++++++ torch_npu/csrc/aten/ops/CdistKernelNpu.cpp | 100 ++++++++ torch_npu/csrc/aten/ops/CeilKernelNpu.cpp | 57 +++++ torch_npu/csrc/aten/ops/CeluKernelNpu.cpp | 55 ++++ torch_npu/csrc/aten/ops/DefaultKernelNpu.cpp | 31 +++ .../aten/ops/EmbeddingBackwardKernelNpu.cpp | 2 +- .../csrc/aten/ops/GeluBackwardKernelNpu.cpp | 47 ++++ 25 files changed, 1961 insertions(+), 1 deletion(-) create mode 100644 test/test_network_ops/test_argmin.py create mode 100644 test/test_network_ops/test_argsort.py create mode 100644 test/test_network_ops/test_asin.py create mode 100644 test/test_network_ops/test_atan.py create mode 100644 test/test_network_ops/test_atan2.py create mode 100644 test/test_network_ops/test_bitwise_not.py create mode 100644 test/test_network_ops/test_cdist.py create mode 100644 test/test_network_ops/test_cdist_backward.py create mode 100644 test/test_network_ops/test_ceil.py create mode 100644 test/test_network_ops/test_celu.py create mode 100644 test/test_network_ops/test_default.py create mode 100644 test/test_network_ops/test_gelu_backward.py create mode 100644 torch_npu/csrc/aten/ops/ArgminKernelNpu.cpp create mode 100644 torch_npu/csrc/aten/ops/ArgsortKernelNpu.cpp create mode 100644 torch_npu/csrc/aten/ops/AsinKernelNpu.cpp create mode 100644 torch_npu/csrc/aten/ops/Atan2KernelNpu.cpp create mode 100644 torch_npu/csrc/aten/ops/AtanKernelNpu.cpp create mode 100644 torch_npu/csrc/aten/ops/BitwiseNotKernelNpu.cpp create mode 100644 torch_npu/csrc/aten/ops/CdistBackwardKernelNpu.cpp create mode 100644 torch_npu/csrc/aten/ops/CdistKernelNpu.cpp create mode 100644 torch_npu/csrc/aten/ops/CeilKernelNpu.cpp create mode 100644 torch_npu/csrc/aten/ops/CeluKernelNpu.cpp create mode 100644 torch_npu/csrc/aten/ops/DefaultKernelNpu.cpp create mode 100644 torch_npu/csrc/aten/ops/GeluBackwardKernelNpu.cpp diff --git a/test/test_network_ops/test_argmin.py b/test/test_network_ops/test_argmin.py new file mode 100644 index 00000000000..e64104fb71e --- /dev/null +++ b/test/test_network_ops/test_argmin.py @@ -0,0 +1,69 @@ +# Copyright (c) 2020 Huawei Technologies Co., Ltd +# All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# coding: utf-8 +import torch +import torch_npu +import numpy as np + +from torch_npu.testing.common_utils import TestCase, run_tests +from torch_npu.testing.common_device_type import Dtypes, instantiate_device_type_tests +from torch_npu.testing.util_test import create_common_tensor + +class TestArgmin(TestCase): + @Dtypes(torch.float) + def test_argmin(self, device, dtype): + inputValues = [-1000, -1, 0, 0.5, 1, 2, 1000] + expectedOutput = [0.0000, 0.2689, 0.5, 0.6225, 0.7311, 0.8808, 1.000] + precision_4dps = 0.0002 + a = torch.tensor(inputValues, dtype=dtype, device=device) + + self.assertRtolEqual(torch.tensor(inputValues, dtype=dtype, device=device).sigmoid().cpu(), + torch.tensor(expectedOutput, dtype=dtype, device=device).cpu(), + precision_4dps) + + def cpu_op_exec(self, input1, dims, keepdim=False): + output = torch.argmin(input1, dim=dims, keepdim=keepdim) + if output.dtype != torch.int32: + output = output.to(torch.int32) + output = output.numpy() + return output + + def npu_op_exec(self, input1, dims, keepdim=False): + output = torch.argmin(input1, dim=dims, keepdim=keepdim) + output = output.to("cpu") + if output.dtype != torch.int32: + output = output.to(torch.int32) + output = output.numpy() + return output + + def test_argmin_shape_format(self, device): + shape_format = [ + [ [np.float32, 0, (6, 4)], 0, False], + [ [np.float32, 0, (6, 4)], 1, True ], + [ [np.float32, 0, (2, 4, 5)], 2, True ], + [ [np.float32, 0, (1, 2, 3, 3)], 2, False], + [ [np.float32, 0, (1, 2, 3, 3)], 2, False], + [ [np.float32, 29, (15, 15, 15, 16)], 1, False], + ] + + for item in shape_format: + cpu_input, npu_input = create_common_tensor(item[0], 1, 100) + cpu_output = self.cpu_op_exec(cpu_input, item[1], keepdim=item[2]) + npu_output = self.npu_op_exec(npu_input, item[1], keepdim=item[2]) + self.assertRtolEqual(cpu_output, npu_output) + +instantiate_device_type_tests(TestArgmin, globals(), except_for="cpu") +if __name__ == "__main__": + run_tests() diff --git a/test/test_network_ops/test_argsort.py b/test/test_network_ops/test_argsort.py new file mode 100644 index 00000000000..57f45693750 --- /dev/null +++ b/test/test_network_ops/test_argsort.py @@ -0,0 +1,81 @@ +# Copyright (c) 2020 Huawei Technologies Co., Ltd +# Copyright (c) 2019, Facebook CORPORATION. +# All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch +import torch_npu +import numpy as np + +from torch_npu.testing.common_utils import TestCase, run_tests +from torch_npu.testing.common_device_type import Dtypes, instantiate_device_type_tests +from torch_npu.testing.util_test import create_common_tensor + +class TestArgSort(TestCase): + def cpu_op_exec(self, input1, dim, descending): + output = torch.argsort(input1, dim=dim, descending=descending) + return output.numpy() + + def npu_op_exec(self, input1, dim, descending): + output = torch.argsort(input1, dim=dim, descending=descending) + + return output.cpu().numpy() + + def cpu_default_op_exec(self, input1): + output = torch.argsort(input1) + return output.numpy() + + def npu_default_op_exec(self, input1): + output = torch.argsort(input1) + return output.cpu().numpy() + + def test_sort_shape_format_fp32(self, device): + shape_format = [ + [[np.float32, 0, (8, 4, 3, 9)], 2, False], + [[np.float32, 0, (2, 3)]], + [[np.float32, 0, (1, 7)], 0, True], + [[np.float32, 0, (1, 5, 6)], 1, False], + ] + + for item in shape_format: + cpu_input1, npu_input1 = create_common_tensor(item[0], -100, 100) + if len(item) > 1: + cpu_output = self.cpu_op_exec(cpu_input1, item[1], item[2]) + npu_output = self.npu_op_exec(npu_input1, item[1], item[2]) + else: + cpu_output = self.cpu_default_op_exec(cpu_input1) + npu_output = self.npu_default_op_exec(npu_input1) + self.assertRtolEqual(cpu_output, npu_output) + + def test_sort_shape_format_fp16(self, device): + shape_format = [ + [[np.float16, 0, (8, 4, 3, 9)], 2, False], + [[np.float16, 0, (2, 3)]], + [[np.float16, 0, (1, 7)], 0, True], + [[np.float16, 0, (1, 5, 6)], 1, False], + ] + + for item in shape_format: + cpu_input1, npu_input1 = create_common_tensor(item[0], -100, 100) + if len(item) > 1: + cpu_output = self.cpu_op_exec(cpu_input1.to(torch.float32), item[1], item[2]) + npu_output = self.npu_op_exec(npu_input1, item[1], item[2]) + else: + cpu_output = self.cpu_default_op_exec(cpu_input1.to(torch.float32)) + npu_output = self.npu_default_op_exec(npu_input1) + self.assertRtolEqual(cpu_output, npu_output) + + +instantiate_device_type_tests(TestArgSort, globals(), except_for="cpu") +if __name__ == "__main__": + run_tests() diff --git a/test/test_network_ops/test_asin.py b/test/test_network_ops/test_asin.py new file mode 100644 index 00000000000..695010ca0d7 --- /dev/null +++ b/test/test_network_ops/test_asin.py @@ -0,0 +1,63 @@ +# Copyright (c) 2020, Huawei Technologies.All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch +import torch_npu +import numpy as np + +from torch_npu.testing.common_utils import TestCase, run_tests +from torch_npu.testing.common_device_type import Dtypes, instantiate_device_type_tests +from torch_npu.testing.util_test import create_common_tensor + +class TestAsin(TestCase): + def cpu_op_exec(self,input1): + output = torch.asin(input1) + output = output.numpy() + return output + + def npu_op_exec(self,input1): + output = torch.asin(input1) + output = output.to("cpu") + output = output.numpy() + return output + + def npu_op_exec_out(self,input1, input2): + torch.asin(input1, out=input2) + output = input2.to("cpu") + output = output.numpy() + return output + + def test_asin_common_shape_format(self, device): + shape_format = [ + [[np.float32, 0, (5,3)]], + ] + for item in shape_format: + cpu_input1, npu_input1 = create_common_tensor(item[0], -1, 1) + cpu_output = self.cpu_op_exec(cpu_input1) + npu_output = self.npu_op_exec(npu_input1) + self.assertRtolEqual(cpu_output, npu_output) + + def test_asin_out_common_shape_format(self, device): + shape_format = [ + [[np.float32, 0, (4,3)], [np.float32, 0, (4,3)]], + ] + for item in shape_format: + cpu_input1, npu_input1 = create_common_tensor(item[0], -1, 1) + cpu_input2, npu_input2 = create_common_tensor(item[1], -1, 1) + cpu_output = self.cpu_op_exec(cpu_input1) + npu_output = self.npu_op_exec_out(npu_input1, npu_input2) + self.assertRtolEqual(cpu_output, npu_output) + +instantiate_device_type_tests(TestAsin, globals(), except_for='cpu') +if __name__ == "__main__": + run_tests() diff --git a/test/test_network_ops/test_atan.py b/test/test_network_ops/test_atan.py new file mode 100644 index 00000000000..d51e441933c --- /dev/null +++ b/test/test_network_ops/test_atan.py @@ -0,0 +1,50 @@ +# Copyright (c) 2020 Huawei Technologies Co., Ltd +# Copyright (c) 2019, Facebook CORPORATION. +# All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch +import torch_npu +import numpy as np + +from torch_npu.testing.common_utils import TestCase, run_tests +from torch_npu.testing.common_device_type import Dtypes, instantiate_device_type_tests +from torch_npu.testing.util_test import create_common_tensor + +class TestAtan(TestCase): + def cpu_op_exec(self, input): + output = torch.atan(input) + return output + + def npu_op_exec(self, input): + output = torch.atan(input) + output = output.to("cpu") + return output + + def test_atan_shape_format(self, device): + shape_format = [ + [[np.float32, 0, 1]], + [[np.float32, 0, (64, 10)]], + [[np.float32, 3, (256, 2048, 7, 7)]], + [[np.float32, 4, (32, 1, 3, 3)]], + [[np.float32, 29, (10, 128)]] + ] + for item in shape_format: + cpu_input, npu_input = create_common_tensor(item[0], -1, 1) + cpu_output = self.cpu_op_exec(cpu_input) + npu_output = self.npu_op_exec(npu_input) + self.assertRtolEqual(cpu_output, npu_output) + +instantiate_device_type_tests(TestAtan, globals(), except_for="cpu") +if __name__ == "__main__": + run_tests() \ No newline at end of file diff --git a/test/test_network_ops/test_atan2.py b/test/test_network_ops/test_atan2.py new file mode 100644 index 00000000000..39c992dd685 --- /dev/null +++ b/test/test_network_ops/test_atan2.py @@ -0,0 +1,90 @@ +# Copyright (c) 2020, Huawei Technologies.All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch +import torch_npu +import numpy as np + +from torch_npu.testing.common_utils import TestCase, run_tests +from torch_npu.testing.common_device_type import Dtypes, instantiate_device_type_tests +from torch_npu.testing.util_test import create_common_tensor + +class TestAtan2(TestCase): + def cpu_op_exec(self,input1, input2): + output = torch.atan2(input1, input2) + output = output.numpy() + return output + + def npu_op_exec(self,input1, input2): + output = torch.atan2(input1, input2) + output = output.to("cpu") + output = output.numpy() + return output + + def npu_op_exec_out(self,input1, input2, out): + torch.atan2(input1, input2, out=out) + output = out.to("cpu") + output = output.numpy() + return output + + def test_atan2_common_shape_format(self, device): + shape_format = [ + [[np.float16, 0, [4, 12, 12, 128]], [np.float16, 0, [4]]], + [[np.float16, 0, [4, 128]], [np.float16, 0, [4, 256, 12]]], + [[np.float32, 0, [4, 12, 12, 128]], [np.float32, 0, [4]]], + [[np.float32, 0, [4, 128]], [np.float32, 0, [4, 256, 12]]], + ] + for item in shape_format: + cpu_input1, npu_input1 = create_common_tensor(item[0], -1, 1) + cpu_input2, npu_input2 = create_common_tensor(item[0], -1, 1) + if cpu_input1.dtype == torch.float16: + cpu_input1 = cpu_input1.to(torch.float32) + if cpu_input2.dtype == torch.float16: + cpu_input2 = cpu_input2.to(torch.float32) + cpu_output = self.cpu_op_exec(cpu_input1, cpu_input2) + npu_output = self.npu_op_exec(npu_input1, npu_input2) + cpu_output = cpu_output.astype(npu_output.dtype) + self.assertRtolEqual(cpu_output, npu_output) + + def test_atan2_out_common_shape_format(self, device): + shape_format = [ + [[np.float16, 0, [4, 12, 12, 128]], [np.float16, 0, [4]]], + [[np.float16, 0, [4, 128]], [np.float16, 0, [4, 256, 12]]], + [[np.float32, 0, [4, 12, 12, 128]], [np.float32, 0, [4]]], + [[np.float32, 0, [4, 128]], [np.float32, 0, [4, 256, 12]]], + ] + for item in shape_format: + cpu_input1, npu_input1 = create_common_tensor(item[0], -1, 1) + cpu_input2, npu_input2 = create_common_tensor(item[0], -1, 1) + cpu_out, npu_out = create_common_tensor(item[1], -1, 1) + if cpu_input1.dtype == torch.float16: + cpu_input1 = cpu_input1.to(torch.float32) + if cpu_input2.dtype == torch.float16: + cpu_input2 = cpu_input2.to(torch.float32) + cpu_output = self.cpu_op_exec(cpu_input1, cpu_input2) + npu_output = self.npu_op_exec(npu_input1, npu_input2) + npu_output_out = self.npu_op_exec_out(npu_input1, npu_input2, npu_out) + cpu_output = cpu_output.astype(npu_output.dtype) + self.assertRtolEqual(cpu_output, npu_output) + self.assertRtolEqual(cpu_output, npu_output_out) + + def test_atan2_mix_dtype(self, device): + npu_input1, npu_input2 = create_common_tensor([np.float32, 0, (2, 3)], 1, 100) + npu_input3, npu_input4 = create_common_tensor([np.float16, 0, (2, 3)], 1, 100) + cpu_output = self.cpu_op_exec(npu_input1, npu_input3) + npu_output = self.npu_op_exec(npu_input2, npu_input4) + self.assertRtolEqual(cpu_output, npu_output) + +instantiate_device_type_tests(TestAtan2, globals(), except_for='cpu') +if __name__ == "__main__": + run_tests() diff --git a/test/test_network_ops/test_bitwise_not.py b/test/test_network_ops/test_bitwise_not.py new file mode 100644 index 00000000000..feb0d3e2c4a --- /dev/null +++ b/test/test_network_ops/test_bitwise_not.py @@ -0,0 +1,101 @@ +# Copyright (c) 2020, Huawei Technologies.All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch +import torch_npu +import numpy as np + +from torch_npu.testing.common_utils import TestCase, run_tests +from torch_npu.testing.common_device_type import Dtypes, instantiate_device_type_tests +from torch_npu.testing.util_test import create_common_tensor + +class Test_Bitwise_Not(TestCase): + def generate_data(self, min_d, max_d, shape, dtype): + input1 = np.random.uniform(min_d, max_d, shape).astype(dtype) + npu_input1 = torch.from_numpy(input1) + + return npu_input1 + + def generate_bool_data(self, shape): + input1 = np.random.randint(0, 2, shape).astype(np.bool_) + npu_input1 = torch.from_numpy(input1) + return npu_input1 + + def cpu_op_exec(self, input1): + output = torch.bitwise_not(input1) + if output.dtype not in [torch.int32, torch.int8, torch.bool]: + output = output.to(torch.int32) + output = output.numpy() + return output + + def npu_op_exec(self, input1): + input1 = input1.to("npu") + output = torch.bitwise_not(input1) + output = output.to("cpu") + if output.dtype not in [torch.int32, torch.int8, torch.bool]: + output = output.to(torch.int32) + output = output.numpy() + return output + + def npu_op_exec_out(self, input1, input2): + input1 = input1.to("npu") + input2 = input2.to("npu") + torch.bitwise_not(input1, out = input2) + output = input2.to("cpu") + if output.dtype not in [torch.int32, torch.int8, torch.bool]: + output = output.to(torch.int32) + output = output.numpy() + return output + + def test_bitwise_not_bool(self, device): + npu_input1 = self.generate_bool_data((2, 3)) + cpu_output = self.cpu_op_exec(npu_input1) + npu_output = self.npu_op_exec(npu_input1) + self.assertRtolEqual(cpu_output, npu_output) + + def test_bitwise_not_int16(self, device): + npu_input1 = self.generate_data(0, 2342, (2, 3), np.int16) + cpu_output = self.cpu_op_exec(npu_input1) + npu_output = self.npu_op_exec(npu_input1) + self.assertRtolEqual(cpu_output, npu_output) + + def test_bitwise_not_int32(self, device): + npu_input1 = self.generate_data(0, 34222, (2, 3), np.int32) + cpu_output = self.cpu_op_exec(npu_input1) + npu_output = self.npu_op_exec(npu_input1) + self.assertRtolEqual(cpu_output, npu_output) + + def test_bitwise_not_int64(self, device): + npu_input1 = self.generate_data(0, 355553, (2, 3), np.int64) + cpu_output = self.cpu_op_exec(npu_input1) + npu_output = self.npu_op_exec(npu_input1) + self.assertRtolEqual(cpu_output, npu_output) + + def test_bitwise_not_out(self, device): + shape_format = [ + [[0, 2342, [2, 3], np.int16], [0, 2342, [10, 20], np.int16]], + [[0, 34222, [2, 3], np.int32], [0, 34222, [10, 20], np.int32]], + [[0, 355553, [2, 3], np.int64], [0, 355553, [1, 1], np.int64]], + ] + for item in shape_format: + npu_input1 = self.generate_data(item[0][0], item[0][1], item[0][2], item[0][3]) + npu_input2 = self.generate_data(item[1][0], item[1][1], item[1][2], item[1][3]) + cpu_output = self.cpu_op_exec(npu_input1) + npu_output1 = self.npu_op_exec_out(npu_input1, npu_input1) + npu_output2 = self.npu_op_exec_out(npu_input1, npu_input2) + self.assertRtolEqual(cpu_output, npu_output1) + self.assertRtolEqual(cpu_output, npu_output1) + +instantiate_device_type_tests(Test_Bitwise_Not, globals(), except_for='cpu') +if __name__ == "__main__": + run_tests() diff --git a/test/test_network_ops/test_cdist.py b/test/test_network_ops/test_cdist.py new file mode 100644 index 00000000000..d3e95248a88 --- /dev/null +++ b/test/test_network_ops/test_cdist.py @@ -0,0 +1,191 @@ +# Copyright (c) 2020, Huawei Technologies.All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch +import torch_npu +import numpy as np + +from torch_npu.testing.common_utils import TestCase, run_tests +from torch_npu.testing.common_device_type import Dtypes, instantiate_device_type_tests +from torch_npu.testing.util_test import create_common_tensor + +class Testcdist(TestCase): + def generate_data(self, min_n, max_n, shape_x, shape_y, src_type): + np.random.seed(10086) + x1 = np.random.uniform(min_n, max_n, shape_x).astype(src_type) + x2 = np.random.uniform(min_n, max_n, shape_y).astype(src_type) + return x1, x2 + + def op_exec(self, x1, x2, p, device='cpu'): + is_fp16 = x1.dtype == np.float16 + if device == 'cpu' and is_fp16: + x1 = x1.astype(np.float32) + x2 = x2.astype(np.float32) + + x1 = torch.from_numpy(x1) + x2 = torch.from_numpy(x2) + + x1 = x1.to(device) + x2 = x2.to(device) + + y = torch.cdist(x1, x2, p) + y = y.cpu().numpy() + + if device == 'cpu' and is_fp16: + y = y.astype(np.float16) + return y + + def test_cdist_float16_1(self, device): + npu_input1, npu_input2 = self.generate_data(-1, 1, + (5, 64), (4, 64), np.float16) + cpu_output = self.op_exec(npu_input1, npu_input2, 0.0, 'cpu') + npu_output = self.op_exec(npu_input1, npu_input2, 0.0, 'npu') + self.assertRtolEqual(cpu_output, npu_output) + + def test_cdist_float16_2(self, device): + npu_input1, npu_input2 = self.generate_data(-1, 1, + (5, 10), (4, 10), np.float16) + cpu_output = self.op_exec(npu_input1, npu_input2, 0.5, 'cpu') + npu_output = self.op_exec(npu_input1, npu_input2, 0.5, 'npu') + self.assertRtolEqual(cpu_output, npu_output) + + def test_cdist_float16_3(self, device): + npu_input1, npu_input2 = self.generate_data(-1, 1, + (5, 10), (4, 10), np.float16) + cpu_output = self.op_exec(npu_input1, npu_input2, 1.0, 'cpu') + npu_output = self.op_exec(npu_input1, npu_input2, 1.0, 'npu') + self.assertRtolEqual(cpu_output, npu_output) + + def test_cdist_float16_4(self, device): + npu_input1, npu_input2 = self.generate_data(-1, 1, + (5, 10), (4, 10), np.float16) + cpu_output = self.op_exec(npu_input1, npu_input2, 1.5, 'cpu') + npu_output = self.op_exec(npu_input1, npu_input2, 1.5, 'npu') + self.assertRtolEqual(cpu_output, npu_output) + + def test_cdist_float16_5(self, device): + npu_input1, npu_input2 = self.generate_data(-1, 1, + (5, 10), (4, 10), np.float16) + cpu_output = self.op_exec(npu_input1, npu_input2, 2.0, 'cpu') + npu_output = self.op_exec(npu_input1, npu_input2, 2.0, 'npu') + self.assertRtolEqual(cpu_output, npu_output) + + def test_cdist_float16_6(self, device): + npu_input1, npu_input2 = self.generate_data(-1, 1, + (5, 10), (4, 10), np.float16) + cpu_output = self.op_exec(npu_input1, npu_input2, 2.5, 'cpu') + npu_output = self.op_exec(npu_input1, npu_input2, 2.5, 'npu') + self.assertRtolEqual(cpu_output, npu_output) + + def test_cdist_float16_7(self, device): + npu_input1, npu_input2 = self.generate_data(-1, 1, + (3, 5, 500), (4, 500), np.float16) + cpu_output = self.op_exec(npu_input1, npu_input2, 2.0, 'cpu') + npu_output = self.op_exec(npu_input1, npu_input2, 2.0, 'npu') + self.assertRtolEqual(cpu_output, npu_output) + + def test_cdist_float32_1(self, device): + npu_input1, npu_input2 = self.generate_data(-1, 1, + (5, 10), (4, 10), np.float32) + cpu_output = self.op_exec(npu_input1, npu_input2, 0.0, 'cpu') + npu_output = self.op_exec(npu_input1, npu_input2, 0.0, 'npu') + self.assertRtolEqual(cpu_output, npu_output) + + def test_cdist_float32_2(self, device): + npu_input1, npu_input2 = self.generate_data(-1, 1, + (5, 10), (4, 10), np.float32) + cpu_output = self.op_exec(npu_input1, npu_input2, 0.5, 'cpu') + npu_output = self.op_exec(npu_input1, npu_input2, 0.5, 'npu') + self.assertRtolEqual(cpu_output, npu_output) + + def test_cdist_float32_3(self, device): + npu_input1, npu_input2 = self.generate_data(-1, 1, + (5, 10), (4, 10), np.float32) + cpu_output = self.op_exec(npu_input1, npu_input2, 1.0, 'cpu') + npu_output = self.op_exec(npu_input1, npu_input2, 1.0, 'npu') + self.assertRtolEqual(cpu_output, npu_output) + + def test_cdist_float32_4(self, device): + npu_input1, npu_input2 = self.generate_data(-1, 1, + (5, 10), (4, 10), np.float32) + cpu_output = self.op_exec(npu_input1, npu_input2, 1.5, 'cpu') + npu_output = self.op_exec(npu_input1, npu_input2, 1.5, 'npu') + self.assertRtolEqual(cpu_output, npu_output) + + def test_cdist_float32_5(self, device): + npu_input1, npu_input2 = self.generate_data(-1, 1, + (5, 10), (4, 10), np.float32) + cpu_output = self.op_exec(npu_input1, npu_input2, 2.0, 'cpu') + npu_output = self.op_exec(npu_input1, npu_input2, 2.0, 'npu') + self.assertRtolEqual(cpu_output, npu_output) + + def test_cdist_float32_6(self, device): + npu_input1, npu_input2 = self.generate_data(-1, 1, + (5, 10), (4, 10), np.float32) + cpu_output = self.op_exec(npu_input1, npu_input2, 2.5, 'cpu') + npu_output = self.op_exec(npu_input1, npu_input2, 2.5, 'npu') + self.assertRtolEqual(cpu_output, npu_output) + + def test_cdist_float32_7(self, device): + npu_input1, npu_input2 = self.generate_data(-1, 1, + (5, 500), (3, 4, 500), np.float32) + cpu_output = self.op_exec(npu_input1, npu_input2, 2.0, 'cpu') + npu_output = self.op_exec(npu_input1, npu_input2, 2.0, 'npu') + self.assertRtolEqual(cpu_output, npu_output) + + def test_cdist_float32_8(self, device): + npu_input1, npu_input2 = self.generate_data(-100, 100, + (5, 100), (3, 4, 100), np.float32) + cpu_output = self.op_exec(npu_input1, npu_input2, 2.5, 'cpu') + npu_output = self.op_exec(npu_input1, npu_input2, 2.5, 'npu') + self.assertRtolEqual(cpu_output, npu_output) + + def test_cdist_float32_9(self, device): + npu_input1, npu_input2 = self.generate_data(-1000, 1000, + (5, 100), (3, 4, 100), np.float32) + cpu_output = self.op_exec(npu_input1, npu_input2, 1.5, 'cpu') + npu_output = self.op_exec(npu_input1, npu_input2, 1.5, 'npu') + self.assertRtolEqual(cpu_output, npu_output) + + def test_cdist_float32_10(self, device): + npu_input1, npu_input2 = self.generate_data(-0.1, 0.1, + (5, 100), (3, 4, 100), np.float32) + cpu_output = self.op_exec(npu_input1, npu_input2, 2.5, 'cpu') + npu_output = self.op_exec(npu_input1, npu_input2, 2.5, 'npu') + self.assertRtolEqual(cpu_output, npu_output) + + def test_cdist_float32_11(self, device): + npu_input1, npu_input2 = self.generate_data(-0.1, 0.1, + (5, 100), (3, 4, 100), np.float32) + cpu_output = self.op_exec(npu_input1, npu_input2, 0.5, 'cpu') + npu_output = self.op_exec(npu_input1, npu_input2, 0.5, 'npu') + self.assertRtolEqual(cpu_output, npu_output) + + def test_cdist_float32_12(self, device): + npu_input1, npu_input2 = self.generate_data(-0.1, 0.1, + (16, 11, 17, 5, 84, 2), (16, 11, 17, 5, 84, 2), np.float32) + cpu_output = self.op_exec(npu_input1, npu_input2, 2.0, 'cpu') + npu_output = self.op_exec(npu_input1, npu_input2, 2.0, 'npu') + self.assertRtolEqual(cpu_output, npu_output) + + def test_cdist_float32_13(self, device): + npu_input1, npu_input2 = self.generate_data(-0.1, 0.1, + (2, 2, 13, 39, 97, 14, 2, 7), (2, 2, 13, 39, 97, 14, 12, 7), np.float32) + cpu_output = self.op_exec(npu_input1, npu_input2, 2.0, 'cpu') + npu_output = self.op_exec(npu_input1, npu_input2, 2.0, 'npu') + self.assertRtolEqual(cpu_output, npu_output) + + +instantiate_device_type_tests(Testcdist, globals(), except_for="cpu") +if __name__ == "__main__": + run_tests() diff --git a/test/test_network_ops/test_cdist_backward.py b/test/test_network_ops/test_cdist_backward.py new file mode 100644 index 00000000000..6d3b41aaa06 --- /dev/null +++ b/test/test_network_ops/test_cdist_backward.py @@ -0,0 +1,117 @@ +# Copyright (c) 2020 Huawei Technologies Co., Ltd +# Copyright (c) 2019, Facebook CORPORATION. +# All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch +import torch_npu +import numpy as np + +from torch_npu.testing.common_utils import TestCase, run_tests +from torch_npu.testing.common_device_type import Dtypes, instantiate_device_type_tests +from torch_npu.testing.util_test import create_common_tensor + +def cdist_backward(x1, x2, p, grad, cdist): + x1 = torch.unsqueeze(x1, -2) + x2 = torch.unsqueeze(x2, -3) + grad = torch.unsqueeze(grad, -1) + cdist = torch.unsqueeze(cdist, -1) + diff = x1 - x2 + diff_abs = torch.abs(diff) + nz_cdist = torch.where(cdist == 0, torch.ones_like(cdist), cdist) + sign = torch.where(diff > 0, torch.ones_like(diff), torch.full_like(diff, -1)) + sign = torch.where(diff == 0, torch.zeros_like(diff), sign) + + if p == 0.0: + res = torch.zeros_like(diff) + elif p == 1.0: + res = grad * sign + elif p < 2.0: + res = sign * torch.pow(diff_abs, p - 1.0) * grad / torch.pow(nz_cdist, p - 1.0) + res = torch.where(cdist == 0, torch.zeros_like(res), res) + elif p == 2.0: + res = grad * diff / nz_cdist + res = torch.where(cdist == 0, torch.zeros_like(res), res) + elif p == float("inf"): + mask = torch.where(cdist - diff_abs > 0, torch.zeros_like(diff), torch.ones_like(diff)) + res = grad * sign * mask + else: + res = diff * torch.pow(diff_abs, p - 2) * grad / torch.pow(nz_cdist, p - 1.0) + res = torch.where(cdist == 0, torch.zeros_like(res), res) + res = torch.sum(res, -2) + return res + + +class Testcdist(TestCase): + def generate_data(self, min_n, max_n, shape_x, shape_y, src_type): + np.random.seed(10086) + x1 = np.random.uniform(min_n, max_n, shape_x).astype(src_type) + x2 = np.random.uniform(min_n, max_n, shape_y).astype(src_type) + return x1, x2 + + def op_exec(self, x1, x2, p, device='cpu'): + is_fp16 = x1.dtype == np.float16 + + if device == 'cpu' and is_fp16: + x1 = x1.astype(np.float32) + x2 = x2.astype(np.float32) + + x1 = torch.tensor(x1, device=device, requires_grad=True) + x2 = torch.tensor(x2, device=device, requires_grad=True) + + y = torch.cdist(x1, x2, p) + grad = torch.ones_like(y, requires_grad=True, device=device) + + if device == 'cpu' and is_fp16: + y = y.half() + y = y.float() + out = cdist_backward(x1, x2, p, grad, y) + return out.detach().numpy().astype('float16') + + y.backward(grad, retain_graph=True) + out = x1.grad.detach().cpu().numpy() + + return out + + def test_cdis_backward_common_shape(self, device): + shape_items = [ + [np.float16, (5, 10), (4, 10)], + [np.float16, (20, 5, 10), (20, 4, 10)], + [np.float32, (5, 10), (4, 10)], + [np.float32, (20, 5, 10), (20, 4, 10)], + ] + p_ranges = [0.0, 0.5, 1.0, 1.5, 2.0, 2.5] + for item in shape_items: + for p in p_ranges: + input1, input2 = self.generate_data(-1, 1, + item[1], item[2], item[0]) + cpu_output = self.op_exec(input1, input2, p, 'cpu') + npu_output = self.op_exec(input1, input2, p, 'npu') + self.assertRtolEqual(cpu_output, npu_output) + + def test_cdis_backward_input_range(self, device): + item = [np.float32, (20, 5, 5), (20, 4, 5)] + p_ranges = [0.0, 0.5, 1.0, 1.5, 2.0, 2.5] + input_ragnes = [(-0.1, 0.1), (10, 10), (-100, 100)] + for p in p_ranges: + for min_max in input_ragnes: + + input1, input2 = self.generate_data(min_max[0], min_max[1], + item[1], item[2], item[0]) + cpu_output = self.op_exec(input1, input2, p, 'cpu') + npu_output = self.op_exec(input1, input2, p, 'npu') + self.assertRtolEqual(cpu_output, npu_output) + +instantiate_device_type_tests(Testcdist, globals(), except_for="cpu") +if __name__ == "__main__": + run_tests() diff --git a/test/test_network_ops/test_ceil.py b/test/test_network_ops/test_ceil.py new file mode 100644 index 00000000000..66138309ffc --- /dev/null +++ b/test/test_network_ops/test_ceil.py @@ -0,0 +1,60 @@ +# Copyright (c) 2020 Huawei Technologies Co., Ltd +# Copyright (c) 2019, Facebook CORPORATION. +# All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch +import torch_npu +import numpy as np + +from torch_npu.testing.common_utils import TestCase, run_tests +from torch_npu.testing.common_device_type import Dtypes, instantiate_device_type_tests +from torch_npu.testing.util_test import create_common_tensor + +class TestCeil(TestCase): + @Dtypes(torch.float) + def test_ceil(self, device, dtype): + cpu_input = torch.randn(10, 10, dtype=torch.float, device="cpu").to(dtype=dtype) + npu_input = cpu_input.to("npu") + cpu_output = torch.ceil_(cpu_input) + npu_output = torch.ceil_(npu_input) + npu_output = npu_output.to("cpu") + + self.assertRtolEqual(cpu_output, npu_output) + + def cpu_op_exec(self, input): + output = torch.ceil(input) + return output + + def npu_op_exec(self, input): + output = torch.ceil(input) + output = output.to("cpu") + return output + + def test_ceil_shape_format(self, device): + shape_format = [ + [np.float32, 0, 10 ], + [np.float32, 0, (64, 10) ], + [np.float32, 3, (256, 2048, 7, 7)], + [np.float32, 4, (32, 1, 3, 3) ], + [np.float32, 29, (10, 128) ], + ] + for item in shape_format: + cpu_input1, npu_input1 = create_common_tensor(item, 1, 100) + cpu_output = self.cpu_op_exec(cpu_input1) + npu_output = self.npu_op_exec(npu_input1) + self.assertRtolEqual(cpu_output, npu_output) + +instantiate_device_type_tests(TestCeil, globals(), except_for="cpu") +if __name__ == "__main__": + run_tests() diff --git a/test/test_network_ops/test_celu.py b/test/test_network_ops/test_celu.py new file mode 100644 index 00000000000..20f0714e8fb --- /dev/null +++ b/test/test_network_ops/test_celu.py @@ -0,0 +1,234 @@ +# Copyright (c) 2020, Huawei Technologies.All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch +import torch_npu +import numpy as np + +from torch_npu.testing.common_utils import TestCase, run_tests +from torch_npu.testing.common_device_type import Dtypes, instantiate_device_type_tests +from torch_npu.testing.util_test import create_common_tensor + +class TestCelu(TestCase): + def generate_data(self, min_d, max_d, shape, dtype): + input_x = np.random.uniform(min_d, max_d, shape).astype(dtype) + #modify from numpy.ndarray to torch.tensor + npu_input = torch.from_numpy(input_x) + return npu_input + + def cpu_op_exec_functional(self, input1, alpha): + output = torch.nn.functional.celu(input1, alpha=alpha) + output = output.numpy() + return output + + def npu_op_exec_functional(self, input1, alpha): + output = torch.nn.functional.celu(input1, alpha=alpha) + output = output.to("cpu") + output = output.numpy() + return output + + def cpu_op_exec(self, input1, alpha): + output = torch.celu(input1, alpha=alpha) + output = output.numpy() + return output + + def npu_op_exec(self, input1, alpha): + output = torch.celu(input1, alpha=alpha) + output = output.to("cpu") + output = output.numpy() + return output + + def cpu_op_inplace_exec_functional(self, input1, alpha): + output = torch.nn.functional.celu_(input1, alpha=alpha) + output = output.numpy() + return output + + def npu_op_inplace_exec_functional(self, input1, alpha): + output = torch.nn.functional.celu_(input1, alpha=alpha) + output = output.to("cpu") + output = output.numpy() + return output + + def cpu_op_inplace_exec(self, input1, alpha): + output = torch.celu_(input1, alpha=alpha) + output = output.numpy() + return output + + def npu_op_inplace_exec(self, input1, alpha): + output = torch.celu_(input1, alpha=alpha) + output = output.to("cpu") + output = output.numpy() + return output + + def test_celu_3_3_float32_alpha1(self, device): + input_x1 = self.generate_data(-1, 1, (3, 3), np.float32) + cpu_output1 = self.cpu_op_exec(input_x1, 1.0) + npu_output1 = self.npu_op_exec(input_x1, 1.0) + self.assertRtolEqual(cpu_output1, npu_output1) + + def test_celu_10_10_10_10_float32_alpha1(self, device): + input_x1 = self.generate_data(-1, 1, (10, 10, 10, 10), np.float32) + cpu_output1 = self.cpu_op_exec(input_x1, 1.0) + npu_output1 = self.npu_op_exec(input_x1, 1.0) + self.assertRtolEqual(cpu_output1, npu_output1) + + def test_celu_100_100_float32_alpha2(self, device): + input_x1 = self.generate_data(-1, 1, (100, 100), np.float32) + cpu_output1 = self.cpu_op_exec(input_x1, 2.0) + npu_output1 = self.npu_op_exec(input_x1, 2.0) + self.assertRtolEqual(cpu_output1, npu_output1) + + def test_celu_float16_alpha1(self, device): + def cpu_op_exec_fp16(input1, alpha): + input1 = input1.to(torch.float32) + output = torch.nn.functional.celu(input1, alpha=alpha) + output = output.numpy() + output = output.astype(np.float16) + return output + + shape_format = [ + [[np.float16, 0, (65535, 1, 1, 1)]], + [[np.float16, 0, (1, 1, 1, 65535)]], + ] + + for item in shape_format: + cpu_input1, npu_input1 = create_common_tensor(item[0], -2, 2) + cpu_output = cpu_op_exec_fp16(cpu_input1, 1.0) + npu_output = self.npu_op_exec(npu_input1, 1.0) + self.assertRtolEqual(cpu_output, npu_output) + + def test_celu_float16_alpha2_success(self, device): + def cpu_op_exec_fp16(input1, alpha): + input1 = input1.to(torch.float32) + output = torch.nn.functional.celu(input1, alpha=alpha) + output = output.numpy() + output = output.astype(np.float16) + return output + + shape_format = [ + [[np.float16, 0, (65535, 1, 1, 1)]], + [[np.float16, 0, (1, 1, 1, 65535)]], + ] + + for item in shape_format: + cpu_input1, npu_input1 = create_common_tensor(item[0], 1, 100) + cpu_output = cpu_op_exec_fp16(cpu_input1, 2.0) + npu_output = self.npu_op_exec(npu_input1, 2.0) + self.assertRtolEqual(cpu_output, npu_output) + + def test_celu_float16_alpha2_fail(self, device): + def cpu_op_exec_fp16(input1, alpha): + input1 = input1.to(torch.float32) + output = torch.nn.functional.celu(input1, alpha=alpha) + output = output.numpy() + output = output.astype(np.float16) + return output + + shape_format = [ + [[np.float16, 0, (65535, 1, 1, 1)]], + [[np.float16, 0, (1, 1, 1, 65535)]], + ] + + for item in shape_format: + cpu_input1, npu_input1 = create_common_tensor(item[0], -2, 2) + cpu_output = cpu_op_exec_fp16(cpu_input1, 2.0) + npu_output = self.npu_op_exec(npu_input1, 2.0) + self.assertRtolEqual(cpu_output, npu_output) + + def test_celu_inplace_alpha1(self, device): + shape_format = [ + [[np.float32, 0, (65535, 1, 1, 1)]], + [[np.float32, 0, (1, 1, 1, 65535)]], + ] + for item in shape_format: + cpu_input1, npu_input1 = create_common_tensor(item[0], -2, 2) + cpu_output = self.cpu_op_inplace_exec(cpu_input1, 1.0) + npu_output = self.npu_op_inplace_exec(npu_input1, 1.0) + self.assertRtolEqual(cpu_output, npu_output) + + def test_celu_inplace_alpha2(self, device): + shape_format = [ + [[np.float32, 0, (65535, 1, 1, 1)]], + [[np.float32, 0, (1, 1, 1, 65535)]], + ] + for item in shape_format: + cpu_input1, npu_input1 = create_common_tensor(item[0], 1, 100) + cpu_output = self.cpu_op_inplace_exec(cpu_input1, 2.0) + npu_output = self.npu_op_inplace_exec(npu_input1, 2.0) + self.assertRtolEqual(cpu_output, npu_output) + + def test_celu_inplace_alpha2_fail(self, device): + shape_format = [ + [[np.float32, 0, (65535, 1, 1, 1)]], + [[np.float32, 0, (1, 1, 1, 65535)]], + ] + for item in shape_format: + cpu_input1, npu_input1 = create_common_tensor(item[0], -2, 2) + cpu_output = self.cpu_op_inplace_exec(cpu_input1, 2.0) + npu_output = self.npu_op_inplace_exec(npu_input1, 2.0) + self.assertRtolEqual(cpu_output, npu_output) + + def test_celu_inplace_shape_format_alpha_range(self, device): + shape_format_alpha_range = [ + # [[dtype, format, shape], alpha, min, max] + [[np.float16, 2, (16, 5, 7, 11)], 5.6, -2, 2], + [[np.float32, 2, (16, 5, 7, 11)], 0.5, -2, 2], + [[np.float32, 2, (16, 5, 7, 11)], 0.7, -2, 2], + [[np.float32, 2, (16, 5, 7, 11)], 2.6, -2, 2], + [[np.float16, 2, (16, 136, 5, 4)], 0.5, -0.0078125, 0.0078125], + [[np.float16, 2, (16, 136, 5, 4)], 0.7, -0.0078125, 0.0078125], + [[np.float16, 2, (16, 136, 5, 4)], 0.5, -0.01, 0.01], + [[np.float16, 2, (176, 3, 67, 47, 5, 12)], 0.5, -2, 2], + [[np.float16, 2, (176, 3, 67, 47, 5, 12)], 5.4, -2, 2], + [[np.float16, 2, (23, 5, 11, 50, 26, 13, 1, 23)], 0.5, -2, 2], + [[np.float16, 2, (2560, 17)], 0.5, -2, 2], + [[np.float16, 2, (2560, 17)], 5.4, -2, 2] + ] + for item in shape_format_alpha_range: + cpu_input1, npu_input1 = create_common_tensor(item[0], item[2], item[3]) + alpha = item[1] + npu_output = self.npu_op_inplace_exec(npu_input1, alpha) + if item[0][0] == np.float16: + cpu_output = self.cpu_op_inplace_exec(cpu_input1.float(), alpha).astype(np.float16) + else: + cpu_output = self.cpu_op_inplace_exec(cpu_input1, alpha) + self.assertRtolEqual(cpu_output, npu_output) + + def test_celu_inplace_shape_format_alpha_range(self, device): + shape_format_alpha_range = [ + # [[dtype, format, shape], alpha, min, max] + [[np.float32, 2, (16, 5, 7, 11)], 0.5, -2, 2], + [[np.float32, 2, (16, 5, 7, 11)], 0.7, -2, 2], + [[np.float32, 2, (16, 5, 7, 11)], 2.6, -2, 2], + [[np.float16, 2, (16, 136, 5, 4)], 0.5, -0.0078125, 0.0078125], + [[np.float16, 2, (16, 136, 5, 4)], 0.7, -0.0078125, 0.0078125], + [[np.float16, 2, (16, 136, 5, 4)], 0.5, -0.01, 0.01], + [[np.float16, 2, (16, 136, 5, 4)], 0.7, -0.01, 0.01], + [[np.float16, 2, (176, 3, 67, 47, 5, 12)], 0.5, -2, 2], + [[np.float16, 2, (176, 3, 67, 47, 5, 12)], 5.4, -2, 2], + [[np.float16, 2, (2560, 17)], 0.5, -2, 2], + [[np.float16, 2, (2560, 17)], 5.4, -2, 2] + ] + for item in shape_format_alpha_range: + cpu_input1, npu_input1 = create_common_tensor(item[0], item[2], item[3]) + alpha = item[1] + npu_output = self.npu_op_exec(npu_input1, alpha) + if item[0][0] == np.float16: + cpu_output = self.cpu_op_exec(cpu_input1.float(), alpha).astype(np.float16) + else: + cpu_output = self.cpu_op_exec(cpu_input1, alpha) + self.assertRtolEqual(cpu_output, npu_output) + +instantiate_device_type_tests(TestCelu, globals(), except_for='cpu') +if __name__ == "__main__": + run_tests() diff --git a/test/test_network_ops/test_default.py b/test/test_network_ops/test_default.py new file mode 100644 index 00000000000..c92a1664ba8 --- /dev/null +++ b/test/test_network_ops/test_default.py @@ -0,0 +1,41 @@ +# Copyright (c) 2020, Huawei Technologies.All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch +import torch_npu +import numpy as np + +from torch_npu.testing.common_utils import TestCase, run_tests +from torch_npu.testing.common_device_type import Dtypes, instantiate_device_type_tests +from torch_npu.testing.util_test import create_common_tensor + +class TestDefault(TestCase): + def test_isnan(self, device): + cpu_input = torch.arange(1., 10) + npu_input = cpu_input.npu() + + cpu_output = torch.isnan(cpu_input) + npu_output = torch.isnan(npu_input) + self.assertRtolEqual(cpu_output, npu_output.cpu()) + + def test_unfold(self, device): + cpu_input = torch.arange(1., 8) + npu_input = cpu_input.npu() + + cpu_output = cpu_input.unfold(0, 2, 1) + npu_output = npu_input.unfold(0, 2, 1) + self.assertRtolEqual(cpu_output, npu_output.cpu()) + +instantiate_device_type_tests(TestDefault, globals(), except_for="cpu") +if __name__ == "__main__": + run_tests() diff --git a/test/test_network_ops/test_gelu_backward.py b/test/test_network_ops/test_gelu_backward.py new file mode 100644 index 00000000000..89d367df6ce --- /dev/null +++ b/test/test_network_ops/test_gelu_backward.py @@ -0,0 +1,77 @@ +# Copyright (c) 2020, Huawei Technologies.All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import copy +import torch +import torch_npu +import numpy as np + +from torch_npu.testing.common_utils import TestCase, run_tests +from torch_npu.testing.common_device_type import Dtypes, instantiate_device_type_tests +from torch_npu.testing.util_test import create_common_tensor + +class TestGeluBackward(TestCase): + def generate_single_data(self, min_val, max_val, shape, dtype): + input1 = np.random.uniform(min_val, max_val, shape).astype(dtype) + npu_input1 = torch.from_numpy(input1) + return npu_input1 + + def cpu_op_exec(self, input1): + input1.requires_grad_(True) + output = torch.nn.functional.gelu(input1) + z = output.sum() + z.backward() + res = input1.grad + return res.detach().numpy() + + def npu_op_exec(self, input1): + input1 = input1.to("npu") + input1.requires_grad = True + output = torch.nn.functional.gelu(input1) + z = output.sum() + z.backward() + res = input1.grad.to("cpu") + return res.detach().numpy() + + def test_gelu_backward_float32_1(self, device): + input1= self.generate_single_data(0, 100, (4, 3, 1, 1), np.float32) + cpu_input1 = copy.deepcopy(input1) + cpu_output = self.cpu_op_exec(cpu_input1) + npu_output = self.npu_op_exec(input1) + self.assertRtolEqual(cpu_output, npu_output) + + def test_gelu_backward_float32_2(self, device): + input1= self.generate_single_data(0, 100, (15, 3, 1), np.float32) + cpu_input1 = copy.deepcopy(input1) + cpu_output = self.cpu_op_exec(cpu_input1) + npu_output = self.npu_op_exec(input1) + self.assertRtolEqual(cpu_output, npu_output) + + def test_gelu_backward_float32_3(self, device): + input1= self.generate_single_data(0, 100, (4, 4), np.float32) + cpu_input1 = copy.deepcopy(input1) + cpu_output = self.cpu_op_exec(cpu_input1) + npu_output = self.npu_op_exec(input1) + self.assertRtolEqual(cpu_output, npu_output) + + def test_gelu_backward_float16(self, device): + input1 = self.generate_single_data(0, 100, (5, 10, 100), np.float16) + cpu_input1 = input1.to(torch.float32) + cpu_output = self.cpu_op_exec(cpu_input1) + cpu_output = cpu_output.astype(np.float16) + npu_output = self.npu_op_exec(input1) + self.assertRtolEqual(cpu_output, npu_output) + +instantiate_device_type_tests(TestGeluBackward, globals(), except_for="cpu") +if __name__ == "__main__": + run_tests() diff --git a/torch_npu/csrc/aten/ops/ArgminKernelNpu.cpp b/torch_npu/csrc/aten/ops/ArgminKernelNpu.cpp new file mode 100644 index 00000000000..8c046ad6eca --- /dev/null +++ b/torch_npu/csrc/aten/ops/ArgminKernelNpu.cpp @@ -0,0 +1,53 @@ +// Copyright (c) 2020, Huawei Technologies.All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "torch_npu/csrc/framework/utils/OpAdapter.h" +#include "torch_npu/csrc/aten/NPUNativeFunctions.h" + +namespace at_npu { +namespace native { + +at::Tensor NPUNativeFunctions::argmin(const at::Tensor& self, at::optional dim, bool keepdim) { + TORCH_CHECK( + self.numel() > 0, + "cannot perform reduction function argmin on a " + "tensor with no elements because the operation does not have an identity"); + + at::Tensor input = dim.has_value() ? self : self.reshape({-1}); + int64_t realDim = dim.has_value() ? dim.value() : 0; + bool realKeepDim = dim.has_value() ? keepdim : false; + + // calculate the output size + auto outputSize = reduce_ops_npu_output_size(input, realDim, realKeepDim); + + // construct the output tensor of the NPU + at::Tensor result = OpPreparation::ApplyTensorWithFormat( + outputSize, + self.options().dtype(at::kInt), + ACL_FORMAT_ND); + at::SmallVector DimVec = {realDim}; + // calculate the output result of the NPU + OpCommand cmd; + cmd.Name("ArgMin") + .Input(input) + .Input(DimVec, at::kInt) + .Output(result) + .Attr("keep_dims", realKeepDim) + .Run(); + + result = NPUNativeFunctions::npu_dtype_cast(result, at::ScalarType::Long); + return result; +} + +} // namespace native +} // namespace at_npu \ No newline at end of file diff --git a/torch_npu/csrc/aten/ops/ArgsortKernelNpu.cpp b/torch_npu/csrc/aten/ops/ArgsortKernelNpu.cpp new file mode 100644 index 00000000000..38b839a91b7 --- /dev/null +++ b/torch_npu/csrc/aten/ops/ArgsortKernelNpu.cpp @@ -0,0 +1,95 @@ +// Copyright (c) 2020 Huawei Technologies Co., Ltd +// Copyright (c) 2019, Facebook CORPORATION. +// All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include +#include "torch_npu/csrc/framework/utils/OpAdapter.h" +#include "torch_npu/csrc/aten/NPUNativeFunctions.h" + +namespace at_npu { +namespace native { + +at::Tensor& argsort_out_npu_no_transpose( + at::Tensor& values, + at::Tensor& indices, + const at::Tensor& self, + int64_t dim, + bool descending) { + OpCommand cmd; + cmd.Name("Sort") + .Input(self) + .Output(values) + .Output(indices) + .Attr("axis", dim) + .Attr("descending", descending) + .Run(); + + return indices; +} + +at::Tensor& argsort_out_npu_nocheck( + at::Tensor& values, + at::Tensor& indices, + const at::Tensor& self, + int64_t dim, + bool descending) { + dim = make_wrap_dim(dim, self.dim()); + int64_t lastDim = make_wrap_dim(-1, self.dim()); + + at::SmallVector perm; + for (int64_t i = 0; i < self.dim(); i++) { + perm.emplace_back(i); + } + std::swap(perm[dim], perm[lastDim]); + + at::Tensor transposeSelf = NPUNativeFunctions::npu_transpose(self, perm); + auto outputSize = transpose_npu_output_size(values, perm); + at::Tensor transposeValues = OpPreparation::ApplyTensor( + values, + outputSize); + at::Tensor transposeIndices = OpPreparation::ApplyTensor( + indices, + outputSize); + + argsort_out_npu_no_transpose( + transposeValues, transposeIndices, transposeSelf, lastDim, descending); + + NPUNativeFunctions::npu_transpose_out(indices, transposeIndices, perm); + + // indices dtype transform to Int64 + indices = NPUNativeFunctions::npu_dtype_cast(indices, at::kLong); + + return indices; +} + +at::Tensor NPUNativeFunctions::argsort(const at::Tensor& self, + int64_t dim, + bool descending) { + // construct the output tensor of the NPU + at::Tensor values = OpPreparation::ApplyTensor(self); + at::Tensor indices = OpPreparation::ApplyTensor(self, self.options().dtype(at::kInt)); + // calculate the output result of the NPU + argsort_out_npu_nocheck(values, indices, self, dim, descending); + + return indices; +} + +at::Tensor NPUNativeFunctions::argsort_dim(const at::Tensor& self, + at::Dimname dim, + bool descending) { + return NPUNativeFunctions::argsort(self, dimname_to_position(self, dim), descending); +} + +} // namespace native +} // namespace at_npu diff --git a/torch_npu/csrc/aten/ops/AsinKernelNpu.cpp b/torch_npu/csrc/aten/ops/AsinKernelNpu.cpp new file mode 100644 index 00000000000..b0256aa0583 --- /dev/null +++ b/torch_npu/csrc/aten/ops/AsinKernelNpu.cpp @@ -0,0 +1,58 @@ +// Copyright (c) 2020, Huawei Technologies.All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "torch_npu/csrc/framework/utils/OpAdapter.h" +#include "torch_npu/csrc/aten/NPUNativeFunctions.h" + +namespace at_npu { +namespace native { + +at::Tensor& asin_out_npu_nocheck(at::Tensor& result, const at::Tensor& self) { + OpCommand cmd; + cmd.Name("Asin") + .Input(self) + .Output(result) + .Run(); + return result; +} + +at::Tensor& NPUNativeFunctions::asin_out( + const at::Tensor& self, + at::Tensor& result) { + OpPreparation::CheckOut( + {self}, + result, + self); + asin_out_npu_nocheck(result, self); + return result; +} + +at::Tensor NPUNativeFunctions::asin(const at::Tensor& self) { + at::Tensor result = OpPreparation::ApplyTensor(self); + asin_out_npu_nocheck(result, self); + return result; +} + +at::Tensor& NPUNativeFunctions::asin_(at::Tensor& self) { + if (!NpuUtils::check_match(&self)) { + at::Tensor contiguousSelf = NpuUtils::format_contiguous(self); + at::Tensor result = asin_out_npu_nocheck(contiguousSelf, contiguousSelf); + NpuUtils::format_fresh_view(self, result); + } else { + asin_out_npu_nocheck(self, self); + } + return self; +} + +} // namespace native +} // namespace at_npu diff --git a/torch_npu/csrc/aten/ops/Atan2KernelNpu.cpp b/torch_npu/csrc/aten/ops/Atan2KernelNpu.cpp new file mode 100644 index 00000000000..d2e4ce41a23 --- /dev/null +++ b/torch_npu/csrc/aten/ops/Atan2KernelNpu.cpp @@ -0,0 +1,77 @@ +// Copyright (c) 2020, Huawei Technologies.All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "torch_npu/csrc/framework/utils/OpAdapter.h" +#include "torch_npu/csrc/aten/NPUNativeFunctions.h" + +namespace at_npu { +namespace native { + +at::Tensor& atan2_out_npu_nocheck( + at::Tensor& result, + const at::Tensor& self, + const at::Tensor& other) { + auto unified_result = OpPreparation::binary_op_check(result, self, other, true); + OpCommand cmd; + cmd.Name("Atan2") + .Expect(unified_result) + .Input(self) + .Input(other) + .Output(result) + .Run(); + + return result; +} + +at::Tensor& NPUNativeFunctions::atan2_out( + const at::Tensor& self, + const at::Tensor& other, + at::Tensor& result) { + auto outputSize = broadcast_ops_npu_output_size(self, other); + + OpPreparation::CheckOut( + {self}, + result, + self, + outputSize); + + atan2_out_npu_nocheck(result, self, other); + return result; +} + +at::Tensor NPUNativeFunctions::atan2(const at::Tensor& self, const at::Tensor& other) { + auto outputSize = broadcast_ops_npu_output_size(self, other); + at::Tensor result = OpPreparation::ApplyTensor(self, outputSize); + atan2_out_npu_nocheck(result, self, other); + return result; +} + +at::Tensor& NPUNativeFunctions::atan2_(at::Tensor& self, const at::Tensor& other) { + OpPreparation::CheckMemory({self, other}, {self}); + + if (!NpuUtils::check_match(&self)) { + at::Tensor contiguousSelf = NpuUtils::format_contiguous(self); + at::Tensor result = atan2_out_npu_nocheck(contiguousSelf, contiguousSelf, other); + NpuUtils::format_fresh_view(self, result); + } else { + atan2_out_npu_nocheck(self, self, other); + } + return self; +} +TORCH_LIBRARY_IMPL(aten, NPU, m) { + m.impl("atan2", TORCH_FN(atan2_npu)); + m.impl("atan2_", TORCH_FN(atan2_npu_)); + m.impl("atan2.out", TORCH_FN(atan2_out_npu)); +} +} // namespace native +} // namespace at diff --git a/torch_npu/csrc/aten/ops/AtanKernelNpu.cpp b/torch_npu/csrc/aten/ops/AtanKernelNpu.cpp new file mode 100644 index 00000000000..0adf46763c1 --- /dev/null +++ b/torch_npu/csrc/aten/ops/AtanKernelNpu.cpp @@ -0,0 +1,57 @@ +// Copyright (c) 2020 Huawei Technologies Co., Ltd +// Copyright (c) 2019, Facebook CORPORATION. +// All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "torch_npu/csrc/framework/utils/OpAdapter.h" +#include "torch_npu/csrc/aten/NPUNativeFunctions.h" + +namespace at_npu { +namespace native { + +at::Tensor& atan_out_npu_nocheck(const at::Tensor& self, at::Tensor& result) { + OpCommand cmd; + cmd.Name("Atan") + .Input(self) + .Output(result) + .Run(); + return result; +} + +at::Tensor& NPUNativeFunctions::atan_out(const at::Tensor& self, at::Tensor& result) { + OpPreparation::CheckOut( + {self}, + result, + self); + atan_out_npu_nocheck(self, result); + return result; +} + +at::Tensor NPUNativeFunctions::atan(const at::Tensor& self) { + at::Tensor result = OpPreparation::ApplyTensor(self); + atan_out_npu_nocheck(self, result); + return result; +} + +at::Tensor& NPUNativeFunctions::atan_(at::Tensor& self) { + if (!NpuUtils::check_match(&self)) { + at::Tensor contiguousSelf = NpuUtils::format_contiguous(self); + at::Tensor result = atan_out_npu_nocheck(contiguousSelf, contiguousSelf); + NpuUtils::format_fresh_view(self, result); + } else { + atan_out_npu_nocheck(self, self); + } + return self; +} + +}} // namespace at_npu::native \ No newline at end of file diff --git a/torch_npu/csrc/aten/ops/BitwiseNotKernelNpu.cpp b/torch_npu/csrc/aten/ops/BitwiseNotKernelNpu.cpp new file mode 100644 index 00000000000..81faf2ede03 --- /dev/null +++ b/torch_npu/csrc/aten/ops/BitwiseNotKernelNpu.cpp @@ -0,0 +1,63 @@ +// Copyright (c) 2020 Huawei Technologies Co., Ltd +// Copyright (c) 2019, Facebook CORPORATION. +// All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "torch_npu/csrc/framework/utils/OpAdapter.h" +#include "torch_npu/csrc/aten/NPUNativeFunctions.h" + +namespace at_npu { +namespace native { + +at::Tensor& bitwise_not_out_npu_nocheck(at::Tensor& result, const at::Tensor& self) { + string real_op_name = + (self.dtype() == at::ScalarType::Bool) ? "LogicalNot" : "Invert"; + + OpCommand cmd; + cmd.Name(real_op_name) + .Input(self) + .Output(result) + .Run(); + + return result; +} + +at::Tensor& NPUNativeFunctions::bitwise_not_out(const at::Tensor& self, at::Tensor& result) { + OpPreparation::CheckOut( + {self}, + result, + self); + + bitwise_not_out_npu_nocheck(result, self); + return result; +} + +at::Tensor NPUNativeFunctions::bitwise_not(const at::Tensor& self) { + at::Tensor result = OpPreparation::ApplyTensor(self); + bitwise_not_out_npu_nocheck(result, self); + return result; +} + +at::Tensor& NPUNativeFunctions::bitwise_not_(at::Tensor& self) { + if (!NpuUtils::check_match(&self)) { + at::Tensor contiguousSelf = NpuUtils::format_contiguous(self); + at::Tensor result = bitwise_not_out_npu_nocheck(contiguousSelf, contiguousSelf); + NpuUtils::format_fresh_view(self, result); + } else { + bitwise_not_out_npu_nocheck(self, self); + } + return self; +} + +} // namespace native +} // namespace at_npu \ No newline at end of file diff --git a/torch_npu/csrc/aten/ops/CdistBackwardKernelNpu.cpp b/torch_npu/csrc/aten/ops/CdistBackwardKernelNpu.cpp new file mode 100644 index 00000000000..ead328f8f73 --- /dev/null +++ b/torch_npu/csrc/aten/ops/CdistBackwardKernelNpu.cpp @@ -0,0 +1,93 @@ +// Copyright (c) 2020, Huawei Technologies.All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "torch_npu/csrc/framework/utils/OpAdapter.h" +#include "torch_npu/csrc/aten/NPUNativeFunctions.h" + +namespace at_npu { +namespace native { + +static void check_cdist_backward_input( + const at::Tensor& grad, + const at::Tensor& x1, + const at::Tensor& x2, + const double p, + const at::Tensor& cdist) { + TORCH_CHECK(x1.is_contiguous(), "_cdist_backward requires X1 to be contiguous"); + TORCH_CHECK(x2.is_contiguous(), "_cdist_backward requires X2 to be contiguous"); + TORCH_CHECK(cdist.is_contiguous(), "_cdist_backward requires dist to be contiguous"); + TORCH_CHECK(grad.is_contiguous(), "_cdist_backward requires grad to be contiguous"); + auto device1 = x1.device().type(); + TORCH_CHECK(device1 == at::kCPU || device1 == at::kCUDA || device1 == at::kNPU, "_cdist_backward only supports CPU, CUDA and NPU devices, X1 got: ", device1); + auto device2 = x2.device().type(); + TORCH_CHECK(device2 == at::kCPU || device2 == at::kCUDA || device2 == at::kNPU, "_cdist_backward only supports CPU, CUDA and NPU devices, X2 got: ", device2); + TORCH_CHECK(p <= std::numeric_limits::max(), "npu dose not support float64" ); +} + +at::Tensor NPUNativeFunctions::_cdist_backward( + const at::Tensor& grad, + const at::Tensor& x1, + const at::Tensor& x2, + const double p, + const at::Tensor& cdist) { + check_cdist_backward_input(grad, x1, x2, p, cdist); + + // Since double is not supported in NPU, the type of P needs to be converted from double to float. + float p_float; + if (std::isinf(p)) { + p_float = std::numeric_limits::infinity(); + } + else { + p_float = static_cast(p); + } + + // Broadcast + auto dim1 = x1.dim(); + auto dim2 = x2.dim(); + + at::SmallVector tensor1_expand_size = array_to_small_vector(x1.sizes()); + tensor1_expand_size.insert(tensor1_expand_size.begin() + (dim1 - 1), 1); + + at::SmallVector tensor2_expand_size = array_to_small_vector(x2.sizes()); + tensor2_expand_size.insert(tensor2_expand_size.begin() + (dim2 - 2), 1); + + at::SmallVector grad_expand_size = array_to_small_vector(grad.sizes()); + grad_expand_size.insert(grad_expand_size.end(), 1); + + at::SmallVector cdist_expand_size = array_to_small_vector(cdist.sizes()); + cdist_expand_size.insert(cdist_expand_size.end(), 1); + + std::vector tensor_broadcast_size = infer_size(tensor1_expand_size, tensor2_expand_size); + + at::Tensor tensor1_broadcast = x1.view(tensor1_expand_size).expand(tensor_broadcast_size).contiguous(); + at::Tensor tensor2_broadcast = x2.view(tensor2_expand_size).expand(tensor_broadcast_size).contiguous(); + at::Tensor grad_broadcast = grad.view(grad_expand_size).expand(tensor_broadcast_size).contiguous(); + at::Tensor cdist_broadcast = cdist.view(cdist_expand_size).expand(tensor_broadcast_size).contiguous(); + + auto outputSize = input_same_output_size(x1); + at::Tensor result = OpPreparation::ApplyTensor(tensor1_broadcast, outputSize); + OpCommand cmd; + cmd.Name("CdistGrad") + .Input(grad_broadcast) + .Input(tensor1_broadcast) + .Input(tensor2_broadcast) + .Input(cdist_broadcast) + .Attr("p", p_float) + .Output(result) + .Run(); + + return result; +} + +} // namespace native +} // namespace at_npu diff --git a/torch_npu/csrc/aten/ops/CdistKernelNpu.cpp b/torch_npu/csrc/aten/ops/CdistKernelNpu.cpp new file mode 100644 index 00000000000..b60ed5734ea --- /dev/null +++ b/torch_npu/csrc/aten/ops/CdistKernelNpu.cpp @@ -0,0 +1,100 @@ +// Copyright (c) 2020, Huawei Technologies.All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "torch_npu/csrc/framework/utils/CalcuOpUtil.h" +#include "torch_npu/csrc/framework/utils/OpAdapter.h" +#include "torch_npu/csrc/aten/NPUNativeFunctions.h" + +namespace at_npu { +namespace native { +using namespace at::native::npu; + +at::Tensor NPUNativeFunctions::_cdist_forward( + const at::Tensor& x1, + const at::Tensor& x2, + const double p, + c10::optional compute_mode) { + TORCH_CHECK(x1.dim() >= 2, "cdist only supports at least 2D tensors, X1 got: ", x1.dim(), "D"); + TORCH_CHECK(x2.dim() >= 2, "cdist only supports at least 2D tensors, X2 got: ", x2.dim(), "D"); + TORCH_CHECK(x1.size(-1) == x2.size(-1), "X1 and X2 must have the same number of columns. X1: ", x1.size(-1), " X2: ", x2.size(-1)); + TORCH_CHECK(at::isFloatingType(x1.scalar_type()), "cdist only supports floating-point dtypes, X1 got: ", x1.scalar_type()); + TORCH_CHECK(at::isFloatingType(x1.scalar_type()), "cdist only supports floating-point dtypes, X2 got: ", x2.scalar_type()); + TORCH_CHECK(p >= 0, "cdist only supports non-negative p values"); + + // Since double is not supported in NPU, the type of P needs to be converted from double to float. + float p_float; + if (std::isinf(p)) { + p_float = std::numeric_limits::infinity(); + } + else { + TORCH_CHECK(p <= std::numeric_limits::max(), "npu dose not support float64" ); + p_float = static_cast(p); + } + + int64_t mode = compute_mode.value_or(0); + TORCH_CHECK(mode >= 0 && mode <= 2, "possible modes: 0, 1, 2, but was: ", mode); + + // Broadcast + int64_t c1 = x1.size(-1); + int64_t c2 = x2.size(-1); + int64_t r1 = x1.size(-2); + int64_t r2 = x2.size(-2); + auto dim1 = x1.dim(); + auto dim2 = x2.dim(); + + at::IntArrayRef batch_tensor1(x1.sizes().data(), dim1 - 2); + at::IntArrayRef batch_tensor2(x2.sizes().data(), dim2 - 2); + std::vector expand_batch_portion = infer_size(batch_tensor1, batch_tensor2); + std::vector tensor1_expand_size(expand_batch_portion); + tensor1_expand_size.insert(tensor1_expand_size.end(), {r1, c1}); + std::vector tensor2_expand_size(expand_batch_portion); + tensor2_expand_size.insert(tensor2_expand_size.end(), {r2, c2}); + + int expand_batch_product = std::accumulate(expand_batch_portion.begin(), expand_batch_portion.end(), 1, std::multiplies()); + std::vector tensor1_view{expand_batch_product, r1, 1, c1}; + std::vector tensor2_view{expand_batch_product, 1, r2, c2}; + std::vector result_size{expand_batch_product, r1, r2}; + std::vector tensor_broadcast_size = infer_size(tensor1_view, tensor2_view); + + // Broadcast batch dim. + at::Tensor tensor1_expanded = x1.expand(tensor1_expand_size).contiguous().view(tensor1_view); + at::Tensor tensor2_expanded = x2.expand(tensor2_expand_size).contiguous().view(tensor2_view); + + // Broadcast r1 and r2. + at::Tensor tensor1_broadcast = tensor1_expanded.expand(tensor_broadcast_size).contiguous(); + at::Tensor tensor2_broadcast = tensor2_expanded.expand(tensor_broadcast_size).contiguous(); + + auto output_size = cdist_npu_output_size(x1, x2); + at::Tensor result = OpPreparation::ApplyTensor(tensor1_broadcast, result_size); + + OpCommand cmd; + cmd.Name("Cdist") + .Input(tensor1_broadcast) + .Input(tensor2_broadcast) + .Attr("p", p_float) + .Output(result) + .Run(); + + return result.view(output_size); +} + +at::Tensor NPUNativeFunctions::cdist( + const at::Tensor& x1, + const at::Tensor& x2, + const double p, + c10::optional compute_mode) { + return at::_cdist_forward(x1, x2, p, compute_mode); +} + +} // namespace native +} // namespace at_npu diff --git a/torch_npu/csrc/aten/ops/CeilKernelNpu.cpp b/torch_npu/csrc/aten/ops/CeilKernelNpu.cpp new file mode 100644 index 00000000000..133fca49859 --- /dev/null +++ b/torch_npu/csrc/aten/ops/CeilKernelNpu.cpp @@ -0,0 +1,57 @@ +// Copyright (c) 2020 Huawei Technologies Co., Ltd +// Copyright (c) 2019, Facebook CORPORATION. +// All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "torch_npu/csrc/framework/utils/OpAdapter.h" +#include "torch_npu/csrc/aten/NPUNativeFunctions.h" + +namespace at_npu { +namespace native { + +at::Tensor& ceil_out_npu_nocheck(at::Tensor& result, const at::Tensor& self) { + OpCommand cmd; + cmd.Name("Ceil") + .Input(self) + .Output(result) + .Run(); + return result; +} + +at::Tensor& NPUNativeFunctions::ceil_out(const at::Tensor& self, at::Tensor& result) { + OpPreparation::CheckOut( + {self}, + result, + self); + return ceil_out_npu_nocheck(result, self); +} + +at::Tensor NPUNativeFunctions::ceil(const at::Tensor& self) { + at::Tensor result = OpPreparation::ApplyTensor(self); + ceil_out_npu_nocheck(result, self); + return result; +} + +at::Tensor& NPUNativeFunctions::ceil_(at::Tensor& self) { + if (!NpuUtils::check_match(&self)) { + at::Tensor contiguousSelf = NpuUtils::format_contiguous(self); + at::Tensor result = ceil_out_npu_nocheck(contiguousSelf, contiguousSelf); + NpuUtils::format_fresh_view(self, result); + } else { + ceil_out_npu_nocheck(self, self); + } + return self; +} + +} // namespace native +} // namespace at_npu \ No newline at end of file diff --git a/torch_npu/csrc/aten/ops/CeluKernelNpu.cpp b/torch_npu/csrc/aten/ops/CeluKernelNpu.cpp new file mode 100644 index 00000000000..6b38b00e17e --- /dev/null +++ b/torch_npu/csrc/aten/ops/CeluKernelNpu.cpp @@ -0,0 +1,55 @@ +// Copyright (c) 2020 Huawei Technologies Co., Ltd +// Copyright (c) 2019, Facebook CORPORATION. +// All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "torch_npu/csrc/framework/utils/OpAdapter.h" +#include "torch_npu/csrc/aten/NPUNativeFunctions.h" + +namespace at_npu { +namespace native { + +at::Tensor celu_out_npu_nocheck(at::Tensor& result, const at::Tensor& self, Scalar alpha) { + float alpha3 = 1.0; + + OpCommand cmd; + cmd.Name("Celu") + .Input(self) + .Output(result) + .Attr("alpha1", alpha) + .Attr("alpha2", alpha) + .Attr("alpha3", alpha3) + .Run(); + + return result; +} + +at::Tensor NPUNativeFunctions::celu(const at::Tensor& self, Scalar alpha) { + at::Tensor result = OpPreparation::ApplyTensor(self); + celu_out_npu_nocheck(result, self, alpha); + return result; +} + +at::Tensor& NPUNativeFunctions::celu_(at::Tensor& self, Scalar alpha) { + if (!NpuUtils::check_match(&self)) { + at::Tensor contiguousSelf = NpuUtils::format_contiguous(self); + at::Tensor result = celu_out_npu_nocheck(contiguousSelf, contiguousSelf, alpha); + NpuUtils::format_fresh_view(self, result); + } else { + celu_out_npu_nocheck(self, self, alpha); + } + return self; +} + +} // namespace native +} // namespace at_npu diff --git a/torch_npu/csrc/aten/ops/DefaultKernelNpu.cpp b/torch_npu/csrc/aten/ops/DefaultKernelNpu.cpp new file mode 100644 index 00000000000..17a4af15a03 --- /dev/null +++ b/torch_npu/csrc/aten/ops/DefaultKernelNpu.cpp @@ -0,0 +1,31 @@ +// Copyright (c) 2020 Huawei Technologies Co., Ltd +// Copyright (c) 2019, Facebook CORPORATION. +// All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "torch_npu/csrc/framework/utils/OpAdapter.h" +#include "torch_npu/csrc/aten/NPUNativeFunctions.h" + +namespace at_npu { +namespace native { + +at::Tensor NPUNativeFunctions::isnan(const at::Tensor& self) { + return at::native::isnan(self); +} + +at::Tensor NPUNativeFunctions::unfold(const at::Tensor& self, int64_t dimension, int64_t size, int64_t step) { + return at::native::unfold(self, dimension, size, step); +} + +} // namespace native +} // namespace at_npu \ No newline at end of file diff --git a/torch_npu/csrc/aten/ops/EmbeddingBackwardKernelNpu.cpp b/torch_npu/csrc/aten/ops/EmbeddingBackwardKernelNpu.cpp index 217d46771f4..0155cd7895e 100644 --- a/torch_npu/csrc/aten/ops/EmbeddingBackwardKernelNpu.cpp +++ b/torch_npu/csrc/aten/ops/EmbeddingBackwardKernelNpu.cpp @@ -18,7 +18,7 @@ namespace at_npu { namespace native { -at::Tensor embedding_backward_npu( +at::Tensor NPUNativeFunctions::embedding_backward( const at::Tensor& grad, const at::Tensor& indices, int64_t num_weights, diff --git a/torch_npu/csrc/aten/ops/GeluBackwardKernelNpu.cpp b/torch_npu/csrc/aten/ops/GeluBackwardKernelNpu.cpp new file mode 100644 index 00000000000..0e6c809960e --- /dev/null +++ b/torch_npu/csrc/aten/ops/GeluBackwardKernelNpu.cpp @@ -0,0 +1,47 @@ +// Copyright (c) 2020, Huawei Technologies.All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "torch_npu/csrc/framework/utils/CalcuOpUtil.h" +#include "torch_npu/csrc/framework/utils/OpAdapter.h" +#include "torch_npu/csrc/aten/NPUNativeFunctions.h" + +namespace at_npu { +namespace native { + +at::Tensor& gelu_backward_out_npu_nocheck( + at::Tensor& grad_input, + const at::Tensor& grad, + const at::Tensor& self) { + at::Tensor unused = grad; + OpCommand cmd; + cmd.Name("GeluGrad") + .Input(grad) + .Input(self) + .Input(unused) + .Output(grad_input) + .Run(); + + return grad_input; +} + +at::Tensor NPUNativeFunctions::gelu_backward( + const at::Tensor& grad, + const at::Tensor& self) { + at::Tensor grad_input = OpPreparation::ApplyTensor(self); + gelu_backward_out_npu_nocheck(grad_input, grad, self); + return grad_input; +} + +} // namespace native +} // namespace at_npu \ No newline at end of file -- Gitee From a8ecc2976037c5e165b62b156056bc8b982273ff Mon Sep 17 00:00:00 2001 From: wangxiao Date: Fri, 18 Feb 2022 17:41:12 +0800 Subject: [PATCH 5/6] fix bug of atan, ceil, celu, argsort, cdist, celu --- test/test_network_ops/test_argmin.py | 69 ----------- test/test_network_ops/test_atan.py | 8 +- test/test_network_ops/test_atan2.py | 90 -------------- test/test_network_ops/test_cdist_backward.py | 117 ------------------ test/test_network_ops/test_ceil.py | 8 +- test/test_network_ops/test_celu.py | 38 ++---- test/test_network_ops/test_gelu_backward.py | 77 ------------ torch_npu/csrc/aten/ops/ArgminKernelNpu.cpp | 53 -------- torch_npu/csrc/aten/ops/ArgsortKernelNpu.cpp | 4 +- torch_npu/csrc/aten/ops/Atan2KernelNpu.cpp | 77 ------------ .../csrc/aten/ops/CdistBackwardKernelNpu.cpp | 93 -------------- torch_npu/csrc/aten/ops/CdistKernelNpu.cpp | 5 +- torch_npu/csrc/aten/ops/CeluKernelNpu.cpp | 12 +- .../csrc/aten/ops/GeluBackwardKernelNpu.cpp | 47 ------- 14 files changed, 28 insertions(+), 670 deletions(-) delete mode 100644 test/test_network_ops/test_argmin.py delete mode 100644 test/test_network_ops/test_atan2.py delete mode 100644 test/test_network_ops/test_cdist_backward.py delete mode 100644 test/test_network_ops/test_gelu_backward.py delete mode 100644 torch_npu/csrc/aten/ops/ArgminKernelNpu.cpp delete mode 100644 torch_npu/csrc/aten/ops/Atan2KernelNpu.cpp delete mode 100644 torch_npu/csrc/aten/ops/CdistBackwardKernelNpu.cpp delete mode 100644 torch_npu/csrc/aten/ops/GeluBackwardKernelNpu.cpp diff --git a/test/test_network_ops/test_argmin.py b/test/test_network_ops/test_argmin.py deleted file mode 100644 index e64104fb71e..00000000000 --- a/test/test_network_ops/test_argmin.py +++ /dev/null @@ -1,69 +0,0 @@ -# Copyright (c) 2020 Huawei Technologies Co., Ltd -# All rights reserved. -# -# Licensed under the BSD 3-Clause License (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://opensource.org/licenses/BSD-3-Clause -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# coding: utf-8 -import torch -import torch_npu -import numpy as np - -from torch_npu.testing.common_utils import TestCase, run_tests -from torch_npu.testing.common_device_type import Dtypes, instantiate_device_type_tests -from torch_npu.testing.util_test import create_common_tensor - -class TestArgmin(TestCase): - @Dtypes(torch.float) - def test_argmin(self, device, dtype): - inputValues = [-1000, -1, 0, 0.5, 1, 2, 1000] - expectedOutput = [0.0000, 0.2689, 0.5, 0.6225, 0.7311, 0.8808, 1.000] - precision_4dps = 0.0002 - a = torch.tensor(inputValues, dtype=dtype, device=device) - - self.assertRtolEqual(torch.tensor(inputValues, dtype=dtype, device=device).sigmoid().cpu(), - torch.tensor(expectedOutput, dtype=dtype, device=device).cpu(), - precision_4dps) - - def cpu_op_exec(self, input1, dims, keepdim=False): - output = torch.argmin(input1, dim=dims, keepdim=keepdim) - if output.dtype != torch.int32: - output = output.to(torch.int32) - output = output.numpy() - return output - - def npu_op_exec(self, input1, dims, keepdim=False): - output = torch.argmin(input1, dim=dims, keepdim=keepdim) - output = output.to("cpu") - if output.dtype != torch.int32: - output = output.to(torch.int32) - output = output.numpy() - return output - - def test_argmin_shape_format(self, device): - shape_format = [ - [ [np.float32, 0, (6, 4)], 0, False], - [ [np.float32, 0, (6, 4)], 1, True ], - [ [np.float32, 0, (2, 4, 5)], 2, True ], - [ [np.float32, 0, (1, 2, 3, 3)], 2, False], - [ [np.float32, 0, (1, 2, 3, 3)], 2, False], - [ [np.float32, 29, (15, 15, 15, 16)], 1, False], - ] - - for item in shape_format: - cpu_input, npu_input = create_common_tensor(item[0], 1, 100) - cpu_output = self.cpu_op_exec(cpu_input, item[1], keepdim=item[2]) - npu_output = self.npu_op_exec(npu_input, item[1], keepdim=item[2]) - self.assertRtolEqual(cpu_output, npu_output) - -instantiate_device_type_tests(TestArgmin, globals(), except_for="cpu") -if __name__ == "__main__": - run_tests() diff --git a/test/test_network_ops/test_atan.py b/test/test_network_ops/test_atan.py index d51e441933c..9b7926e1e43 100644 --- a/test/test_network_ops/test_atan.py +++ b/test/test_network_ops/test_atan.py @@ -22,12 +22,12 @@ from torch_npu.testing.common_device_type import Dtypes, instantiate_device_type from torch_npu.testing.util_test import create_common_tensor class TestAtan(TestCase): - def cpu_op_exec(self, input): - output = torch.atan(input) + def cpu_op_exec(self, input1): + output = torch.atan(input1) return output - def npu_op_exec(self, input): - output = torch.atan(input) + def npu_op_exec(self, input1): + output = torch.atan(input1) output = output.to("cpu") return output diff --git a/test/test_network_ops/test_atan2.py b/test/test_network_ops/test_atan2.py deleted file mode 100644 index 39c992dd685..00000000000 --- a/test/test_network_ops/test_atan2.py +++ /dev/null @@ -1,90 +0,0 @@ -# Copyright (c) 2020, Huawei Technologies.All rights reserved. -# -# Licensed under the BSD 3-Clause License (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://opensource.org/licenses/BSD-3-Clause -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import torch -import torch_npu -import numpy as np - -from torch_npu.testing.common_utils import TestCase, run_tests -from torch_npu.testing.common_device_type import Dtypes, instantiate_device_type_tests -from torch_npu.testing.util_test import create_common_tensor - -class TestAtan2(TestCase): - def cpu_op_exec(self,input1, input2): - output = torch.atan2(input1, input2) - output = output.numpy() - return output - - def npu_op_exec(self,input1, input2): - output = torch.atan2(input1, input2) - output = output.to("cpu") - output = output.numpy() - return output - - def npu_op_exec_out(self,input1, input2, out): - torch.atan2(input1, input2, out=out) - output = out.to("cpu") - output = output.numpy() - return output - - def test_atan2_common_shape_format(self, device): - shape_format = [ - [[np.float16, 0, [4, 12, 12, 128]], [np.float16, 0, [4]]], - [[np.float16, 0, [4, 128]], [np.float16, 0, [4, 256, 12]]], - [[np.float32, 0, [4, 12, 12, 128]], [np.float32, 0, [4]]], - [[np.float32, 0, [4, 128]], [np.float32, 0, [4, 256, 12]]], - ] - for item in shape_format: - cpu_input1, npu_input1 = create_common_tensor(item[0], -1, 1) - cpu_input2, npu_input2 = create_common_tensor(item[0], -1, 1) - if cpu_input1.dtype == torch.float16: - cpu_input1 = cpu_input1.to(torch.float32) - if cpu_input2.dtype == torch.float16: - cpu_input2 = cpu_input2.to(torch.float32) - cpu_output = self.cpu_op_exec(cpu_input1, cpu_input2) - npu_output = self.npu_op_exec(npu_input1, npu_input2) - cpu_output = cpu_output.astype(npu_output.dtype) - self.assertRtolEqual(cpu_output, npu_output) - - def test_atan2_out_common_shape_format(self, device): - shape_format = [ - [[np.float16, 0, [4, 12, 12, 128]], [np.float16, 0, [4]]], - [[np.float16, 0, [4, 128]], [np.float16, 0, [4, 256, 12]]], - [[np.float32, 0, [4, 12, 12, 128]], [np.float32, 0, [4]]], - [[np.float32, 0, [4, 128]], [np.float32, 0, [4, 256, 12]]], - ] - for item in shape_format: - cpu_input1, npu_input1 = create_common_tensor(item[0], -1, 1) - cpu_input2, npu_input2 = create_common_tensor(item[0], -1, 1) - cpu_out, npu_out = create_common_tensor(item[1], -1, 1) - if cpu_input1.dtype == torch.float16: - cpu_input1 = cpu_input1.to(torch.float32) - if cpu_input2.dtype == torch.float16: - cpu_input2 = cpu_input2.to(torch.float32) - cpu_output = self.cpu_op_exec(cpu_input1, cpu_input2) - npu_output = self.npu_op_exec(npu_input1, npu_input2) - npu_output_out = self.npu_op_exec_out(npu_input1, npu_input2, npu_out) - cpu_output = cpu_output.astype(npu_output.dtype) - self.assertRtolEqual(cpu_output, npu_output) - self.assertRtolEqual(cpu_output, npu_output_out) - - def test_atan2_mix_dtype(self, device): - npu_input1, npu_input2 = create_common_tensor([np.float32, 0, (2, 3)], 1, 100) - npu_input3, npu_input4 = create_common_tensor([np.float16, 0, (2, 3)], 1, 100) - cpu_output = self.cpu_op_exec(npu_input1, npu_input3) - npu_output = self.npu_op_exec(npu_input2, npu_input4) - self.assertRtolEqual(cpu_output, npu_output) - -instantiate_device_type_tests(TestAtan2, globals(), except_for='cpu') -if __name__ == "__main__": - run_tests() diff --git a/test/test_network_ops/test_cdist_backward.py b/test/test_network_ops/test_cdist_backward.py deleted file mode 100644 index 6d3b41aaa06..00000000000 --- a/test/test_network_ops/test_cdist_backward.py +++ /dev/null @@ -1,117 +0,0 @@ -# Copyright (c) 2020 Huawei Technologies Co., Ltd -# Copyright (c) 2019, Facebook CORPORATION. -# All rights reserved. -# -# Licensed under the BSD 3-Clause License (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://opensource.org/licenses/BSD-3-Clause -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import torch -import torch_npu -import numpy as np - -from torch_npu.testing.common_utils import TestCase, run_tests -from torch_npu.testing.common_device_type import Dtypes, instantiate_device_type_tests -from torch_npu.testing.util_test import create_common_tensor - -def cdist_backward(x1, x2, p, grad, cdist): - x1 = torch.unsqueeze(x1, -2) - x2 = torch.unsqueeze(x2, -3) - grad = torch.unsqueeze(grad, -1) - cdist = torch.unsqueeze(cdist, -1) - diff = x1 - x2 - diff_abs = torch.abs(diff) - nz_cdist = torch.where(cdist == 0, torch.ones_like(cdist), cdist) - sign = torch.where(diff > 0, torch.ones_like(diff), torch.full_like(diff, -1)) - sign = torch.where(diff == 0, torch.zeros_like(diff), sign) - - if p == 0.0: - res = torch.zeros_like(diff) - elif p == 1.0: - res = grad * sign - elif p < 2.0: - res = sign * torch.pow(diff_abs, p - 1.0) * grad / torch.pow(nz_cdist, p - 1.0) - res = torch.where(cdist == 0, torch.zeros_like(res), res) - elif p == 2.0: - res = grad * diff / nz_cdist - res = torch.where(cdist == 0, torch.zeros_like(res), res) - elif p == float("inf"): - mask = torch.where(cdist - diff_abs > 0, torch.zeros_like(diff), torch.ones_like(diff)) - res = grad * sign * mask - else: - res = diff * torch.pow(diff_abs, p - 2) * grad / torch.pow(nz_cdist, p - 1.0) - res = torch.where(cdist == 0, torch.zeros_like(res), res) - res = torch.sum(res, -2) - return res - - -class Testcdist(TestCase): - def generate_data(self, min_n, max_n, shape_x, shape_y, src_type): - np.random.seed(10086) - x1 = np.random.uniform(min_n, max_n, shape_x).astype(src_type) - x2 = np.random.uniform(min_n, max_n, shape_y).astype(src_type) - return x1, x2 - - def op_exec(self, x1, x2, p, device='cpu'): - is_fp16 = x1.dtype == np.float16 - - if device == 'cpu' and is_fp16: - x1 = x1.astype(np.float32) - x2 = x2.astype(np.float32) - - x1 = torch.tensor(x1, device=device, requires_grad=True) - x2 = torch.tensor(x2, device=device, requires_grad=True) - - y = torch.cdist(x1, x2, p) - grad = torch.ones_like(y, requires_grad=True, device=device) - - if device == 'cpu' and is_fp16: - y = y.half() - y = y.float() - out = cdist_backward(x1, x2, p, grad, y) - return out.detach().numpy().astype('float16') - - y.backward(grad, retain_graph=True) - out = x1.grad.detach().cpu().numpy() - - return out - - def test_cdis_backward_common_shape(self, device): - shape_items = [ - [np.float16, (5, 10), (4, 10)], - [np.float16, (20, 5, 10), (20, 4, 10)], - [np.float32, (5, 10), (4, 10)], - [np.float32, (20, 5, 10), (20, 4, 10)], - ] - p_ranges = [0.0, 0.5, 1.0, 1.5, 2.0, 2.5] - for item in shape_items: - for p in p_ranges: - input1, input2 = self.generate_data(-1, 1, - item[1], item[2], item[0]) - cpu_output = self.op_exec(input1, input2, p, 'cpu') - npu_output = self.op_exec(input1, input2, p, 'npu') - self.assertRtolEqual(cpu_output, npu_output) - - def test_cdis_backward_input_range(self, device): - item = [np.float32, (20, 5, 5), (20, 4, 5)] - p_ranges = [0.0, 0.5, 1.0, 1.5, 2.0, 2.5] - input_ragnes = [(-0.1, 0.1), (10, 10), (-100, 100)] - for p in p_ranges: - for min_max in input_ragnes: - - input1, input2 = self.generate_data(min_max[0], min_max[1], - item[1], item[2], item[0]) - cpu_output = self.op_exec(input1, input2, p, 'cpu') - npu_output = self.op_exec(input1, input2, p, 'npu') - self.assertRtolEqual(cpu_output, npu_output) - -instantiate_device_type_tests(Testcdist, globals(), except_for="cpu") -if __name__ == "__main__": - run_tests() diff --git a/test/test_network_ops/test_ceil.py b/test/test_network_ops/test_ceil.py index 66138309ffc..bfa32911ea6 100644 --- a/test/test_network_ops/test_ceil.py +++ b/test/test_network_ops/test_ceil.py @@ -32,12 +32,12 @@ class TestCeil(TestCase): self.assertRtolEqual(cpu_output, npu_output) - def cpu_op_exec(self, input): - output = torch.ceil(input) + def cpu_op_exec(self, input1): + output = torch.ceil(input1) return output - def npu_op_exec(self, input): - output = torch.ceil(input) + def npu_op_exec(self, input1): + output = torch.ceil(input1) output = output.to("cpu") return output diff --git a/test/test_network_ops/test_celu.py b/test/test_network_ops/test_celu.py index 20f0714e8fb..43d1ac2cda3 100644 --- a/test/test_network_ops/test_celu.py +++ b/test/test_network_ops/test_celu.py @@ -69,6 +69,13 @@ class TestCelu(TestCase): output = output.to("cpu") output = output.numpy() return output + + def cpu_op_exec_fp16(self, input1, alpha): + input1 = input1.to(torch.float32) + output = torch.nn.functional.celu(input1, alpha=alpha) + output = output.numpy() + output = output.astype(np.float16) + return output def test_celu_3_3_float32_alpha1(self, device): input_x1 = self.generate_data(-1, 1, (3, 3), np.float32) @@ -89,13 +96,6 @@ class TestCelu(TestCase): self.assertRtolEqual(cpu_output1, npu_output1) def test_celu_float16_alpha1(self, device): - def cpu_op_exec_fp16(input1, alpha): - input1 = input1.to(torch.float32) - output = torch.nn.functional.celu(input1, alpha=alpha) - output = output.numpy() - output = output.astype(np.float16) - return output - shape_format = [ [[np.float16, 0, (65535, 1, 1, 1)]], [[np.float16, 0, (1, 1, 1, 65535)]], @@ -103,18 +103,11 @@ class TestCelu(TestCase): for item in shape_format: cpu_input1, npu_input1 = create_common_tensor(item[0], -2, 2) - cpu_output = cpu_op_exec_fp16(cpu_input1, 1.0) + cpu_output = self.cpu_op_exec_fp16(cpu_input1, 1.0) npu_output = self.npu_op_exec(npu_input1, 1.0) self.assertRtolEqual(cpu_output, npu_output) def test_celu_float16_alpha2_success(self, device): - def cpu_op_exec_fp16(input1, alpha): - input1 = input1.to(torch.float32) - output = torch.nn.functional.celu(input1, alpha=alpha) - output = output.numpy() - output = output.astype(np.float16) - return output - shape_format = [ [[np.float16, 0, (65535, 1, 1, 1)]], [[np.float16, 0, (1, 1, 1, 65535)]], @@ -122,18 +115,11 @@ class TestCelu(TestCase): for item in shape_format: cpu_input1, npu_input1 = create_common_tensor(item[0], 1, 100) - cpu_output = cpu_op_exec_fp16(cpu_input1, 2.0) + cpu_output = self.cpu_op_exec_fp16(cpu_input1, 2.0) npu_output = self.npu_op_exec(npu_input1, 2.0) self.assertRtolEqual(cpu_output, npu_output) def test_celu_float16_alpha2_fail(self, device): - def cpu_op_exec_fp16(input1, alpha): - input1 = input1.to(torch.float32) - output = torch.nn.functional.celu(input1, alpha=alpha) - output = output.numpy() - output = output.astype(np.float16) - return output - shape_format = [ [[np.float16, 0, (65535, 1, 1, 1)]], [[np.float16, 0, (1, 1, 1, 65535)]], @@ -141,7 +127,7 @@ class TestCelu(TestCase): for item in shape_format: cpu_input1, npu_input1 = create_common_tensor(item[0], -2, 2) - cpu_output = cpu_op_exec_fp16(cpu_input1, 2.0) + cpu_output = self.cpu_op_exec_fp16(cpu_input1, 2.0) npu_output = self.npu_op_exec(npu_input1, 2.0) self.assertRtolEqual(cpu_output, npu_output) @@ -180,7 +166,7 @@ class TestCelu(TestCase): def test_celu_inplace_shape_format_alpha_range(self, device): shape_format_alpha_range = [ - # [[dtype, format, shape], alpha, min, max] + # 注:[[dtype, format, shape], alpha, min, max] [[np.float16, 2, (16, 5, 7, 11)], 5.6, -2, 2], [[np.float32, 2, (16, 5, 7, 11)], 0.5, -2, 2], [[np.float32, 2, (16, 5, 7, 11)], 0.7, -2, 2], @@ -206,7 +192,7 @@ class TestCelu(TestCase): def test_celu_inplace_shape_format_alpha_range(self, device): shape_format_alpha_range = [ - # [[dtype, format, shape], alpha, min, max] + # 注:[[dtype, format, shape], alpha, min, max] [[np.float32, 2, (16, 5, 7, 11)], 0.5, -2, 2], [[np.float32, 2, (16, 5, 7, 11)], 0.7, -2, 2], [[np.float32, 2, (16, 5, 7, 11)], 2.6, -2, 2], diff --git a/test/test_network_ops/test_gelu_backward.py b/test/test_network_ops/test_gelu_backward.py deleted file mode 100644 index 89d367df6ce..00000000000 --- a/test/test_network_ops/test_gelu_backward.py +++ /dev/null @@ -1,77 +0,0 @@ -# Copyright (c) 2020, Huawei Technologies.All rights reserved. -# -# Licensed under the BSD 3-Clause License (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://opensource.org/licenses/BSD-3-Clause -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import copy -import torch -import torch_npu -import numpy as np - -from torch_npu.testing.common_utils import TestCase, run_tests -from torch_npu.testing.common_device_type import Dtypes, instantiate_device_type_tests -from torch_npu.testing.util_test import create_common_tensor - -class TestGeluBackward(TestCase): - def generate_single_data(self, min_val, max_val, shape, dtype): - input1 = np.random.uniform(min_val, max_val, shape).astype(dtype) - npu_input1 = torch.from_numpy(input1) - return npu_input1 - - def cpu_op_exec(self, input1): - input1.requires_grad_(True) - output = torch.nn.functional.gelu(input1) - z = output.sum() - z.backward() - res = input1.grad - return res.detach().numpy() - - def npu_op_exec(self, input1): - input1 = input1.to("npu") - input1.requires_grad = True - output = torch.nn.functional.gelu(input1) - z = output.sum() - z.backward() - res = input1.grad.to("cpu") - return res.detach().numpy() - - def test_gelu_backward_float32_1(self, device): - input1= self.generate_single_data(0, 100, (4, 3, 1, 1), np.float32) - cpu_input1 = copy.deepcopy(input1) - cpu_output = self.cpu_op_exec(cpu_input1) - npu_output = self.npu_op_exec(input1) - self.assertRtolEqual(cpu_output, npu_output) - - def test_gelu_backward_float32_2(self, device): - input1= self.generate_single_data(0, 100, (15, 3, 1), np.float32) - cpu_input1 = copy.deepcopy(input1) - cpu_output = self.cpu_op_exec(cpu_input1) - npu_output = self.npu_op_exec(input1) - self.assertRtolEqual(cpu_output, npu_output) - - def test_gelu_backward_float32_3(self, device): - input1= self.generate_single_data(0, 100, (4, 4), np.float32) - cpu_input1 = copy.deepcopy(input1) - cpu_output = self.cpu_op_exec(cpu_input1) - npu_output = self.npu_op_exec(input1) - self.assertRtolEqual(cpu_output, npu_output) - - def test_gelu_backward_float16(self, device): - input1 = self.generate_single_data(0, 100, (5, 10, 100), np.float16) - cpu_input1 = input1.to(torch.float32) - cpu_output = self.cpu_op_exec(cpu_input1) - cpu_output = cpu_output.astype(np.float16) - npu_output = self.npu_op_exec(input1) - self.assertRtolEqual(cpu_output, npu_output) - -instantiate_device_type_tests(TestGeluBackward, globals(), except_for="cpu") -if __name__ == "__main__": - run_tests() diff --git a/torch_npu/csrc/aten/ops/ArgminKernelNpu.cpp b/torch_npu/csrc/aten/ops/ArgminKernelNpu.cpp deleted file mode 100644 index 8c046ad6eca..00000000000 --- a/torch_npu/csrc/aten/ops/ArgminKernelNpu.cpp +++ /dev/null @@ -1,53 +0,0 @@ -// Copyright (c) 2020, Huawei Technologies.All rights reserved. -// -// Licensed under the BSD 3-Clause License (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://opensource.org/licenses/BSD-3-Clause -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -#include "torch_npu/csrc/framework/utils/OpAdapter.h" -#include "torch_npu/csrc/aten/NPUNativeFunctions.h" - -namespace at_npu { -namespace native { - -at::Tensor NPUNativeFunctions::argmin(const at::Tensor& self, at::optional dim, bool keepdim) { - TORCH_CHECK( - self.numel() > 0, - "cannot perform reduction function argmin on a " - "tensor with no elements because the operation does not have an identity"); - - at::Tensor input = dim.has_value() ? self : self.reshape({-1}); - int64_t realDim = dim.has_value() ? dim.value() : 0; - bool realKeepDim = dim.has_value() ? keepdim : false; - - // calculate the output size - auto outputSize = reduce_ops_npu_output_size(input, realDim, realKeepDim); - - // construct the output tensor of the NPU - at::Tensor result = OpPreparation::ApplyTensorWithFormat( - outputSize, - self.options().dtype(at::kInt), - ACL_FORMAT_ND); - at::SmallVector DimVec = {realDim}; - // calculate the output result of the NPU - OpCommand cmd; - cmd.Name("ArgMin") - .Input(input) - .Input(DimVec, at::kInt) - .Output(result) - .Attr("keep_dims", realKeepDim) - .Run(); - - result = NPUNativeFunctions::npu_dtype_cast(result, at::ScalarType::Long); - return result; -} - -} // namespace native -} // namespace at_npu \ No newline at end of file diff --git a/torch_npu/csrc/aten/ops/ArgsortKernelNpu.cpp b/torch_npu/csrc/aten/ops/ArgsortKernelNpu.cpp index 38b839a91b7..0a8fc964d26 100644 --- a/torch_npu/csrc/aten/ops/ArgsortKernelNpu.cpp +++ b/torch_npu/csrc/aten/ops/ArgsortKernelNpu.cpp @@ -65,7 +65,7 @@ at::Tensor& argsort_out_npu_nocheck( argsort_out_npu_no_transpose( transposeValues, transposeIndices, transposeSelf, lastDim, descending); - NPUNativeFunctions::npu_transpose_out(indices, transposeIndices, perm); + NPUNativeFunctions::npu_transpose_out(transposeIndices, perm, indices); // indices dtype transform to Int64 indices = NPUNativeFunctions::npu_dtype_cast(indices, at::kLong); @@ -85,7 +85,7 @@ at::Tensor NPUNativeFunctions::argsort(const at::Tensor& self, return indices; } -at::Tensor NPUNativeFunctions::argsort_dim(const at::Tensor& self, +at::Tensor NPUNativeFunctions::argsort(const at::Tensor& self, at::Dimname dim, bool descending) { return NPUNativeFunctions::argsort(self, dimname_to_position(self, dim), descending); diff --git a/torch_npu/csrc/aten/ops/Atan2KernelNpu.cpp b/torch_npu/csrc/aten/ops/Atan2KernelNpu.cpp deleted file mode 100644 index d2e4ce41a23..00000000000 --- a/torch_npu/csrc/aten/ops/Atan2KernelNpu.cpp +++ /dev/null @@ -1,77 +0,0 @@ -// Copyright (c) 2020, Huawei Technologies.All rights reserved. -// -// Licensed under the BSD 3-Clause License (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://opensource.org/licenses/BSD-3-Clause -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -#include "torch_npu/csrc/framework/utils/OpAdapter.h" -#include "torch_npu/csrc/aten/NPUNativeFunctions.h" - -namespace at_npu { -namespace native { - -at::Tensor& atan2_out_npu_nocheck( - at::Tensor& result, - const at::Tensor& self, - const at::Tensor& other) { - auto unified_result = OpPreparation::binary_op_check(result, self, other, true); - OpCommand cmd; - cmd.Name("Atan2") - .Expect(unified_result) - .Input(self) - .Input(other) - .Output(result) - .Run(); - - return result; -} - -at::Tensor& NPUNativeFunctions::atan2_out( - const at::Tensor& self, - const at::Tensor& other, - at::Tensor& result) { - auto outputSize = broadcast_ops_npu_output_size(self, other); - - OpPreparation::CheckOut( - {self}, - result, - self, - outputSize); - - atan2_out_npu_nocheck(result, self, other); - return result; -} - -at::Tensor NPUNativeFunctions::atan2(const at::Tensor& self, const at::Tensor& other) { - auto outputSize = broadcast_ops_npu_output_size(self, other); - at::Tensor result = OpPreparation::ApplyTensor(self, outputSize); - atan2_out_npu_nocheck(result, self, other); - return result; -} - -at::Tensor& NPUNativeFunctions::atan2_(at::Tensor& self, const at::Tensor& other) { - OpPreparation::CheckMemory({self, other}, {self}); - - if (!NpuUtils::check_match(&self)) { - at::Tensor contiguousSelf = NpuUtils::format_contiguous(self); - at::Tensor result = atan2_out_npu_nocheck(contiguousSelf, contiguousSelf, other); - NpuUtils::format_fresh_view(self, result); - } else { - atan2_out_npu_nocheck(self, self, other); - } - return self; -} -TORCH_LIBRARY_IMPL(aten, NPU, m) { - m.impl("atan2", TORCH_FN(atan2_npu)); - m.impl("atan2_", TORCH_FN(atan2_npu_)); - m.impl("atan2.out", TORCH_FN(atan2_out_npu)); -} -} // namespace native -} // namespace at diff --git a/torch_npu/csrc/aten/ops/CdistBackwardKernelNpu.cpp b/torch_npu/csrc/aten/ops/CdistBackwardKernelNpu.cpp deleted file mode 100644 index ead328f8f73..00000000000 --- a/torch_npu/csrc/aten/ops/CdistBackwardKernelNpu.cpp +++ /dev/null @@ -1,93 +0,0 @@ -// Copyright (c) 2020, Huawei Technologies.All rights reserved. -// -// Licensed under the BSD 3-Clause License (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://opensource.org/licenses/BSD-3-Clause -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -#include "torch_npu/csrc/framework/utils/OpAdapter.h" -#include "torch_npu/csrc/aten/NPUNativeFunctions.h" - -namespace at_npu { -namespace native { - -static void check_cdist_backward_input( - const at::Tensor& grad, - const at::Tensor& x1, - const at::Tensor& x2, - const double p, - const at::Tensor& cdist) { - TORCH_CHECK(x1.is_contiguous(), "_cdist_backward requires X1 to be contiguous"); - TORCH_CHECK(x2.is_contiguous(), "_cdist_backward requires X2 to be contiguous"); - TORCH_CHECK(cdist.is_contiguous(), "_cdist_backward requires dist to be contiguous"); - TORCH_CHECK(grad.is_contiguous(), "_cdist_backward requires grad to be contiguous"); - auto device1 = x1.device().type(); - TORCH_CHECK(device1 == at::kCPU || device1 == at::kCUDA || device1 == at::kNPU, "_cdist_backward only supports CPU, CUDA and NPU devices, X1 got: ", device1); - auto device2 = x2.device().type(); - TORCH_CHECK(device2 == at::kCPU || device2 == at::kCUDA || device2 == at::kNPU, "_cdist_backward only supports CPU, CUDA and NPU devices, X2 got: ", device2); - TORCH_CHECK(p <= std::numeric_limits::max(), "npu dose not support float64" ); -} - -at::Tensor NPUNativeFunctions::_cdist_backward( - const at::Tensor& grad, - const at::Tensor& x1, - const at::Tensor& x2, - const double p, - const at::Tensor& cdist) { - check_cdist_backward_input(grad, x1, x2, p, cdist); - - // Since double is not supported in NPU, the type of P needs to be converted from double to float. - float p_float; - if (std::isinf(p)) { - p_float = std::numeric_limits::infinity(); - } - else { - p_float = static_cast(p); - } - - // Broadcast - auto dim1 = x1.dim(); - auto dim2 = x2.dim(); - - at::SmallVector tensor1_expand_size = array_to_small_vector(x1.sizes()); - tensor1_expand_size.insert(tensor1_expand_size.begin() + (dim1 - 1), 1); - - at::SmallVector tensor2_expand_size = array_to_small_vector(x2.sizes()); - tensor2_expand_size.insert(tensor2_expand_size.begin() + (dim2 - 2), 1); - - at::SmallVector grad_expand_size = array_to_small_vector(grad.sizes()); - grad_expand_size.insert(grad_expand_size.end(), 1); - - at::SmallVector cdist_expand_size = array_to_small_vector(cdist.sizes()); - cdist_expand_size.insert(cdist_expand_size.end(), 1); - - std::vector tensor_broadcast_size = infer_size(tensor1_expand_size, tensor2_expand_size); - - at::Tensor tensor1_broadcast = x1.view(tensor1_expand_size).expand(tensor_broadcast_size).contiguous(); - at::Tensor tensor2_broadcast = x2.view(tensor2_expand_size).expand(tensor_broadcast_size).contiguous(); - at::Tensor grad_broadcast = grad.view(grad_expand_size).expand(tensor_broadcast_size).contiguous(); - at::Tensor cdist_broadcast = cdist.view(cdist_expand_size).expand(tensor_broadcast_size).contiguous(); - - auto outputSize = input_same_output_size(x1); - at::Tensor result = OpPreparation::ApplyTensor(tensor1_broadcast, outputSize); - OpCommand cmd; - cmd.Name("CdistGrad") - .Input(grad_broadcast) - .Input(tensor1_broadcast) - .Input(tensor2_broadcast) - .Input(cdist_broadcast) - .Attr("p", p_float) - .Output(result) - .Run(); - - return result; -} - -} // namespace native -} // namespace at_npu diff --git a/torch_npu/csrc/aten/ops/CdistKernelNpu.cpp b/torch_npu/csrc/aten/ops/CdistKernelNpu.cpp index b60ed5734ea..380c8d63f9c 100644 --- a/torch_npu/csrc/aten/ops/CdistKernelNpu.cpp +++ b/torch_npu/csrc/aten/ops/CdistKernelNpu.cpp @@ -17,7 +17,6 @@ namespace at_npu { namespace native { -using namespace at::native::npu; at::Tensor NPUNativeFunctions::_cdist_forward( const at::Tensor& x1, @@ -54,7 +53,7 @@ at::Tensor NPUNativeFunctions::_cdist_forward( at::IntArrayRef batch_tensor1(x1.sizes().data(), dim1 - 2); at::IntArrayRef batch_tensor2(x2.sizes().data(), dim2 - 2); - std::vector expand_batch_portion = infer_size(batch_tensor1, batch_tensor2); + std::vector expand_batch_portion = at::infer_size(batch_tensor1, batch_tensor2); std::vector tensor1_expand_size(expand_batch_portion); tensor1_expand_size.insert(tensor1_expand_size.end(), {r1, c1}); std::vector tensor2_expand_size(expand_batch_portion); @@ -64,7 +63,7 @@ at::Tensor NPUNativeFunctions::_cdist_forward( std::vector tensor1_view{expand_batch_product, r1, 1, c1}; std::vector tensor2_view{expand_batch_product, 1, r2, c2}; std::vector result_size{expand_batch_product, r1, r2}; - std::vector tensor_broadcast_size = infer_size(tensor1_view, tensor2_view); + std::vector tensor_broadcast_size = at::infer_size(tensor1_view, tensor2_view); // Broadcast batch dim. at::Tensor tensor1_expanded = x1.expand(tensor1_expand_size).contiguous().view(tensor1_view); diff --git a/torch_npu/csrc/aten/ops/CeluKernelNpu.cpp b/torch_npu/csrc/aten/ops/CeluKernelNpu.cpp index 6b38b00e17e..676a496a588 100644 --- a/torch_npu/csrc/aten/ops/CeluKernelNpu.cpp +++ b/torch_npu/csrc/aten/ops/CeluKernelNpu.cpp @@ -19,28 +19,24 @@ namespace at_npu { namespace native { -at::Tensor celu_out_npu_nocheck(at::Tensor& result, const at::Tensor& self, Scalar alpha) { - float alpha3 = 1.0; - +at::Tensor celu_out_npu_nocheck(at::Tensor& result, const at::Tensor& self, at::Scalar alpha) { OpCommand cmd; cmd.Name("Celu") .Input(self) .Output(result) - .Attr("alpha1", alpha) - .Attr("alpha2", alpha) - .Attr("alpha3", alpha3) + .Attr("alpha", alpha) .Run(); return result; } -at::Tensor NPUNativeFunctions::celu(const at::Tensor& self, Scalar alpha) { +at::Tensor NPUNativeFunctions::celu(const at::Tensor& self, at::Scalar alpha) { at::Tensor result = OpPreparation::ApplyTensor(self); celu_out_npu_nocheck(result, self, alpha); return result; } -at::Tensor& NPUNativeFunctions::celu_(at::Tensor& self, Scalar alpha) { +at::Tensor& NPUNativeFunctions::celu_(at::Tensor& self, at::Scalar alpha) { if (!NpuUtils::check_match(&self)) { at::Tensor contiguousSelf = NpuUtils::format_contiguous(self); at::Tensor result = celu_out_npu_nocheck(contiguousSelf, contiguousSelf, alpha); diff --git a/torch_npu/csrc/aten/ops/GeluBackwardKernelNpu.cpp b/torch_npu/csrc/aten/ops/GeluBackwardKernelNpu.cpp deleted file mode 100644 index 0e6c809960e..00000000000 --- a/torch_npu/csrc/aten/ops/GeluBackwardKernelNpu.cpp +++ /dev/null @@ -1,47 +0,0 @@ -// Copyright (c) 2020, Huawei Technologies.All rights reserved. -// -// Licensed under the BSD 3-Clause License (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://opensource.org/licenses/BSD-3-Clause -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "torch_npu/csrc/framework/utils/CalcuOpUtil.h" -#include "torch_npu/csrc/framework/utils/OpAdapter.h" -#include "torch_npu/csrc/aten/NPUNativeFunctions.h" - -namespace at_npu { -namespace native { - -at::Tensor& gelu_backward_out_npu_nocheck( - at::Tensor& grad_input, - const at::Tensor& grad, - const at::Tensor& self) { - at::Tensor unused = grad; - OpCommand cmd; - cmd.Name("GeluGrad") - .Input(grad) - .Input(self) - .Input(unused) - .Output(grad_input) - .Run(); - - return grad_input; -} - -at::Tensor NPUNativeFunctions::gelu_backward( - const at::Tensor& grad, - const at::Tensor& self) { - at::Tensor grad_input = OpPreparation::ApplyTensor(self); - gelu_backward_out_npu_nocheck(grad_input, grad, self); - return grad_input; -} - -} // namespace native -} // namespace at_npu \ No newline at end of file -- Gitee From cf9d5c0d97968b57ee5a0b3e5e753ea9024cc20b Mon Sep 17 00:00:00 2001 From: junqiang521 Date: Mon, 21 Feb 2022 15:40:55 +0800 Subject: [PATCH 6/6] =?UTF-8?q?=E9=97=A8=E7=A6=81=E5=8A=A0=E5=85=A5resnet?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 修改pytorch_resnet.py 运行路径 删除run_tests.py相关代码,这个文件不存在 --- ci/access_control_test.py | 6 +- ci/pytorch_resnet.py | 451 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 453 insertions(+), 4 deletions(-) create mode 100644 ci/pytorch_resnet.py diff --git a/ci/access_control_test.py b/ci/access_control_test.py index aa61c310b33..3ebebbc8848 100644 --- a/ci/access_control_test.py +++ b/ci/access_control_test.py @@ -22,7 +22,7 @@ from abc import ABCMeta, abstractmethod BASE_DIR = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) -DEFAULT_UT_FILE = os.path.join(BASE_DIR, 'test/test_network_ops/test_add.py') +DEFAULT_UT_FILE = os.path.join(BASE_DIR, 'ci/pytorch_resnet.py') class AccurateTest(metaclass=ABCMeta): @@ -115,9 +115,7 @@ class TestMgr(): if os.path.exists(changed_file): exist_ut_file.append(changed_file) self.ut_files = exist_ut_file - - if len(self.ut_files) == 0: - self.ut_files.append(DEFAULT_UT_FILE) + self.ut_files.append(DEFAULT_UT_FILE) def get_ut_files(self): return self.ut_files diff --git a/ci/pytorch_resnet.py b/ci/pytorch_resnet.py new file mode 100644 index 00000000000..81c7ee8d7f2 --- /dev/null +++ b/ci/pytorch_resnet.py @@ -0,0 +1,451 @@ +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import os +import random +import shutil +import time +import warnings + +import torch +import torch_npu +import torch.nn as nn +import torch.nn.parallel +import torch.backends.cudnn as cudnn +import torch.distributed as dist +import torch.optim +import torch.multiprocessing as mp +import torch.utils.data +import torch.utils.data.distributed +import torchvision.transforms as transforms +import torchvision.datasets as datasets +import torchvision.models as models +import torch.npu + +SOURCE_DIR = os.environ.get('SOURCE_DIR') +BATCH_SIZE = 128 +EPOCHS_SIZE = 1 +TRAIN_STEP = 10 +LOG_STEP = 1 + +CALCULATE_DEVICE = "npu:0" +PRINT_DEVICE = "cpu" + + +model_names = sorted(name for name in models.__dict__ + if name.islower() and not name.startswith("__") + and callable(models.__dict__[name])) + +parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') +parser.add_argument('--data', metavar='DIR', default=SOURCE_DIR, + help='path to dataset') +parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18', + choices=model_names, + help='model architecture: ' + + ' | '.join(model_names) + + ' (default: resnet18)') +parser.add_argument('-j', '--workers', default=8, type=int, metavar='N', + help='number of data loading workers (default: 8)') +parser.add_argument('--epochs', default=EPOCHS_SIZE, type=int, metavar='N', + help='number of total epochs to run') +parser.add_argument('--start-epoch', default=0, type=int, metavar='N', + help='manual epoch number (useful on restarts)') +parser.add_argument('-b', '--batch-size', default=BATCH_SIZE, type=int, + metavar='N', + help='mini-batch size (default: 256), this is the total ' + 'batch size of all GPUs on the current node when ' + 'using Data Parallel or Distributed Data Parallel') +parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, + metavar='LR', help='initial learning rate', dest='lr') +parser.add_argument('--momentum', default=0.9, type=float, metavar='M', + help='momentum') +parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float, + metavar='W', help='weight decay (default: 1e-4)', + dest='weight_decay') +parser.add_argument('-p', '--print-freq', default=10, type=int, + metavar='N', help='print frequency (default: 10)') +parser.add_argument('--resume', default='', type=str, metavar='PATH', + help='path to latest checkpoint (default: none)') +parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', + help='evaluate model on validation set') +parser.add_argument('--pretrained', dest='pretrained', action='store_true', + help='use pre-trained model') +parser.add_argument('--world-size', default=-1, type=int, + help='number of nodes for distributed training') +parser.add_argument('--rank', default=-1, type=int, + help='node rank for distributed training') +parser.add_argument('--dist-url', type=str, + help='url used to set up distributed training') +parser.add_argument('--dist-backend', default='nccl', type=str, + help='distributed backend') +parser.add_argument('--seed', default=None, type=int, + help='seed for initializing training. ') +parser.add_argument('--gpu', default=None, type=int, + help='GPU id to use.') +parser.add_argument('--multiprocessing-distributed', action='store_true', + help='Use multi-processing distributed training to launch ' + 'N processes per node, which has N GPUs. This is the ' + 'fastest way to use PyTorch for either single node or ' + 'multi node data parallel training') + +best_acc1 = 0 + + +def main(): + args = parser.parse_args() + + if args.seed is not None: + random.seed(args.seed) + torch.manual_seed(args.seed) + cudnn.deterministic = True + warnings.warn('You have chosen to seed training. ' + 'This will turn on the CUDNN deterministic setting, ' + 'which can slow down your training considerably! ' + 'You may see unexpected behavior when restarting ' + 'from checkpoints.') + + if args.gpu is not None: + warnings.warn('You have chosen a specific GPU. This will completely ' + 'disable data parallelism.') + + if args.dist_url == "env://" and args.world_size == -1: + args.world_size = int(os.environ["WORLD_SIZE"]) + + args.distributed = args.world_size > 1 or args.multiprocessing_distributed + + ngpus_per_node = torch.cuda.device_count() + if args.multiprocessing_distributed: + # Since we have ngpus_per_node processes per node, the total world_size + # needs to be adjusted accordingly + args.world_size = ngpus_per_node * args.world_size + # Use torch.multiprocessing.spawn to launch distributed processes: the + # main_worker process function + mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args)) + else: + # Simply call main_worker function + main_worker(args.gpu, ngpus_per_node, args) + + +def main_worker(gpu, ngpus_per_node, args): + global best_acc1 + args.gpu = gpu + + if args.gpu is not None: + print("Use GPU: {} for training".format(args.gpu)) + + if args.distributed: + if args.dist_url == "env://" and args.rank == -1: + args.rank = int(os.environ["RANK"]) + if args.multiprocessing_distributed: + # For multiprocessing distributed training, rank needs to be the + # global rank among all the processes + args.rank = args.rank * ngpus_per_node + gpu + dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url, + world_size=args.world_size, rank=args.rank) + # create model + if args.pretrained: + print("=> using pre-trained model '{}'".format(args.arch)) + model = models.__dict__[args.arch](pretrained=True) + else: + print("=> creating model '{}'".format(args.arch)) + model = models.__dict__[args.arch]() + + if args.distributed: + # For multiprocessing distributed, DistributedDataParallel constructor + # should always set the single device scope, otherwise, + # DistributedDataParallel will use all available devices. + if args.gpu is not None: + torch.cuda.set_device(args.gpu) + model.cuda(args.gpu) + # When using a single GPU per process and per + # DistributedDataParallel, we need to divide the batch size + # ourselves based on the total number of GPUs we have + args.batch_size = int(args.batch_size / ngpus_per_node) + args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node) + model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) + else: + model.cuda() + # DistributedDataParallel will divide and allocate batch_size to all + # available GPUs if device_ids are not set + model = torch.nn.parallel.DistributedDataParallel(model) + elif args.gpu is not None: + torch.cuda.set_device(args.gpu) + model = model.cuda(args.gpu) + else: + # DataParallel will divide and allocate batch_size to all available GPUs + if args.arch.startswith('alexnet') or args.arch.startswith('vgg'): + model.features = torch.nn.DataParallel(model.features) + model.cuda() + else: + model = model.to(CALCULATE_DEVICE) + + # define loss function (criterion) and optimizer + criterion = nn.CrossEntropyLoss().to(CALCULATE_DEVICE) + + optimizer = torch.optim.SGD(model.parameters(), args.lr, + momentum=args.momentum, + weight_decay=args.weight_decay) + + # optionally resume from a checkpoint + if args.resume: + if os.path.isfile(args.resume): + print("=> loading checkpoint '{}'".format(args.resume)) + if args.gpu is None: + checkpoint = torch.load(args.resume) + else: + # Map model to be loaded to specified single gpu. + loc = 'cuda:{}'.format(args.gpu) + checkpoint = torch.load(args.resume, map_location=loc) + args.start_epoch = checkpoint['epoch'] + best_acc1 = checkpoint['best_acc1'] + if args.gpu is not None: + # best_acc1 may be from a checkpoint from a different GPU + best_acc1 = best_acc1.to(args.gpu) + model.load_state_dict(checkpoint['state_dict']) + optimizer.load_state_dict(checkpoint['optimizer']) + print("=> loaded checkpoint '{}' (epoch {})" + .format(args.resume, checkpoint['epoch'])) + else: + print("=> no checkpoint found at '{}'".format(args.resume)) + + cudnn.benchmark = True + + # Data loading code + traindir = os.path.join(args.data, 'train') + valdir = os.path.join(args.data, 'val') + normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]) + + train_dataset = datasets.ImageFolder( + traindir, + transforms.Compose([ + transforms.RandomResizedCrop(224), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + normalize, + ])) + + if args.distributed: + train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) + else: + train_sampler = None + + train_loader = torch.utils.data.DataLoader( + train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None), + num_workers=args.workers, pin_memory=True, sampler=train_sampler) + + val_loader = torch.utils.data.DataLoader( + datasets.ImageFolder(valdir, transforms.Compose([ + transforms.Resize(256), + transforms.CenterCrop(224), + transforms.ToTensor(), + normalize, + ])), + batch_size=args.batch_size, shuffle=False, + num_workers=args.workers, pin_memory=True) + + if args.evaluate: + validate(val_loader, model, criterion, args) + return + + for epoch in range(args.start_epoch, args.epochs): + if args.distributed: + train_sampler.set_epoch(epoch) + adjust_learning_rate(optimizer, epoch, args) + + # train for one epoch + train(train_loader, model, criterion, optimizer, epoch, args) + + # evaluate on validation set + acc1 = validate(val_loader, model, criterion, args) + + # remember best acc@1 and save checkpoint + is_best = acc1 > best_acc1 + best_acc1 = max(acc1, best_acc1) + +def train(train_loader, model, criterion, optimizer, epoch, args): + batch_time = AverageMeter('Time', ':6.3f') + data_time = AverageMeter('Data', ':6.3f') + losses = AverageMeter('Loss', ':.4e') + top1 = AverageMeter('Acc@1', ':6.2f') + top5 = AverageMeter('Acc@5', ':6.2f') + progress = ProgressMeter( + len(train_loader), + [batch_time, data_time, losses, top1, top5], + prefix="Epoch: [{}]".format(epoch)) + + # switch to train mode + model.train() + + end = time.time() + for i, (images, target) in enumerate(train_loader): + # measure data loading time + data_time.update(time.time() - end) + + if args.gpu is not None: + images = images.cuda(args.gpu, non_blocking=True) + if 'npu' in CALCULATE_DEVICE: + target = target.to(torch.int32) + images, target = images.to(CALCULATE_DEVICE, non_blocking=True), target.to(CALCULATE_DEVICE, non_blocking=True) + + # compute output + output = model(images) + + loss = criterion(output, target) + + # measure accuracy and record loss + acc1, acc5 = accuracy(output, target, topk=(1, 5)) + losses.update(loss.item(), images.size(0)) + top1.update(acc1[0], images.size(0)) + top5.update(acc5[0], images.size(0)) + + # compute gradient and do SGD step + optimizer.zero_grad() + loss.backward() + optimizer.step() + + # measure elapsed time + batch_time.update(time.time() - end) + end = time.time() + + if i % LOG_STEP == 0: + progress.display(i) + + if i == TRAIN_STEP: + break + + +def validate(val_loader, model, criterion, args): + batch_time = AverageMeter('Time', ':6.3f') + losses = AverageMeter('Loss', ':.4e') + top1 = AverageMeter('Acc@1', ':6.2f') + top5 = AverageMeter('Acc@5', ':6.2f') + progress = ProgressMeter( + len(val_loader), + [batch_time, losses, top1, top5], + prefix='Test: ') + + # switch to evaluate mode + model.eval() + + with torch.no_grad(): + end = time.time() + for i, (images, target) in enumerate(val_loader): + if args.gpu is not None: + images = images.cuda(args.gpu, non_blocking=True) + if 'npu' in CALCULATE_DEVICE: + target = target.to(torch.int32) + images, target = images.to(CALCULATE_DEVICE, non_blocking=True), target.to(CALCULATE_DEVICE, non_blocking=True) + # compute output + output = model(images) + + loss = criterion(output, target) + + # measure accuracy and record loss + acc1, acc5 = accuracy(output, target, topk=(1, 5)) + losses.update(loss.item(), images.size(0)) + top1.update(acc1[0], images.size(0)) + top5.update(acc5[0], images.size(0)) + + # measure elapsed time + batch_time.update(time.time() - end) + end = time.time() + + if i % LOG_STEP == 0: + progress.display(i) + break + # TODO: this should also be done with the ProgressMeter + print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}' + .format(top1=top1, top5=top5)) + + return top1.avg + + +def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'): + torch.save(state, filename) + if is_best: + shutil.copyfile(filename, 'model_best.pth.tar') + + +class AverageMeter(object): + """Computes and stores the average and current value""" + def __init__(self, name, fmt=':f'): + self.name = name + self.fmt = fmt + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + def __str__(self): + fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' + return fmtstr.format(**self.__dict__) + + +class ProgressMeter(object): + def __init__(self, num_batches, meters, prefix=""): + self.batch_fmtstr = self._get_batch_fmtstr(num_batches) + self.meters = meters + self.prefix = prefix + + def display(self, batch): + entries = [self.prefix + self.batch_fmtstr.format(batch)] + entries += [str(meter) for meter in self.meters] + print('\t'.join(entries)) + + def _get_batch_fmtstr(self, num_batches): + num_digits = len(str(num_batches // 1)) + fmt = '{:' + str(num_digits) + 'd}' + return '[' + fmt + '/' + fmt.format(num_batches) + ']' + + +def adjust_learning_rate(optimizer, epoch, args): + """Sets the learning rate to the initial LR decayed by 10 every 30 epochs""" + lr = args.lr * (0.1 ** (epoch // 30)) + for param_group in optimizer.param_groups: + param_group['lr'] = lr + + +def accuracy(output, target, topk=(1,)): + """Computes the accuracy over the k top predictions for the specified values of k""" + with torch.no_grad(): + maxk = max(topk) + batch_size = target.size(0) + + _, pred = output.topk(maxk, 1, True, True) + pred = pred.t() + correct = pred.eq(target.view(1, -1).expand_as(pred)) + + res = [] + for k in topk: + correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) + res.append(correct_k.mul_(100.0 / batch_size)) + return res + + +if __name__ == '__main__': + if 'npu' in CALCULATE_DEVICE: + torch.npu.set_device(CALCULATE_DEVICE) + main() -- Gitee