diff --git a/PyTorch/contrib/cv/others/VDSR/LICENSE b/PyTorch/contrib/cv/others/VDSR/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..7ba87e07665bc056ebf56f921dbc2ea89bc51d21
--- /dev/null
+++ b/PyTorch/contrib/cv/others/VDSR/LICENSE
@@ -0,0 +1,24 @@
+The MIT License (MIT)
+
+Copyright (c) 2017- Jiu XU
+Copyright (c) 2017- Rakuten, Inc
+Copyright (c) 2017- Rakuten Institute of Technology
+Copyright 2021 Huawei Technologies Co., Ltd
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
\ No newline at end of file
diff --git a/PyTorch/contrib/cv/others/VDSR/README.md b/PyTorch/contrib/cv/others/VDSR/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..4f2cbed4d85cced24adc988677a2ee3149e90c89
--- /dev/null
+++ b/PyTorch/contrib/cv/others/VDSR/README.md
@@ -0,0 +1,180 @@
+# VDSR for PyTorch
+
+- [Overview](概述.md)
+- [Preparing the training environment](准备训练环境.md)
+- [Starting training](开始训练.md)
+- [Training results](训练结果展示.md)
+- [Release notes](版本说明.md)
+
+
+
+# Overview
+
+## Introduction
+
+VDSR is a classic super-resolution model. It uses a very deep network, combines residual learning with a very high learning rate to speed up training, and uses adaptive gradient clipping to keep training stable. Like SRCNN, it first upsamples the low-resolution input to the target resolution with bicubic interpolation and then runs the model on the result. The design has two parts: a VGG-like deep network in which every layer is a 3x3 convolution with padding followed by a ReLU for non-linearity, and residual learning, which adds the model's prediction element-wise to the interpolated input to produce the final result.
+
+
+- Reference implementation:
+
+  ```
+  url=https://github.com/twtygqyy/pytorch-vdsr.git
+  ```
+
+- Implementation adapted to the Ascend AI processor:
+
+  ```
+  url=https://gitee.com/ascend/ModelZoo-PyTorch.git
+  code_path=PyTorch/contrib/cv/others/VDSR
+  ```
+
+- Getting the code with Git:
+
+  ```
+  git clone https://gitee.com/yxl0321/ModelZoo-PyTorch.git # clone the repository
+  cd ./ModelZoo-PyTorch/PyTorch/contrib/cv/others/VDSR # switch to the model code directory; not needed if the repository contains only this model
+  ```
+
+- Alternatively, click "Download now" to download the source package.
+
+# Preparing the training environment
+
+## Preparing the environment
+
+- The firmware and driver, CANN, and PyTorch versions supported by this model are listed in the table below.
+
+  **Table 1** Version compatibility
+
+  | Component | Version |
+  | ---------- | ------------------------------------------------------------ |
+  | Firmware and driver | [1.0.15](https://www.hiascend.com/hardware/firmware-drivers?tag=commercial) |
+  | CANN | [5.1.RC1](https://www.hiascend.com/software/cann/commercial?version=5.1.RC1) |
+  | PyTorch | [1.8.1](https://gitee.com/ascend/pytorch/tree/master/) |
+
+- Environment setup guide.
+
+  See [PyTorch framework training environment preparation](https://www.hiascend.com/document/detail/zh/ModelZoo/pytorchframework/ptes).
+
+- Install dependencies (add any further dependencies the model needs).
+
+  ```
+  pip install -r requirements.txt
+  ```
+
+
+## Preparing the dataset
+
+1. Training set.
+
+   The training set is generated from the T91 dataset with MATLAB bicubic interpolation. Run ./data/generate_train.m to generate your own training set (set "folder=real_image_path" first), or use the pre-generated dataset file located on the physical machine 192.168.99.101 at './forDocker/home/yxl/VDSR/data/train.h5'.
+2. Test set.
+
+   The test set is Set5, generated from 5 high-resolution images by bicubic interpolation. Place the test set under the ./VDSR folder; a copy is located on the physical machine 192.168.99.101 at './forDocker/home/yxl/VDSR/Set5_mat'.
+
+
+
+## Obtaining the pretrained model
+
+Pretrained models are stored in the './checkpoint' folder. During training, a checkpoint is saved every 5 epochs.
+
+# Starting training
+
+## Training the model
+
+1. Enter the root directory of the extracted source package.
+
+   ```
+   cd ./VDSR
+   ```
+
+2. Run the training script.
+
+   The model supports single-node single-card training and single-node 8-card training.
+
+   - Single-node single-card training
+
+     Start single-card training.
+
+     ```
+     # training 1p accuracy
+     bash ./test/train_full_1p.sh --data_path=real_data_path --valdata=valdata_path
+
+     # training 1p performance
+     bash ./test/train_performance_1p.sh --data_path=real_data_path --valdata=valdata_path
+
+     # finetuning 1p
+     bash test/train_finetune_1p.sh --data_path=real_data_path --valdata=valdata_path --pth_path=real_pre_train_model_path
+
+     ```
+
+   - Single-node 8-card training
+
+     Start 8-card training.
+
+     ```
+     # training 8p accuracy
+     bash ./test/train_full_8p.sh --data_path=real_data_path --valdata=valdata_path
+
+     # training 8p performance
+     bash ./test/train_performance_8p.sh --data_path=real_data_path --valdata=valdata_path
+
+     #test 8p accuracy
+     bash test/train_eval_8p.sh --data_path=real_data_path --valdata=valdata_path --pth_path=real_pre_train_model_path
+     ```
+
+   --data\_path: path to the training set. Default: ./data/train.h5
+   --valdata: path to the test set. Default: ./Set5_mat
+   --pth\_path: path to the pretrained model. Default: ./checkpoint/model_epoch_50.pth
+
+   The training script parameters are described below.
+
+   ```
+   Common parameters:
+   --data_path                         // path to the training set
+   --valdata                           // path to the test set
+   --workers                           // number of data-loading workers
+   --addr                              // master address
+   --nEpochs                           // number of training epochs
+   --batchSize                         // training batch size
+   --lr                                // initial learning rate, default: 0.1
+   --momentum                          // momentum, default: 0.9
+   --weight-decay                      // weight decay, default: 0.0001
+   --resume                            // checkpoint path to resume interrupted training from
+   --start-epoch                       // epoch to start training from, default: 1
+   --clip                              // gradient clipping threshold, default: 0.4
+   --pretrained                        // path to the pretrained model
+   --amp                               // whether to use mixed precision
+   --loss-scale                        // loss scale for mixed precision
+   --opt-level                         // mixed-precision optimization level
+   Multi-card training parameters:
+   --multiprocessing-distributed       // whether to use multi-card training
+   --device_list '0,1,2,3,4,5,6,7'     // devices to use for multi-card training
+   ```
+
+   After training, the weight files are saved under './checkpoint', and the training accuracy and performance figures are printed.
+
+# Training results
+
+**Table 2** Training results
+
+For each scale we report the model's PSNR (PSNR_predicted) alongside the PSNR of plain bicubic interpolation (PSNR_bicubic) for comparison. The results show that the model is simple and effective.
+
+| NAME | PSNR_x2 | PSNR_x3 | PSNR_x4 | FPS | Epochs | AMP_Type |
+| ------- | ----- | ------ | ------ | ----: | ------ | -------: |
+| 1p-competitor | 37.07 | 33.35 | 31.10 | 2434 | 50 | - |
+| 1p-NPU | 37.30 | 33.53 | 31.25 | 4112 | 50 | O1 |
+| 8p-competitor | 37.03 | 33.34 | 31.04 | 10934 | 50 | - |
+| 8p-NPU | 37.17 | 33.42 | 31.10 | 6334 | 50 | O1 |
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/PyTorch/contrib/cv/others/VDSR/data/generate_test_mat.m b/PyTorch/contrib/cv/others/VDSR/data/generate_test_mat.m
new file mode 100644
index 0000000000000000000000000000000000000000..d8ab163a8673b3089c22a10ff672afe3c74ae5a3
--- /dev/null
+++ b/PyTorch/contrib/cv/others/VDSR/data/generate_test_mat.m
@@ -0,0 +1,28 @@
+clear;close all;
+%% settings
+folder = 'Set5';
+
+%% generate data
+filepaths = [];
+filepaths = [filepaths; dir(fullfile(folder, '*.bmp'))];
+
+scale = [2, 3, 4];
+
+for i = 1 : length(filepaths)
+    im_gt = imread(fullfile(folder,filepaths(i).name));
+    for s = 1 : length(scale)
+        im_gt = modcrop(im_gt, scale(s));
+        im_gt = double(im_gt);
+        im_gt_ycbcr = rgb2ycbcr(im_gt / 255.0);
+        im_gt_y = im_gt_ycbcr(:,:,1) * 255.0;
+        im_l_ycbcr = imresize(im_gt_ycbcr,1/scale(s),'bicubic');
+        im_b_ycbcr = imresize(im_l_ycbcr,scale(s),'bicubic');
+        im_l_y = im_l_ycbcr(:,:,1) * 255.0;
+        im_l = ycbcr2rgb(im_l_ycbcr) * 255.0;
+        im_b_y = im_b_ycbcr(:,:,1) * 255.0;
+        im_b = ycbcr2rgb(im_b_ycbcr) * 255.0;
+        last = length(filepaths(i).name)-4;
+        filename = sprintf('Set5_mat/%s_x%s.mat',filepaths(i).name(1 : last),num2str(scale(s)));
+        save(filename, 'im_gt_y', 'im_b_y', 'im_gt', 'im_b', 'im_l_ycbcr', 'im_l_y', 'im_l');
+    end
+end
diff --git a/PyTorch/contrib/cv/others/VDSR/data/generate_train.m b/PyTorch/contrib/cv/others/VDSR/data/generate_train.m
new file mode 100644
index 0000000000000000000000000000000000000000..8caba08852ee70781d33061ac36c16ce5fdac868
--- /dev/null
+++ b/PyTorch/contrib/cv/others/VDSR/data/generate_train.m
@@ -0,0 +1,93 @@
+clear;close all;
+
+folder = 'path/to/train/folder';
+
+savepath = 'train.h5';
+size_input = 41;
+size_label = 41;
+stride = 41;
+
+%% scale factors
+scale = [2,3,4];
+%% downsizing
+downsizes = [1,0.7,0.5];
+
+%% initialization
+data = zeros(size_input, size_input, 1, 1);
+label = zeros(size_label, size_label, 1, 1);
+
+count = 0;
+margain = 0;
+
+%% generate data
+filepaths = [];
+filepaths = [filepaths; dir(fullfile(folder, '*.jpg'))];
+filepaths = [filepaths; dir(fullfile(folder, '*.bmp'))];
+
+for i = 1 : length(filepaths)
+    for flip = 1: 3
+        for degree = 1 : 4
+            for s = 1 : length(scale)
+                for downsize = 1 : length(downsizes)
+                    image = imread(fullfile(folder,filepaths(i).name));
+
+                    if flip == 1
+                        image = flipdim(image ,1);
+                    end
+                    if flip == 2
+                        image = flipdim(image ,2);
+                    end
+
+                    image = imrotate(image, 90 * (degree - 1));
+
+                    image = imresize(image,downsizes(downsize),'bicubic');
+
+                    if size(image,3)==3
+                        image = rgb2ycbcr(image);
+                        image = im2double(image(:, :, 1));
+
+                        im_label = modcrop(image, scale(s));
+                        [hei,wid] = size(im_label);
+                        im_input = imresize(imresize(im_label,1/scale(s),'bicubic'),[hei,wid],'bicubic');
+                        filepaths(i).name
+                        for x = 1 : stride : hei-size_input+1
+                            for y = 1 :stride : wid-size_input+1
+
+                                subim_input = im_input(x : x+size_input-1, y : y+size_input-1);
+                                subim_label = im_label(x : x+size_label-1, y : y+size_label-1);
+
+                                count=count+1;
+
+                                data(:, :, 1, count) = subim_input;
+                                label(:, :, 1, count) = subim_label;
+                            end
+                        end
+                    end
+                end
+            end
+        end
+    end
+end
+
+order = randperm(count);
+data = data(:, :, 1, order);
+label = label(:, :, 1, order);
+
+%% writing to HDF5
+chunksz = 64;
+created_flag = false;
+totalct = 0;
+
+for batchno = 1:floor(count/chunksz)
+    batchno
+    last_read=(batchno-1)*chunksz;
+    batchdata = data(:,:,1,last_read+1:last_read+chunksz);
+    batchlabs = label(:,:,1,last_read+1:last_read+chunksz);
+
+    startloc = struct('dat',[1,1,1,totalct+1], 'lab', [1,1,1,totalct+1]);
+    curr_dat_sz = store2hdf5(savepath, batchdata, batchlabs, ~created_flag, startloc, chunksz);
+    created_flag = true;
+    totalct = curr_dat_sz(end);
+end
+
+h5disp(savepath);
diff --git a/PyTorch/contrib/cv/others/VDSR/data/modcrop.m b/PyTorch/contrib/cv/others/VDSR/data/modcrop.m
new file mode 100644
index 0000000000000000000000000000000000000000..728c68810609913d8ae8475a0d7305a92a1f1fae
--- /dev/null
+++ b/PyTorch/contrib/cv/others/VDSR/data/modcrop.m
@@ -0,0 +1,12 @@
+function imgs = modcrop(imgs, modulo)
+if size(imgs,3)==1
+    sz = size(imgs);
+    sz = sz - mod(sz, modulo);
+    imgs = imgs(1:sz(1), 1:sz(2));
+else
+    tmpsz = size(imgs);
+    sz = tmpsz(1:2);
+    sz = sz - mod(sz, modulo);
+    imgs = imgs(1:sz(1), 1:sz(2),:);
+end
+
diff --git a/PyTorch/contrib/cv/others/VDSR/data/store2hdf5.m b/PyTorch/contrib/cv/others/VDSR/data/store2hdf5.m
new file mode 100644
index 0000000000000000000000000000000000000000..0a0016dca40925652f679900be9627cee75e9c22
--- /dev/null
+++ b/PyTorch/contrib/cv/others/VDSR/data/store2hdf5.m
@@ -0,0 +1,59 @@
+function [curr_dat_sz, curr_lab_sz] = store2hdf5(filename, data, labels, create, startloc, chunksz)
+  % *data* is a W*H*C*N matrix of images; it should be normalized (e.g. to lie between 0 and 1) beforehand
+  % *label* is a D*N matrix of labels (D labels per sample)
+  % *create* [0/1] specifies whether to create the file newly or to append to a previously created file, useful to store information in batches when a dataset is too big to be held in memory (default: 1)
+  % *startloc* (point at which to start writing data). By default,
+  % if create=1 (create mode), startloc.dat=[1 1 1 1], and startloc.lab=[1 1];
+  % if create=0 (append mode), startloc.dat=[1 1 1 K+1], and startloc.lab = [1 K+1]; where K is the current number of samples stored in the HDF5 file
+  % chunksz (used only in create mode), specifies number of samples to be stored per chunk (see HDF5 documentation on chunking) for creating HDF5 files with unbounded maximum size - TLDR; higher chunk sizes allow faster read-write operations
+
+  % verify that format is right
+  dat_dims=size(data);
+  lab_dims=size(labels);
+  num_samples=dat_dims(end);
+
+  assert(lab_dims(end)==num_samples, 'Number of samples should be matched between data and labels');
+
+  if ~exist('create','var')
+    create=true;
+  end
+
+
+  if create
+    %fprintf('Creating dataset with %d samples\n', num_samples);
+    if ~exist('chunksz', 'var')
+      chunksz=1000;
+    end
+    if exist(filename, 'file')
+      fprintf('Warning: replacing existing file %s \n', filename);
+      delete(filename);
+    end
+    h5create(filename, '/data', [dat_dims(1:end-1) Inf], 'Datatype', 'single', 'ChunkSize', [dat_dims(1:end-1) chunksz]); % width, height, channels, number
+    h5create(filename, '/label', [lab_dims(1:end-1) Inf], 'Datatype', 'single', 'ChunkSize', [lab_dims(1:end-1) chunksz]); % width, height, channels, number
+    if ~exist('startloc','var')
+      startloc.dat=[ones(1,length(dat_dims)-1), 1];
+      startloc.lab=[ones(1,length(lab_dims)-1), 1];
+    end
+  else  % append mode
+    if ~exist('startloc','var')
+      info=h5info(filename);
+      prev_dat_sz=info.Datasets(1).Dataspace.Size;
+      prev_lab_sz=info.Datasets(2).Dataspace.Size;
+      assert(prev_dat_sz(1:end-1)==dat_dims(1:end-1), 'Data dimensions must match existing dimensions in dataset');
+      assert(prev_lab_sz(1:end-1)==lab_dims(1:end-1), 'Label dimensions must match existing dimensions in dataset');
+      startloc.dat=[ones(1,length(dat_dims)-1), prev_dat_sz(end)+1];
+      startloc.lab=[ones(1,length(lab_dims)-1), prev_lab_sz(end)+1];
+    end
+  end
+
+  if ~isempty(data)
+    h5write(filename, '/data', single(data), startloc.dat, size(data));
+    h5write(filename, '/label', single(labels), startloc.lab, size(labels));
+  end
+
+  if nargout
+    info=h5info(filename);
+    curr_dat_sz=info.Datasets(1).Dataspace.Size;
+    curr_lab_sz=info.Datasets(2).Dataspace.Size;
+  end
+end
diff --git a/PyTorch/contrib/cv/others/VDSR/dataset.py b/PyTorch/contrib/cv/others/VDSR/dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..a523e19179d37f77066c300b5d0f523247a24225
--- /dev/null
+++ b/PyTorch/contrib/cv/others/VDSR/dataset.py
@@ -0,0 +1,31 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import torch.utils.data as data
+import torch
+import h5py
+
+class DatasetFromHdf5(data.Dataset):
+    def __init__(self, file_path):
+        super(DatasetFromHdf5, self).__init__()
+        hf = h5py.File(file_path, 'r')  # open read-only; older h5py versions default to append mode
+        self.data = hf.get('data')
+        self.target = hf.get('label')
+
+    def __getitem__(self, index):
+        return torch.from_numpy(self.data[index,:,:,:]).float(), torch.from_numpy(self.target[index,:,:,:]).float()
+
+    def __len__(self):
+        return self.data.shape[0]
\ No newline at end of file
diff --git a/PyTorch/contrib/cv/others/VDSR/main.py b/PyTorch/contrib/cv/others/VDSR/main.py
new file mode 100644
index 0000000000000000000000000000000000000000..45b0054fdb9d1da11d3c4673171cdd1799189c32
--- /dev/null
+++ b/PyTorch/contrib/cv/others/VDSR/main.py
@@ -0,0 +1,556 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import argparse
+import os
+import warnings
+import torch
+if tuple(int(p) for p in torch.__version__.split('+')[0].split('.')[:2]) >= (1, 8):  # numeric compare: as strings, '1.10' sorts before '1.8'
+    import torch_npu
+import sys
+import time
+import random
+import math
+import glob
+import scipy.io as sio
+import numpy as np
+import apex
+from apex import amp
+import torch.nn as nn
+import torch.optim as optim
+import torch.backends.cudnn as cudnn
+import torch.multiprocessing as mp
+import torch.utils.data.distributed
+from torch.autograd import Variable
+from torch.utils.data import DataLoader
+import torch.distributed as dist
+from models.vdsr import Net
+from dataset import DatasetFromHdf5
+from apex.parallel import DistributedDataParallel
+
+
+# Training settings
+parser = argparse.ArgumentParser(description="PyTorch VDSR")
+parser.add_argument('--data_path', default='./data/train.h5', type=str,help='path to dataset')
+parser.add_argument("--batchSize", type=int, default=128, help="Training batch size")
+parser.add_argument("--nEpochs", type=int, default=50, help="Number of epochs to train for")
+parser.add_argument("--lr", type=float, default=0.1, help="Learning Rate. Default=0.1")
+parser.add_argument("--step", type=int, default=10, help="Sets the learning rate to the initial LR decayed by 10x every n epochs, Default: n=10")
+parser.add_argument("--resume", default="", type=str, help="Path to checkpoint (default: none)")
+parser.add_argument("--start-epoch", default=1, type=int, help="Manual epoch number (useful on restarts)")
+parser.add_argument("--clip", type=float, default=0.4, help="Clipping Gradients. Default=0.4")
+parser.add_argument("--momentum", default=0.9, type=float, help="Momentum, Default: 0.9")
+parser.add_argument("--weight-decay", "--wd", default=1e-4, type=float, help="Weight decay, Default: 1e-4")
+parser.add_argument('--seed', default=49, type=int,help='seed for initializing training. ')
+parser.add_argument('--pretrained', default='', type=str, help='path to pretrained model (default: none)')
+parser.add_argument("--gpu", default= None, type=int, help="device id to use (default: None)")
+parser.add_argument('--num_classes', default=10, type=int,help='The number of classes.')
+parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',help='number of data loading workers (default: 4)')
+parser.add_argument('--valdata', default='./Set5_mat', type=str,help='path to the test set (default: ./Set5_mat)')
+#GPU
+parser.add_argument('--device', default='npu', type=str, help='npu or gpu')
+parser.add_argument("--world_size", default=1, type=int, help="Number of nodes for distributed training, Default: 1")
+parser.add_argument('--local_rank',default=-1,type=int,help='node rank for distributed training')
+parser.add_argument('--loss-scale', default=None, type=float,help='loss scale used in amp (default: None, keep the opt-level default)')
+parser.add_argument('--opt-level', default='O1', type=str,help='amp optimization level (default: O1)')
+parser.add_argument('--prof', default=False, action='store_true',help='use profiling to evaluate the performance of model')
+parser.add_argument('--amp', default=False, action='store_true',help='use amp to train the model')
+parser.add_argument('--distributed',default =False,help='')
+parser.add_argument('--addr', default='127.0.0.1',type=str, help='master addr')
+parser.add_argument('--device_list', default='0,1,2,3,4,5,6,7',type=str, help='device id list')
+parser.add_argument('--dist-url', default='tcp://127.0.0.1:50001', type=str,help='url used to set up distributed training')
+parser.add_argument('--dist-backend', default='hccl', type=str,help='distributed backend')
+parser.add_argument('--multiprocessing-distributed', default=False, action='store_true',
+                    help='Use multi-processing distributed training to launch '
+                         'N processes per node, which has N GPUs. This is the '
+                         'fastest way to use PyTorch for either single node or '
+                         'multi node data parallel training')
+parser.add_argument('--rank', default=0, type=int,help='node rank for distributed training')
+
+def device_id_to_process_device_map(device_list):
+    devices = device_list.split(",")
+    devices = [int(x) for x in devices]
+    devices.sort()
+
+    process_device_map = dict()
+    for process_id, device_id in enumerate(devices):
+        process_device_map[process_id] = device_id
+
+    return process_device_map
+
+# for servers to immediately record the logs
+def flush_print(func):
+    def new_print(*args, **kwargs):
+        func(*args, **kwargs)
+        sys.stdout.flush()
+    return new_print
+print = flush_print(print)
+
+def main():
+    opt = parser.parse_args()
+
+    os.environ['MASTER_ADDR'] = opt.addr
+    os.environ['MASTER_PORT'] = '29777'
+
+    if opt.seed is not None:
+        random.seed(opt.seed)
+        torch.manual_seed(opt.seed)
+        cudnn.deterministic = True
+        warnings.warn('You have chosen to seed training. '
+                      'This will turn on the CUDNN deterministic setting, '
+                      'which can slow down your training considerably! '
+                      'You may see unexpected behavior when restarting '
+                      'from checkpoints.')
+
+    if opt.gpu is not None:
+        warnings.warn('You have chosen a specific GPU. This will completely '
+                      'disable data parallelism.')
+
+    if opt.dist_url == "env://" and opt.world_size == -1:
+        opt.world_size = int(os.environ["WORLD_SIZE"])
+
+    opt.distributed = opt.world_size > 1 or opt.multiprocessing_distributed
+
+    opt.process_device_map = device_id_to_process_device_map(opt.device_list)
+
+    if opt.device == 'npu':
+        ngpus_per_node = int(os.environ["RANK_SIZE"])
+    else:
+        if opt.distributed:
+            ngpus_per_node = torch.cuda.device_count()
+        else:
+            ngpus_per_node = 1
+    if opt.multiprocessing_distributed:
+        # Since we have ngpus_per_node processes per node, the total world_size
+        # needs to be adjusted accordingly
+        opt.world_size = ngpus_per_node * opt.world_size
+        # Use torch.multiprocessing.spawn to launch distributed processes: the
+        # main_worker process function
+        if opt.device=='npu':
+            main_worker(opt.gpu, ngpus_per_node,opt)
+        else:
+            mp.spawn(main_worker, nprocs=ngpus_per_node,
+                     args=(ngpus_per_node, opt))
+    else:
+        # Simply call main_worker function
+        main_worker(opt.gpu, ngpus_per_node,opt)
+
+def main_worker(gpu, ngpus_per_node,opt):
+
+    opt.gpu = opt.process_device_map[gpu]
+
+
+    if opt.distributed:
+        if opt.dist_url == "env://" and opt.rank == -1:
+            opt.rank = int(os.environ["RANK"])
+        if opt.multiprocessing_distributed:
+            # For multiprocessing distributed training, rank needs to be the
+            # global rank among all the processes
+            opt.rank = opt.rank * ngpus_per_node + gpu
+
+        if opt.device == 'npu':
+            dist.init_process_group(backend=opt.dist_backend, # init_method=args.dist_url,
+                                    world_size=opt.world_size, rank=opt.rank)
+        else:
+            dist.init_process_group(backend=opt.dist_backend, init_method=opt.dist_url,
+                                    world_size=opt.world_size, rank=opt.rank)
+
+
+    if opt.pretrained:
+        model = Net()
+        if os.path.isfile(opt.pretrained):
+            pretrained_dict = torch.load(opt.pretrained, map_location="cpu")["state_dict"]
+        else:
+            pretrained_dict = torch.load("./checkpoint/model_epoch_50.pth", map_location="cpu")["state_dict"]
+        if "fc.weight" in pretrained_dict:
+            pretrained_dict.pop('fc.weight')
+            pretrained_dict.pop('fc.bias')
+        model.load_state_dict(pretrained_dict, strict=False)
+    else:
+        model = Net()
+
+
+    if opt.distributed:
+        # For multiprocessing distributed, DistributedDataParallel constructor
+        # should always set the single device scope, otherwise,
+        # DistributedDataParallel will use all available devices.
+        if opt.gpu is not None:
+            if opt.device == 'npu':
+                loc = 'npu:{}'.format(opt.gpu)
+                torch.npu.set_device(loc)
+                model = model.to(loc)
+            else:
+                torch.cuda.set_device(opt.gpu)
+                model.cuda(opt.gpu)
+
+            # When using a single GPU per process and per
+            # DistributedDataParallel, we need to divide the batch size
+            # ourselves based on the total number of GPUs we have
+            opt.batchSize = int(opt.batchSize / opt.world_size)
+            # opt.workers = int((opt.workers + ngpus_per_node - 1) / ngpus_per_node)
+        else:
+            if opt.device == 'npu':
+                loc = 'npu:{}'.format(opt.gpu)
+                model = model.to(loc)
+            else:
+                model.cuda()
+            # DistributedDataParallel will divide and allocate batch_size to all
+            # available GPUs if device_ids are not set
+    elif opt.gpu is not None:
+        if opt.device == 'npu':
+            loc = 'npu:{}'.format(opt.gpu)
+            torch.npu.set_device(opt.gpu)
+            model = model.to(loc)
+        else:
+            torch.cuda.set_device(opt.gpu)
+            model = model.cuda(opt.gpu)
+    else:
+        # DataParallel will divide and allocate batch_size to all available GPUs
+        if opt.device == 'npu':
+            loc = 'npu:{}'.format(opt.gpu)
+
+    optimizer = apex.optimizers.NpuFusedSGD(model.parameters(), lr=opt.lr, momentum=opt.momentum, weight_decay=opt.weight_decay)
+    if opt.amp:
+        model, optimizer = amp.initialize(model, optimizer, opt_level=opt.opt_level, loss_scale=opt.loss_scale, combine_grad=True)
+
+    if opt.distributed:
+        # For multiprocessing distributed, DistributedDataParallel constructor
+        # should always set the single device scope, otherwise,
+        # DistributedDataParallel will use all available devices.
+        if opt.gpu is not None:
+            # When using a single GPU per process and per
+            # DistributedDataParallel, we need to divide the batch size
+            # ourselves based on the total number of GPUs we have
+            if opt.pretrained:
+                model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[opt.gpu], broadcast_buffers=False,
+                                                                  find_unused_parameters=True)
+            else:
+                model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[opt.gpu], broadcast_buffers=False)
+        else:
+            model = torch.nn.parallel.DistributedDataParallel(model)
+    else:
+        # DataParallel will divide and allocate batch_size to all available GPUs
+        if opt.device == 'npu':
+            loc = 'npu:{}'.format(opt.gpu)
+            model = torch.nn.DataParallel(model).to(loc)
+        else:
+            model = torch.nn.DataParallel(model).cuda()
+
+
+    if opt.device == 'npu':
+        loc = 'npu:{}'.format(opt.gpu)
+        criterion = nn.MSELoss(reduction='sum').to(loc)
+    else:
+        criterion = nn.MSELoss(reduction='sum').cuda(opt.gpu)
+
+    # optionally resume from a checkpoint
+    if opt.resume:
+        if os.path.isfile(opt.resume):
+            if opt.gpu is None:
+                checkpoint = torch.load(opt.resume)
+            else:
+                # Map model to be loaded to specified single gpu.
+                if opt.device == 'npu':
+                    loc = 'npu:{}'.format(opt.gpu)
+                else:
+                    loc = 'cuda:{}'.format(opt.gpu)
+                checkpoint = torch.load(opt.resume, map_location=loc)
+            opt.start_epoch = checkpoint['epoch'] + 1  # the checkpoint stores the last completed epoch
+            model.load_state_dict(checkpoint['state_dict'])
+            optimizer.load_state_dict(checkpoint['optimizer'])
+            if opt.amp:
+                amp.load_state_dict(checkpoint['amp'])
+
+
+    cudnn.benchmark = True
+
+    train_set = DatasetFromHdf5(opt.data_path)
+    if opt.distributed:
+        train_sampler = torch.utils.data.distributed.DistributedSampler(
+            train_set)
+    else:
+        train_sampler = None
+    if opt.device == 'npu':
+        training_data_loader = DataLoader(train_set,batch_size=opt.batchSize,
+                                          num_workers=opt.workers,
+                                          shuffle=(train_sampler is None),
+                                          pin_memory=False,
+                                          sampler=train_sampler,
+                                          drop_last=True)
+    else:
+        training_data_loader = DataLoader(train_set,batch_size=opt.batchSize,
+                                          shuffle=(train_sampler is None),
+                                          pin_memory=False,
+                                          sampler=train_sampler,
+                                          drop_last=True)
+
+    if opt.prof:
+        profiling(training_data_loader, model, criterion, optimizer, opt)
+        return
+
+    start_time = time.time()
+    for epoch in range(opt.start_epoch, opt.nEpochs+1):
+        if opt.distributed:
+            train_sampler.set_epoch(epoch)
+        train(training_data_loader, optimizer, model, criterion, epoch, opt, ngpus_per_node)
+        if not opt.multiprocessing_distributed or (opt.multiprocessing_distributed
+                and opt.rank % ngpus_per_node == 0):
+            if (epoch) % 5 != 0: # only save a checkpoint every 5 epochs
+                continue
+
+            ############## npu modify begin #############
+            if opt.amp:
+                save_checkpoint({
+                    'epoch': epoch,
+                    'arch': 'VDSR',
+                    'state_dict': model.state_dict(),
+                    'optimizer': optimizer.state_dict(),
+                    'amp': amp.state_dict(),
+                }, epoch)
+            else:
+                save_checkpoint({
+                    'epoch': epoch,
+                    'arch': 'VDSR',
+                    'state_dict': model.state_dict(),
+                    'optimizer': optimizer.state_dict(),
+                }, epoch)
+            ############## npu modify end #############
+
+        if not opt.multiprocessing_distributed or (opt.multiprocessing_distributed
+                and opt.rank % ngpus_per_node == 0):
+            test(model,opt)
+
+def adjust_learning_rate(optimizer, epoch, opt):
+    """Sets the learning rate to the initial LR decayed by 10 every opt.step epochs"""
+    lr = opt.lr * (0.1 ** (epoch // opt.step))
+    for param_group in optimizer.param_groups:
+        param_group["lr"] = lr
+
+def profiling(data_loader, model, criterion, optimizer, opt):
+    # switch to train mode
+    model.train()
+
+    def update(model, images, target, optimizer):
+        output = model(images)
+        loss = criterion(output, target)
+        optimizer.zero_grad()  # clear old gradients before backward, not after it
+        if opt.amp:
+            with amp.scale_loss(loss, optimizer) as scaled_loss:
+                scaled_loss.backward()
+        else:
+            loss.backward()
+        optimizer.step()
+
+    for step, (images, target) in enumerate(data_loader):
+        if opt.device == 'npu':
+            loc = 'npu:{}'.format(opt.gpu)
+            images = images.to(loc, non_blocking=True).to(torch.float)
+            target = target.to(torch.float).to(loc, non_blocking=True)
+        else:
+            images = images.cuda(opt.gpu, non_blocking=True)
+            target = target.cuda(opt.gpu, non_blocking=True)
+
+        if step < 5:
+            update(model, images, target, optimizer)
+        else:
+            if opt.device == 'npu':
+                with torch.autograd.profiler.profile(use_npu=True) as prof:
+                    update(model, images, target, optimizer)
+            else:
+                with torch.autograd.profiler.profile(use_cuda=True) as prof:
+                    update(model, images, target, optimizer)
+            break
+
+    prof.export_chrome_trace("output.prof")
+
+def train(training_data_loader, optimizer, model, criterion, epoch, opt, ngpus_per_node):
+    batch_time = AverageMeter('Time', ':6.3f')
+    data_time = AverageMeter('Data', ':6.3f')
+    losses = AverageMeter('Loss', ':.4e')
+    top1 = AverageMeter('Acc@1', ':6.2f')
+    top5 = AverageMeter('Acc@5', ':6.2f')
+    progress = ProgressMeter(
+        len(training_data_loader),
+        [batch_time, data_time, losses, top1, top5],
+        prefix="Epoch: [{}]".format(epoch))
+
+    adjust_learning_rate(optimizer, epoch, opt)
+    model.train()
+    end = time.time()
+    for iteration, batch in enumerate(training_data_loader, 1):
+        input, target = Variable(batch[0]), Variable(batch[1], requires_grad=False)
+        data_time.update(time.time() - end)
+
+        if opt.device == 'npu':
+            loc = 'npu:{}'.format(opt.gpu)
+            input = input.to(loc, non_blocking=True)
+            target = target.to(loc, non_blocking=True)
+        else:
+            input = input.cuda(opt.gpu, non_blocking=True)
+            target = target.cuda(opt.gpu, non_blocking=True)
+
+        output = model(input)
+        loss = criterion(output, target)
+
+        optimizer.zero_grad()
+        if opt.amp:
+            with amp.scale_loss(loss, optimizer) as scaled_loss:
+                scaled_loss.backward()
+        else:
+            loss.backward()
+        # gradient clipping
+        nn.utils.clip_grad_norm_(model.parameters(), opt.clip)
+        optimizer.step()
+        if opt.device == 'npu':
+            torch.npu.synchronize()
+        # measure elapsed time
+        cost_time = time.time() - end
+        batch_time.update(cost_time)
+        end = time.time()
+
+        if not opt.multiprocessing_distributed or (opt.multiprocessing_distributed
+                and opt.rank % ngpus_per_node == 0):
+            if iteration % 100 == 0:
+                print("===> Epoch[{}]({}/{}): Loss: {:.10f}".format(epoch, iteration, len(training_data_loader), loss.item()))
+                if batch_time.avg:
+                    print("[npu id:", opt.gpu, "]", "batch_size:", opt.world_size * opt.batchSize,
+                          'Time: {:.3f}'.format(batch_time.avg), '* FPS@all {:.3f}'.format(
+                            opt.batchSize * opt.world_size / batch_time.avg))
+
+class AverageMeter(object):
+    """Computes and stores the average and current value"""
+
+    def __init__(self, name, fmt=':f', start_count_index=10):
+        self.name = name
+        self.fmt = fmt
+        self.reset()
+        self.start_count_index = start_count_index
+
+    def reset(self):
+        self.val = 0
+        self.avg = 0
+        self.sum = 0
+        self.count = 0
+
+    def update(self, val, n=1):
+        if self.count == 0:
+            self.N = n
+
+        self.val = val
+        self.count += n
+        if self.count > (self.start_count_index * self.N):
+            self.sum += val * n
+            self.avg = self.sum / (self.count - self.start_count_index * self.N)
+
+    def __str__(self):
+        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
+        return fmtstr.format(**self.__dict__)
+
+class ProgressMeter(object):
+    def __init__(self, num_batches, meters, prefix=""):
+        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
+        self.meters = meters
+        self.prefix = prefix
+
+    def display(self, batch):
+        entries = [self.prefix + self.batch_fmtstr.format(batch)]
+        entries += [str(meter) for meter in self.meters]
+        print("[npu id:", os.environ['LOCAL_DEVICE_ID'], "]", '\t'.join(entries))
+
+    def _get_batch_fmtstr(self, num_batches):
+        num_digits = len(str(num_batches // 1))
+        fmt = '{:' + str(num_digits) + 'd}'
+        return '[' + fmt + '/' + fmt.format(num_batches) + ']'
+
+def save_checkpoint(state, epoch):
+    model_out_path = "checkpoint/" + "model_epoch_{}.pth".format(epoch)
+    if not os.path.exists("checkpoint/"):
+        os.makedirs("checkpoint/")
+    torch.save(state, model_out_path)
+
+def PSNR(pred, gt, shave_border=0):
+    height, width = pred.shape[:2]
+    pred = pred[shave_border:height - shave_border, shave_border:width - shave_border]
+    gt = gt[shave_border:height - shave_border, shave_border:width - shave_border]
+    imdff = pred - gt
+    rmse = math.sqrt(np.mean(imdff ** 2))
+    if rmse == 0:
+        return 100
+    return 20 * math.log10(255.0 / rmse)
+
+def test(model, opt):
+
+    model.eval()
+
+    scales = [2, 3, 4]
+
+    image_list = glob.glob(opt.valdata+"/*.*")
+
+    for scale in scales:
+        avg_psnr_predicted = 0.0
+        avg_psnr_bicubic = 0.0
+        avg_elapsed_time = 0.0
+        count = 0.0
+        for image_name in image_list:
+            if "_x{}.".format(scale) in os.path.basename(image_name):  # match the '<name>_x<scale>.mat' suffix, not any digit in the path
+                count += 1
+                # print("Processing ", image_name)
+                im_gt_y = sio.loadmat(image_name)['im_gt_y']
+                im_b_y = sio.loadmat(image_name)['im_b_y']
+
+                im_gt_y = im_gt_y.astype(float)
+                im_b_y = im_b_y.astype(float)
+
+                psnr_bicubic = PSNR(im_gt_y, im_b_y,shave_border=scale)
+                avg_psnr_bicubic += psnr_bicubic
+
+                im_input = im_b_y/255.
+ + im_input = Variable(torch.from_numpy(im_input).float()).view(1, -1, im_input.shape[0], im_input.shape[1]) + + if opt.device == 'npu': + loc = 'npu:{}'.format(opt.gpu) + im_input = im_input.to(loc, non_blocking=True).to(torch.float) + else: + im_input = im_input.cuda(opt.gpu, non_blocking=True) + + start_time = time.time() + HR = model(im_input) + elapsed_time = time.time() - start_time + avg_elapsed_time += elapsed_time + + HR = HR.cpu() + + im_h_y = HR.data[0].numpy().astype(np.float32) + + im_h_y = im_h_y * 255. + im_h_y[im_h_y < 0] = 0 + im_h_y[im_h_y > 255.] = 255. + im_h_y = im_h_y[0,:,:] + + psnr_predicted = PSNR(im_gt_y, im_h_y,shave_border=scale) + avg_psnr_predicted += psnr_predicted + + print("Scale=", scale) + print("Dataset=", opt.valdata) + print("PSNR_predicted=", avg_psnr_predicted/count) + print("PSNR_bicubic=", avg_psnr_bicubic/count) + print("It takes average {}s for processing".format(avg_elapsed_time/count)) + +if __name__ == "__main__": + main() + + + diff --git a/PyTorch/contrib/cv/others/VDSR/models/vdsr.py b/PyTorch/contrib/cv/others/VDSR/models/vdsr.py new file mode 100644 index 0000000000000000000000000000000000000000..60d806cfed29b5f2e0463757c76a45142637c996 --- /dev/null +++ b/PyTorch/contrib/cv/others/VDSR/models/vdsr.py @@ -0,0 +1,55 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ + +import torch +import torch.nn as nn +from math import sqrt + +class ConvReLUBlock(nn.Module): + def __init__(self): + super(ConvReLUBlock, self).__init__() + self.conv = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, bias=False) + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + return self.relu(self.conv(x)) + +class Net(nn.Module): + def __init__(self): + super(Net, self).__init__() + self.residual_layer = self.make_layer(ConvReLUBlock, 18) + self.input = nn.Conv2d(in_channels=1, out_channels=64, kernel_size=3, stride=1, padding=1, bias=False) + self.output = nn.Conv2d(in_channels=64, out_channels=1, kernel_size=3, stride=1, padding=1, bias=False) + self.relu = nn.ReLU(inplace=True) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, sqrt(2. / n)) + + def make_layer(self, block, num_of_layer): + layers = [] + for _ in range(num_of_layer): + layers.append(block()) + return nn.Sequential(*layers) + + def forward(self, x): + residual = x + out = self.relu(self.input(x)) + out = self.residual_layer(out) + out = self.output(out) + out = torch.add(out, residual) + return out + \ No newline at end of file diff --git a/PyTorch/contrib/cv/others/VDSR/modelzoo_level.txt b/PyTorch/contrib/cv/others/VDSR/modelzoo_level.txt new file mode 100644 index 0000000000000000000000000000000000000000..55a9add9fa74832ca908108d73946cd76281a9cd --- /dev/null +++ b/PyTorch/contrib/cv/others/VDSR/modelzoo_level.txt @@ -0,0 +1,3 @@ +FuncStatus:OK +PerfStatus:OK +PrecisionStatus:POK \ No newline at end of file diff --git a/PyTorch/contrib/cv/others/VDSR/requirements.txt b/PyTorch/contrib/cv/others/VDSR/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..cf3531aa4776cc8ebf671732a56f491d92019041 --- /dev/null +++ 
b/PyTorch/contrib/cv/others/VDSR/requirements.txt @@ -0,0 +1,5 @@ +scipy +h5py +numpy +apex +glob \ No newline at end of file diff --git a/PyTorch/contrib/cv/others/VDSR/test/env_npu.sh b/PyTorch/contrib/cv/others/VDSR/test/env_npu.sh new file mode 100644 index 0000000000000000000000000000000000000000..099ec5cc8be67ce5b0a899708139c8e83f397f10 --- /dev/null +++ b/PyTorch/contrib/cv/others/VDSR/test/env_npu.sh @@ -0,0 +1,76 @@ +#!/bin/bash +export install_path=/usr/local/Ascend + +if [ -d ${install_path}/toolkit ]; then + export LD_LIBRARY_PATH=${install_path}/fwkacllib/lib64/:/usr/include/hdf5/lib/:/usr/local/:/usr/local/lib/:/usr/lib/:${install_path}/driver/lib64/common/:${install_path}/driver/lib64/driver/:${install_path}/add-ons:${path_lib}:${LD_LIBRARY_PATH} + export PATH=${install_path}/fwkacllib/ccec_compiler/bin:${install_path}/fwkacllib/bin:$PATH + export PYTHONPATH=${install_path}/fwkacllib/python/site-packages:${install_path}/tfplugin/python/site-packages:${install_path}/toolkit/python/site-packages:$PYTHONPATH + export PYTHONPATH=/usr/local/python3.7.5/lib/python3.7/site-packages:$PYTHONPATH + export ASCEND_OPP_PATH=${install_path}/opp +else + if [ -d ${install_path}/nnae/latest ];then + export LD_LIBRARY_PATH=${install_path}/nnae/latest/fwkacllib/lib64/:/usr/local/:/usr/local/python3.7.5/lib/:/usr/local/openblas/lib:/usr/local/lib/:/usr/lib64/:/usr/lib/:${install_path}/driver/lib64/common/:${install_path}/driver/lib64/driver/:${install_path}/add-ons/:/usr/lib/aarch64_64-linux-gnu:$LD_LIBRARY_PATH + export PATH=$PATH:${install_path}/nnae/latest/fwkacllib/ccec_compiler/bin/:${install_path}/nnae/latest/toolkit/tools/ide_daemon/bin/ + export ASCEND_OPP_PATH=${install_path}/nnae/latest/opp/ + export OPTION_EXEC_EXTERN_PLUGIN_PATH=${install_path}/nnae/latest/fwkacllib/lib64/plugin/opskernel/libfe.so:${install_path}/nnae/latest/fwkacllib/lib64/plugin/opskernel/libaicpu_engine.so:${install_path}/nnae/latest/fwkacllib/lib64/plugin/opskernel/libge_local_engine.so 
+ export PYTHONPATH=${install_path}/nnae/latest/fwkacllib/python/site-packages/:${install_path}/nnae/latest/fwkacllib/python/site-packages/auto_tune.egg/auto_tune:${install_path}/nnae/latest/fwkacllib/python/site-packages/schedule_search.egg:$PYTHONPATH + export ASCEND_AICPU_PATH=${install_path}/nnae/latest + else + export LD_LIBRARY_PATH=${install_path}/ascend-toolkit/latest/fwkacllib/lib64/:/usr/local/:/usr/local/lib/:/usr/lib64/:/usr/lib/:/usr/local/python3.7.5/lib/:/usr/local/openblas/lib:${install_path}/driver/lib64/common/:${install_path}/driver/lib64/driver/:${install_path}/add-ons/:/usr/lib/aarch64-linux-gnu:$LD_LIBRARY_PATH + export PATH=$PATH:${install_path}/ascend-toolkit/latest/fwkacllib/ccec_compiler/bin/:${install_path}/ascend-toolkit/latest/toolkit/tools/ide_daemon/bin/ + export ASCEND_OPP_PATH=${install_path}/ascend-toolkit/latest/opp/ + export OPTION_EXEC_EXTERN_PLUGIN_PATH=${install_path}/ascend-toolkit/latest/fwkacllib/lib64/plugin/opskernel/libfe.so:${install_path}/ascend-toolkit/latest/fwkacllib/lib64/plugin/opskernel/libaicpu_engine.so:${install_path}/ascend-toolkit/latest/fwkacllib/lib64/plugin/opskernel/libge_local_engine.so + export PYTHONPATH=${install_path}/ascend-toolkit/latest/fwkacllib/python/site-packages/:${install_path}/ascend-toolkit/latest/fwkacllib/python/site-packages/auto_tune.egg/auto_tune:${install_path}/ascend-toolkit/latest/fwkacllib/python/site-packages/schedule_search.egg:$PYTHONPATH + export ASCEND_AICPU_PATH=${install_path}/ascend-toolkit/latest + fi +fi + +${install_path}/driver/tools/msnpureport -g error -d 0 +${install_path}/driver/tools/msnpureport -g error -d 1 +${install_path}/driver/tools/msnpureport -g error -d 2 +${install_path}/driver/tools/msnpureport -g error -d 3 +${install_path}/driver/tools/msnpureport -g error -d 4 +${install_path}/driver/tools/msnpureport -g error -d 5 +${install_path}/driver/tools/msnpureport -g error -d 6 +${install_path}/driver/tools/msnpureport -g error -d 7 + 
+# Print host logs to standard output: 0 = off, 1 = on +export ASCEND_SLOG_PRINT_TO_STDOUT=0 +# Default log level: 0 = debug, 1 = info, 2 = warning, 3 = error +export ASCEND_GLOBAL_LOG_LEVEL=3 +# Enable event logging: 0 = off, 1 = on +export ASCEND_GLOBAL_EVENT_ENABLE=0 +# Enable the task queue: 0 = off, 1 = on +export TASK_QUEUE_ENABLE=0 +# Enable PTCopy: 0 = off, 1 = on +export PTCOPY_ENABLE=1 +# Enable the combined flag: 0 = off, 1 = on +export COMBINED_ENABLE=0 +# Whether recompilation is needed in special scenarios; no need to modify +export DYNAMIC_OP="ADD#MUL" +# HCCL whitelist switch: 1 = disabled, 0 = enabled +export HCCL_WHITELIST_DISABLE=1 +export HCCL_IF_IP=$(hostname -I |awk '{print $1}') + +ulimit -SHn 512000 + +path_lib=$(python3.7 -c """ +import sys +import re +result='' +for index in range(len(sys.path)): + match_sit = re.search('-packages', sys.path[index]) + if match_sit is not None: + match_lib = re.search('lib', sys.path[index]) + + if match_lib is not None: + end=match_lib.span()[1] + result += sys.path[index][0:end] + ':' + + result+=sys.path[index] + '/torch/lib:' +print(result)""" +) + +echo ${path_lib} + +export LD_LIBRARY_PATH=/usr/local/python3.7.5/lib/:${path_lib}:$LD_LIBRARY_PATH diff --git a/PyTorch/contrib/cv/others/VDSR/test/train_eval_8p.sh b/PyTorch/contrib/cv/others/VDSR/test/train_eval_8p.sh new file mode 100644 index 0000000000000000000000000000000000000000..b0eea76f850496a56570cdbc4efbc6c88a9f1fd7 --- /dev/null +++ b/PyTorch/contrib/cv/others/VDSR/test/train_eval_8p.sh @@ -0,0 +1,142 @@ +#!/bin/bash + +################ Basic configuration parameters; review and adjust per model ################## +# Required fields (parameters that must be defined here): Network batch_size resume RANK_SIZE +# Network name, same as the directory name +Network="VDSR" +# Training batch_size +batch_size=384 +# Number of NPU cards used for training +export RANK_SIZE=8 +# Dataset path; leave empty, no need to modify +data_path="" +# Test set path; leave empty, no need to modify +valdata="" +# Checkpoint file path; use the actual path +pth_path="" +# Number of training epochs +train_epochs=50 +# Learning rate +learning_rate=0.3 + + + +# Argument parsing: data_path is required; other parameters may be added or removed as the model needs; any parameter added here must be defined and assigned above +for para in $* +do + if [[ $para == --workers* ]];then + workers=`echo ${para#*=}` + elif [[ $para == --data_path* ]];then + data_path=`echo ${para#*=}` + elif [[ $para == --pth_path* ]];then + 
pth_path=`echo ${para#*=}` + elif [[ $para == --valdata* ]];then + valdata=`echo ${para#*=}` + fi +done + +# Check that data_path was passed; no need to modify +if [[ $data_path == "" ]];then + echo "[Error] para \"data_path\" must be configured" + exit 1 +fi + +# Check that pth_path was passed; the evaluation script requires it +if [[ $pth_path == "" ]];then + echo "[Error] para \"pth_path\" must be configured" + exit 1 +fi + +############### Set the training script execution path ############### +# cd to the directory at the same level as the test folder before running, for compatibility; test_path_dir is the path that includes the test folder +cur_path=`pwd` +cur_path_last_diename=${cur_path##*/} +if [ x"${cur_path_last_diename}" == x"test" ];then + test_path_dir=${cur_path} + cd .. + cur_path=`pwd` +else + test_path_dir=${cur_path}/test +fi + + +################# Create the log output directory; no need to modify ################# +ASCEND_DEVICE_ID=0 +if [ -d ${test_path_dir}/output/${ASCEND_DEVICE_ID} ];then + rm -rf ${test_path_dir}/output/${ASCEND_DEVICE_ID} + mkdir -p ${test_path_dir}/output/$ASCEND_DEVICE_ID +else + mkdir -p ${test_path_dir}/output/$ASCEND_DEVICE_ID +fi + + +################# Launch the training script ################# +# Training start time; no need to modify +start_time=$(date +%s) +# Source the environment variables when not running on the platform +check_etp_flag=`env | grep etp_running_flag` +etp_flag=`echo ${check_etp_flag#*=}` +if [ x"${etp_flag}" != x"true" ];then + source ${test_path_dir}/env_npu.sh +fi +RANK_ID_START=0 +KERNEL_NUM=$(($(nproc)/8)) +for((RANK_ID=$RANK_ID_START;RANK_ID<$((RANK_SIZE+RANK_ID_START));RANK_ID++)) +do + PID_START=$((KERNEL_NUM * RANK_ID)) + PID_END=$((PID_START + KERNEL_NUM - 1)) + nohup taskset -c $PID_START-$PID_END python3.7 ./main.py -j ${KERNEL_NUM}\ + --data_path=${data_path} \ + --valdata=${valdata} \ + --addr=$(hostname -I |awk '{print $1}') \ + --seed=49 \ + --lr=${learning_rate} \ + --momentum=0.9 \ + --weight-decay=1e-4 \ + --workers=4 \ + --gpu=$RANK_ID \ + --dist-url='tcp://127.0.0.1:50011' \ + --dist-backend 'hccl' \ + --multiprocessing-distributed \ + --world_size=1 \ + --device='npu' \ + --nEpochs=${train_epochs} \ + --resume ${pth_path} \ + --amp \ + --batchSize=${batch_size} > 
${test_path_dir}/output/${ASCEND_DEVICE_ID}/train_${ASCEND_DEVICE_ID}.log 2>&1 & +done +wait + + +################## Collect training results ################ +# Training end time; no need to modify +end_time=$(date +%s) +e2e_time=$(( $end_time - $start_time )) + +# Print results; no need to modify +echo "------------------ Final result ------------------" + +# Output training accuracy; review and adjust per model +train_accuracy=`grep -a 'PSNR_predicted' ${test_path_dir}/output/${ASCEND_DEVICE_ID}/train_${ASCEND_DEVICE_ID}.log|awk -F "PSNR_predicted=" 'NR==1{print $NF}'|awk -F " " '{print $1}'` +# Print; no need to modify +echo "Final 2xPSNR_predicted : ${train_accuracy}" +echo "E2E Training Duration sec : $e2e_time" + +# Training case information; no need to modify +BatchSize=${batch_size} +DeviceType=`uname -m` +CaseName=${Network}_bs${BatchSize}_${RANK_SIZE}'p'_'acc' + + +# Loss value of the last iteration; no need to modify +ActualLoss=`grep Test ${test_path_dir}/output/$ASCEND_DEVICE_ID/train_${ASCEND_DEVICE_ID}.log | awk '{print $8}' | awk 'END {print}'` + +# Write key information to ${CaseName}.log; no need to modify +echo "Network = ${Network}" > ${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "RankSize = ${RANK_SIZE}" >> ${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "BatchSize = ${BatchSize}" >> ${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "DeviceType = ${DeviceType}" >> ${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "CaseName = ${CaseName}" >> ${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "TrainAccuracy = ${train_accuracy}" >> ${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "ActualLoss = ${ActualLoss}" >> ${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "E2ETrainingTime = ${e2e_time}" >> ${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log \ No newline at end of file diff --git a/PyTorch/contrib/cv/others/VDSR/test/train_finetune_1p.sh b/PyTorch/contrib/cv/others/VDSR/test/train_finetune_1p.sh new file mode 100644 index 0000000000000000000000000000000000000000..e668449b8b48d636058fc5004293faa5410e2241 --- /dev/null +++ 
b/PyTorch/contrib/cv/others/VDSR/test/train_finetune_1p.sh @@ -0,0 +1,162 @@ +#!/bin/bash + +################ Basic configuration parameters; review and adjust per model ################## +# Required fields (parameters that must be defined here): Network batch_size RANK_SIZE +# Network name, same as the directory name +Network="VDSR" +# Training batch_size +batch_size=128 +# Number of NPU cards used for training +export RANK_SIZE=1 +# Dataset path; leave empty, no need to modify +data_path="" +# Test set path; leave empty, no need to modify +valdata="" +# Checkpoint file path; use the actual path +pth_path="" +# Number of training epochs +train_epochs=50 +# Learning rate +learning_rate=0.1 +# ID of the NPU device used for training +device_id=0 + +# Argument parsing: data_path is required; other parameters may be added or removed as the model needs; any parameter added here must be defined and assigned above +for para in $* +do + if [[ $para == --device_id* ]];then + device_id=`echo ${para#*=}` + elif [[ $para == --data_path* ]];then + data_path=`echo ${para#*=}` + elif [[ $para == --valdata* ]];then + valdata=`echo ${para#*=}` + elif [[ $para == --pth_path* ]];then + pth_path=`echo ${para#*=}` + fi +done + +# Check that data_path was passed; no need to modify +if [[ $data_path == "" ]];then + echo "[Error] para \"data_path\" must be configured" + exit 1 +fi + +# Check that pth_path was passed; the fine-tuning script requires it +if [[ $pth_path == "" ]];then + echo "[Error] para \"pth_path\" must be configured" + exit 1 +fi + +# Check that a device ID is set, either assigned dynamically (ASCEND_DEVICE_ID) or specified manually (device_id); no need to modify +if [ $ASCEND_DEVICE_ID ];then + echo "device id is ${ASCEND_DEVICE_ID}" +elif [ ${device_id} ];then + export ASCEND_DEVICE_ID=${device_id} + echo "device id is ${ASCEND_DEVICE_ID}" +else + echo "[Error] device id must be configured" + exit 1 +fi + + + +############### Set the training script execution path ############### +# cd to the directory at the same level as the test folder before running, for compatibility; test_path_dir is the path that includes the test folder +cur_path=`pwd` +cur_path_last_dirname=${cur_path##*/} +if [ x"${cur_path_last_dirname}" == x"test" ];then + test_path_dir=${cur_path} + cd .. 
+ cur_path=`pwd` +else + test_path_dir=${cur_path}/test +fi + + +################# Create the log output directory; no need to modify ################# +if [ -d ${test_path_dir}/output/${ASCEND_DEVICE_ID} ];then + rm -rf ${test_path_dir}/output/${ASCEND_DEVICE_ID} + mkdir -p ${test_path_dir}/output/$ASCEND_DEVICE_ID +else + mkdir -p ${test_path_dir}/output/$ASCEND_DEVICE_ID +fi + + +################# Launch the training script ################# +# Training start time; no need to modify +start_time=$(date +%s) +# Source the environment variables when not running on the platform +check_etp_flag=`env | grep etp_running_flag` +etp_flag=`echo ${check_etp_flag#*=}` +if [ x"${etp_flag}" != x"true" ];then + source ${test_path_dir}/env_npu.sh +fi +nohup python3.7 ./main.py \ + --data_path=${data_path} \ + --valdata=${valdata} \ + --addr=$(hostname -I |awk '{print $1}') \ + --seed=49 \ + --lr=${learning_rate} \ + --momentum=0.9 \ + --weight-decay=1e-4 \ + --workers=4 \ + --world_size=1 \ + --device='npu' \ + --gpu=${ASCEND_DEVICE_ID} \ + --dist-url='tcp://127.0.0.1:50021' \ + --dist-backend 'hccl' \ + --nEpochs=${train_epochs} \ + --amp \ + --pretrained=${pth_path} \ + --batchSize=${batch_size} > ${test_path_dir}/output/${ASCEND_DEVICE_ID}/train_${ASCEND_DEVICE_ID}.log 2>&1 & + +wait + + +################## Collect training results ################ +# Training end time; no need to modify +end_time=$(date +%s) +e2e_time=$(( $end_time - $start_time )) + +# Print results; no need to modify +echo "------------------ Final result ------------------" +# Output performance (FPS); review and adjust per model +FPS=`grep -a 'FPS' ${test_path_dir}/output/${ASCEND_DEVICE_ID}/train_${ASCEND_DEVICE_ID}.log|awk -F " " '{print $11}'|awk 'END {print}'` +# Print; no need to modify +echo "Final Performance images/sec : $FPS" + +# Output training accuracy; review and adjust per model +train_accuracy=`grep -a 'PSNR_predicted' ${test_path_dir}/output/${ASCEND_DEVICE_ID}/train_${ASCEND_DEVICE_ID}.log|awk -F "PSNR_predicted=" 'NR==1{print $NF}'|awk -F " " '{print $1}'` +# Print; no need to modify +echo "Final 2xPSNR_predicted : ${train_accuracy}" +echo "E2E Training Duration sec : $e2e_time" + +# Performance monitoring summary +# Training case information; no need to modify +BatchSize=${batch_size} +DeviceType=`uname -m` 
+CaseName=${Network}_bs${BatchSize}_${RANK_SIZE}'p'_'acc' + +## Collect performance data; no need to modify +# Throughput +ActualFPS=${FPS} +# Training time per iteration +TrainingTime=`awk 'BEGIN{printf "%.2f\n", '${batch_size}'*1000/'${FPS}'}'` + +# Extract Loss from train_$ASCEND_DEVICE_ID.log into train_${CaseName}_loss.txt; review per model +grep Epoch ${test_path_dir}/output/$ASCEND_DEVICE_ID/train_$ASCEND_DEVICE_ID.log|grep -v Test|awk -F "Loss" '{print $NF}' | awk -F " " '{print $2}' >> ${test_path_dir}/output/$ASCEND_DEVICE_ID/train_${CaseName}_loss.txt + +# Loss value of the last iteration; no need to modify +ActualLoss=`awk 'END {print}' ${test_path_dir}/output/$ASCEND_DEVICE_ID/train_${CaseName}_loss.txt` + +# Write key information to ${CaseName}.log; no need to modify +echo "Network = ${Network}" > ${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "RankSize = ${RANK_SIZE}" >> ${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "BatchSize = ${BatchSize}" >> ${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "DeviceType = ${DeviceType}" >> ${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "CaseName = ${CaseName}" >> ${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "ActualFPS = ${ActualFPS}" >> ${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "TrainingTime = ${TrainingTime}" >> ${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "TrainAccuracy = ${train_accuracy}" >> ${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "ActualLoss = ${ActualLoss}" >> ${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "E2ETrainingTime = ${e2e_time}" >> ${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log \ No newline at end of file diff --git a/PyTorch/contrib/cv/others/VDSR/test/train_full_1p.sh b/PyTorch/contrib/cv/others/VDSR/test/train_full_1p.sh new file mode 100644 index 0000000000000000000000000000000000000000..882669a8796b57249c6e500ff386c1d3179b9f69 --- /dev/null +++ b/PyTorch/contrib/cv/others/VDSR/test/train_full_1p.sh @@ -0,0 +1,150 @@ +#!/bin/bash + 
+################ Basic configuration parameters; review and adjust per model ################## +# Required fields (parameters that must be defined here): Network batch_size RANK_SIZE +# Network name, same as the directory name +Network="VDSR" +# Training batch_size +batch_size=128 +# Number of NPU cards used for training +export RANK_SIZE=1 +# Dataset path; leave empty, no need to modify +data_path="" +# Test set path; leave empty, no need to modify +valdata="" +# Number of training epochs +train_epochs=50 +# Learning rate +learning_rate=0.1 +# ID of the NPU device used for training +device_id=4 + + +# Argument parsing: data_path is required; other parameters may be added or removed as the model needs; any parameter added here must be defined and assigned above +for para in $* +do + if [[ $para == --device_id* ]];then + device_id=`echo ${para#*=}` + elif [[ $para == --data_path* ]];then + data_path=`echo ${para#*=}` + elif [[ $para == --valdata* ]];then + valdata=`echo ${para#*=}` + fi +done + +# Check that data_path was passed; no need to modify +if [[ $data_path == "" ]];then + echo "[Error] para \"data_path\" must be configured" + exit 1 +fi +# Check that a device ID is set, either assigned dynamically (ASCEND_DEVICE_ID) or specified manually (device_id); no need to modify +if [ $ASCEND_DEVICE_ID ];then + echo "device id is ${ASCEND_DEVICE_ID}" +elif [ ${device_id} ];then + export ASCEND_DEVICE_ID=${device_id} + echo "device id is ${ASCEND_DEVICE_ID}" +else + echo "[Error] device id must be configured" + exit 1 +fi + + + +############### Set the training script execution path ############### +# cd to the directory at the same level as the test folder before running, for compatibility; test_path_dir is the path that includes the test folder +cur_path=`pwd` +cur_path_last_dirname=${cur_path##*/} +if [ x"${cur_path_last_dirname}" == x"test" ];then + test_path_dir=${cur_path} + cd .. 
+ cur_path=`pwd` +else + test_path_dir=${cur_path}/test +fi + + +################# Create the log output directory; no need to modify ################# +if [ -d ${test_path_dir}/output/${ASCEND_DEVICE_ID} ];then + rm -rf ${test_path_dir}/output/${ASCEND_DEVICE_ID} + mkdir -p ${test_path_dir}/output/$ASCEND_DEVICE_ID +else + mkdir -p ${test_path_dir}/output/$ASCEND_DEVICE_ID +fi + + +################# Launch the training script ################# +# Training start time; no need to modify +start_time=$(date +%s) +# Source the environment variables when not running on the platform +check_etp_flag=`env | grep etp_running_flag` +etp_flag=`echo ${check_etp_flag#*=}` +if [ x"${etp_flag}" != x"true" ];then + source ${test_path_dir}/env_npu.sh +fi + nohup python3.7 ./main.py \ + --data_path=${data_path} \ + --valdata=${valdata} \ + --addr=$(hostname -I |awk '{print $1}') \ + --seed=49 \ + --lr=${learning_rate} \ + --momentum=0.9 \ + --weight-decay=1e-4 \ + --workers=4 \ + --world_size=1 \ + --device='npu' \ + --gpu=${ASCEND_DEVICE_ID} \ + --dist-url='tcp://127.0.0.1:50021' \ + --dist-backend 'hccl' \ + --nEpochs=${train_epochs} \ + --amp \ + --batchSize=${batch_size} > ${test_path_dir}/output/${ASCEND_DEVICE_ID}/train_${ASCEND_DEVICE_ID}.log 2>&1 & +wait + + +################## Collect training results ################ +# Training end time; no need to modify +end_time=$(date +%s) +e2e_time=$(( $end_time - $start_time )) + +# Print results; no need to modify +echo "------------------ Final result ------------------" +# Output performance (FPS); review and adjust per model +FPS=`grep -a 'FPS' ${test_path_dir}/output/${ASCEND_DEVICE_ID}/train_${ASCEND_DEVICE_ID}.log|awk -F " " '{print $11}'|awk 'END {print}'` +# Print; no need to modify +echo "Final Performance images/sec : $FPS" + +# Output training accuracy; review and adjust per model +train_accuracy=`grep -a 'PSNR_predicted' ${test_path_dir}/output/${ASCEND_DEVICE_ID}/train_${ASCEND_DEVICE_ID}.log|awk -F "PSNR_predicted=" 'NR==1{print $NF}'|awk -F " " '{print $1}'` +# Print; no need to modify +echo "Final 2xPSNR_predicted : ${train_accuracy}" +echo "E2E Training Duration sec : $e2e_time" + +# Performance monitoring summary +# Training case information; no need to modify +BatchSize=${batch_size} +DeviceType=`uname -m` +CaseName=${Network}_bs${BatchSize}_${RANK_SIZE}'p'_'acc' + 
+## Collect performance data; no need to modify +# Throughput +ActualFPS=${FPS} +# Training time per iteration +TrainingTime=`awk 'BEGIN{printf "%.2f\n", '${batch_size}'*1000/'${FPS}'}'` + +# Extract Loss from train_$ASCEND_DEVICE_ID.log into train_${CaseName}_loss.txt; review per model +grep Epoch ${test_path_dir}/output/$ASCEND_DEVICE_ID/train_$ASCEND_DEVICE_ID.log|grep -v Test|awk -F "Loss" '{print $NF}' | awk -F " " '{print $2}' >> ${test_path_dir}/output/$ASCEND_DEVICE_ID/train_${CaseName}_loss.txt + +# Loss value of the last iteration; no need to modify +ActualLoss=`awk 'END {print}' ${test_path_dir}/output/$ASCEND_DEVICE_ID/train_${CaseName}_loss.txt` + +# Write key information to ${CaseName}.log; no need to modify +echo "Network = ${Network}" > ${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "RankSize = ${RANK_SIZE}" >> ${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "BatchSize = ${BatchSize}" >> ${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "DeviceType = ${DeviceType}" >> ${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "CaseName = ${CaseName}" >> ${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "ActualFPS = ${ActualFPS}" >> ${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "TrainingTime = ${TrainingTime}" >> ${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "TrainAccuracy = ${train_accuracy}" >> ${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "ActualLoss = ${ActualLoss}" >> ${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "E2ETrainingTime = ${e2e_time}" >> ${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log \ No newline at end of file diff --git a/PyTorch/contrib/cv/others/VDSR/test/train_full_8p.sh b/PyTorch/contrib/cv/others/VDSR/test/train_full_8p.sh new file mode 100644 index 0000000000000000000000000000000000000000..e8189c2aa3e4e69a26340733929307246d330d21 --- /dev/null +++ b/PyTorch/contrib/cv/others/VDSR/test/train_full_8p.sh @@ -0,0 +1,148 @@ +#!/bin/bash + +################ Basic configuration parameters; review and adjust per model ################## +# Required fields (parameters that must be defined here): Network batch_size 
RANK_SIZE +# Network name, same as the directory name +Network="VDSR" +# Training batch_size +batch_size=384 +# Number of NPU cards used for training +export RANK_SIZE=8 +# Dataset path; leave empty, no need to modify +data_path="" +# Test set path; leave empty, no need to modify +valdata="" +# Number of training epochs +train_epochs=50 +# Learning rate +learning_rate=0.3 + + + +# Argument parsing: data_path is required; other parameters may be added or removed as the model needs; any parameter added here must be defined and assigned above +for para in $* +do + if [[ $para == --workers* ]];then + workers=`echo ${para#*=}` + elif [[ $para == --data_path* ]];then + data_path=`echo ${para#*=}` + elif [[ $para == --valdata* ]];then + valdata=`echo ${para#*=}` + fi +done + +# Check that data_path was passed; no need to modify +if [[ $data_path == "" ]];then + echo "[Error] para \"data_path\" must be configured" + exit 1 +fi + + +############### Set the training script execution path ############### +# cd to the directory at the same level as the test folder before running, for compatibility; test_path_dir is the path that includes the test folder +cur_path=`pwd` +cur_path_last_diename=${cur_path##*/} +if [ x"${cur_path_last_diename}" == x"test" ];then + test_path_dir=${cur_path} + cd .. + cur_path=`pwd` +else + test_path_dir=${cur_path}/test +fi + + +################# Create the log output directory; no need to modify ################# +ASCEND_DEVICE_ID=0 +if [ -d ${test_path_dir}/output/${ASCEND_DEVICE_ID} ];then + rm -rf ${test_path_dir}/output/${ASCEND_DEVICE_ID} + mkdir -p ${test_path_dir}/output/$ASCEND_DEVICE_ID +else + mkdir -p ${test_path_dir}/output/$ASCEND_DEVICE_ID +fi + + +################# Launch the training script ################# +# Training start time; no need to modify +start_time=$(date +%s) +# Source the environment variables when not running on the platform +check_etp_flag=`env | grep etp_running_flag` +etp_flag=`echo ${check_etp_flag#*=}` +if [ x"${etp_flag}" != x"true" ];then + source ${test_path_dir}/env_npu.sh +fi + +RANK_ID_START=0 +KERNEL_NUM=$(($(nproc)/8)) +for((RANK_ID=$RANK_ID_START;RANK_ID<$((RANK_SIZE+RANK_ID_START));RANK_ID++)) +do + PID_START=$((KERNEL_NUM * RANK_ID)) + PID_END=$((PID_START + KERNEL_NUM - 1)) + nohup taskset -c $PID_START-$PID_END python3.7 ./main.py -j ${KERNEL_NUM}\ + --data_path=${data_path} \ + --valdata=${valdata} \ + --addr=$(hostname -I |awk '{print $1}') \ + --seed=49 \ + --lr=${learning_rate} \ + --momentum=0.9 \ + --weight-decay=1e-4 \ + 
--workers=16 \ + --gpu=$RANK_ID \ + --dist-url='tcp://127.0.0.1:50011' \ + --dist-backend 'hccl' \ + --multiprocessing-distributed \ + --world_size=1 \ + --device='npu' \ + --nEpochs=${train_epochs} \ + --amp \ + --batchSize=${batch_size} > ${test_path_dir}/output/${ASCEND_DEVICE_ID}/train_${ASCEND_DEVICE_ID}.log 2>&1 & +done +wait + + +################## Collect training results ################ +# Training end time; no need to modify +end_time=$(date +%s) +e2e_time=$(( $end_time - $start_time )) + +# Print results; no need to modify +echo "------------------ Final result ------------------" +# Output performance (FPS); review and adjust per model +FPS=`grep -a 'FPS' ${test_path_dir}/output/${ASCEND_DEVICE_ID}/train_${ASCEND_DEVICE_ID}.log|awk -F " " '{print $11}'|awk 'END {print}'` +# Print; no need to modify +echo "Final Performance images/sec : $FPS" + +# Output training accuracy; review and adjust per model +train_accuracy=`grep -a 'PSNR_predicted' ${test_path_dir}/output/${ASCEND_DEVICE_ID}/train_${ASCEND_DEVICE_ID}.log|awk -F "PSNR_predicted=" 'NR==1{print $NF}'|awk -F " " '{print $1}'` +# Print; no need to modify +echo "Final 2xPSNR_predicted : ${train_accuracy}" +echo "E2E Training Duration sec : $e2e_time" + +# Performance monitoring summary +# Training case information; no need to modify +BatchSize=${batch_size} +DeviceType=`uname -m` +CaseName=${Network}_bs${BatchSize}_${RANK_SIZE}'p'_'acc' + +## Collect performance data; no need to modify +# Throughput +ActualFPS=${FPS} +# Training time per iteration +TrainingTime=`awk 'BEGIN{printf "%.2f\n", '${batch_size}'*1000/'${FPS}'}'` + +# Extract Loss from train_$ASCEND_DEVICE_ID.log into train_${CaseName}_loss.txt; review per model +grep Epoch ${test_path_dir}/output/$ASCEND_DEVICE_ID/train_$ASCEND_DEVICE_ID.log|grep -v Test|awk -F "Loss" '{print $NF}' | awk -F " " '{print $2}' >> ${test_path_dir}/output/$ASCEND_DEVICE_ID/train_${CaseName}_loss.txt + +# Loss value of the last iteration; no need to modify +ActualLoss=`awk 'END {print}' ${test_path_dir}/output/$ASCEND_DEVICE_ID/train_${CaseName}_loss.txt` + +# Write key information to ${CaseName}.log; no need to modify +echo "Network = ${Network}" > ${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "RankSize = ${RANK_SIZE}" >> ${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "BatchSize = ${BatchSize}" >> 
${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "DeviceType = ${DeviceType}" >> ${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "CaseName = ${CaseName}" >> ${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "ActualFPS = ${ActualFPS}" >> ${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "TrainingTime = ${TrainingTime}" >> ${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "TrainAccuracy = ${train_accuracy}" >> ${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "ActualLoss = ${ActualLoss}" >> ${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "E2ETrainingTime = ${e2e_time}" >> ${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log \ No newline at end of file diff --git a/PyTorch/contrib/cv/others/VDSR/test/train_performance_1p.sh b/PyTorch/contrib/cv/others/VDSR/test/train_performance_1p.sh new file mode 100644 index 0000000000000000000000000000000000000000..14566a12e5123c9229b366a07d74cc2a7b4c4e45 --- /dev/null +++ b/PyTorch/contrib/cv/others/VDSR/test/train_performance_1p.sh @@ -0,0 +1,150 @@ +#!/bin/bash + +################ Basic configuration parameters; review and adjust per model ################## +# Required fields (parameters that must be defined here): Network batch_size RANK_SIZE +# Network name, same as the directory name +Network="VDSR" +# Training batch_size +batch_size=128 +# Number of NPU cards used for training +export RANK_SIZE=1 +# Dataset path; leave empty, no need to modify +data_path="" +# Test set path; leave empty, no need to modify +valdata="" +# Number of training epochs +train_epochs=4 +# Learning rate +learning_rate=0.1 +# ID of the NPU device used for training +device_id=0 + + +# Argument parsing: data_path is required; other parameters may be added or removed as the model needs; any parameter added here must be defined and assigned above +for para in $* +do + if [[ $para == --device_id* ]];then + device_id=`echo ${para#*=}` + elif [[ $para == --data_path* ]];then + data_path=`echo ${para#*=}` + elif [[ $para == --valdata* ]];then + valdata=`echo ${para#*=}` + fi +done + +# Check that data_path was passed; no need to modify +if [[ $data_path == "" ]];then + echo "[Error] para \"data_path\" must be configured" + exit 1 +fi +# Check that a device ID is set, either assigned dynamically (ASCEND_DEVICE_ID) or specified manually (device_id); no need to modify +if [ $ASCEND_DEVICE_ID ];then + echo "device 
id is ${ASCEND_DEVICE_ID}" +elif [ ${device_id} ];then + export ASCEND_DEVICE_ID=${device_id} + echo "device id is ${ASCEND_DEVICE_ID}" +else + echo "[Error] device id must be configured" + exit 1 +fi + + + +############### Set the training script execution path ############### +# cd to the directory at the same level as the test folder before running, for compatibility; test_path_dir is the path that includes the test folder +cur_path=`pwd` +cur_path_last_dirname=${cur_path##*/} +if [ x"${cur_path_last_dirname}" == x"test" ];then + test_path_dir=${cur_path} + cd .. + cur_path=`pwd` +else + test_path_dir=${cur_path}/test +fi + + +################# Create the log output directory; no need to modify ################# +if [ -d ${test_path_dir}/output/${ASCEND_DEVICE_ID} ];then + rm -rf ${test_path_dir}/output/${ASCEND_DEVICE_ID} + mkdir -p ${test_path_dir}/output/$ASCEND_DEVICE_ID +else + mkdir -p ${test_path_dir}/output/$ASCEND_DEVICE_ID +fi + + +################# Launch the training script ################# +# Training start time; no need to modify +start_time=$(date +%s) +# Source the environment variables when not running on the platform +check_etp_flag=`env | grep etp_running_flag` +etp_flag=`echo ${check_etp_flag#*=}` +if [ x"${etp_flag}" != x"true" ];then + source ${test_path_dir}/env_npu.sh +fi + nohup python3.7 ./main.py \ + --data_path=${data_path} \ + --valdata=${valdata} \ + --addr=$(hostname -I |awk '{print $1}') \ + --seed=49 \ + --lr=${learning_rate} \ + --momentum=0.9 \ + --weight-decay=1e-4 \ + --workers=4 \ + --world_size=1 \ + --device='npu' \ + --gpu=${ASCEND_DEVICE_ID} \ + --dist-url='tcp://127.0.0.1:50021' \ + --dist-backend 'hccl' \ + --nEpochs=${train_epochs} \ + --amp \ + --batchSize=${batch_size} > ${test_path_dir}/output/${ASCEND_DEVICE_ID}/train_${ASCEND_DEVICE_ID}.log 2>&1 & +wait + + +################## Collect training results ################ +# Training end time; no need to modify +end_time=$(date +%s) +e2e_time=$(( $end_time - $start_time )) + +# Print results; no need to modify +echo "------------------ Final result ------------------" +# Output performance (FPS); review and adjust per model +FPS=`grep -a 'FPS' ${test_path_dir}/output/${ASCEND_DEVICE_ID}/train_${ASCEND_DEVICE_ID}.log|awk -F " " '{print $11}'|awk 'END {print}'` +# Print; no need to modify +echo "Final Performance images/sec : $FPS" + 
+# Extract training accuracy; review and modify per model +train_accuracy=`grep -a 'PSNR_predicted' ${test_path_dir}/output/${ASCEND_DEVICE_ID}/train_${ASCEND_DEVICE_ID}.log|awk -F "PSNR_predicted=" 'NR==1{print $NF}'|awk -F " " '{print $1}'` +# Print; no modification needed +echo "Final 2xPSNR_predicted : ${train_accuracy}" +echo "E2E Training Duration sec : $e2e_time" + +# Performance monitoring summary +# Training case information; no modification needed +BatchSize=${batch_size} +DeviceType=`uname -m` +CaseName=${Network}_bs${BatchSize}_${RANK_SIZE}'p'_'acc' + +## Collect performance data; no modification needed +# Throughput +ActualFPS=${FPS} +# Training time per iteration, in ms +TrainingTime=`awk 'BEGIN{printf "%.2f\n", '${batch_size}'*1000/'${FPS}'}'` + +# Extract Loss from train_$ASCEND_DEVICE_ID.log into train_${CaseName}_loss.txt; review per model +grep Epoch ${test_path_dir}/output/$ASCEND_DEVICE_ID/train_$ASCEND_DEVICE_ID.log|grep -v Test|awk -F "Loss" '{print $NF}' | awk -F " " '{print $2}' >> ${test_path_dir}/output/$ASCEND_DEVICE_ID/train_${CaseName}_loss.txt + +# Loss of the last iteration; no modification needed +ActualLoss=`awk 'END {print}' ${test_path_dir}/output/$ASCEND_DEVICE_ID/train_${CaseName}_loss.txt` + +# Write key information to ${CaseName}.log; no modification needed +echo "Network = ${Network}" > ${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "RankSize = ${RANK_SIZE}" >> ${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "BatchSize = ${BatchSize}" >> ${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "DeviceType = ${DeviceType}" >> ${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "CaseName = ${CaseName}" >> ${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "ActualFPS = ${ActualFPS}" >> ${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "TrainingTime = ${TrainingTime}" >> ${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "TrainAccuracy = ${train_accuracy}" >> ${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "ActualLoss = ${ActualLoss}" >> ${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "E2ETrainingTime = ${e2e_time}" >> ${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log \ No newline at
end of file diff --git a/PyTorch/contrib/cv/others/VDSR/test/train_performance_8p.sh b/PyTorch/contrib/cv/others/VDSR/test/train_performance_8p.sh new file mode 100644 index 0000000000000000000000000000000000000000..bee61e34a319c751203b034337a649e91619f45c --- /dev/null +++ b/PyTorch/contrib/cv/others/VDSR/test/train_performance_8p.sh @@ -0,0 +1,148 @@ +#!/bin/bash + +################ Basic configuration parameters; review and modify per model ################## +# Required fields (parameters that must be defined here): Network batch_size RANK_SIZE +# Network name, same as the directory name +Network="VDSR" +# Training batch_size +batch_size=384 +# Number of NPU cards used for training +export RANK_SIZE=8 +# Dataset path; leave empty, no modification needed +data_path="" +# Test set path; leave empty, no modification needed +valdata="" +# Number of training epochs +train_epochs=4 +# Learning rate +learning_rate=0.3 + + + +# Argument parsing: data_path is required; other arguments may be added or removed per model; any argument added here must be defined and assigned above +for para in $* +do + if [[ $para == --workers* ]];then + workers=`echo ${para#*=}` + elif [[ $para == --data_path* ]];then + data_path=`echo ${para#*=}` + elif [[ $para == --valdata* ]];then + valdata=`echo ${para#*=}` + fi +done + +# Check that data_path was passed; no modification needed +if [[ $data_path == "" ]];then + echo "[Error] para \"data_path\" must be configured" + exit 1 +fi + + +############### Set the training script execution path ############### +# cd to the directory at the same level as the test folder before running, for compatibility; test_path_dir is the path that includes the test folder +cur_path=`pwd` +cur_path_last_dirname=${cur_path##*/} +if [ x"${cur_path_last_dirname}" == x"test" ];then + test_path_dir=${cur_path} + cd ..
+ cur_path=`pwd` +else + test_path_dir=${cur_path}/test +fi + + +################# Create the log output directory; no modification needed ################# +ASCEND_DEVICE_ID=0 +if [ -d ${test_path_dir}/output/${ASCEND_DEVICE_ID} ];then + rm -rf ${test_path_dir}/output/${ASCEND_DEVICE_ID} + mkdir -p ${test_path_dir}/output/$ASCEND_DEVICE_ID +else + mkdir -p ${test_path_dir}/output/$ASCEND_DEVICE_ID +fi + + +################# Launch the training script ################# +# Training start time; no modification needed +start_time=$(date +%s) +# Source the environment variables when not running on the platform +check_etp_flag=`env | grep etp_running_flag` +etp_flag=`echo ${check_etp_flag#*=}` +if [ x"${etp_flag}" != x"true" ];then + source ${test_path_dir}/env_npu.sh +fi + +RANK_ID_START=0 +KERNEL_NUM=$(($(nproc)/8)) +for((RANK_ID=$RANK_ID_START;RANK_ID<$((RANK_SIZE+RANK_ID_START));RANK_ID++)) +do + PID_START=$((KERNEL_NUM * RANK_ID)) + PID_END=$((PID_START + KERNEL_NUM - 1)) + nohup taskset -c $PID_START-$PID_END python3.7 ./main.py -j ${KERNEL_NUM}\ + --data_path=${data_path} \ + --valdata=${valdata} \ + --addr=$(hostname -I |awk '{print $1}') \ + --seed=49 \ + --lr=${learning_rate} \ + --momentum=0.9 \ + --weight-decay=1e-4 \ + --workers=16 \ + --gpu=$RANK_ID \ + --dist-url='tcp://127.0.0.1:50011' \ + --dist-backend 'hccl' \ + --multiprocessing-distributed \ + --world_size=1 \ + --device='npu' \ + --nEpochs=${train_epochs} \ + --amp \ + --batchSize=${batch_size} > ${test_path_dir}/output/${ASCEND_DEVICE_ID}/train_${ASCEND_DEVICE_ID}.log 2>&1 & +done +wait + + +################## Collect training results ################ +# Training end time; no modification needed +end_time=$(date +%s) +e2e_time=$(( $end_time - $start_time )) + +# Print results; no modification needed +echo "------------------ Final result ------------------" +# Extract performance (FPS); review and modify per model +FPS=`grep -a 'FPS' ${test_path_dir}/output/${ASCEND_DEVICE_ID}/train_${ASCEND_DEVICE_ID}.log|awk -F " " '{print $11}'|awk 'END {print}'` +# Print; no modification needed +echo "Final Performance images/sec : $FPS" + +# Extract training accuracy; review and modify per model +train_accuracy=`grep -a 'PSNR_predicted' ${test_path_dir}/output/${ASCEND_DEVICE_ID}/train_${ASCEND_DEVICE_ID}.log|awk -F
"PSNR_predicted=" 'NR==1{print $NF}'|awk -F " " '{print $1}'` +#打印,不需要修改 +echo "Final 2xPSNR_predicted : ${train_accuracy}" +echo "E2E Training Duration sec : $e2e_time" + +#性能看护结果汇总 +#训练用例信息,不需要修改 +BatchSize=${batch_size} +DeviceType=`uname -m` +CaseName=${Network}_bs${BatchSize}_${RANK_SIZE}'p'_'acc' + +##获取性能数据,不需要修改 +#吞吐量 +ActualFPS=${FPS} +#单迭代训练时长 +TrainingTime=`awk 'BEGIN{printf "%.2f\n", '${batch_size}'*1000/'${FPS}'}'` + +#从train_$ASCEND_DEVICE_ID.log提取Loss到train_${CaseName}_loss.txt中,需要根据模型审视 +grep Epoch ${test_path_dir}/output/$ASCEND_DEVICE_ID/train_$ASCEND_DEVICE_ID.log|grep -v Test|awk -F "Loss" '{print $NF}' | awk -F " " '{print $2}' >> ${test_path_dir}/output/$ASCEND_DEVICE_ID/train_${CaseName}_loss.txt + +#最后一个迭代loss值,不需要修改 +ActualLoss=`awk 'END {print}' ${test_path_dir}/output/$ASCEND_DEVICE_ID/train_${CaseName}_loss.txt` + +#关键信息打印到${CaseName}.log中,不需要修改 +echo "Network = ${Network}" > ${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "RankSize = ${RANK_SIZE}" >> ${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "BatchSize = ${BatchSize}" >> ${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "DeviceType = ${DeviceType}" >> ${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "CaseName = ${CaseName}" >> ${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "ActualFPS = ${ActualFPS}" >> ${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "TrainingTime = ${TrainingTime}" >> ${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "TrainAccuracy = ${train_accuracy}" >> ${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "ActualLoss = ${ActualLoss}" >> ${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "E2ETrainingTime = ${e2e_time}" >> ${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log \ No newline at end of file