diff --git a/contrib/Faster_R-CNN/README.md b/contrib/Faster_R-CNN/README.md
index b455eadca13ec30166bea25f5a5aac1c6724862c..4f2d2ff75632ccafefbb9d7f7e01bc151ecbd2bc 100644
--- a/contrib/Faster_R-CNN/README.md
+++ b/contrib/Faster_R-CNN/README.md
@@ -1,6 +1,7 @@
# X射线图像焊缝缺陷检测
## 1. 介绍
+### 1.1 简介
在本系统中,目的是基于MindX SDK,在昇腾平台上,开发端到端X射线图像焊缝缺陷检测的参考设计,实现对图像中的焊缝缺陷进行缺陷类别识别的功能,并把可视化结果保存到本地,达到功能要求。
@@ -8,33 +9,7 @@
样例输出:框出并标有缺陷类型与置信度的jpg图片。
-### 1.1 数据集介绍
-
-GDXray是一个公开X射线数据集,其中包括一个关于X射线焊接图像(Welds)的数据,该数据由德国柏林的BAM联邦材料研究和测试研究所收集。
-
-Welds集中W0003 包括了68张焊接公司的X射线图像。本文基于W0003数据集并在焊接专家的帮助下将焊缝和其内部缺陷标注。
-
-数据集下载地址:https://domingomery.ing.puc.cl/material/gdxray/
-
-注:本系统训练使用的数据是由原png图片转为jpg图片,然后经过焊缝裁剪和滑窗裁剪后输入模型训练,推理时所用的图片是已经经过焊缝裁剪的图片。
-
-### 1.2 支持的产品
-
-本项目以昇腾Atlas310卡为主要的硬件平台。
-
-### 1.3 支持的版本
-
-支持的SDK版本为 5.0.0, CANN 版本为 7.0.0, MindSpore版本为1.8。
-
-HDK版本号查询方法,在Atlas产品环境下,运行命令:
-
-```shell
-npu-smi info
-```
-
-可以查询支持SDK的版本号。
-
-### 1.4 软件方案介绍
+软件方案介绍
本方案中,会先进行滑窗裁剪处理,然后将处理好的图片通过 appsrc 插件输入到业务流程中,最终根据Faster—RCNN模型识别得到缺陷类别和置信度生成框输出标有缺陷类别与置信度的jpg图片。
@@ -50,6 +25,37 @@ npu-smi info
| 6 | 结果输出 | 获取检测结果 |
| 7 | 结果可视化 | 将检测结果标注在输入图片上 |
+
+技术实现流程图
+
+
+
+
+
+
+
+### 1.2 支持的产品
+
+本项目以昇腾Atlas310卡为主要的硬件平台。
+
+### 1.3 支持的版本
+
+本样例配套的MxVision版本、CANN版本、Driver/Firmware版本如下所示:
+| MxVision版本 | CANN版本 | Driver/Firmware版本 |
+| --------- | ------------------ | -------------- |
+| 5.0.0 | 7.0.0 | 23.0.0 |
+
+### 1.4 三方依赖
+
+推荐系统为ubuntu 18.04,环境依赖软件和版本如下表:
+
+| 软件名称 | 版本 |
+| :-----------: | :------: |
+| numpy | 1.23.3 |
+| opencv-python | 4.6.0.66 |
+| pycocotools | 2.0.5 |
+| mmcv | 1.7.0 |
+
### 1.5 代码目录结构与说明
本工程名称为 Faster_R-CNN,工程目录如下所示:
@@ -78,7 +84,7 @@ npu-smi info
│ │ └── postprocess.py
│ ├── models
│ │ ├── aipp-configs
-│ │ │ ├── aipp.cfg # sdk做图像预处理aipp配置文件
+│ │ │ ├── aipp.cfg # sdk做图像预处理aipp配置文件
│ │ │ └── aipp_rgb.cfg # opencv做图像预处理aipp配置文件
│ │ ├── conversion-scripts #(需创建)转换前后模型所放的位置
│ │ ├── convert_om.sh # 模型转换相关环境变量配置可参考该文件
@@ -105,65 +111,34 @@ npu-smi info
注:验证时有COCO和VOC两种数据格式是因为原图片经过滑窗裁剪后的小图片是以coco的数据格式进行训练的,而本系统最终采用的验证方式是,将经过推理后得到的小图片的标注框信息还原到未经过滑窗裁剪的图片上,再进行VOC评估。
-### 1.6 技术实现流程图
-
-
-
-
-
-### 1.7 特性及适用场景
+### 1.6 相关约束
-经过测试,在现有数据集的基础上,该项目检测算法可以检测八种焊缝缺陷:气孔、裂纹、夹渣、未熔合、未焊透、咬边、内凹、成形不良,关于缺陷召回率和MAP分数在后续内容中将会提到。本项目属于工业缺陷中焊缝缺陷检测领域,主要针对DR成像设备(数字化X射线成像设备)拍摄金属焊接处成像形成的焊接X射线图像进行缺陷检测。
+经过测试,在现有数据集的基础上,该项目检测算法可以检测八种焊缝缺陷:气孔、裂纹、夹渣、未熔合、未焊透、咬边、内凹、成形不良。本项目属于工业缺陷中焊缝缺陷检测领域,主要针对DR成像设备(数字化X射线成像设备)拍摄金属焊接处成像形成的焊接X射线图像进行缺陷检测。
-## 2. 环境依赖
+## 2. 设置环境变量
-推荐系统为ubuntu 18.04,环境依赖软件和版本如下表:
-
-| 软件名称 | 版本 |
-| :-----------: | :------: |
-| ubantu | 18.04 |
-| MindX SDK | 5.0.0 |
-| Python | 3.9.2 |
-| CANN | 7.0.0 |
-| numpy | 1.23.3 |
-| opencv-python | 4.6.0.66 |
-| pycocotools | 2.0.5 |
-| mmcv | 1.7.0 |
-
-确保环境中正确安装mxVision SDK。
在编译运行项目前,需要设置环境变量:
-
-MindSDK 环境变量:
-
-```shell
-. ${SDK-path}/set_env.sh
```
-
-CANN 环境变量:
-
-```shell
+#设置CANN环境变量(请确认install_path路径是否正确)
. ${ascend-toolkit-path}/set_env.sh
-```
-
-- 环境变量介绍
-```
-SDK-path: mxVision SDK 安装路径。
-ascend-toolkit-path: CANN 安装路径。
-```
+#设置MindX SDK 环境变量,SDK-path为mxVision SDK 安装路径
+. ${SDK-path}/set_env.sh
-## 3. 模型转换
+#查看环境变量
+env
-本项目中采用的模型是 Faster—RCNN模型,参考实现代码:https://www.hiascend.com/zh/software/modelzoo/models/detail/C/8d8b656fe2404616a1f0f491410a224c/1
+```
+## 3. 准备模型
-1. 将训练好的模型 [fasterrcnn_mindspore.air](https://mindx.sdk.obs.cn-north-4.myhuaweicloud.com/mindxsdk-referenceapps%20/contrib/Faster-RCNN/fasterrcnn_mindspore.air) 下载至 ``python/models/conversion-scripts``(文件夹需创建)文件夹下。
+**步骤1** 将训练好的Faster—RCNN模型 [fasterrcnn_mindspore.air](https://mindx.sdk.obs.cn-north-4.myhuaweicloud.com/mindxsdk-referenceapps%20/contrib/Faster-RCNN/fasterrcnn_mindspore.air) 下载至 ``python/models/conversion-scripts``(文件夹需创建)文件夹下。
-2. 将该模型转换为om模型,具体操作为: ``python/models`` 文件夹下,执行指令进行模型转换:
+**步骤2** 将该模型转换为om模型,具体操作为: ``python/models`` 文件夹下,执行指令进行模型转换:
### DVPP模型转换
@@ -179,6 +154,7 @@ bash convert_om.sh conversion-scripts/fasterrcnn_mindspore.air aipp-configs/aipp
**注**:转换后的OPENCV模型会用OpenCV对图片做预处理,然后进行推理,用户可自行进行选择。
+
## 4. 编译与运行
**步骤1** 编译后处理插件
@@ -209,15 +185,15 @@ python3 main.py
1. 准备精度测试所需图片,将[验证集](https://mindx.sdk.obs.cn-north-4.myhuaweicloud.com/mindxsdk-referenceapps%20/contrib/Faster-RCNN/eval.zip)下载到`python/data/eval/`目录下并解压。
-1. 打开`python/pipeline/fasterrcnn_ms_dvpp.pipeline`文件,将第45行(postProcessConfigPath)配置参数改为`../models/fasterrcnn_coco2017_acc_test.cfg`。
+2. 打开`python/pipeline/fasterrcnn_ms_dvpp.pipeline`文件,将第45行(postProcessConfigPath)配置参数改为`../models/fasterrcnn_coco2017_acc_test.cfg`。
-1. 使用dvpp模式对图片进行推理,切换到``python/Main``目录下,执行命令:
+3. 使用dvpp模式对图片进行推理,切换到``python/Main``目录下,执行命令:
```python
python3 main.py --img_path ../data/eval/cocodataset/val2017/ --pipeline_path ../pipeline/fasterrcnn_ms_dvpp.pipeline --model_type dvpp --infer_mode eval --ann_file ../data/eval/cocodataset/annotations/instances_val2017.json
```
-2. 因为涉及到去重处理,每种缺陷需要分开评估精度,切换到``python/Main``目录下,执行命令:
+4. 因为涉及到去重处理,每种缺陷需要分开评估精度,切换到``python/Main``目录下,执行命令:
```python
# 验证气孔精度
@@ -228,11 +204,6 @@ python3 main.py
```
**注**:cat_id为缺陷标签,object_name为对应缺陷名称,在 ``python/models/coco2017.names``可查看缺陷类别。
-
- | 缺陷种类 | AP |
- | :------: | :----: |
- | 气孔 | 0.7251 |
- | 裂纹 | 0.7597 |
## 5. 常见问题
diff --git a/contrib/Faster_R-CNN/build.sh b/contrib/Faster_R-CNN/build.sh
deleted file mode 100644
index 73428ac0284bc8c2e354f59c7d19822712966cd6..0000000000000000000000000000000000000000
--- a/contrib/Faster_R-CNN/build.sh
+++ /dev/null
@@ -1,38 +0,0 @@
-#!/bin/bash
-# Copyright 2022 Huawei Technologies Co., Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-set -e
-CURRENT_PATH="$( cd "$(dirname "$0")" ;pwd -P )"
-
-
-POSTPROCESS_FOLDER=(
- /postprocess/
-)
-
-
-FLAG=0
-for path in ${POSTPROCESS_FOLDER[@]};do
- cd ${CURRENT_PATH}/${path}
- bash build.sh || {
- echo -e "Failed to build postprocess plugin ${path}"
- FLAG=1
- }
-done
-
-
-if [ ${FLAG} -eq 1 ]; then
- exit 1
-fi
-exit 0
diff --git a/contrib/Faster_R-CNN/python/Main/main.py b/contrib/Faster_R-CNN/python/Main/main.py
index e83c7f43467094dde0d335aa65a5201e6519d960..a7283bd8c3dbe4755b109f00ff73a45233a7e8b6 100644
--- a/contrib/Faster_R-CNN/python/Main/main.py
+++ b/contrib/Faster_R-CNN/python/Main/main.py
@@ -48,7 +48,7 @@ def parser_args():
type=str,
required=False,
default="../pipeline/fasterrcnn_ms_dvpp.pipeline",
- help="image file path. The default is 'config/maskrcnn_ms.pipeline'. ")
+ help="image file path. The default is '../pipeline/fasterrcnn_ms_dvpp.pipeline'. ")
parser.add_argument(
"--model_type",
type=str,
diff --git a/contrib/STGCN/README.md b/contrib/STGCN/README.md
index f723a090ca86a176a4695ad8f32b9df84ad913f1..242278e7e700ac3a83fce5028b0ef240dfbce92e 100644
--- a/contrib/STGCN/README.md
+++ b/contrib/STGCN/README.md
@@ -1,21 +1,18 @@
# 城市道路交通预测
## 1 介绍
+### 1.1 简介
-STGCN主要用于交通预测领域,是一种时空卷积网络,解决在交通领域的时间序列预测问题。在定义图上的问题,并用纯卷积结构建立模型,这使得使用更少的参数能带来更快的训练速度。本样例基于MindxSDK开发,是在STGCN模型的基础上对SZ-Taxi数据集进行训练转化,可以对未来一定时段内的交通速度进行预测。通过在SZ-Taxi的测试集上进行推理测试,精度可以达到 MAE 2.81 | RMSE 4.29。该模型在SZ-Taxi数据集上无具体目标精度可以参考,此精度值为官方认可值。
-
+STGCN主要用于交通预测领域,是一种时空卷积网络,解决在交通领域的时间序列预测问题。在定义图上的问题,并用纯卷积结构建立模型,这使得使用更少的参数能带来更快的训练速度。本样例基于MindxSDK开发,是在STGCN模型的基础上对SZ-Taxi数据集进行训练转化,可以对未来一定时段内的交通速度进行预测。
论文原文:https://arxiv.org/abs/1709.04875
STGCN模型GitHub仓库:https://github.com/hazdzz/STGCN
-SZ-Taxi数据集:https://github.com/lehaifeng/T-GCN/tree/master/data
+SZ-Taxi数据集:https://github.com/lehaifeng/T-GCN/tree/master/data
SZ-Taxi数据集包含深圳市的出租车动向,包括道路邻接矩阵和道路交通速度信息。
-### 1.1 支持的产品
-本项目以昇腾Atlas310卡为主要的硬件平台。
-
-### 1.2 软件方案介绍
+软件方案介绍
基于MindX SDK的城市道路交通预测模型的推理流程为:
@@ -29,10 +26,32 @@ SZ-Taxi数据集包含深圳市的出租车动向,包括道路邻接矩阵和
| 2 | 模型推理 | 调用MindX SDK的mxpi_tensorinfer对输入张量进行推理 |
| 3 | 结果输出 | 调用MindX SDK的mxpi_dataserialize和appsink以及pythonAPI的GetProtobuf()函数输出结果 |
-### 1.3 特性及适用场景
-模型的原始训练是基于SZ-Taxi数据集训练的,读取的图为深圳罗湖区156条主要道路的交通连接情况。因此对于针对罗湖区的自定义交通速度数据(大小为N×156,N>12),都能给出具有参考价值的未来一定时段的交通速度,从而有助于判断未来一段时间内道路的拥堵情况等。
+主程序流程
+
+1、初始化流管理。
+2、读取数据集。
+3、向流发送数据,进行推理。
+4、获取pipeline各插件输出结果。
+5、销毁流。
+
+### 1.2 支持的产品
+
+本项目以昇腾Atlas300V pro、 Atlas300I pro为主要的硬件平台
-### 1.4 代码目录结构与说明
+### 1.3 支持的版本
+
+本样例配套的MxVision版本、CANN版本、Driver/Firmware版本如下所示:
+| MxVision版本 | CANN版本 | Driver/Firmware版本 |
+| --------- | ------------------ | -------------- |
+| 6.0.RC3 | 8.0.RC3 | 24.1.RC3 |
+
+### 1.4 三方依赖
+| 依赖软件 | 版本 |
+| -------- | --------- |
+| scipy | 1.13.1 |
+| numpy | 1.24.0 |
+
+### 1.5 代码目录结构与说明
eg:本sample工程名称为STGCN,工程目录如下图所示:
```
@@ -44,98 +63,38 @@ eg:本sample工程名称为STGCN,工程目录如下图所示:
├── predict.py # 根据输入的数据集输出未来一定时段的交通速度
├── README.md
├── convert_om.sh # onnx文件转化为om文件
-├── results # 预测结果存放
-└── train_need
- └── export_onnx.py # 将pth文件转化成onnx文件,添加进训练项目
+└── results # 预测结果存放
```
-## 2 环境依赖
+### 1.6 相关约束
-eg:推荐系统为ubuntu 18.04,环境依赖软件和版本如下表:
+模型的原始训练是基于SZ-Taxi数据集训练的,读取的图为深圳罗湖区156条主要道路的交通连接情况。因此对于针对罗湖区的自定义交通速度数据(大小为N×156,N>12),都能给出具有参考价值的未来一定时段的交通速度,从而有助于判断未来一段时间内道路的拥堵情况等。
+
+## 2 设置环境变量
-| 软件名称 | 版本 |
-| -------- | ------ |
-| mxVision | 2.0.4 |
-| Python | 3.9 |
-| CANN | 5.1.RC1 |
-- 环境变量介绍
在编译运行项目前,需要设置环境变量:
```
-. ${SDK安装路径}/set_env.sh
-. ${CANN安装路径}/set_env.sh
-```
+#设置CANN环境变量(请确认install_path路径是否正确)
+. ${ascend-toolkit-path}/set_env.sh
-## 依赖安装
-```
-CANN软件包获取地址:https://www.hiascend.com/software/cann/commercial
-SDK官方下载地址:https://www.hiascend.com/zh/software/mindx-sdk
-我的安装步骤是本地下载好对应版本的安装包然后上传到服务器,然后再完成以下两个步骤
-1、给.run安装包设置可执行权限
-2、执行安装指令
-./ *.run --install
-```
+#设置MindX SDK 环境变量,SDK-path为mxVision SDK 安装路径
+. ${SDK-path}/set_env.sh
-## 3 城市道路交通预测开发实现
-总体流程如下:
-```
-模型训练->模型转化->模型推理
-```
-### 3.1 模型训练
-首先需要使用STGCN对SZ-Taxi数据集进行训练,使用的模型代码和数据集获取方式如下。
-```
-STGCN模型GitHub仓库:https://github.com/hazdzz/STGCN
+#查看环境变量
+env
-SZ-Taxi数据集:https://github.com/lehaifeng/T-GCN/tree/master/data
-```
-自行参照GitHub项目中的README.md和requiremments.txt文件配置训练所需环境。
-为训练SZ-Taxi数据集,主要需要修改两个部分:
-```
-1、stgcn.py部分
-(1)训练参数如下:
-'learning_rate': 0.001,
-'epochs': 1000,
-'batch_size': 8,
-'gamma': 0.95,
-'drop_rate': 0.5,
-'weight_decay_rate': 0.0005
-(2)将数据集放到指定文件夹后增加
-args.dataset == 'sz-taxis'
-
-2、dataloader.py部分
-(1)load_adj()
-读取邻接矩阵部分改为
-my_data = np.genfromtxt('data/sz-taxis/sz_adj.csv', delimiter=',') # 邻接矩阵路径
-smy_data = sp.csr_matrix(my_data)
-adj = smy_data.tocsc()
-并且增加
-elif dataset_name == 'sz-taxis':
- n_vertex = 156
-(2)load_data()
-train的划分改为:
-train = vel[: len_train + len_val]
```
-修改完毕后将SZ-Taxi数据集中的sz_adj.csv和sz_speed.csv文件放在'data/sz-taxis/'目录下(如果没有自行创建),运行STGCN模型GitHub仓库中的main.py文件即可开始训练。训练获得pth文件可通过export_onnx.py转换成onnx文件。
-训练好的pth文件连接如下:
-```
-https://mindx.sdk.obs.cn-north-4.myhuaweicloud.com/mindxsdk-referenceapps%20/contrib/STGCN/stgcn_sym_norm_lap_45_mins.pth
-```
+## 3 准备模型
-### 3.2 模型转化
-本项目推理模型权重采用Github仓库中Pytorch框架的STGCN模型训练SZ-Taxi数据集得到的权重转化得到。经过以下两步模型转化:
-1、pth转化为onnx
-可以根据实际的路径和输入大小修改export_onnx.py(该文件需要依赖于项目结构目录,请放到训练代码项目主目录下再运行)
-运行指令如下:
-```
-python export_onnx.py
-```
-转换好的onnx文件连接如下:
+**步骤1** 模型下载
+本项目提供训练好的onnx模型,下载链接如下:
```
https://mindx.sdk.obs.cn-north-4.myhuaweicloud.com/mindxsdk-referenceapps%20/contrib/STGCN/stgcn10.onnx
```
-2、onnx转化为om
+**步骤2** onnx转化为om
根据实际路径修改convert_om.sh
```
bash convert_om.sh [model_path] stgcn10
@@ -144,24 +103,12 @@ model_path:onnx文件路径须自行输入。
stgcn10:生成的om模型文件名,转换脚本会在此基础上添加.om后缀。
```
-## 4 模型推理
-### 4.1 pipeline编排
-```
- appsrc # 输入
- mxpi_tensorinfer # 模型推理
- mxpi_dataserialize
- appsink # 输出
-```
-### 4.2 主程序流程
+## 4 运行
+### 4.1 数据集准备
+SZ-Taxi数据集下载链接:https://github.com/lehaifeng/T-GCN/tree/master/data
+将sz_speed.csv放置在工程目录/data下
-1、初始化流管理。
-2、读取数据集。
-3、向流发送数据,进行推理。
-4、获取pipeline各插件输出结果。
-5、销毁流。
-
-## 5 运行
-### 5.1 运行main.py
+### 4.2 运行main.py
运行main.py可以在sz_speed.csv的测试集上获得推理精度,指令如下:
```
python main.py [image_path] [result_dir] [n_pred]
@@ -176,10 +123,8 @@ n_pred:预测时段,如9
```
最后sz_speed.csv测试集的推理预测的结果会保存在results/predictions.txt文件中,实际数据会保存在results/labels.txt文件中。
推理精度会直接显示在界面上。
-```
-MAE 2.81 | RMSE 4.29
-```
-### 5.2 运行predict.py
+
+### 4.3 运行predict.py
如果需要推理自定义的数据集(行数大于12行,列数为156列的csv文件),运行predict.py,指令如下:
```
python predict.py [image_path] [result_dir]
@@ -193,10 +138,3 @@ result_dir:推理结果保存路径,如“results/”
则会在results文件夹下生成代表预测的交通速度数据prediction.txt文件
这是通过已知数据集里过去时段的交通速度数据预测未来一定时间内的交通速度,无标准参考,所以只会输出代表预测的交通速度数据的prediction.txt文件,而没有MAE和RMSE等精度。
另外和main.py的运行指令相比少一个n_pred参数,因为已在代码中定义了确定数值,无需额外输入。
-
-
-## 6 常见问题
-1、服务器上进行推理的时候出现coredump报错
-```
-原因:因为服务器上安装了好几个版本的mxVision,使用RC2版本的时候出现了这个问题,2.0.4版本的时候就可以了,是版本不匹配导致的。运行前可以先运行一下对应版本的set_env.sh
-```
\ No newline at end of file
diff --git a/contrib/STGCN/convert_om.sh b/contrib/STGCN/convert_om.sh
index a5ee3cf6dc69780c3fb89666732f91a948b32bfc..c3cd4e5bccd8a69751b9e7794be7f9205560d924 100644
--- a/contrib/STGCN/convert_om.sh
+++ b/contrib/STGCN/convert_om.sh
@@ -23,7 +23,7 @@ atc \
--enable_small_channel=1 \
--log=error \
--input_format=NCHW \
- --soc_version=Ascend310 \
+ --soc_version=Ascend310P3 \
--op_select_implmode=high_precision \
--output_type=FP32
diff --git a/contrib/STGCN/data/.keep b/contrib/STGCN/data/.keep
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/contrib/STGCN/results/.keep b/contrib/STGCN/results/.keep
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/contrib/STGCN/train_need/export_onnx.py b/contrib/STGCN/train_need/export_onnx.py
deleted file mode 100644
index 02ee618f5b710118309cd4c98fb6aaba0357766e..0000000000000000000000000000000000000000
--- a/contrib/STGCN/train_need/export_onnx.py
+++ /dev/null
@@ -1,27 +0,0 @@
-# Copyright(C) 2022. Huawei Technologies Co.,Ltd. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import torch
-import torch.nn
-import onnx
-
-model = torch.load('model/save/sz-taxis/stgcn_sym_norm_lap_45_mins.pth')
-input_names = ['input']
-output_names = ['output']
-
-x = torch.randn(64, 1, 12, 156, device='cpu')
-
-torch.onnx.export(model, x, 'stgcn10.onnx',\
- opset_version = 12, input_names=input_names, \
- output_names=output_names, verbose='True')