diff --git a/PyTorch/contrib/cv/video/SiamRPN/README.md b/PyTorch/contrib/cv/video/SiamRPN/README.md
index 34d85a548b12652c49977af5de09a74144c9c6fc..e1586346ab084f37d454ef56f6ae592017d566e3 100644
--- a/PyTorch/contrib/cv/video/SiamRPN/README.md
+++ b/PyTorch/contrib/cv/video/SiamRPN/README.md
@@ -1,17 +1,19 @@
-# SiamRPN++ for PyTorch
-
-- [Overview](https://gitee.com/ascend/docs-openmind/blob/master/guide/modelzoo/pytorch_model/tutorials/概述.md)
-- [Preparing the Training Environment](https://gitee.com/ascend/docs-openmind/blob/master/guide/modelzoo/pytorch_model/tutorials/准备训练环境.md)
-- [Starting Training](https://gitee.com/ascend/docs-openmind/blob/master/guide/modelzoo/pytorch_model/tutorials/开始训练.md)
-- [Training Results](https://gitee.com/ascend/docs-openmind/blob/master/guide/modelzoo/pytorch_model/tutorials/训练结果展示.md)
-- [Release Notes](https://gitee.com/ascend/docs-openmind/blob/master/guide/modelzoo/pytorch_model/tutorials/版本说明.md)
+# SiamRPN/SiamRPN++ for PyTorch
 
 # Overview
 
 ## Brief Description
 
+### SiamRPN
+
+The Siamese region proposal network (Siamese-RPN) can be trained offline, end to end, on large-scale image pairs. It consists of a Siamese subnetwork for feature extraction and a region proposal subnetwork with a classification branch and a regression branch. During tracking, the task is formulated as one-shot detection: the template branch of the Siamese subnetwork (the first frame) is pre-computed and folded into the detection branch's region proposal network as a convolution layer for online tracking. Thanks to these improvements, traditional multi-scale testing and online fine-tuning can be discarded, which greatly improves speed. Siamese-RPN runs at 160 FPS and achieves leading results on VOT2015, VOT2016, and VOT2017.
+
+### SiamRPN++
+
 SiamRPN++ is a Siamese tracker driven by a ResNet backbone. It uses a simple yet effective sampling strategy to break the restriction of spatial invariance, and a layer-wise feature aggregation structure for the cross-correlation operation, which helps the tracker predict the similarity map from features learned at multiple levels. As an efficient visual tracking model, it reaches a new state of the art in tracking accuracy while running at 35 FPS.
 
+### Model Implementation
+
 - Reference implementation:
 
   ```
@@ -52,6 +54,11 @@ SiamRPN++ is a Siamese tracker driven by a ResNet backbone
   pip install -r requirements.txt
   ```
+- Build the extension
+
+  ```
+  python3 setup.py build_ext --inplace
+  ```
 
 ## Prepare the Dataset
@@ -179,12 +186,15 @@ SiamRPN++ is a Siamese tracker driven by a ResNet backbone
 > **Note:**
 > The training scripts for this dataset are provided only as a reference example.
 
+## Obtain the Pre-trained Model (Optional)
+
+Refer to the README of the original repository (source implementation) to obtain the pre-trained weights that match the configuration you use, and place them under `./pretrained_models` in the root directory of the source package.
 
-## Obtain the Pre-trained Model
+## Obtain the Configuration Files
 
-**1. Configuration files:** Download the files under the `experiments/` directory of the original repository (source implementation) into `./pysot-master/experiments/` in this project. The configuration file used by this project is `./pysot/experiments/siamrpn_r50_l234_dwxcorr_8gpu/config.yaml` from the original repository.
+**1. SiamRPN:** The single-card training configuration is `./pysot-master/experiments/siamrpn_alexnet/config.yaml`; the multi-card training configuration is `./pysot-master/experiments/siamrpn_alexnet_8npu/config.yaml`.
 
-**2. Pre-trained weights:** Refer to the README of the original repository (source implementation) to obtain the pre-trained weights. Download model.pth for the siamrpn_r50_l234_dwxcorr model, place it under `./pretrained_models` in the root directory of the source package, and rename it to resnet50.model.
+**2. SiamRPN++:** Download the files under the `experiments/` directory of the original repository into `./pysot-master/experiments/` in this project. The configuration file used by this project is `./pysot/experiments/siamrpn_r50_l234_dwxcorr_8gpu/config.yaml` from the original repository.
 
 # Start Training
@@ -198,52 +208,61 @@ SiamRPN++ is a Siamese tracker driven by a ResNet backbone
 
 2. Run the training script.
 
+   **Note: before training, change the model configuration path inside the training script to the model you need.**
+
    The model supports single-machine single-card training and single-machine 8-card training.
 
-   - Single-machine single-card training
+- Single-machine single-card training
 
-     Start single-card training.
+  Start single-card training.
 
-     ```
-     bash ./test/train_full_1p.sh          # single-card accuracy
-
-     bash ./test/train_performance_1p.sh   # single-card performance
-     ```
+  - When testing SiamRPN training performance, the first 10 epochs are not representative, because the backbone only starts training at epoch 10. You can resume from the epoch-10 checkpoint of an 8P run (already configured in the `config` file) to speed up the test.
 
-   - Single-machine 8-card training
+  ```
+  bash ./test/train_full_1p.sh          # single-card accuracy
+  bash ./test/train_performance_1p.sh   # single-card performance
+  ```
 
-     Start 8-card training.
+- Single-machine 8-card training
 
-     ```
-     bash ./test/train_full_8p.sh          # 8-card accuracy
-
-     bash ./test/train_performance_8p.sh   # 8-card performance
-     ```
+  Start 8-card training.
 
-   - Single-machine 8-card evaluation
+  ```
+  bash ./test/train_full_8p.sh          # 8-card accuracy
+  bash ./test/train_performance_8p.sh   # 8-card performance
+  ```
 
-     Start 8-card evaluation.
+  - Single-machine 8-card evaluation
 
-     ```
-     bash ./test/train_eval_8p.sh
-     ```
+    Start 8-card evaluation.
 
-   The training script parameters are described below.
+    ```
+    bash ./test/train_eval_8p.sh        # 8-card evaluation
+    ```
 
-   ```
+  The training script parameters are described below.
+
+  ```
   Common parameters:
   --seed                // random seed
  --cfg                 // configuration file
   --is_performance      // whether to run a performance test
   --max_step            // maximum number of iterations
-   ```
+  ```
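+
+  For reference, the wrapper scripts above ultimately launch the training entry point with these parameters. A minimal sketch of a direct single-card launch follows; the argument values are illustrative only, and the authoritative invocation is the one inside the `test/*.sh` scripts:
+
+  ```
+  # Hypothetical direct launch of the 1P entry point using the common parameters listed above
+  python3 ./pysot-master/tools_1p/train.py \
+      --cfg ./pysot-master/experiments/siamrpn_alexnet/config.yaml \
+      --seed 123456 \
+      --is_performance True \
+      --max_step 1000
+  ```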
 
-   After single-card training completes, checkpoints are saved under SiamRPN/pysot-master/snapshot_1p; 8P checkpoints are saved under SiamRPN/pysot-master/snapshot_8p. Model training accuracy and performance information is printed as well.
+After single-card training completes, checkpoints are saved under SiamRPN/pysot-master/snapshot_1p; 8P checkpoints are saved under SiamRPN/pysot-master/snapshot_8p. Model training accuracy and performance information is printed as well.
 
 # Training Results
 
-**Table 2** Training results
+### 1. SiamRPN
+
+| NAME | EAO | Accuracy | Robustness | FPS | Epochs | BatchSize | AMP_Type |
+| :--------: | :---: | :------: | :--------: | :--: | :----: | :-------: | :------: |
+| 1p-NPU-e11 | - | - | - | 524 | 1 | 550 | O1 |
+| 8p-NPU-e50 | 0.325 | 0.579 | 0.322 | 3638 | 50 | 550 | O1 |
+
+### 2. SiamRPN++
 
 | NAME | FPS↑ | Accuracy↑ | Robustness↓ | EAO↑ | AMP_Type | Torch_Version |
 | :--------: | :---: | :----: | :------: | :---: | :-----: | :-----: |
@@ -255,9 +274,11 @@ SiamRPN++ is a Siamese tracker driven by a ResNet backbone
 
 ## Changes
 
-  2023.03.13: Updated the README and re-released.
+2023.04.17: Updated the README; added SiamRPN model training content.
+
+2023.03.13: Updated the README and re-released.
 
-  2020.11.10: First release.
+2020.11.10: First release.
 
 ## FAQ
diff --git a/PyTorch/contrib/cv/video/SiamRPN/pysot-master/experiments/siamrpn_alexnet/config.yaml b/PyTorch/contrib/cv/video/SiamRPN/pysot-master/experiments/siamrpn_alexnet/config.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..8fe745286378415c0acc6e494ace4c38bcd9e74b
--- /dev/null
+++ b/PyTorch/contrib/cv/video/SiamRPN/pysot-master/experiments/siamrpn_alexnet/config.yaml
@@ -0,0 +1,88 @@
+META_ARC: "SiamRPN-8P"
+
+BACKBONE:
+    TYPE: "alexnet"
+    KWARGS:
+        width_mult: 1.0
+    PRETRAINED: "./pretrained_models/alexnet-bn.pth"
+    TRAIN_LAYERS: ["layer4", "layer5"]
+    TRAIN_EPOCH: 10
+    LAYERS_LR: 1.0
+
+ADJUST:
+    ADJUST: False
+
+RPN:
+    TYPE: "UPChannelRPN"
+    KWARGS:
+        anchor_num: 5
+        feature_in: 256
+
+MASK:
+    MASK: False
+
+ANCHOR:
+    STRIDE: 8
+    RATIOS: [0.33, 0.5, 1, 2, 3]
+    SCALES: [8]
+    ANCHOR_NUM: 5
+
+TRACK:
+    TYPE: "SiamRPNTracker"
+    PENALTY_K: 0.12  # 0.16 / 0.12
+    WINDOW_INFLUENCE: 0.38  # 0.40 / 0.38
+    LR: 0.32  # 0.30 / 0.32
+    EXEMPLAR_SIZE: 127
+    INSTANCE_SIZE: 287  # 271 or 255
+    BASE_SIZE: 0
+    CONTEXT_AMOUNT: 0.5
+
+TRAIN:
+    EPOCH: 11
+    NUM_WORKERS: 8
+    PRINT_FREQ: 100
+    START_EPOCH: 0
+    BATCH_SIZE: 550
+    BASE_SIZE: 0
+    OUTPUT_SIZE: 17
+    BASE_LR: 0.005
+    CLS_WEIGHT: 1.0
+    LOC_WEIGHT: 1.2
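+    # Per the README above: a 1P performance test resumes from the epoch-10
+    # checkpoint of an 8P run, skipping the epochs before the backbone is unfrozen.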
+    RESUME: "./snapshot_8p/checkpoint_e10.pth"
+
+    LR:
+        TYPE: 'log'
+        KWARGS:
+            start_lr: 0.01
+            end_lr: 0.001
+    LR_WARMUP:
+        TYPE: 'step'
+        EPOCH: 5
+        KWARGS:
+            start_lr: 0.001
+            end_lr: 0.005
+            step: 1
+
+DATASET:
+    NAMES:
+    - 'VID'
+    - 'YOUTUBEBB'
+    - 'COCO'
+    - 'DET'
+
+    TEMPLATE:
+        SHIFT: 4
+        SCALE: 0.05
+        BLUR: 0.0
+        FLIP: 0.0
+        COLOR: 1.0
+
+    SEARCH:
+        SHIFT: 64
+        SCALE: 0.18
+        BLUR: 0.2
+        FLIP: 0.0
+        COLOR: 1.0
+
+    NEG: 0.2
+    GRAY: 0.0
diff --git a/PyTorch/contrib/cv/video/SiamRPN/pysot-master/experiments/siamrpn_alexnet_8npu/config.yaml b/PyTorch/contrib/cv/video/SiamRPN/pysot-master/experiments/siamrpn_alexnet_8npu/config.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..b78ac9c757a3caecd91feb45e9cba7be522c44df
--- /dev/null
+++ b/PyTorch/contrib/cv/video/SiamRPN/pysot-master/experiments/siamrpn_alexnet_8npu/config.yaml
@@ -0,0 +1,87 @@
+META_ARC: "SiamRPN-8P"
+
+BACKBONE:
+    TYPE: "alexnet"
+    KWARGS:
+        width_mult: 1.0
+    PRETRAINED: "./pretrained_models/alexnet-bn.pth"
+    TRAIN_LAYERS: ["layer4", "layer5"]
+    TRAIN_EPOCH: 10
+    LAYERS_LR: 1.0
+
+ADJUST:
+    ADJUST: False
+
+RPN:
+    TYPE: "UPChannelRPN"
+    KWARGS:
+        anchor_num: 5
+        feature_in: 256
+
+MASK:
+    MASK: False
+
+ANCHOR:
+    STRIDE: 8
+    RATIOS: [0.33, 0.5, 1, 2, 3]
+    SCALES: [8]
+    ANCHOR_NUM: 5
+
+TRACK:
+    TYPE: "SiamRPNTracker"
+    PENALTY_K: 0.12  # 0.16 / 0.12
+    WINDOW_INFLUENCE: 0.38  # 0.40 / 0.38
+    LR: 0.32  # 0.30 / 0.32
+    EXEMPLAR_SIZE: 127
+    INSTANCE_SIZE: 287  # 271 or 255
+    BASE_SIZE: 0
+    CONTEXT_AMOUNT: 0.5
+
+TRAIN:
+    EPOCH: 50
+    NUM_WORKERS: 8
+    PRINT_FREQ: 100
+    START_EPOCH: 0
+    BATCH_SIZE: 550
+    BASE_SIZE: 0
+    OUTPUT_SIZE: 17
+    BASE_LR: 0.005
+    CLS_WEIGHT: 1.0
+    LOC_WEIGHT: 1.2
+
+    LR:
+        TYPE: 'log'
+        KWARGS:
+            start_lr: 0.01
+            end_lr: 0.001
+    LR_WARMUP:
+        TYPE: 'step'
+        EPOCH: 5
+        KWARGS:
+            start_lr: 0.001
+            end_lr: 0.005
+            step: 1
+
+DATASET:
+    NAMES:
+    - 'VID'
+    - 'YOUTUBEBB'
+    - 'COCO'
+    - 'DET'
+
+    TEMPLATE:
+        SHIFT: 4
+        SCALE: 0.05
+        BLUR: 0.0
+        FLIP: 0.0
+        COLOR: 1.0
+
+    SEARCH:
+        SHIFT: 64
+        SCALE: 0.18
+        BLUR: 0.2
+        FLIP: 0.0
+        COLOR: 1.0
+
+    NEG: 0.2
+    GRAY: 0.0
diff --git a/PyTorch/contrib/cv/video/SiamRPN/pysot-master/tools_1p/train.py b/PyTorch/contrib/cv/video/SiamRPN/pysot-master/tools_1p/train.py
index 3a5b143e875d1e23212534f663cb9c84337981ea..8f9e3433088eb279e37eeec4591d551a3a90aa28 100644
--- a/PyTorch/contrib/cv/video/SiamRPN/pysot-master/tools_1p/train.py
+++ b/PyTorch/contrib/cv/video/SiamRPN/pysot-master/tools_1p/train.py
@@ -27,13 +27,14 @@ import json
 import random
 import numpy as np
 import apex
+
 try:
     from apex import amp
 except:
     print('no apex')
 
 import torch
-if torch.__version__>= '1.8':
+if torch.__version__ >= '1.8':
     import torch_npu
 
 import torch.nn as nn
@@ -84,13 +85,18 @@ def build_data_loader():
     logger.info("build dataset done")
 
     train_sampler = None
+
     if get_world_size() > 1:
-        train_sampler = DistributedSampler(train_dataset)
+        train_sampler = DistributedSampler(train_dataset,
+                                           num_replicas=get_world_size(),
+                                           rank=get_rank())
+
     train_loader = DataLoader(train_dataset,
                               batch_size=cfg.TRAIN.BATCH_SIZE,
                               num_workers=cfg.TRAIN.NUM_WORKERS,
                               pin_memory=False,
                               sampler=train_sampler)
+
     return train_loader
@@ -102,8 +108,8 @@ def build_opt_lr(model):
             m.eval()
 
     trainable_params = []
-    trainable_params += [{'params': filter(lambda x: x.requires_grad,
-                                           model.backbone.parameters()),
+
+    trainable_params += [{'params': model.backbone.parameters(),
                           'lr': cfg.BACKBONE.LAYERS_LR * cfg.TRAIN.BASE_LR}]
 
     if cfg.ADJUST.ADJUST:
@@ -120,9 +126,11 @@ def build_opt_lr(model):
     if cfg.REFINE.REFINE:
         trainable_params += [{'params': model.refine_head.parameters(),
                               'lr': cfg.TRAIN.LR.BASE_LR}]
+
     optimizer = torch.optim.SGD(trainable_params,
                                 momentum=cfg.TRAIN.MOMENTUM,
                                 weight_decay=cfg.TRAIN.WEIGHT_DECAY)
+
     lr_scheduler = build_lr_scheduler(optimizer, epochs=cfg.TRAIN.EPOCH)
     lr_scheduler.step(cfg.TRAIN.START_EPOCH)
     return optimizer, lr_scheduler
@@ -139,7 +147,7 @@ def train(train_loader, model, optimizer, lr_scheduler):
     world_size = get_world_size()
     num_per_epoch = len(train_loader.dataset) // \
-        cfg.TRAIN.EPOCH // (cfg.TRAIN.BATCH_SIZE * world_size)
+                    cfg.TRAIN.EPOCH // (cfg.TRAIN.BATCH_SIZE * world_size)
 
     start_epoch = cfg.TRAIN.START_EPOCH
     epoch = start_epoch
@@ -168,12 +176,20 @@ def train(train_loader, model, optimizer, lr_scheduler):
 
         if cfg.BACKBONE.TRAIN_EPOCH == epoch:
             logger.info('start training backbone.')
-            for layer in cfg.BACKBONE.TRAIN_LAYERS:
-                for param in getattr(model.module.backbone, layer).parameters():
-                    param.requires_grad = True
-                for m in getattr(model.module.backbone, layer).modules():
-                    if isinstance(m, nn.BatchNorm2d):
-                        m.train()
+            if isinstance(model, torch.nn.parallel.DistributedDataParallel):
+                for layer in cfg.BACKBONE.TRAIN_LAYERS:
+                    for param in getattr(model.module.backbone, layer).parameters():
+                        param.requires_grad = True
+                    for m in getattr(model.module.backbone, layer).modules():
+                        if isinstance(m, nn.BatchNorm2d):
+                            m.train()
+            else:
+                for layer in cfg.BACKBONE.TRAIN_LAYERS:
+                    for param in getattr(model.backbone, layer).parameters():
+                        param.requires_grad = True
+                    for m in getattr(model.backbone, layer).modules():
+                        if isinstance(m, nn.BatchNorm2d):
+                            m.train()
 
             logger.info("model\n{}".format(describe(model.module)))
@@ -181,7 +197,7 @@ def train(train_loader, model, optimizer, lr_scheduler):
 
             cur_lr = lr_scheduler.get_cur_lr()
             logger.info('epoch: {}'.format(epoch + 1))
 
-        tb_idx = idx
+        # tb_idx = idx
         if idx % num_per_epoch == 0 and idx != 0:
             for idx, pg in enumerate(optimizer.param_groups):
@@ -201,7 +217,6 @@ def train(train_loader, model, optimizer, lr_scheduler):
 
         reduce_gradients(model)
 
-        # clip gradient
         clip_grad_norm_(model.parameters(), cfg.TRAIN.GRAD_CLIP)
 
         optimizer.step()
@@ -222,7 +237,6 @@ def train(train_loader, model, optimizer, lr_scheduler):
             info = "Epoch: [{}][{}/{}] lr: {:.6f}\n".format(
                         epoch + 1, (idx + 1) % num_per_epoch,
                         num_per_epoch, cur_lr)
-            avgtime = batch_info['batch_time'] + batch_info['data_time']
 
             for cc, (k, v) in enumerate(batch_info.items()):
                 if cc % 2 == 0:
                     info += ("\t{:s}\t").format(
@@ -235,8 +249,8 @@ def train(train_loader, model, optimizer, lr_scheduler):
             print_speed(idx + 1 + start_epoch * num_per_epoch,
                         average_meter.batch_time.avg,
                         cfg.TRAIN.EPOCH * num_per_epoch)
-
-            print('FPS', (28 * 1 / avgtime))
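+            # Single-card throughput: samples per optimizer step (batch size)
+            # divided by the average time per step.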
+            avgtime = average_meter.batch_time.avg
+            print('FPS', (cfg.TRAIN.BATCH_SIZE * 1 / avgtime))
 
         end = time.time()
 
         if args.is_performance:
@@ -245,12 +259,16 @@
 
 
 def main():
+    os.environ['RANK'] = str(args.local_rank)
+
     rank, world_size = dist_init()
+
     logger.info("init done")
-
+
     # load cfg
     cfg.merge_from_file(args.cfg)
 
     if rank == 0:
+
         if not os.path.exists(cfg.TRAIN.LOG_DIR):
             os.makedirs(cfg.TRAIN.LOG_DIR)
         init_log('global', logging.INFO)
@@ -268,7 +286,8 @@ def main():
     # load pretrained backbone weights
     if cfg.BACKBONE.PRETRAINED:
         cur_path = os.path.dirname(os.path.realpath(__file__))
-        backbone_path = os.path.join(cur_path, '../../', cfg.BACKBONE.PRETRAINED)
+        backbone_path = os.path.join(
+            cur_path, '../../', cfg.BACKBONE.PRETRAINED)
         load_pretrain(model.backbone, backbone_path)
 
     # build dataset loader
@@ -278,6 +297,7 @@ def main():
     optimizer, lr_scheduler = build_opt_lr(model)
 
     # resume training
+
     if cfg.TRAIN.RESUME:
         logger.info("resume from {}".format(cfg.TRAIN.RESUME))
         assert os.path.isfile(cfg.TRAIN.RESUME), \
@@ -287,15 +307,15 @@ def main():
     # load pretrain
     elif cfg.TRAIN.PRETRAINED:
         load_pretrain(model, cfg.TRAIN.PRETRAINED)
-
-    dist_model = DistModule(model)
 
     logger.info(lr_scheduler)
     logger.info("model prepare done")
 
-    dist_model, optimizer = amp.initialize(dist_model, optimizer, opt_level="O1",
-                                           loss_scale=32, combine_grad=True)
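+    # Initialize AMP on the bare model first, then wrap the AMP-patched model
+    # with native DistributedDataParallel (replacing the previous DistModule wrapper).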
prepare done") - dist_model, optimizer = amp.initialize(dist_model, optimizer, opt_level="O1", - loss_scale=32, combine_grad=True) + dist_model, optimizer = amp.initialize(model, optimizer, opt_level="O1", + loss_scale=32) + dist_model = torch.nn.parallel.DistributedDataParallel(dist_model, + device_ids=[rank]) # start training train(train_loader, dist_model, optimizer, lr_scheduler) diff --git a/PyTorch/contrib/cv/video/SiamRPN/pysot-master/tools_8p/train.py b/PyTorch/contrib/cv/video/SiamRPN/pysot-master/tools_8p/train.py index 9ad96c94f5eafff3f76c3b85531eeb5d7db23565..4e456005d35abe9c560e282985b28484866c71c5 100644 --- a/PyTorch/contrib/cv/video/SiamRPN/pysot-master/tools_8p/train.py +++ b/PyTorch/contrib/cv/video/SiamRPN/pysot-master/tools_8p/train.py @@ -34,7 +34,7 @@ except: print('no apex') import torch -if torch.__version__>= '1.8': +if torch.__version__ >= '1.8': import torch_npu import torch.nn as nn @@ -147,7 +147,7 @@ def train(train_loader, model, optimizer, lr_scheduler): world_size = get_world_size() num_per_epoch = len(train_loader.dataset) // \ - cfg.TRAIN.EPOCH // (cfg.TRAIN.BATCH_SIZE * world_size) + cfg.TRAIN.EPOCH // (cfg.TRAIN.BATCH_SIZE * world_size) start_epoch = cfg.TRAIN.START_EPOCH epoch = start_epoch @@ -237,7 +237,7 @@ def train(train_loader, model, optimizer, lr_scheduler): info = "Epoch: [{}][{}/{}] lr: {:.6f}\n".format( epoch + 1, (idx + 1) % num_per_epoch, num_per_epoch, cur_lr) - avgtime = batch_info['batch_time'] + batch_info['data_time'] + for cc, (k, v) in enumerate(batch_info.items()): if cc % 2 == 0: info += ("\t{:s}\t").format( @@ -249,9 +249,10 @@ def train(train_loader, model, optimizer, lr_scheduler): print_speed(idx + 1 + start_epoch * num_per_epoch, average_meter.batch_time.avg, cfg.TRAIN.EPOCH * num_per_epoch) - print('FPS', (28 * 8 / avgtime.item())) + avgtime = average_meter.batch_time.avg + print('FPS', (cfg.TRAIN.BATCH_SIZE * 8 / avgtime)) end = time.time() - + if args.is_performance: if idx == args.max_step: exit() @@ -285,7 +286,8 @@ def main(): # load pretrained backbone weights if cfg.BACKBONE.PRETRAINED: cur_path = os.path.dirname(os.path.realpath(__file__)) - backbone_path = os.path.join(cur_path, '../../', cfg.BACKBONE.PRETRAINED) + backbone_path = os.path.join( + cur_path, '../../', cfg.BACKBONE.PRETRAINED) load_pretrain(model.backbone, backbone_path) # build dataset loader