From 76803f522108711862859dd61db2bed5a6ce7abe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BC=A0=E6=B4=8B=E6=B4=8B?= <10527367+yang-yang-zhang123456@user.noreply.gitee.com> Date: Tue, 12 Apr 2022 10:56:45 +0000 Subject: [PATCH 1/2] update TensorFlow/contrib/cv/TNT_ID1233_for_TensorFlow/train_cnn_trajectory_2d.py. --- .../train_cnn_trajectory_2d.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/TensorFlow/contrib/cv/TNT_ID1233_for_TensorFlow/train_cnn_trajectory_2d.py b/TensorFlow/contrib/cv/TNT_ID1233_for_TensorFlow/train_cnn_trajectory_2d.py index 669b67ec8..efbeef22a 100644 --- a/TensorFlow/contrib/cv/TNT_ID1233_for_TensorFlow/train_cnn_trajectory_2d.py +++ b/TensorFlow/contrib/cv/TNT_ID1233_for_TensorFlow/train_cnn_trajectory_2d.py @@ -156,9 +156,10 @@ def main(args): for k in range(len(total_batch_x)): if np.sum(total_batch_x[k, 0, :, 1]) == 0: remove_idx.append(k) - total_batch_x = np.delete(total_batch_x, np.array(remove_idx), axis=0) - total_batch_y = np.delete(total_batch_y, np.array(remove_idx), axis=0) - print(len(total_batch_y)) + if len(remove_idx): + total_batch_x = np.delete(total_batch_x, np.array(remove_idx), axis=0) + total_batch_y = np.delete(total_batch_y, np.array(remove_idx), axis=0) + print(len(total_batch_y)) total_batch_x[:, 4:, :, 0] = 10 * total_batch_x[:, 4:, :, 0] temp_X = np.copy(total_batch_x) @@ -915,13 +916,14 @@ def parse_arguments(argv): default='/home/ma-user/modelarts/outputs/train_url_0/temp') parser.add_argument('--triplet_model', type=str, default='/home/ma-user/modelarts/outputs/train_url_0/model_data/20211209-124102/ ') - parser.add_argument('--max_step', type=int,default='20000000') + parser.add_argument('--max_step', type=int,default='2000000') parser.add_argument('--save_dir', type=str, default='/home/ma-user/modelarts/outputs/train_url_0/models/result/model.ckpt') parser.add_argument('--output_path', type=str, default='/home/ma-user/modelarts/outputs/train_url_0/logs/') return parser.parse_args(argv) + if __name__ == '__main__': - main(parse_arguments(sys.argv[1:])) \ No newline at end of file + main(parse_arguments(sys.argv[1:])) -- Gitee From 57b03a73d2be389c9dd7e97623bd0fa4ecc30eab Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BC=A0=E6=B4=8B=E6=B4=8B?= <10527367+yang-yang-zhang123456@user.noreply.gitee.com> Date: Tue, 12 Apr 2022 10:57:40 +0000 Subject: [PATCH 2/2] update TensorFlow/contrib/cv/TNT_ID1233_for_TensorFlow/test/train_full_1p.sh. --- .../test/train_full_1p.sh | 27 +++++++++++-------- 1 file changed, 16 insertions(+), 11 deletions(-) diff --git a/TensorFlow/contrib/cv/TNT_ID1233_for_TensorFlow/test/train_full_1p.sh b/TensorFlow/contrib/cv/TNT_ID1233_for_TensorFlow/test/train_full_1p.sh index b2c00b6ef..8887efad6 100644 --- a/TensorFlow/contrib/cv/TNT_ID1233_for_TensorFlow/test/train_full_1p.sh +++ b/TensorFlow/contrib/cv/TNT_ID1233_for_TensorFlow/test/train_full_1p.sh @@ -108,25 +108,29 @@ start_time=$(date +%s) # 您的训练数据集在${data_path}路径下,请直接使用这个变量获取 # 您的训练输出目录在${output_path}路径下,请直接使用这个变量获取 # 您的其他基础参数,可以自定义增加,但是batch_size请保留,并且设置正确的值 - +batch_size=32 if [ x"${modelarts_flag}" != x ]; then python3.7 ./train_cnn_trajectory_2d.py \ - --MAT_folder ${data_path}original_data/MOT17Det/mat/ \ + --MAT_folder ${data_path}/dataset/original_data/MOT17Det/mat \ + --img_folder ${data_path}/dataset/original_data/MOT17Det/train \ --temp_folder ${output_path}temp/ \ - --triplet_model ${data_path}model_data/20211209-124102/ \ - --save_dir ${output_path}model_data/traj/model.ckpt \ - --max_step 2001 \ - --output_path ${output_path} + --triplet_model ${data_path}/dataset/model_data/20211209-124102/ \ + --save_dir ${output_path}model_data/traj/ \ + --max_step 15 \ + --output_path ${output_path} >${print_log} 2>&1 + else python3.7 ./train_cnn_trajectory_2d.py \ - --MAT_folder ${data_path}original_data/MOT17Det/mat/ \ + --MAT_folder ${data_path}/dataset/original_data/MOT17Det/mat \ + --img_folder ${data_path}/dataset/original_data/MOT17Det/train \ --temp_folder ${output_path}temp/ \ - --triplet_model ${data_path}model_data/20211209-124102/ \ + --triplet_model ${data_path}/dataset/model_data/20211209-124102/ \ --save_dir ${output_path}model_data/traj/ \ - --max_step 2001 \ - --output_path ${output_path} + --max_step 15 \ + --output_path ${output_path} >${print_log} 2>&1 + fi # 性能相关数据计算 @@ -192,4 +196,5 @@ echo "CaseName = ${CaseName}" >> $cur_path/output/$ASCEND_DEVICE_ID/${CaseName}. echo "ActualFPS = ${FPS}" >> $cur_path/output/$ASCEND_DEVICE_ID/${CaseName}.log echo "TrainingTime = ${StepTime}" >> $cur_path/output/$ASCEND_DEVICE_ID/${CaseName}.log echo "ActualLoss = ${ActualLoss}" >> $cur_path/output/$ASCEND_DEVICE_ID/${CaseName}.log -echo "E2ETrainingTime = ${e2e_time}" >> $cur_path/output/$ASCEND_DEVICE_ID/${CaseName}.log \ No newline at end of file +echo "TrainAccuracy = ${train_accuracy}" >> $cur_path/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "E2ETrainingTime = ${e2e_time}" >> $cur_path/output/$ASCEND_DEVICE_ID/${CaseName}.log -- Gitee