diff --git a/official/gnn/gcn/README.md b/official/gnn/gcn/README.md index 77571bd31d639541ef54bfad6b822e711d2d2a13..9fa28fd06c6912766c5e60f3e4b1b11b8f8f179d 100644 --- a/official/gnn/gcn/README.md +++ b/official/gnn/gcn/README.md @@ -147,7 +147,7 @@ Epoch: 0198 train_loss= 0.60680 train_acc= 0.95000 val_loss= 1.04847 val_acc= 0. Epoch: 0199 train_loss= 0.61920 train_acc= 0.96429 val_loss= 1.04797 val_acc= 0.78400 time= 0.00413 Epoch: 0200 train_loss= 0.57948 train_acc= 0.96429 val_loss= 1.04753 val_acc= 0.78600 time= 0.00415 Optimization Finished! -Test set results: cost= 1.00983 accuracy= 0.81300 time= 0.39083 +Test set results: cost= 1.00983 accuracy= 0.81800 time= 0.39083 ... ``` @@ -158,13 +158,13 @@ Test set results: cost= 1.00983 accuracy= 0.81300 time= 0.39083 | Parameters | GCN | | -------------------------- | -------------------------------------------------------------- | | Resource | Ascend 910 | -| uploaded Date | 06/09/2020 (month/day/year) | +| uploaded Date | 09/17/2021 (month/day/year) | | MindSpore Version | 1.0.0 | | Dataset | Cora/Citeseer | | Training Parameters | epoch=200 | | Optimizer | Adam | | Loss Function | Softmax Cross Entropy | -| Accuracy | 81.5/70.3 | +| Accuracy | 81.8/71.3 | | Parameters (B) | 92160/59344 | | Scripts | | diff --git a/official/gnn/gcn/README_CN.md b/official/gnn/gcn/README_CN.md index f89bffdafedc9b1ea993436754e08b6add09fd52..164612b5b491f14922ce70a15b2097fabd9a39bc 100644 --- a/official/gnn/gcn/README_CN.md +++ b/official/gnn/gcn/README_CN.md @@ -1,192 +1,412 @@ -# 目录 - - - -- [目录](#目录) -- [图卷积网络描述](#图卷积网络描述) -- [模型架构](#模型架构) -- [数据集](#数据集) -- [环境要求](#环境要求) -- [快速入门](#快速入门) - - [用法](#用法) - - [启动](#启动) -- [脚本说明](#脚本说明) - - [脚本及样例代码](#脚本及样例代码) - - [脚本参数](#脚本参数) - - [培训、评估、测试过程](#培训评估测试过程) - - [用法](#用法-1) - - [启动](#启动-1) - - [结果](#结果) -- [模型描述](#模型描述) - - [性能](#性能) -- [随机情况说明](#随机情况说明) -- [ModelZoo主页](#modelzoo主页) - - - -# 图卷积网络描述 +# 交付件基本信息 
-图卷积网络(GCN)于2016年提出,旨在对图结构数据进行半监督学习。它提出了一种基于卷积神经网络有效变体的可扩展方法,可直接在图上操作。该模型在图边缘的数量上线性缩放,并学习隐藏层表示,这些表示编码了局部图结构和节点特征。 - -[论文](https://arxiv.org/abs/1609.02907): Thomas N. Kipf, Max Welling.2016.Semi-Supervised Classification with Graph Convolutional Networks.In ICLR 2016. - -# 模型架构 - -GCN包含两个图卷积层。每一层以节点特征和邻接矩阵为输入,通过聚合相邻特征来更新节点特征。 - -# 数据集 - -| 数据集 | 类型 | 节点 | 边 | 类 | 特征 | 标签率 | -| ------- | ---------------:|-----:| ----:| ------:|--------:| ---------:| -| Cora | Citation network | 2708 | 5429 | 7 | 1433 | 0.052 | -| Citeseer| Citation network | 3327 | 4732 | 6 | 3703 | 0.036 | +**发布者(Publisher)**:Huawei -# 环境要求 +**应用领域(Application Domain)**:GNN -- 硬件(Ascend处理器) - - 准备Ascend或GPU处理器搭建硬件环境。 -- 框架 - - [MindSpore](https://gitee.com/mindspore/mindspore) -- 如需查看详情,请参见如下资源: - - [MindSpore教程](https://www.mindspore.cn/tutorial/training/zh-CN/master/index.html) - - [MindSpore Python API](https://www.mindspore.cn/doc/api_python/zh-CN/master/index.html) +**版本(Version)**:1.2 -# 快速入门 +**修改时间(Modified)**:2021.09.17 -- 安装[MindSpore](https://www.mindspore.cn/install) +**大小(Size)(cora)**:91 KB (ckpt)/ 91 KB (onnx)/ 99 KB (air\)/ 129 KB \(om) -- 从github下载/kimiyoung/planetoid提供的数据集Cora或Citeseer +**大小(Size)(citeseer)**:232 KB (ckpt)/ 233 KB (onnx)/ 241 KB (air\)/ 194 KB \(om) -- 将数据集放到任意路径,文件夹应该包含如下文件(以Cora数据集为例): +**框架(Framework)**:MindSpore\_1.2.0 -```text -. 
-└─data
-    ├─ind.cora.allx
-    ├─ind.cora.ally
-    ├─ind.cora.graph
-    ├─ind.cora.test.index
-    ├─ind.cora.tx
-    ├─ind.cora.ty
-    ├─ind.cora.x
-    └─ind.cora.y
-```
-
-- 为Cora或Citeseer生成MindRecord格式的数据集
-
-## 用法
+**模型格式(Model Format)**:ckpt/onnx/air/om
 
-```buildoutcfg
-cd ./scripts
-# SRC_PATH为下载的数据集文件路径,DATASET_NAME为Cora或Citeseer
-sh run_process_data.sh [SRC_PATH] [DATASET_NAME]
-```
+**精度(Precision)**:Mixed/FP16
 
-## 启动
+**处理器(Processor)**:昇腾910/昇腾310
 
-```text
-# 为Cora生成MindRecord格式的数据集
-sh run_process_data.sh ./data cora
-# 为Citeseer生成MindRecord格式的数据集
-sh run_process_data.sh ./data citeseer
-```
+**应用级别(Categories)**:Released
 
-# 脚本说明
-
-## 脚本及样例代码
-
-```shell
-.
-└─gcn
-  ├─README.md
-  ├─scripts
-  | ├─run_process_data.sh  # 生成MindRecord格式的数据集
-  | └─run_train.sh         # 启动训练,目前只支持Ascend后端
-  |
-  ├─src
-  | ├─config.py            # 参数配置
-  | ├─dataset.py           # 数据预处理
-  | ├─gcn.py               # GCN骨干
-  | └─metrics.py           # 损失和准确率
-  |
-  └─train.py               # 训练网络,每个训练轮次后评估验证结果收敛后,训练停止,然后进行测试。
-```
+**描述(Description)**:基于MindSpore框架的GCN图卷积网络模型,训练并保存模型后,通过ATC工具转换,可在昇腾AI设备上运行,支持使用MindX SDK及MxBase进行推理
 
-## 脚本参数
+# 概述
 
-训练参数可以在config.py中配置。
+## 简述
 
-```text
-"learning_rate": 0.01,    # 学习率
-"epochs": 200,            # 训练轮次
-"hidden1": 16,            # 第一图卷积层隐藏大小
-"dropout": 0.5,           # 第一图卷积层dropout率
-"weight_decay": 5e-4,     # 第一图卷积层参数的权重衰减
-"early_stopping": 10,     # 早停容限
-```
-
-## 培训、评估、测试过程
-
-### 用法
+图卷积网络(GCN)于2016年提出,旨在对图结构数据进行半监督学习。它提出了一种基于卷积神经网络有效变体的可扩展方法,可直接在图上操作。该模型在图边缘的数量上线性缩放,并学习隐藏层表示,这些表示编码了局部图结构和节点特征。
 
-```text
-# 使用Cora或Citeseer数据集进行训练,DATASET_NAME为Cora或Citeseer
-sh run_train.sh [DATASET_NAME]
-```
+[论文](https://arxiv.org/abs/1609.02907): Thomas N. Kipf, Max Welling. 2016. Semi-Supervised Classification with Graph Convolutional Networks. In ICLR 2016.
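简述中提到,GCN 的每一层以节点特征和邻接矩阵为输入,通过聚合相邻节点的特征来更新节点表示,即传播规则 H' = σ(ÂHW),其中 Â 为加自环后对称归一化的邻接矩阵。下面用纯 Python 给出该传播规则的一个极简示意;其中的 4 节点小图、特征矩阵和权重均为随意假设的演示数据,与本仓库的实际代码及 Cora/Citeseer 数据无关:

```python
import math

def matmul(A, B):
    # 朴素矩阵乘法(嵌套列表表示矩阵)
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]

def relu(M):
    return [[x if x > 0 else 0.0 for x in row] for row in M]

def softmax_rows(M):
    out = []
    for row in M:
        e = [math.exp(x) for x in row]
        s = sum(e)
        out.append([x / s for x in e])
    return out

# 假设的 4 节点无向小图,仅作演示
A = [[0, 1, 0, 0],
     [1, 0, 1, 1],
     [0, 1, 0, 0],
     [0, 1, 0, 0]]
n = len(A)

# 加自环并做对称归一化:A_hat = D^{-1/2} (A + I) D^{-1/2}
A_tilde = [[A[i][j] + (1 if i == j else 0) for j in range(n)] for i in range(n)]
deg = [sum(row) for row in A_tilde]
A_hat = [[A_tilde[i][j] / math.sqrt(deg[i] * deg[j]) for j in range(n)] for i in range(n)]

# 假设的节点特征与两层权重(维度远小于实际模型,便于演示)
X  = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, -0.5]]
W0 = [[0.2, -0.1, 0.3], [0.4, 0.1, -0.2]]
W1 = [[0.1, -0.3], [0.2, 0.2], [-0.1, 0.4]]

# 两层 GCN 前向:每层先用 A_hat 聚合邻居特征,再做线性变换
H1 = relu(matmul(matmul(A_hat, X), W0))
probs = softmax_rows(matmul(matmul(A_hat, H1), W1))

for row in probs:
    print([round(x, 3) for x in row])  # 每行是一个节点的类别概率,行内和为 1
```

实际模型中第一层隐藏大小为 hidden1=16、Cora 的类别数为 7,并由 MindSpore 算子实现,此处仅演示聚合与变换的先后顺序。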
-### 启动 +通过Git获取对应commit_id的代码方法如下: -```bash -sh run_train.sh cora ``` - -### 结果 - -训练结果将保存在脚本路径下,文件夹名称以“train”开头。您可在日志中找到如下结果: - -```text -Epoch:0001 train_loss= 1.95373 train_acc= 0.09286 val_loss= 1.95075 val_acc= 0.20200 time= 7.25737 -Epoch:0002 train_loss= 1.94812 train_acc= 0.32857 val_loss= 1.94717 val_acc= 0.34000 time= 0.00438 -Epoch:0003 train_loss= 1.94249 train_acc= 0.47857 val_loss= 1.94337 val_acc= 0.43000 time= 0.00428 -Epoch:0004 train_loss= 1.93550 train_acc= 0.55000 val_loss= 1.93957 val_acc= 0.46400 time= 0.00421 -Epoch:0005 train_loss= 1.92617 train_acc= 0.67143 val_loss= 1.93558 val_acc= 0.45400 time= 0.00430 -... -Epoch:0196 train_loss= 0.60326 train_acc= 0.97857 val_loss= 1.05155 val_acc= 0.78200 time= 0.00418 -Epoch:0197 train_loss= 0.60377 train_acc= 0.97143 val_loss= 1.04940 val_acc= 0.78000 time= 0.00418 -Epoch:0198 train_loss= 0.60680 train_acc= 0.95000 val_loss= 1.04847 val_acc= 0.78000 time= 0.00414 -Epoch:0199 train_loss= 0.61920 train_acc= 0.96429 val_loss= 1.04797 val_acc= 0.78400 time= 0.00413 -Epoch:0200 train_loss= 0.57948 train_acc= 0.96429 val_loss= 1.04753 val_acc= 0.78600 time= 0.00415 -Optimization Finished! -Test set results: cost= 1.00983 accuracy= 0.81300 time= 0.39083 -... +git clone {repository_url} # 克隆仓库的代码 +cd {repository_name} # 切换到模型的代码仓目录 +git checkout {branch} # 切换到对应分支 +git reset --hard {commit_id} # 代码设置到对应的commit_id +cd {code_path} # 切换到模型代码所在路径,若仓库下只有该模型,则无需切换 ``` -# 模型描述 - -## 性能 - -| 参数 | GCN | -| -------------------------- | -------------------------------------------------------------- | -| 资源 | Ascend 910 | -| 上传日期 | 2020-06-09 | -| MindSpore版本 | 0.5.0-beta | -| 数据集 | Cora/Citeseer | -| 训练参数 | epoch=200 | -| 优化器 | Adam | -| 损失函数 | Softmax交叉熵 | -| 准确率 | 81.5/70.3 | -| 参数(B) | 92160/59344 | -| 脚本 | | - -# 随机情况说明 - -以下两种随机情况: - -- 根据入参--seed在train.py中设置种子。 -- 随机失活操作。 +# 推理 -train.py已经设置了一些种子,避免权重初始化的随机性。若需关闭随机失活,将src/config.py中相应的dropout_prob参数设置为0。 +## 准备容器环境 + +1. 
下载源码包。 + + 单击“下载模型脚本”和“下载模型”,下载所需软件包。 + +2. 将源码上传至推理服务器任意目录并解压(如:“/home/data/cz“)。 + +3. 下载所需的软件包。 + + 下载MindX SDK开发套件(mxManufacture)。 -# ModelZoo主页 +4. 编译镜像。 + + **docker build -t** *infer_image* **--build-arg FROM_IMAGE_NAME=** *base_image:tag* **--build-arg SDK_PKG=** *sdk_pkg* **.** + + **表 1** 参数说明 + + + + + + + + + + + + + + + + + + +

+   | 参数        | 说明                                                         |
+   | ----------- | ------------------------------------------------------------ |
+   | infer_image | 推理镜像名称,根据实际写入。                                 |
+   | base_image  | 基础镜像,可从Ascend Hub上下载。                             |
+   | tag         | 镜像tag,请根据实际配置,如:21.0.1。                        |
+   | sdk_pkg     | 下载的mxManufacture包名称,如Ascend-mindxsdk-mxmanufacture_{version}_linux-{arch}.run。 |
+
+   >![输入图片说明](https://images.gitee.com/uploads/images/2021/0719/172222_3c2963f4_923381.gif "icon-note.gif") **说明:**
+   >不要遗漏命令结尾的“.”。
+
+5. 准备数据。
+
+   * 将 ModelArts 数据预处理之后 results 文件夹里生成的推理数据拷贝到 infer/data/input 下。
+
+6. 启动容器。
+
+   **bash docker_start_infer.sh** *infer_image* *data_path*
+
+   **表 2** 参数说明

+   | 参数        | 说明                               |
+   | ----------- | ---------------------------------- |
+   | infer_image | 推理镜像名称,根据实际写入。       |
+   | data_path   | 数据路径。如:“/home/HwHiAiUser”。 |
+
+7. 进入容器。
+
+   ```
+   # 进入容器
+   docker exec -it -u root npu0 bash
+   # 切换工作目录
+   cd /home/data/cz/gcn
+   ```
+
+## 模型转换
+
+1. 准备模型文件。
+
+   * 将ModelArts训练之后导出的 results/model/**.air 模型文件放入 infer/data/model 目录下。
+
+2. 模型转换。
+
+   * 执行infer/convert/run.sh。
+
+   ```
+   cd ./infer/convert
+   # 对Cora模型进行转换
+   sh run.sh cora
+   # 对Citeseer模型进行转换
+   sh run.sh citeseer
+   ```
+
+   执行完成后会在infer/data/model目录下生成**.om模型文件,注意此处om文件名需与pipeline中的保持一致。
+
+## mxBase推理
+
+1. 编译工程。
+
+   ```
+   cd ./infer/mxbase
+   # 通过 vim run.sh 更改 MX_SDK_HOME 路径,将 MX_SDK_HOME 设置为 MX_SDK 安装路径
+   # 对 Cora 模型进行编译
+   sh run.sh cora
+   # 对 Citeseer 数据集进行编译
+   sh run.sh citeseer
+   ```
+
+   编译完成后在当前目录生成可执行文件main。
+
+2. 运行推理服务。
+
+   ```
+   # 设置环境变量
+   export MX_SDK_HOME="/home/data/cz/app/mxManufacture"  # MX_SDK_HOME设置为 MX_SDK 安装路径
+   export LD_LIBRARY_PATH="${MX_SDK_HOME}/lib":"${MX_SDK_HOME}/opensource/lib":"${MX_SDK_HOME}/opensource/lib64":"/usr/local/Ascend/ascend-toolkit/latest/acllib/lib64":${LD_LIBRARY_PATH}
+   export GST_PLUGIN_SCANNER="${MX_SDK_HOME}/opensource/libexec/gstreamer-1.0/gst-plugin-scanner"
+   export GST_PLUGIN_PATH="${MX_SDK_HOME}/opensource/lib/gstreamer-1.0":"${MX_SDK_HOME}/lib/plugins"
+
+   # 在 ./infer/mxbase 下运行推理服务
+   ./main
+   ```
+
+3. 查看推理结果。
+
+   推理脚本会在命令行显示如下结果:
+
+   ```
+   # cora 数据集
+   node_nums=2708
+   class_nums=7
+   ============================ Infer Result ============================
+   Infer acc:0.818
+   =======================================================================
+   I0914 04:41:45.136170 44822 MxsmStream.cpp:688] Begin to destroy stream(gcn).
+   I0914 04:41:45.136253 44822 MxsmStream.cpp:743] Send custom eos to the Stream successfully.
+   I0914 04:41:45.136407 44822 MxsmStream.cpp:749] Send eos to the Stream successfully.
+   I0914 04:41:45.136466 44822 MxsmStream.cpp:755] Flushes the Stream data successfully.
+
+   # citeseer 数据集
+   node_nums=3312
+   class_nums=6
+   ============================ Infer Result ============================
+   Infer acc:0.713
+   =======================================================================
+   I0914 04:42:46.791231 45207 MxsmStream.cpp:688] Begin to destroy stream(gcn).
+   I0914 04:42:46.791312 45207 MxsmStream.cpp:743] Send custom eos to the Stream successfully.
+   I0914 04:42:46.791458 45207 MxsmStream.cpp:749] Send eos to the Stream successfully.
+   I0914 04:42:46.791514 45207 MxsmStream.cpp:755] Flushes the Stream data successfully.
+   ```
+
+## MindX SDK推理
+
+1. 修改配置文件。
+
+   修改pipeline文件:
+
+   ```
+   vim infer/data/config/gcn_cora.pipeline
+   vim infer/data/config/gcn_citeseer.pipeline
+   ```
+
+   如需替换模型,修改“modelPath”字段对应的模型路径。
+
+2. 运行推理服务。
+
+   1. 执行推理。
+
+      ```
+      cd infer/sdk
+      bash run.sh [Dataset]  # Dataset 在["cora", "citeseer"]中选择
+      ```
+
+   2. 查看推理结果。
+
+      推理脚本会在命令行终端显示如下结果:
+
+      ```
+      # cora 数据集
+      adj.txt shape : [1, 7333264]
+      Send successfully!
+      feature.txt shape : [1, 3880564]
+      Send successfully!
+      ============================ Infer Result ============================
+      Pred_label label:[3 4 4 ... 1 3 3]
+      Infer acc:0.818000
+      =======================================================================
+      I0913 12:15:27.282696 230520 MxsmStream.cpp:720] Begin to destroy stream(gcn).
+      I0913 12:15:27.282829 230520 MxsmStream.cpp:776] Send custom eos to the Stream successfully.
+      I0913 12:15:27.282997 230520 MxsmStream.cpp:782] Send eos to the Stream successfully.
+      I0913 12:15:27.283041 230520 MxsmStream.cpp:788] Flushes the Stream data successfully.
+
+      # citeseer 数据集
+      adj.txt shape : [1, 10969344]
+      Send successfully!
+      feature.txt shape : [1, 12264336]
+      Send successfully!
+      ============================ Infer Result ============================
+      Pred_label label:[3 1 5 ... 3 1 5]
+      Infer acc:0.713000
+      =======================================================================
+      I0913 12:17:04.869657 230570 MxsmStream.cpp:720] Begin to destroy stream(gcn).
+      I0913 12:17:04.869745 230570 MxsmStream.cpp:776] Send custom eos to the Stream successfully.
+      I0913 12:17:04.869858 230570 MxsmStream.cpp:782] Send eos to the Stream successfully.
+      I0913 12:17:04.869911 230570 MxsmStream.cpp:788] Flushes the Stream data successfully.
+      ```
+
+# 在ModelArts上应用
+
+## 创建OBS桶
+
+1. 创建桶。
+
+   * 点击“创建桶”
+   * “区域”选择“华北-北京四”
+   * “存储类别”选取“标准存储”
+   * “桶ACL”选取“私有”
+   * 关闭“多AZ”
+   * 输入全局唯一桶名称,例如“S3”
+   * 点击“确定”
+
+2. 创建文件夹存放数据。
+
+   在创建的桶中创建以下文件夹:
+
+   * gcn:存放训练脚本、数据集、训练生成ckpt模型
+   * logs:存放训练日志目录
+
+3. 上传代码。
+
+   * 进入 gcn 代码文件根目录
+   * 将 gcn 目录下的文件全部上传至 obs://S3/gcn 文件夹下
+
+## 创建训练作业
+
+1. 登录ModelArts。
+
+2. 创建训练作业。
+
+### 模型训练、评估、测试、冻结
+
+在 ModelArts 上使用单卡训练。
+
+```
+# ==================================创建算法==========================================
+# (1) 上传你的代码和数据集到 S3 桶上
+# (2) 创建方式: 自定义脚本
+      AI引擎:Ascend-Powered-Engine mindspore_1.3.0-cann_5.0.2-py_3.7-euler_2.8.3-aarch64
+      代码目录: /S3/gcn/
+      启动文件: /S3/gcn/train.py
+# (3) 超参:
+      名称             类型      必需
+      dataset          String    是
+      data_dir         String    是
+      train_dir        String    是
+      train_nodes_num  String    是
+# (4) 自定义超参:支持
+# (5) 输入数据配置: "映射名称 = '数据来源'", "代码路径参数 = 'data_dir'"
+# (6) 输出数据配置: "映射名称 = '模型输出'", "代码路径参数 = 'train_dir'"
+# (7) 添加训练约束: 否
+
+# ==================================创建训练作业=======================================
+# (1) 算法: 在我的算法中选择前面创建的算法
+# (2) 训练输入: '/S3/gcn/data/'
+# (3) 训练输出: '/S3/gcn/results/'
+
+# 训练 cora 数据集
+# (4) 超参:
+      "dataset = 'cora'"
+      "data_dir = 'obs://S3/gcn/data/'"
+      "train_dir='obs://S3/gcn/results/'"
+      "train_nodes_num=140"
+# 训练 citeseer 数据集
+# (4) 超参:
+      "dataset = 'citeseer'"
+      "data_dir = 'obs://S3/gcn/data/'"
+      "train_dir='obs://S3/gcn/results/'"
+      "train_nodes_num=120"
+
+# (5) 设置作业日志路径
+```
+
+训练结果模型将保存在 obs://S3/gcn/results/model/ 文件夹下。您可在 /logs 文件夹下的日志文件中找到如下结果:
+
+```text
+Epoch:0001 train_loss= 1.95373 train_acc= 0.09286 val_loss= 1.95075 val_acc= 0.20200 time= 7.25737
+Epoch:0002 train_loss= 1.94812 train_acc= 0.32857 val_loss= 1.94717 val_acc= 0.34000 time= 0.00438
+Epoch:0003 train_loss= 1.94249 train_acc= 0.47857 val_loss= 1.94337 val_acc= 0.43000 time= 0.00428
+Epoch:0004 train_loss= 1.93550 train_acc= 0.55000 val_loss= 1.93957 val_acc= 0.46400 time= 0.00421
+Epoch:0005 train_loss= 1.92617 train_acc= 0.67143 val_loss= 1.93558 val_acc= 0.45400 time= 0.00430
+...
+Epoch:0196 train_loss= 0.60326 train_acc= 0.97857 val_loss= 1.05155 val_acc= 0.78200 time= 0.00418
+Epoch:0197 train_loss= 0.60377 train_acc= 0.97143 val_loss= 1.04940 val_acc= 0.78000 time= 0.00418
+Epoch:0198 train_loss= 0.60680 train_acc= 0.95000 val_loss= 1.04847 val_acc= 0.78000 time= 0.00414
+Epoch:0199 train_loss= 0.61920 train_acc= 0.96429 val_loss= 1.04797 val_acc= 0.78400 time= 0.00413
+Epoch:0200 train_loss= 0.57948 train_acc= 0.96429 val_loss= 1.04753 val_acc= 0.78600 time= 0.00415
+Optimization Finished!
+Test set results: cost= 1.00983 accuracy= 0.81800 time= 0.39083
+...
+```
+
+可以得到最后训练的 cora 模型的测试精度为 0.818,citeseer 模型的测试精度为 0.713。
+
+### 推理数据导出
+
+完成上述模型冻结操作后,可以执行推理数据导出脚本,生成推理服务所需格式的测试数据。
+
+华为云 ModelArts 相关配置内容如下:
+
+```
+# ==================================创建算法==========================================
+# (1) 上传你的代码和数据集到 S3 桶上
+# (2) 创建方式: 自定义脚本
+      AI引擎:Ascend-Powered-Engine mindspore_1.3.0-cann_5.0.2-py_3.7-euler_2.8.3-aarch64
+      代码目录: /S3/gcn/
+      启动文件: /S3/gcn/preprocess.py
+# (3) 超参:
+      名称            类型      必需
+      dataset         String    是
+      data_dir        String    是
+      test_nodes_num  String    是
+      result_path     String    是
+# (4) 自定义超参:支持
+# (5) 输入数据配置: "映射名称 = '数据来源'", "代码路径参数 = 'data_dir'"
+# (6) 输出数据配置: "映射名称 = '模型输出'", "代码路径参数 = 'result_path'"
+# (7) 添加训练约束: 否
+
+# ==================================创建训练作业=======================================
+# (1) 算法: 在我的算法中选择前面创建的算法
+# (2) 训练输入: '/S3/gcn/results/data_mr/'
+# (3) 训练输出: '/S3/gcn/results/data/'
+
+# 导出 cora 数据集
+# (4) 超参:
+      "dataset = 'cora'"
+      "data_dir = 'obs://S3/gcn/results/data_mr/'"
+      "test_nodes_num='1000'"
+      "result_path='obs://S3/gcn/results/data/'"
+# 导出 citeseer 数据集
+# (4) 超参:
+      "dataset = 'citeseer'"
+      "data_dir = 'obs://S3/gcn/results/data_mr/'"
+      "test_nodes_num='1000'"
+      "result_path='obs://S3/gcn/results/data/'"
+
+# (5) 设置作业日志路径
+```
+
+启动推理数据导出脚本,执行结束后生成推理服务所需测试数据并保存在 obs://S3/gcn/results/data/ 文件夹下,将该数据下载用于推理服务。
+
+## 查看训练任务日志
+
+1. 训练完成后进入logs文件夹,点击对应当次训练作业的日志文件即可。
+
+2.
logs文件夹内生成日志文件,记录模型精度评估结果。 -请浏览官网[主页](https://gitee.com/mindspore/mindspore/tree/master/model_zoo)。 diff --git a/official/gnn/gcn/data/ind.citeseer.allx b/official/gnn/gcn/data/ind.citeseer.allx new file mode 100644 index 0000000000000000000000000000000000000000..592091071ac35b78599c722d0e60b1c655711728 Binary files /dev/null and b/official/gnn/gcn/data/ind.citeseer.allx differ diff --git a/official/gnn/gcn/data/ind.citeseer.ally b/official/gnn/gcn/data/ind.citeseer.ally new file mode 100644 index 0000000000000000000000000000000000000000..7503f81a0b6ac1cc21cb5a97073542a6b1dfe99c Binary files /dev/null and b/official/gnn/gcn/data/ind.citeseer.ally differ diff --git a/official/gnn/gcn/data/ind.citeseer.graph b/official/gnn/gcn/data/ind.citeseer.graph new file mode 100644 index 0000000000000000000000000000000000000000..a01dca663d0103c130f525017a4a9c313037e286 Binary files /dev/null and b/official/gnn/gcn/data/ind.citeseer.graph differ diff --git a/official/gnn/gcn/data/ind.citeseer.test.index b/official/gnn/gcn/data/ind.citeseer.test.index new file mode 100644 index 0000000000000000000000000000000000000000..62d9e3d5ed53c66b069125af98529d76c70d9b3e --- /dev/null +++ b/official/gnn/gcn/data/ind.citeseer.test.index @@ -0,0 +1,1000 @@ +2488 +2644 +3261 +2804 +3176 +2432 +3310 +2410 +2812 +2520 +2994 +3282 +2680 +2848 +2670 +3005 +2977 +2592 +2967 +2461 +3184 +2852 +2768 +2905 +2851 +3129 +3164 +2438 +2793 +2763 +2528 +2954 +2347 +2640 +3265 +2874 +2446 +2856 +3149 +2374 +3097 +3301 +2664 +2418 +2655 +2464 +2596 +3262 +3278 +2320 +2612 +2614 +2550 +2626 +2772 +3007 +2733 +2516 +2476 +2798 +2561 +2839 +2685 +2391 +2705 +3098 +2754 +3251 +2767 +2630 +2727 +2513 +2701 +3264 +2792 +2821 +3260 +2462 +3307 +2639 +2900 +3060 +2672 +3116 +2731 +3316 +2386 +2425 +2518 +3151 +2586 +2797 +2479 +3117 +2580 +3182 +2459 +2508 +3052 +3230 +3215 +2803 +2969 +2562 +2398 +3325 +2343 +3030 +2414 +2776 +2383 +3173 +2850 +2499 +3312 +2648 +2784 +2898 +3056 +2484 +3179 +3132 +2577 +2563 +2867 
+3317 +2355 +3207 +3178 +2968 +3319 +2358 +2764 +3001 +2683 +3271 +2321 +2567 +2502 +3246 +2715 +3066 +2390 +2381 +3162 +2741 +2498 +2790 +3038 +3321 +2481 +3050 +3161 +3122 +2801 +2957 +3177 +2965 +2621 +3208 +2921 +2802 +2357 +2677 +2519 +2860 +2696 +2368 +3241 +2858 +2419 +2762 +2875 +3222 +3064 +2827 +3044 +2471 +3062 +2982 +2736 +2322 +2709 +2766 +2424 +2602 +2970 +2675 +3299 +2554 +2964 +2597 +2753 +2979 +2523 +2912 +2896 +2317 +3167 +2813 +2482 +2557 +3043 +3244 +2985 +2460 +2363 +3272 +3045 +3192 +2453 +2656 +2834 +2443 +3202 +2926 +2711 +2633 +2384 +2752 +3285 +2817 +2483 +2919 +2924 +2661 +2698 +2361 +2662 +2819 +3143 +2316 +3196 +2739 +2345 +2578 +2822 +3229 +2908 +2917 +2692 +3200 +2324 +2522 +3322 +2697 +3163 +3093 +3233 +2774 +2371 +2835 +2652 +2539 +2843 +3231 +2976 +2429 +2367 +3144 +2564 +3283 +3217 +3035 +2962 +2433 +2415 +2387 +3021 +2595 +2517 +2468 +3061 +2673 +2348 +3027 +2467 +3318 +2959 +3273 +2392 +2779 +2678 +3004 +2634 +2974 +3198 +2342 +2376 +3249 +2868 +2952 +2710 +2838 +2335 +2524 +2650 +3186 +2743 +2545 +2841 +2515 +2505 +3181 +2945 +2738 +2933 +3303 +2611 +3090 +2328 +3010 +3016 +2504 +2936 +3266 +3253 +2840 +3034 +2581 +2344 +2452 +2654 +3199 +3137 +2514 +2394 +2544 +2641 +2613 +2618 +2558 +2593 +2532 +2512 +2975 +3267 +2566 +2951 +3300 +2869 +2629 +2747 +3055 +2831 +3105 +3168 +3100 +2431 +2828 +2684 +3269 +2910 +2865 +2693 +2884 +3228 +2783 +3247 +2770 +3157 +2421 +2382 +2331 +3203 +3240 +2351 +3114 +2986 +2688 +2439 +2996 +3079 +3103 +3296 +2349 +2372 +3096 +2422 +2551 +3069 +2737 +3084 +3304 +3022 +2542 +3204 +2949 +2318 +2450 +3140 +2734 +2881 +2576 +3054 +3089 +3125 +2761 +3136 +3111 +2427 +2466 +3101 +3104 +3259 +2534 +2961 +3191 +3000 +3036 +2356 +2800 +3155 +3224 +2646 +2735 +3020 +2866 +2426 +2448 +3226 +3219 +2749 +3183 +2906 +2360 +2440 +2946 +2313 +2859 +2340 +3008 +2719 +3058 +2653 +3023 +2888 +3243 +2913 +3242 +3067 +2409 +3227 +2380 +2353 +2686 +2971 +2847 +2947 +2857 +3263 +3218 +2861 +3323 +2635 +2966 +2604 +2456 
+2832 +2694 +3245 +3119 +2942 +3153 +2894 +2555 +3128 +2703 +2323 +2631 +2732 +2699 +2314 +2590 +3127 +2891 +2873 +2814 +2326 +3026 +3288 +3095 +2706 +2457 +2377 +2620 +2526 +2674 +3190 +2923 +3032 +2334 +3254 +2991 +3277 +2973 +2599 +2658 +2636 +2826 +3148 +2958 +3258 +2990 +3180 +2538 +2748 +2625 +2565 +3011 +3057 +2354 +3158 +2622 +3308 +2983 +2560 +3169 +3059 +2480 +3194 +3291 +3216 +2643 +3172 +2352 +2724 +2485 +2411 +2948 +2445 +2362 +2668 +3275 +3107 +2496 +2529 +2700 +2541 +3028 +2879 +2660 +3324 +2755 +2436 +3048 +2623 +2920 +3040 +2568 +3221 +3003 +3295 +2473 +3232 +3213 +2823 +2897 +2573 +2645 +3018 +3326 +2795 +2915 +3109 +3086 +2463 +3118 +2671 +2909 +2393 +2325 +3029 +2972 +3110 +2870 +3284 +2816 +2647 +2667 +2955 +2333 +2960 +2864 +2893 +2458 +2441 +2359 +2327 +3256 +3099 +3073 +3138 +2511 +2666 +2548 +2364 +2451 +2911 +3237 +3206 +3080 +3279 +2934 +2981 +2878 +3130 +2830 +3091 +2659 +2449 +3152 +2413 +2722 +2796 +3220 +2751 +2935 +3238 +2491 +2730 +2842 +3223 +2492 +3074 +3094 +2833 +2521 +2883 +3315 +2845 +2907 +3083 +2572 +3092 +2903 +2918 +3039 +3286 +2587 +3068 +2338 +3166 +3134 +2455 +2497 +2992 +2775 +2681 +2430 +2932 +2931 +2434 +3154 +3046 +2598 +2366 +3015 +3147 +2944 +2582 +3274 +2987 +2642 +2547 +2420 +2930 +2750 +2417 +2808 +3141 +2997 +2995 +2584 +2312 +3033 +3070 +3065 +2509 +3314 +2396 +2543 +2423 +3170 +2389 +3289 +2728 +2540 +2437 +2486 +2895 +3017 +2853 +2406 +2346 +2877 +2472 +3210 +2637 +2927 +2789 +2330 +3088 +3102 +2616 +3081 +2902 +3205 +3320 +3165 +2984 +3185 +2707 +3255 +2583 +2773 +2742 +3024 +2402 +2718 +2882 +2575 +3281 +2786 +2855 +3014 +2401 +2535 +2687 +2495 +3113 +2609 +2559 +2665 +2530 +3293 +2399 +2605 +2690 +3133 +2799 +2533 +2695 +2713 +2886 +2691 +2549 +3077 +3002 +3049 +3051 +3087 +2444 +3085 +3135 +2702 +3211 +3108 +2501 +2769 +3290 +2465 +3025 +3019 +2385 +2940 +2657 +2610 +2525 +2941 +3078 +2341 +2916 +2956 +2375 +2880 +3009 +2780 +2370 +2925 +2332 +3146 +2315 +2809 +3145 +3106 +2782 +2760 +2493 +2765 +2556 
+2890 +2400 +2339 +3201 +2818 +3248 +3280 +2570 +2569 +2937 +3174 +2836 +2708 +2820 +3195 +2617 +3197 +2319 +2744 +2615 +2825 +2603 +2914 +2531 +3193 +2624 +2365 +2810 +3239 +3159 +2537 +2844 +2758 +2938 +3037 +2503 +3297 +2885 +2608 +2494 +2712 +2408 +2901 +2704 +2536 +2373 +2478 +2723 +3076 +2627 +2369 +2669 +3006 +2628 +2788 +3276 +2435 +3139 +3235 +2527 +2571 +2815 +2442 +2892 +2978 +2746 +3150 +2574 +2725 +3188 +2601 +2378 +3075 +2632 +2794 +3270 +3071 +2506 +3126 +3236 +3257 +2824 +2989 +2950 +2428 +2405 +3156 +2447 +2787 +2805 +2720 +2403 +2811 +2329 +2474 +2785 +2350 +2507 +2416 +3112 +2475 +2876 +2585 +2487 +3072 +3082 +2943 +2757 +2388 +2600 +3294 +2756 +3142 +3041 +2594 +2998 +3047 +2379 +2980 +2454 +2862 +3175 +2588 +3031 +3012 +2889 +2500 +2791 +2854 +2619 +2395 +2807 +2740 +2412 +3131 +3013 +2939 +2651 +2490 +2988 +2863 +3225 +2745 +2714 +3160 +3124 +2849 +2676 +2872 +3287 +3189 +2716 +3115 +2928 +2871 +2591 +2717 +2546 +2777 +3298 +2397 +3187 +2726 +2336 +3268 +2477 +2904 +2846 +3121 +2899 +2510 +2806 +2963 +3313 +2679 +3302 +2663 +3053 +2469 +2999 +3311 +2470 +2638 +3120 +3171 +2689 +2922 +2607 +2721 +2993 +2887 +2837 +2929 +2829 +3234 +2649 +2337 +2759 +2778 +2771 +2404 +2589 +3123 +3209 +2729 +3252 +2606 +2579 +2552 diff --git a/official/gnn/gcn/data/ind.citeseer.tx b/official/gnn/gcn/data/ind.citeseer.tx new file mode 100644 index 0000000000000000000000000000000000000000..b2aff18aa1fc6f4f5e8c64db7643b59b24e42584 Binary files /dev/null and b/official/gnn/gcn/data/ind.citeseer.tx differ diff --git a/official/gnn/gcn/data/ind.citeseer.ty b/official/gnn/gcn/data/ind.citeseer.ty new file mode 100644 index 0000000000000000000000000000000000000000..3795f79df0cd1123f4478b77c148a5fbacc4cefb Binary files /dev/null and b/official/gnn/gcn/data/ind.citeseer.ty differ diff --git a/official/gnn/gcn/data/ind.citeseer.x b/official/gnn/gcn/data/ind.citeseer.x new file mode 100644 index 
0000000000000000000000000000000000000000..f094104a66f2689ffbd01085fe6a856689e9db9e Binary files /dev/null and b/official/gnn/gcn/data/ind.citeseer.x differ diff --git a/official/gnn/gcn/data/ind.citeseer.y b/official/gnn/gcn/data/ind.citeseer.y new file mode 100644 index 0000000000000000000000000000000000000000..e857ac49b611da98e101a068cf4855350274f68a Binary files /dev/null and b/official/gnn/gcn/data/ind.citeseer.y differ diff --git a/official/gnn/gcn/data/ind.cora.allx b/official/gnn/gcn/data/ind.cora.allx new file mode 100644 index 0000000000000000000000000000000000000000..44d53b1fece343538e45592caac521d73c6f98d6 Binary files /dev/null and b/official/gnn/gcn/data/ind.cora.allx differ diff --git a/official/gnn/gcn/data/ind.cora.ally b/official/gnn/gcn/data/ind.cora.ally new file mode 100644 index 0000000000000000000000000000000000000000..04fbd0b083d09341fcf16b395c642329637f0542 Binary files /dev/null and b/official/gnn/gcn/data/ind.cora.ally differ diff --git a/official/gnn/gcn/data/ind.cora.graph b/official/gnn/gcn/data/ind.cora.graph new file mode 100644 index 0000000000000000000000000000000000000000..4d3bf85dfc8b1c105a2c995597a1b3eed6947925 Binary files /dev/null and b/official/gnn/gcn/data/ind.cora.graph differ diff --git a/official/gnn/gcn/data/ind.cora.test.index b/official/gnn/gcn/data/ind.cora.test.index new file mode 100644 index 0000000000000000000000000000000000000000..ded8092db4cd767a367eebe09003b954a5316f24 --- /dev/null +++ b/official/gnn/gcn/data/ind.cora.test.index @@ -0,0 +1,1000 @@ +2692 +2532 +2050 +1715 +2362 +2609 +2622 +1975 +2081 +1767 +2263 +1725 +2588 +2259 +2357 +1998 +2574 +2179 +2291 +2382 +1812 +1751 +2422 +1937 +2631 +2510 +2378 +2589 +2345 +1943 +1850 +2298 +1825 +2035 +2507 +2313 +1906 +1797 +2023 +2159 +2495 +1886 +2122 +2369 +2461 +1925 +2565 +1858 +2234 +2000 +1846 +2318 +1723 +2559 +2258 +1763 +1991 +1922 +2003 +2662 +2250 +2064 +2529 +1888 +2499 +2454 +2320 +2287 +2203 +2018 +2002 +2632 +2554 +2314 +2537 +1760 +2088 +2086 
+2218 +2605 +1953 +2403 +1920 +2015 +2335 +2535 +1837 +2009 +1905 +2636 +1942 +2193 +2576 +2373 +1873 +2463 +2509 +1954 +2656 +2455 +2494 +2295 +2114 +2561 +2176 +2275 +2635 +2442 +2704 +2127 +2085 +2214 +2487 +1739 +2543 +1783 +2485 +2262 +2472 +2326 +1738 +2170 +2100 +2384 +2152 +2647 +2693 +2376 +1775 +1726 +2476 +2195 +1773 +1793 +2194 +2581 +1854 +2524 +1945 +1781 +1987 +2599 +1744 +2225 +2300 +1928 +2042 +2202 +1958 +1816 +1916 +2679 +2190 +1733 +2034 +2643 +2177 +1883 +1917 +1996 +2491 +2268 +2231 +2471 +1919 +1909 +2012 +2522 +1865 +2466 +2469 +2087 +2584 +2563 +1924 +2143 +1736 +1966 +2533 +2490 +2630 +1973 +2568 +1978 +2664 +2633 +2312 +2178 +1754 +2307 +2480 +1960 +1742 +1962 +2160 +2070 +2553 +2433 +1768 +2659 +2379 +2271 +1776 +2153 +1877 +2027 +2028 +2155 +2196 +2483 +2026 +2158 +2407 +1821 +2131 +2676 +2277 +2489 +2424 +1963 +1808 +1859 +2597 +2548 +2368 +1817 +2405 +2413 +2603 +2350 +2118 +2329 +1969 +2577 +2475 +2467 +2425 +1769 +2092 +2044 +2586 +2608 +1983 +2109 +2649 +1964 +2144 +1902 +2411 +2508 +2360 +1721 +2005 +2014 +2308 +2646 +1949 +1830 +2212 +2596 +1832 +1735 +1866 +2695 +1941 +2546 +2498 +2686 +2665 +1784 +2613 +1970 +2021 +2211 +2516 +2185 +2479 +2699 +2150 +1990 +2063 +2075 +1979 +2094 +1787 +2571 +2690 +1926 +2341 +2566 +1957 +1709 +1955 +2570 +2387 +1811 +2025 +2447 +2696 +2052 +2366 +1857 +2273 +2245 +2672 +2133 +2421 +1929 +2125 +2319 +2641 +2167 +2418 +1765 +1761 +1828 +2188 +1972 +1997 +2419 +2289 +2296 +2587 +2051 +2440 +2053 +2191 +1923 +2164 +1861 +2339 +2333 +2523 +2670 +2121 +1921 +1724 +2253 +2374 +1940 +2545 +2301 +2244 +2156 +1849 +2551 +2011 +2279 +2572 +1757 +2400 +2569 +2072 +2526 +2173 +2069 +2036 +1819 +1734 +1880 +2137 +2408 +2226 +2604 +1771 +2698 +2187 +2060 +1756 +2201 +2066 +2439 +1844 +1772 +2383 +2398 +1708 +1992 +1959 +1794 +2426 +2702 +2444 +1944 +1829 +2660 +2497 +2607 +2343 +1730 +2624 +1790 +1935 +1967 +2401 +2255 +2355 +2348 +1931 +2183 +2161 +2701 +1948 +2501 +2192 +2404 +2209 +2331 +1810 +2363 +2334 
+1887 +2393 +2557 +1719 +1732 +1986 +2037 +2056 +1867 +2126 +1932 +2117 +1807 +1801 +1743 +2041 +1843 +2388 +2221 +1833 +2677 +1778 +2661 +2306 +2394 +2106 +2430 +2371 +2606 +2353 +2269 +2317 +2645 +2372 +2550 +2043 +1968 +2165 +2310 +1985 +2446 +1982 +2377 +2207 +1818 +1913 +1766 +1722 +1894 +2020 +1881 +2621 +2409 +2261 +2458 +2096 +1712 +2594 +2293 +2048 +2359 +1839 +2392 +2254 +1911 +2101 +2367 +1889 +1753 +2555 +2246 +2264 +2010 +2336 +2651 +2017 +2140 +1842 +2019 +1890 +2525 +2134 +2492 +2652 +2040 +2145 +2575 +2166 +1999 +2434 +1711 +2276 +2450 +2389 +2669 +2595 +1814 +2039 +2502 +1896 +2168 +2344 +2637 +2031 +1977 +2380 +1936 +2047 +2460 +2102 +1745 +2650 +2046 +2514 +1980 +2352 +2113 +1713 +2058 +2558 +1718 +1864 +1876 +2338 +1879 +1891 +2186 +2451 +2181 +2638 +2644 +2103 +2591 +2266 +2468 +1869 +2582 +2674 +2361 +2462 +1748 +2215 +2615 +2236 +2248 +2493 +2342 +2449 +2274 +1824 +1852 +1870 +2441 +2356 +1835 +2694 +2602 +2685 +1893 +2544 +2536 +1994 +1853 +1838 +1786 +1930 +2539 +1892 +2265 +2618 +2486 +2583 +2061 +1796 +1806 +2084 +1933 +2095 +2136 +2078 +1884 +2438 +2286 +2138 +1750 +2184 +1799 +2278 +2410 +2642 +2435 +1956 +2399 +1774 +2129 +1898 +1823 +1938 +2299 +1862 +2420 +2673 +1984 +2204 +1717 +2074 +2213 +2436 +2297 +2592 +2667 +2703 +2511 +1779 +1782 +2625 +2365 +2315 +2381 +1788 +1714 +2302 +1927 +2325 +2506 +2169 +2328 +2629 +2128 +2655 +2282 +2073 +2395 +2247 +2521 +2260 +1868 +1988 +2324 +2705 +2541 +1731 +2681 +2707 +2465 +1785 +2149 +2045 +2505 +2611 +2217 +2180 +1904 +2453 +2484 +1871 +2309 +2349 +2482 +2004 +1965 +2406 +2162 +1805 +2654 +2007 +1947 +1981 +2112 +2141 +1720 +1758 +2080 +2330 +2030 +2432 +2089 +2547 +1820 +1815 +2675 +1840 +2658 +2370 +2251 +1908 +2029 +2068 +2513 +2549 +2267 +2580 +2327 +2351 +2111 +2022 +2321 +2614 +2252 +2104 +1822 +2552 +2243 +1798 +2396 +2663 +2564 +2148 +2562 +2684 +2001 +2151 +2706 +2240 +2474 +2303 +2634 +2680 +2055 +2090 +2503 +2347 +2402 +2238 +1950 +2054 +2016 +1872 +2233 +1710 +2032 +2540 +2628 
+1795 +2616 +1903 +2531 +2567 +1946 +1897 +2222 +2227 +2627 +1856 +2464 +2241 +2481 +2130 +2311 +2083 +2223 +2284 +2235 +2097 +1752 +2515 +2527 +2385 +2189 +2283 +2182 +2079 +2375 +2174 +2437 +1993 +2517 +2443 +2224 +2648 +2171 +2290 +2542 +2038 +1855 +1831 +1759 +1848 +2445 +1827 +2429 +2205 +2598 +2657 +1728 +2065 +1918 +2427 +2573 +2620 +2292 +1777 +2008 +1875 +2288 +2256 +2033 +2470 +2585 +2610 +2082 +2230 +1915 +1847 +2337 +2512 +2386 +2006 +2653 +2346 +1951 +2110 +2639 +2520 +1939 +2683 +2139 +2220 +1910 +2237 +1900 +1836 +2197 +1716 +1860 +2077 +2519 +2538 +2323 +1914 +1971 +1845 +2132 +1802 +1907 +2640 +2496 +2281 +2198 +2416 +2285 +1755 +2431 +2071 +2249 +2123 +1727 +2459 +2304 +2199 +1791 +1809 +1780 +2210 +2417 +1874 +1878 +2116 +1961 +1863 +2579 +2477 +2228 +2332 +2578 +2457 +2024 +1934 +2316 +1841 +1764 +1737 +2322 +2239 +2294 +1729 +2488 +1974 +2473 +2098 +2612 +1834 +2340 +2423 +2175 +2280 +2617 +2208 +2560 +1741 +2600 +2059 +1747 +2242 +2700 +2232 +2057 +2147 +2682 +1792 +1826 +2120 +1895 +2364 +2163 +1851 +2391 +2414 +2452 +1803 +1989 +2623 +2200 +2528 +2415 +1804 +2146 +2619 +2687 +1762 +2172 +2270 +2678 +2593 +2448 +1882 +2257 +2500 +1899 +2478 +2412 +2107 +1746 +2428 +2115 +1800 +1901 +2397 +2530 +1912 +2108 +2206 +2091 +1740 +2219 +1976 +2099 +2142 +2671 +2668 +2216 +2272 +2229 +2666 +2456 +2534 +2697 +2688 +2062 +2691 +2689 +2154 +2590 +2626 +2390 +1813 +2067 +1952 +2518 +2358 +1789 +2076 +2049 +2119 +2013 +2124 +2556 +2105 +2093 +1885 +2305 +2354 +2135 +2601 +1770 +1995 +2504 +1749 +2157 diff --git a/official/gnn/gcn/data/ind.cora.tx b/official/gnn/gcn/data/ind.cora.tx new file mode 100644 index 0000000000000000000000000000000000000000..6e856d777401ee15dc8619db76c97d0a40ba2d60 Binary files /dev/null and b/official/gnn/gcn/data/ind.cora.tx differ diff --git a/official/gnn/gcn/data/ind.cora.ty b/official/gnn/gcn/data/ind.cora.ty new file mode 100644 index 0000000000000000000000000000000000000000..da1734ab670b5c8ecee78a9c30b35566df916f4b Binary 
files /dev/null and b/official/gnn/gcn/data/ind.cora.ty differ diff --git a/official/gnn/gcn/data/ind.cora.x b/official/gnn/gcn/data/ind.cora.x new file mode 100644 index 0000000000000000000000000000000000000000..c4a91d008245403e7f26aa616b437191509793c0 Binary files /dev/null and b/official/gnn/gcn/data/ind.cora.x differ diff --git a/official/gnn/gcn/data/ind.cora.y b/official/gnn/gcn/data/ind.cora.y new file mode 100644 index 0000000000000000000000000000000000000000..58e30ef120d3553cfaac33ff28c6b5436dbef661 Binary files /dev/null and b/official/gnn/gcn/data/ind.cora.y differ diff --git a/official/gnn/gcn/data/ind.pubmed.allx b/official/gnn/gcn/data/ind.pubmed.allx new file mode 100644 index 0000000000000000000000000000000000000000..3b86bccf08111c37c632a9dd8ab63db25ea76361 Binary files /dev/null and b/official/gnn/gcn/data/ind.pubmed.allx differ diff --git a/official/gnn/gcn/data/ind.pubmed.ally b/official/gnn/gcn/data/ind.pubmed.ally new file mode 100644 index 0000000000000000000000000000000000000000..4c8bcdc80882dd225fe4dd4fafd723baa616ac97 Binary files /dev/null and b/official/gnn/gcn/data/ind.pubmed.ally differ diff --git a/official/gnn/gcn/data/ind.pubmed.graph b/official/gnn/gcn/data/ind.pubmed.graph new file mode 100644 index 0000000000000000000000000000000000000000..76eb87e628a6b9d43fc31ea42c8c6c8f3e52b404 Binary files /dev/null and b/official/gnn/gcn/data/ind.pubmed.graph differ diff --git a/official/gnn/gcn/data/ind.pubmed.test.index b/official/gnn/gcn/data/ind.pubmed.test.index new file mode 100644 index 0000000000000000000000000000000000000000..41f6befcb7750cc4858cb852f2b349a9598cb2f4 --- /dev/null +++ b/official/gnn/gcn/data/ind.pubmed.test.index @@ -0,0 +1,1000 @@ +18747 +19392 +19181 +18843 +19221 +18962 +19560 +19097 +18966 +19014 +18756 +19313 +19000 +19569 +19359 +18854 +18970 +19073 +19661 +19180 +19377 +18750 +19401 +18788 +19224 +19447 +19017 +19241 +18890 +18908 +18965 +19001 +18849 +19641 +18852 +19222 +19172 +18762 +19156 +19162 +18856 
+18763 +19318 +18826 +19712 +19192 +19695 +19030 +19523 +19249 +19079 +19232 +19455 +18743 +18800 +19071 +18885 +19593 +19394 +19390 +18832 +19445 +18838 +19632 +19548 +19546 +18825 +19498 +19266 +19117 +19595 +19252 +18730 +18913 +18809 +19452 +19520 +19274 +19555 +19388 +18919 +19099 +19637 +19403 +18720 +19526 +18905 +19451 +19408 +18923 +18794 +19322 +19431 +18912 +18841 +19239 +19125 +19258 +19565 +18898 +19482 +19029 +18778 +19096 +19684 +19552 +18765 +19361 +19171 +19367 +19623 +19402 +19327 +19118 +18888 +18726 +19510 +18831 +19490 +19576 +19050 +18729 +18896 +19246 +19012 +18862 +18873 +19193 +19693 +19474 +18953 +19115 +19182 +19269 +19116 +18837 +18872 +19007 +19212 +18798 +19102 +18772 +19660 +19511 +18914 +18886 +19672 +19360 +19213 +18810 +19420 +19512 +18719 +19432 +19350 +19127 +18782 +19587 +18924 +19488 +18781 +19340 +19190 +19383 +19094 +18835 +19487 +19230 +18791 +18882 +18937 +18928 +18755 +18802 +19516 +18795 +18786 +19273 +19349 +19398 +19626 +19130 +19351 +19489 +19446 +18959 +19025 +18792 +18878 +19304 +19629 +19061 +18785 +19194 +19179 +19210 +19417 +19583 +19415 +19443 +18739 +19662 +18904 +18910 +18901 +18960 +18722 +18827 +19290 +18842 +19389 +19344 +18961 +19098 +19147 +19334 +19358 +18829 +18984 +18931 +18742 +19320 +19111 +19196 +18887 +18991 +19469 +18990 +18876 +19261 +19270 +19522 +19088 +19284 +19646 +19493 +19225 +19615 +19449 +19043 +19674 +19391 +18918 +19155 +19110 +18815 +19131 +18834 +19715 +19603 +19688 +19133 +19053 +19166 +19066 +18893 +18757 +19582 +19282 +19257 +18869 +19467 +18954 +19371 +19151 +19462 +19598 +19653 +19187 +19624 +19564 +19534 +19581 +19478 +18985 +18746 +19342 +18777 +19696 +18824 +19138 +18728 +19643 +19199 +18731 +19168 +18948 +19216 +19697 +19347 +18808 +18725 +19134 +18847 +18828 +18996 +19106 +19485 +18917 +18911 +18776 +19203 +19158 +18895 +19165 +19382 +18780 +18836 +19373 +19659 +18947 +19375 +19299 +18761 +19366 +18754 +19248 +19416 +19658 +19638 +19034 +19281 +18844 +18922 +19491 +19272 
+19341 +19068 +19332 +19559 +19293 +18804 +18933 +18935 +19405 +18936 +18945 +18943 +18818 +18797 +19570 +19464 +19428 +19093 +19433 +18986 +19161 +19255 +19157 +19046 +19292 +19434 +19298 +18724 +19410 +19694 +19214 +19640 +19189 +18963 +19218 +19585 +19041 +19550 +19123 +19620 +19376 +19561 +18944 +19706 +19056 +19283 +18741 +19319 +19144 +19542 +18821 +19404 +19080 +19303 +18793 +19306 +19678 +19435 +19519 +19566 +19278 +18946 +19536 +19020 +19057 +19198 +19333 +19649 +19699 +19399 +19654 +19136 +19465 +19321 +19577 +18907 +19665 +19386 +19596 +19247 +19473 +19568 +19355 +18925 +19586 +18982 +19616 +19495 +19612 +19023 +19438 +18817 +19692 +19295 +19414 +19676 +19472 +19107 +19062 +19035 +18883 +19409 +19052 +19606 +19091 +19651 +19475 +19413 +18796 +19369 +19639 +19701 +19461 +19645 +19251 +19063 +19679 +19545 +19081 +19363 +18995 +19549 +18790 +18855 +18833 +18899 +19395 +18717 +19647 +18768 +19103 +19245 +18819 +18779 +19656 +19076 +18745 +18971 +19197 +19711 +19074 +19128 +19466 +19139 +19309 +19324 +18814 +19092 +19627 +19060 +18806 +18929 +18737 +18942 +18906 +18858 +19456 +19253 +19716 +19104 +19667 +19574 +18903 +19237 +18864 +19556 +19364 +18952 +19008 +19323 +19700 +19170 +19267 +19345 +19238 +18909 +18892 +19109 +19704 +18902 +19275 +19680 +18723 +19242 +19112 +19169 +18956 +19343 +19650 +19541 +19698 +19521 +19087 +18976 +19038 +18775 +18968 +19671 +19412 +19407 +19573 +19027 +18813 +19357 +19460 +19673 +19481 +19036 +19614 +18787 +19195 +18732 +18884 +19613 +19657 +19575 +19226 +19589 +19234 +19617 +19707 +19484 +18740 +19424 +18784 +19419 +19159 +18865 +19105 +19315 +19480 +19664 +19378 +18803 +19605 +18870 +19042 +19426 +18848 +19223 +19509 +19532 +18752 +19691 +18718 +19209 +19362 +19090 +19492 +19567 +19687 +19018 +18830 +19530 +19554 +19119 +19442 +19558 +19527 +19427 +19291 +19543 +19422 +19142 +18897 +18950 +19425 +19002 +19588 +18978 +19551 +18930 +18736 +19101 +19215 +19150 +19263 +18949 +18974 +18759 +19335 +19200 +19129 +19328 +19437 
+18988 +19429 +19368 +19406 +19049 +18811 +19296 +19256 +19385 +19602 +18770 +19337 +19580 +19476 +19045 +19132 +19089 +19120 +19265 +19483 +18767 +19227 +18934 +19069 +18820 +19006 +19459 +18927 +19037 +19280 +19441 +18823 +19015 +19114 +19618 +18957 +19176 +18853 +19648 +19201 +19444 +19279 +18751 +19302 +19505 +18733 +19601 +19533 +18863 +19708 +19387 +19346 +19152 +19206 +18851 +19338 +19681 +19380 +19055 +18766 +19085 +19591 +19547 +18958 +19146 +18840 +19051 +19021 +19207 +19235 +19086 +18979 +19300 +18939 +19100 +19619 +19287 +18980 +19277 +19326 +19108 +18920 +19625 +19374 +19078 +18734 +19634 +19339 +18877 +19423 +19652 +19683 +19044 +18983 +19330 +19529 +19714 +19468 +19075 +19540 +18839 +19022 +19286 +19537 +19175 +19463 +19167 +19705 +19562 +19244 +19486 +19611 +18801 +19178 +19590 +18846 +19450 +19205 +19381 +18941 +19670 +19185 +19504 +19633 +18997 +19113 +19397 +19636 +19709 +19289 +19264 +19353 +19584 +19126 +18938 +19669 +18964 +19276 +18774 +19173 +19231 +18973 +18769 +19064 +19040 +19668 +18738 +19082 +19655 +19236 +19352 +19609 +19628 +18951 +19384 +19122 +18875 +18992 +18753 +19379 +19254 +19301 +19506 +19135 +19010 +19682 +19400 +19579 +19316 +19553 +19208 +19635 +19644 +18891 +19024 +18989 +19250 +18850 +19317 +18915 +19607 +18799 +18881 +19479 +19031 +19365 +19164 +18744 +18760 +19502 +19058 +19517 +18735 +19448 +19243 +19453 +19285 +18857 +19439 +19016 +18975 +19503 +18998 +18981 +19186 +18994 +19240 +19631 +19070 +19174 +18900 +19065 +19220 +19229 +18880 +19308 +19372 +19496 +18771 +19325 +19538 +19033 +18874 +19077 +19211 +18764 +19458 +19571 +19121 +19019 +19059 +19497 +18969 +19666 +19297 +19219 +19622 +19184 +18977 +19702 +19539 +19329 +19095 +19675 +18972 +19514 +19703 +19188 +18866 +18812 +19314 +18822 +18845 +19494 +19411 +18916 +19686 +18967 +19294 +19143 +19204 +18805 +19689 +19233 +18758 +18748 +19011 +19685 +19336 +19608 +19454 +19124 +18868 +18807 +19544 +19621 +19228 +19154 +19141 +19145 +19153 +18860 +19163 +19393 +19268 
+19160 +19305 +19259 +19471 +19524 +18783 +19396 +18894 +19430 +19690 +19348 +19597 +19592 +19677 +18889 +19331 +18773 +19137 +19009 +18932 +19599 +18816 +19054 +19067 +19477 +19191 +18921 +18940 +19578 +19183 +19004 +19072 +19710 +19005 +19610 +18955 +19457 +19148 +18859 +18993 +19642 +19047 +19418 +19535 +19600 +19312 +19039 +19028 +18879 +19003 +19026 +19013 +19149 +19177 +19217 +18987 +19354 +19525 +19202 +19084 +19032 +18749 +18867 +19048 +18999 +19260 +19630 +18727 +19356 +19083 +18926 +18789 +19370 +18861 +19311 +19557 +19531 +19436 +19140 +19310 +19501 +18721 +19604 +19713 +19262 +19563 +19507 +19440 +19572 +19513 +19515 +19518 +19421 +19470 +19499 +19663 +19508 +18871 +19528 +19500 +19307 +19288 +19594 +19271 diff --git a/official/gnn/gcn/data/ind.pubmed.tx b/official/gnn/gcn/data/ind.pubmed.tx new file mode 100644 index 0000000000000000000000000000000000000000..eee4f3c988fde0da87a9ec7b01ca5e7863c1b1f1 Binary files /dev/null and b/official/gnn/gcn/data/ind.pubmed.tx differ diff --git a/official/gnn/gcn/data/ind.pubmed.ty b/official/gnn/gcn/data/ind.pubmed.ty new file mode 100644 index 0000000000000000000000000000000000000000..225a0bb9bb6430e3925d04688daf6a0da964c07a Binary files /dev/null and b/official/gnn/gcn/data/ind.pubmed.ty differ diff --git a/official/gnn/gcn/data/ind.pubmed.x b/official/gnn/gcn/data/ind.pubmed.x new file mode 100644 index 0000000000000000000000000000000000000000..16c0eca8d2b3297014ff80327416f9332eda10d7 Binary files /dev/null and b/official/gnn/gcn/data/ind.pubmed.x differ diff --git a/official/gnn/gcn/data/ind.pubmed.y b/official/gnn/gcn/data/ind.pubmed.y new file mode 100644 index 0000000000000000000000000000000000000000..e86670576c85340eeb2271b3c918bfa89322b2cf Binary files /dev/null and b/official/gnn/gcn/data/ind.pubmed.y differ diff --git a/official/gnn/gcn/data/trans.citeseer.graph b/official/gnn/gcn/data/trans.citeseer.graph new file mode 100644 index 
0000000000000000000000000000000000000000..8f84c514dfc0a1247a400820c5e7d611fa4c486a Binary files /dev/null and b/official/gnn/gcn/data/trans.citeseer.graph differ diff --git a/official/gnn/gcn/data/trans.citeseer.tx b/official/gnn/gcn/data/trans.citeseer.tx new file mode 100644 index 0000000000000000000000000000000000000000..b2aff18aa1fc6f4f5e8c64db7643b59b24e42584 Binary files /dev/null and b/official/gnn/gcn/data/trans.citeseer.tx differ diff --git a/official/gnn/gcn/data/trans.citeseer.ty b/official/gnn/gcn/data/trans.citeseer.ty new file mode 100644 index 0000000000000000000000000000000000000000..3795f79df0cd1123f4478b77c148a5fbacc4cefb Binary files /dev/null and b/official/gnn/gcn/data/trans.citeseer.ty differ diff --git a/official/gnn/gcn/data/trans.citeseer.x b/official/gnn/gcn/data/trans.citeseer.x new file mode 100644 index 0000000000000000000000000000000000000000..f094104a66f2689ffbd01085fe6a856689e9db9e Binary files /dev/null and b/official/gnn/gcn/data/trans.citeseer.x differ diff --git a/official/gnn/gcn/data/trans.citeseer.y b/official/gnn/gcn/data/trans.citeseer.y new file mode 100644 index 0000000000000000000000000000000000000000..e857ac49b611da98e101a068cf4855350274f68a Binary files /dev/null and b/official/gnn/gcn/data/trans.citeseer.y differ diff --git a/official/gnn/gcn/data/trans.cora.graph b/official/gnn/gcn/data/trans.cora.graph new file mode 100644 index 0000000000000000000000000000000000000000..df2946aab731d71701c2e32149069f8b0c7acd14 Binary files /dev/null and b/official/gnn/gcn/data/trans.cora.graph differ diff --git a/official/gnn/gcn/data/trans.cora.tx b/official/gnn/gcn/data/trans.cora.tx new file mode 100644 index 0000000000000000000000000000000000000000..6e856d777401ee15dc8619db76c97d0a40ba2d60 Binary files /dev/null and b/official/gnn/gcn/data/trans.cora.tx differ diff --git a/official/gnn/gcn/data/trans.cora.ty b/official/gnn/gcn/data/trans.cora.ty new file mode 100644 index 
0000000000000000000000000000000000000000..da1734ab670b5c8ecee78a9c30b35566df916f4b Binary files /dev/null and b/official/gnn/gcn/data/trans.cora.ty differ diff --git a/official/gnn/gcn/data/trans.cora.x b/official/gnn/gcn/data/trans.cora.x new file mode 100644 index 0000000000000000000000000000000000000000..c4a91d008245403e7f26aa616b437191509793c0 Binary files /dev/null and b/official/gnn/gcn/data/trans.cora.x differ diff --git a/official/gnn/gcn/data/trans.cora.y b/official/gnn/gcn/data/trans.cora.y new file mode 100644 index 0000000000000000000000000000000000000000..58e30ef120d3553cfaac33ff28c6b5436dbef661 Binary files /dev/null and b/official/gnn/gcn/data/trans.cora.y differ diff --git a/official/gnn/gcn/data/trans.pubmed.graph b/official/gnn/gcn/data/trans.pubmed.graph new file mode 100644 index 0000000000000000000000000000000000000000..fa65f4d555424aee1b4ba08431a8803ccd17bf76 Binary files /dev/null and b/official/gnn/gcn/data/trans.pubmed.graph differ diff --git a/official/gnn/gcn/data/trans.pubmed.tx b/official/gnn/gcn/data/trans.pubmed.tx new file mode 100644 index 0000000000000000000000000000000000000000..eee4f3c988fde0da87a9ec7b01ca5e7863c1b1f1 Binary files /dev/null and b/official/gnn/gcn/data/trans.pubmed.tx differ diff --git a/official/gnn/gcn/data/trans.pubmed.ty b/official/gnn/gcn/data/trans.pubmed.ty new file mode 100644 index 0000000000000000000000000000000000000000..225a0bb9bb6430e3925d04688daf6a0da964c07a Binary files /dev/null and b/official/gnn/gcn/data/trans.pubmed.ty differ diff --git a/official/gnn/gcn/data/trans.pubmed.x b/official/gnn/gcn/data/trans.pubmed.x new file mode 100644 index 0000000000000000000000000000000000000000..16c0eca8d2b3297014ff80327416f9332eda10d7 Binary files /dev/null and b/official/gnn/gcn/data/trans.pubmed.x differ diff --git a/official/gnn/gcn/data/trans.pubmed.y b/official/gnn/gcn/data/trans.pubmed.y new file mode 100644 index 0000000000000000000000000000000000000000..e86670576c85340eeb2271b3c918bfa89322b2cf Binary 
files /dev/null and b/official/gnn/gcn/data/trans.pubmed.y differ diff --git a/official/gnn/gcn/export.py b/official/gnn/gcn/export.py index aa91b622bdd377861f141086c1e00b5fcf77e73c..41940343659fea49f71f99e76e633e024d07c745 100644 --- a/official/gnn/gcn/export.py +++ b/official/gnn/gcn/export.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd +# Copyright 2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,6 +14,8 @@ # ============================================================================ """export checkpoint file into air models""" import argparse +import os + import numpy as np from mindspore import Tensor, context, load_checkpoint, export @@ -23,10 +25,10 @@ from src.config import ConfigGCN parser = argparse.ArgumentParser(description="GCN export") parser.add_argument("--device_id", type=int, default=0, help="Device id") -parser.add_argument("--ckpt_file", type=str, required=True, help="Checkpoint file path.") +parser.add_argument("--ckpt_dir", type=str, default = "./results/model", help="Checkpoint file path.") parser.add_argument("--dataset", type=str, default="cora", choices=["cora", "citeseer"], help="Dataset.") -parser.add_argument("--file_name", type=str, default="gcn", help="output file name.") -parser.add_argument("--file_format", type=str, choices=["AIR", "ONNX", "MINDIR"], default="AIR", help="file format") +parser.add_argument("--file_dir", type=str, default="./results/model", help="output file name.") +parser.add_argument("--file_format", type=str, choices=["AIR", "ONNX", "MINDIR"], default="ONNX", help="file format") parser.add_argument("--device_target", type=str, default="Ascend", choices=["Ascend", "GPU", "CPU"], help="device target (default: Ascend)") args = parser.parse_args() @@ -41,18 +43,18 @@ if __name__ == "__main__": if args.dataset == "cora": input_dim = 1433 class_num = 7 - adj = 
Tensor(np.zeros((2708, 2708), np.float64)) - feature = Tensor(np.zeros((2708, 1433), np.float32)) + node_nums = 2708 + adj = Tensor(np.zeros((1, node_nums*node_nums), np.float32)) + feature = Tensor(np.zeros((1, node_nums*input_dim), np.float32)) else: input_dim = 3703 class_num = 6 - adj = Tensor(np.zeros((3312, 3312), np.float64)) - feature = Tensor(np.zeros((3312, 3703), np.float32)) - - gcn_net = GCN(config, input_dim, class_num) + node_nums = 3312 + adj = Tensor(np.zeros((1, node_nums*node_nums), np.float32)) + feature = Tensor(np.zeros((1, node_nums*input_dim), np.float32)) + gcn_net = GCN(config, input_dim, class_num, node_nums) gcn_net.set_train(False) - load_checkpoint(args.ckpt_file, net=gcn_net) + load_checkpoint(os.path.join(args.ckpt_dir, args.dataset + ".ckpt"), net=gcn_net) gcn_net.add_flags_recursive(fp16=True) - - export(gcn_net, adj, feature, file_name=args.file_name, file_format=args.file_format) + export(gcn_net, adj, feature, file_name=os.path.join(args.file_dir, args.dataset), file_format=args.file_format) diff --git a/official/gnn/gcn/infer/convert/run.sh b/official/gnn/gcn/infer/convert/run.sh new file mode 100644 index 0000000000000000000000000000000000000000..d4001b0300af5ecab0bb32743ce2c24eb04cec81 --- /dev/null +++ b/official/gnn/gcn/infer/convert/run.sh @@ -0,0 +1,42 @@ +#!/bin/bash + +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+if [ $# != 1 ]
+then
+    echo "Usage: sh run.sh [DATASET_NAME]"
+exit 1
+fi
+
+DATASET_NAME=$1
+echo $DATASET_NAME
+
+cd ../../ || exit
+
+if [ $DATASET_NAME = cora ]; then
+    atc --model=./infer/data/model/cora.air \
+        --framework=1 \
+        --output=./infer/data/model/cora \
+        --soc_version=Ascend310
+else
+    atc --model=./infer/data/model/citeseer.air \
+        --framework=1 \
+        --output=./infer/data/model/citeseer \
+        --soc_version=Ascend310
+fi
+
+
+
diff --git a/official/gnn/gcn/infer/data/config/gcn_citeseer.pipeline b/official/gnn/gcn/infer/data/config/gcn_citeseer.pipeline
new file mode 100644
index 0000000000000000000000000000000000000000..c281ab11608ab411f8910e79e290f1a4e9936782
--- /dev/null
+++ b/official/gnn/gcn/infer/data/config/gcn_citeseer.pipeline
@@ -0,0 +1,42 @@
+{
+    "gcn": {
+        "stream_config": {
+            "deviceId": "0"
+        },
+        "appsrc0": {
+            "props": {
+                "blocksize": "409600"
+            },
+            "factory": "appsrc",
+            "next": "mxpi_tensorinfer0:0"
+        },
+        "appsrc1": {
+            "props": {
+                "blocksize": "409600"
+            },
+            "factory": "appsrc",
+            "next": "mxpi_tensorinfer0:1"
+        },
+        "mxpi_tensorinfer0": {
+            "props": {
+                "dataSource":"appsrc0,appsrc1",
+                "modelPath": "../data/model/citeseer.om"
+            },
+            "factory": "mxpi_tensorinfer",
+            "next": "mxpi_dataserialize0"
+        },
+        "mxpi_dataserialize0": {
+            "props": {
+                "outputDataKeys": "mxpi_tensorinfer0"
+            },
+            "factory": "mxpi_dataserialize",
+            "next": "appsink0"
+        },
+        "appsink0": {
+            "props": {
+                "blocksize": "4096000"
+            },
+            "factory": "appsink"
+        }
+    }
+}
diff --git a/official/gnn/gcn/infer/data/config/gcn_cora.pipeline b/official/gnn/gcn/infer/data/config/gcn_cora.pipeline
new file mode 100644
index 0000000000000000000000000000000000000000..dba0fda37da98c87b197aae406db8bce2d5e9900
--- /dev/null
+++ b/official/gnn/gcn/infer/data/config/gcn_cora.pipeline
@@ -0,0 +1,42 @@
+{
+    "gcn": {
+        "stream_config": {
+            "deviceId": "0"
+        },
+        "appsrc0": {
+            "props": {
+                "blocksize": "409600"
+            },
+            "factory": "appsrc",
+            "next": 
"mxpi_tensorinfer0:0" + }, + "appsrc1": { + "props": { + "blocksize": "409600" + }, + "factory": "appsrc", + "next": "mxpi_tensorinfer0:1" + }, + "mxpi_tensorinfer0": { + "props": { + "dataSource":"appsrc0,appsrc1", + "modelPath": "../data/model/cora.om" + }, + "factory": "mxpi_tensorinfer", + "next": "mxpi_dataserialize0" + }, + "mxpi_dataserialize0": { + "props": { + "outputDataKeys": "mxpi_tensorinfer0" + }, + "factory": "mxpi_dataserialize", + "next": "appsink0" + }, + "appsink0": { + "props": { + "blocksize": "4096000" + }, + "factory": "appsink" + } + } +} diff --git a/official/gnn/gcn/infer/data/model/citeseer.air b/official/gnn/gcn/infer/data/model/citeseer.air new file mode 100644 index 0000000000000000000000000000000000000000..9f157896ea705ff6e9a7fbfe85fc112b6455363b Binary files /dev/null and b/official/gnn/gcn/infer/data/model/citeseer.air differ diff --git a/official/gnn/gcn/infer/data/model/citeseer.om b/official/gnn/gcn/infer/data/model/citeseer.om new file mode 100644 index 0000000000000000000000000000000000000000..7709c64b8891825489173c62b9ff67380d28e10f Binary files /dev/null and b/official/gnn/gcn/infer/data/model/citeseer.om differ diff --git a/official/gnn/gcn/infer/data/model/citeseer.onnx b/official/gnn/gcn/infer/data/model/citeseer.onnx new file mode 100644 index 0000000000000000000000000000000000000000..6bab6a97e1c9e880c25eda27924086a658356c27 Binary files /dev/null and b/official/gnn/gcn/infer/data/model/citeseer.onnx differ diff --git a/official/gnn/gcn/infer/data/model/cora.air b/official/gnn/gcn/infer/data/model/cora.air new file mode 100644 index 0000000000000000000000000000000000000000..aa4eb0f38331164bad5a9413c19affa3e5ab09bc Binary files /dev/null and b/official/gnn/gcn/infer/data/model/cora.air differ diff --git a/official/gnn/gcn/infer/data/model/cora.om b/official/gnn/gcn/infer/data/model/cora.om new file mode 100644 index 0000000000000000000000000000000000000000..d39ac68caa52217a3c575e59476beb573972bc36 Binary files /dev/null 
and b/official/gnn/gcn/infer/data/model/cora.om differ
diff --git a/official/gnn/gcn/infer/data/model/cora.onnx b/official/gnn/gcn/infer/data/model/cora.onnx
new file mode 100644
index 0000000000000000000000000000000000000000..247b3e2473ffb6977e42663f9d06ce59203d42d9
Binary files /dev/null and b/official/gnn/gcn/infer/data/model/cora.onnx differ
diff --git a/official/gnn/gcn/infer/mxbase/run.sh b/official/gnn/gcn/infer/mxbase/run.sh
new file mode 100644
index 0000000000000000000000000000000000000000..fb26b1c006ad0659973e9fceeeb2b2dd0a1565ee
--- /dev/null
+++ b/official/gnn/gcn/infer/mxbase/run.sh
@@ -0,0 +1,49 @@
+#!/bin/bash
+
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+if [ $# != 1 ]
+then
+    echo "Usage: sh run.sh [DATASET_NAME]"
+exit 1
+fi
+
+DATASET_NAME=$1
+echo $DATASET_NAME
+
+set -e
+
+#CUR_PATH=$(cd "$(dirname "$0")" || { warn "Failed to check path/to/run.sh" ; exit ; } ; pwd)
+
+# Simple log helper functions
+info() { echo -e "\033[1;34m[INFO ][MxStream] $1\033[1;37m" ; }
+warn() { echo >&2 -e "\033[1;31m[WARN ][MxStream] $1\033[1;37m" ; }
+
+#export MX_SDK_HOME="${CUR_PATH}/../../.."
+export MX_SDK_HOME="/home/data/cz/app/mxManufacture" +export LD_LIBRARY_PATH="${MX_SDK_HOME}/lib":"${MX_SDK_HOME}/opensource/lib":"${MX_SDK_HOME}/opensource/lib64":"/usr/local/Ascend/ascend-toolkit/latest/acllib/lib64":${LD_LIBRARY_PATH} +export GST_PLUGIN_SCANNER="${MX_SDK_HOME}/opensource/libexec/gstreamer-1.0/gst-plugin-scanner" +export GST_PLUGIN_PATH="${MX_SDK_HOME}/opensource/lib/gstreamer-1.0":"${MX_SDK_HOME}/lib/plugins" + +# complie +if [ $DATASET_NAME = cora ]; then +g++ ./src/cora_infer.cpp -I "${MX_SDK_HOME}/include/" -I "${MX_SDK_HOME}/opensource/include/" -L "${MX_SDK_HOME}/lib/" -L "${MX_SDK_HOME}/opensource/lib/" -L "${MX_SDK_HOME}/opensource/lib64/" -std=c++11 -D_GLIBCXX_USE_CXX11_ABI=0 -Dgoogle=mindxsdk_private -fPIC -fstack-protector-all -g -Wl,-z,relro,-z,now,-z,noexecstack -pie -Wall -lglog -lmxbase -lmxpidatatype -lplugintoolkit -lstreammanager -lcpprest -lmindxsdk_protobuf -o main + +else +g++ ./src/citeseer_infer.cpp -I "${MX_SDK_HOME}/include/" -I "${MX_SDK_HOME}/opensource/include/" -L "${MX_SDK_HOME}/lib/" -L "${MX_SDK_HOME}/opensource/lib/" -L "${MX_SDK_HOME}/opensource/lib64/" -std=c++11 -D_GLIBCXX_USE_CXX11_ABI=0 -Dgoogle=mindxsdk_private -fPIC -fstack-protector-all -g -Wl,-z,relro,-z,now,-z,noexecstack -pie -Wall -lglog -lmxbase -lmxpidatatype -lplugintoolkit -lstreammanager -lcpprest -lmindxsdk_protobuf -o main +fi + +# run +./main +exit 0 \ No newline at end of file diff --git a/official/gnn/gcn/infer/mxbase/src/citeseer_infer.cpp b/official/gnn/gcn/infer/mxbase/src/citeseer_infer.cpp new file mode 100644 index 0000000000000000000000000000000000000000..7a0e83287e1eb1b17a41c8aeea8151397c3cf1a1 --- /dev/null +++ b/official/gnn/gcn/infer/mxbase/src/citeseer_infer.cpp @@ -0,0 +1,336 @@ +/* + * Copyright (c) 2021.Huawei Technologies Co., Ltd. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <iostream>
+#include "MxBase/Log/Log.h"
+#include "MxBase/MemoryHelper/MemoryHelper.h"
+#include "MxBase/Tensor/TensorBase/TensorBase.h"
+#include "MxStream/StreamManager/MxStreamManager.h"
+#include "MxTools/Proto/MxpiDataType.pb.h"
+#include <sstream>
+#include "MxTools/Proto/MxpiDataTypeDeleter.h"
+#include <fstream>
+namespace {
+const int INT32_BYTELEN = 4;
+const int FLOAT16_BYTELEN = 2;
+const int FLOAT32_BYTELEN = 4;
+const int FLOAT64_BYTELEN = 8;
+
+const int FS_BITS = 0x8000;
+const int FE_BITS = 0x7c00;
+const int FM_BITS = 0x03ff;
+const int FS_SHIFT = 16;
+const int FE_SHIFT = 10;
+const int FM_SHIFT = 13;
+const int EXP_SHIFT = 23;
+const int FE_BIAS = 0x70;
+
+const int INPUT_NUMS = 2;
+const int TEST_NODE_NUMS = 1000;
+}
+using namespace std;
+
+int node_nums,class_nums;
+
+std::vector<std::string> Split(const std::string& inString, char delimiter)
+{
+    std::vector<std::string> result;
+    if (inString.empty()) {
+        return result;
+    }
+
+    std::string::size_type fast = 0;
+    std::string::size_type slow = 0;
+    while ((fast = inString.find_first_of(delimiter, slow)) != std::string::npos) {
+        result.push_back(inString.substr(slow, fast - slow));
+        slow = inString.find_first_not_of(delimiter, fast);
+    }
+
+    if (slow != std::string::npos) {
+        result.push_back(inString.substr(slow, fast - slow));
+    }
+
+    return result;
+}
+
+std::string& Trim(std::string& str)
+{
+    str.erase(0, str.find_first_not_of(' '));
+    str.erase(str.find_last_not_of(' ') + 1);
+    return str;
+}
+
+std::vector<int> SplitWithRemoveBlank(std::string& str, char rule)
+{
+    Trim(str);
+    std::vector<std::string> strVec = Split(str, rule);
+    for (size_t i = 0; i < strVec.size(); i++) {
+        strVec[i] = Trim(strVec[i]);
+    }
+    std::vector<int> res = {};
+    for (size_t i = 0; i < strVec.size(); i++) {
+        res.push_back(std::stoi(strVec[i]));
+    }
+    return res;
+}
+
+std::string ReadPipelineConfig(std::string &pipelineConfigPath)
+{
+    std::ifstream file(pipelineConfigPath.c_str(), std::ifstream::binary);
+    if (!file) {
+        LogError << pipelineConfigPath << " file does not exist";
+        return "";
+    }
+    file.seekg(0, std::ifstream::end);
+    uint32_t fileSize = file.tellg();
+    file.seekg(0);
+    std::unique_ptr<char[]> data(new char[fileSize]);
+    file.read(data.get(), fileSize);
+    file.close();
+    std::string pipelineConfig(data.get(), fileSize);
+    return pipelineConfig;
+}
+
+
+// restore the float16 value to a 32-bit float
+float half_to_float(unsigned short h)
+{
+    short *ptr;
+    int fs, fe, fm, rlt;
+    ptr = (short *)&h;
+    fs = ((*ptr)&FS_BITS) << FS_SHIFT;
+    fe = ((*ptr)&FE_BITS) >> FE_SHIFT;
+    fe = fe + FE_BIAS;
+    fe = fe << EXP_SHIFT;
+    fm = ((*ptr)&FM_BITS) << FM_SHIFT;
+    rlt = fs | fe | fm;
+    return *((float *)&rlt);
+}
+
+void GetTensors(const std::shared_ptr<MxTools::MxpiTensorPackageList> &tensorPackageList, std::vector<MxBase::TensorBase> &tensors) {
+    for (int i = 0; i < tensorPackageList->tensorpackagevec_size(); ++i) {
+        for (int j = 0; j < tensorPackageList->tensorpackagevec(i).tensorvec_size(); j++) {
+            MxBase::MemoryData memoryData = {};
+            memoryData.deviceId = tensorPackageList->tensorpackagevec(i).tensorvec(j).deviceid();
+            memoryData.type = (MxBase::MemoryData::MemoryType)tensorPackageList
+                ->tensorpackagevec(i)
+                .tensorvec(j)
+                .memtype();
+            memoryData.size = (uint32_t)tensorPackageList->tensorpackagevec(i)
+                .tensorvec(j)
+                .tensordatasize();
+            memoryData.ptrData = (void *)tensorPackageList->tensorpackagevec(i)
+                .tensorvec(j)
+                .tensordataptr();
+            if (memoryData.type == MxBase::MemoryData::MEMORY_HOST ||
+                memoryData.type == MxBase::MemoryData::MEMORY_HOST_MALLOC ||
+                memoryData.type == MxBase::MemoryData::MEMORY_HOST_NEW) {
+                memoryData.deviceId = -1;
+            }
+            std::vector<uint32_t> outputShape = {};
+            
for (int k = 0; k < tensorPackageList->tensorpackagevec(i) + .tensorvec(j) + .tensorshape_size(); + ++k) { + outputShape.push_back( + (uint32_t)tensorPackageList->tensorpackagevec(i) + .tensorvec(j) + .tensorshape(k)); + } + MxBase::TensorBase tmpTensor( + memoryData, true, outputShape, + (MxBase::TensorDataType)tensorPackageList->tensorpackagevec(0) + .tensorvec(j) + .tensordatatype()); + tensors.push_back(tmpTensor); + class_nums = tensors[0].GetSize(); + node_nums = memoryData.size/FLOAT16_BYTELEN/tensors[0].GetSize(); + cout<<"node_nums="<(new MxTools::MxpiTensorPackageList, MxTools::g_deleteFuncMxpiTensorPackageList); + auto tensorPackage = tensorPackageList->add_tensorpackagevec(); + auto tensorVec = tensorPackage->add_tensorvec(); + tensorVec->set_tensordataptr((uint64_t)memoryDst.ptrData); + tensorVec->set_tensordatasize(dataSize); + tensorVec->set_tensordatatype(MxBase::TENSOR_DTYPE_FLOAT32); + tensorVec->set_memtype(MxTools::MXPI_MEMORY_HOST_NEW); + tensorVec->set_deviceid(0); + tensorVec->add_tensorshape(1); + tensorVec->add_tensorshape(vec.size()); + + MxStream::MxstProtobufIn dataBuffer; + ostringstream dataSource; + dataSource << "appsrc" << inPluginId; + + dataBuffer.key = dataSource.str(); + dataBuffer.messagePtr = static_pointer_cast(tensorPackageList); + vector dataBufferVec; + dataBufferVec.push_back(dataBuffer); + ret = mxStreamManager->SendProtobuf(STREAM_NAME, inPluginId, dataBufferVec); + return ret; +} + + +vector get_pred(unsigned short *data) +{ + vector ret; + int index; + float temp, hf; + for(int i=0;i get_label(string file) +{ + vector onehots, ret; + ifstream ifile(file); + int k; + for(int i = 0; i < node_nums*class_nums; ++i) + { + ifile >> k; + onehots.push_back(k); + } + ifile.close(); + + for(int i=0;i pred, vector label, int test_nodes_num) +{ + int end = pred.size(); + int start = end - test_nodes_num; + float s = 0; + for(int i=start;i(); + APP_ERROR ret = mxStreamManager->InitManager(); + if (ret != APP_ERR_OK) { + LogError 
<< GetError(ret) << "Failed to init Stream manager.";
+        return ret;
+    }
+    ret = mxStreamManager->CreateMultipleStreams(pipelineConfig);
+    if (ret != APP_ERR_OK) {
+        LogError << GetError(ret) << "Failed to create Stream.";
+        return ret;
+    }
+    std::vector<std::string> filePaths = {
+        "../data/input/citeseer/adjacency.txt", "../data/input/citeseer/feature.txt"
+    };
+
+    for (int i = 0; i < INPUT_NUMS; i++) {
+        ifstream ifile(filePaths[i]);
+        ostringstream buf;
+        char ch;
+        while (buf && ifile.get(ch)) {buf.put(ch);}
+        std::string str = buf.str();
+        stringstream ss(str);  // initialize the stream from the file contents
+        float x;
+        std::vector<float> vec;
+        while (ss >> x){vec.push_back(x);}
+        ret = SendEachProtobuf(streamName,i,vec,mxStreamManager);
+    }
+
+    std::vector<std::string> keyVec = {"mxpi_tensorinfer0"};
+    std::vector<MxStream::MxstProtobufOut> output = mxStreamManager->GetProtobuf(streamName, 0, keyVec);
+
+    if (output.size() == 0) {
+        LogError << "output size is 0";
+        return APP_ERR_ACL_FAILURE;
+    }
+    if (output[0].errorCode != APP_ERR_OK) {
+        LogError << "GetProtobuf error. errorCode=" << output[0].errorCode;
+        return output[0].errorCode;
+    }
+    LogInfo << "errorCode=" << output[0].errorCode;
+    LogInfo << "key=" << output[0].messageName;
+    LogInfo << "value=" << output[0].messagePtr->DebugString();
+
+    auto tensorPackageList = std::static_pointer_cast<MxTools::MxpiTensorPackageList>(output[0].messagePtr);
+    vector<MxBase::TensorBase> tensors = {};
+    GetTensors(tensorPackageList, tensors);
+    void *tensorPtr = tensors[0].GetBuffer();
+    std::vector<uint32_t> sp = tensors[0].GetShape();
+    unsigned short *ptr = (unsigned short *)tensorPtr;
+
+    vector<int> pred_label = get_pred(ptr);
+    vector<int> true_label = get_label("../data/input/citeseer/label_onehot.txt");
+    float accuracy = Acc(pred_label,true_label,TEST_NODE_NUMS);
+
+    cout<<"============================ Infer Result ============================"<<endl;
+    cout<<"accuracy: "<<accuracy<<endl;
+    mxStreamManager->DestroyAllStreams();
+
+    return 0;
+}
\ No newline at end of file
diff --git a/official/gnn/gcn/infer/mxbase/src/cora_infer.cpp b/official/gnn/gcn/infer/mxbase/src/cora_infer.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..264b734334611c747cf8d58952d785428ce99382
--- /dev/null
+++ b/official/gnn/gcn/infer/mxbase/src/cora_infer.cpp
@@ -0,0 +1,337 @@
+/*
+ * Copyright (c) 2021.Huawei Technologies Co., Ltd. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <iostream>
+#include "MxBase/Log/Log.h"
+#include "MxBase/MemoryHelper/MemoryHelper.h"
+#include "MxBase/Tensor/TensorBase/TensorBase.h"
+#include "MxStream/StreamManager/MxStreamManager.h"
+#include "MxTools/Proto/MxpiDataType.pb.h"
+#include <sstream>
+#include "MxTools/Proto/MxpiDataTypeDeleter.h"
+#include <fstream>
+namespace {
+const int INT32_BYTELEN = 4;
+const int FLOAT16_BYTELEN = 2;
+const int FLOAT32_BYTELEN = 4;
+const int FLOAT64_BYTELEN = 8;
+
+const int FS_BITS = 0x8000;
+const int FE_BITS = 0x7c00;
+const int FM_BITS = 0x03ff;
+const int FS_SHIFT = 16;
+const int FE_SHIFT = 10;
+const int FM_SHIFT = 13;
+const int EXP_SHIFT = 23;
+const int FE_BIAS = 0x70;
+
+const int INPUT_NUMS = 2;
+const int TEST_NODE_NUMS = 1000;
+}
+using namespace std;
+
+int node_nums,class_nums;
+
+std::vector<std::string> Split(const std::string& inString, char delimiter)
+{
+    std::vector<std::string> result;
+    if (inString.empty()) {
+        return result;
+    }
+
+    std::string::size_type fast = 0;
+    std::string::size_type slow = 0;
+    while ((fast = inString.find_first_of(delimiter, slow)) != std::string::npos) {
+        result.push_back(inString.substr(slow, fast - slow));
+        slow = inString.find_first_not_of(delimiter, fast);
+    }
+
+    if (slow != std::string::npos) {
+        result.push_back(inString.substr(slow, fast - slow));
+    }
+
+    return result;
+}
+
+std::string& Trim(std::string& str)
+{
+    str.erase(0, str.find_first_not_of(' '));
+    str.erase(str.find_last_not_of(' ') + 1);
+    return str;
+}
+
+std::vector<int> SplitWithRemoveBlank(std::string& str, char rule)
+{
+    Trim(str);
+    std::vector<std::string> strVec = Split(str, rule);
+    for (size_t i = 0; i < strVec.size(); i++) {
+        strVec[i] = Trim(strVec[i]);
+    }
+    std::vector<int> res = {};
+    for (size_t i = 0; i < strVec.size(); i++) {
+        res.push_back(std::stoi(strVec[i]));
+    }
+    return res;
+}
+
+std::string ReadPipelineConfig(std::string &pipelineConfigPath)
+{
+    std::ifstream file(pipelineConfigPath.c_str(), std::ifstream::binary);
+    if (!file) {
+        LogError << pipelineConfigPath << " file does not exist";
+        return "";
+    }
+    file.seekg(0, std::ifstream::end);
+    uint32_t fileSize = file.tellg();
+    file.seekg(0);
+    std::unique_ptr<char[]> data(new char[fileSize]);
+    file.read(data.get(), fileSize);
+    file.close();
+    std::string pipelineConfig(data.get(), fileSize);
+    return pipelineConfig;
+}
+
+// restore the float16 value to a 32-bit float
+float half_to_float(unsigned short h)
+{
+    short *ptr;
+    int fs, fe, fm, rlt;
+    ptr = (short *)&h;
+    fs = ((*ptr)&FS_BITS) << FS_SHIFT;
+    fe = ((*ptr)&FE_BITS) >> FE_SHIFT;
+    fe = fe + FE_BIAS;
+    fe = fe << EXP_SHIFT;
+    fm = ((*ptr)&FM_BITS) << FM_SHIFT;
+    rlt = fs | fe | fm;
+    return *((float *)&rlt);
+}
+
+void GetTensors(const std::shared_ptr<MxTools::MxpiTensorPackageList> &tensorPackageList, std::vector<MxBase::TensorBase> &tensors) {
+// int size = tensorPackageList->tensorpackagevec(0).tensorvec(0).tensordatasize();
+    for (int i = 0; i < tensorPackageList->tensorpackagevec_size(); ++i) {
+        for (int j = 0; j < tensorPackageList->tensorpackagevec(i).tensorvec_size(); j++) {
+// cout<<'*'<<tensorPackageList->tensorpackagevec(i).tensorvec_size()<<endl;
+            MxBase::MemoryData memoryData = {};
+            memoryData.deviceId = tensorPackageList->tensorpackagevec(i).tensorvec(j).deviceid();
+            memoryData.type = (MxBase::MemoryData::MemoryType)tensorPackageList
+                ->tensorpackagevec(i)
+                .tensorvec(j)
+                .memtype();
+            
memoryData.size = (uint32_t)tensorPackageList->tensorpackagevec(i) + .tensorvec(j) + .tensordatasize(); + memoryData.ptrData = (void *)tensorPackageList->tensorpackagevec(i) + .tensorvec(j) + .tensordataptr(); + if (memoryData.type == MxBase::MemoryData::MEMORY_HOST || + memoryData.type == MxBase::MemoryData::MEMORY_HOST_MALLOC || + memoryData.type == MxBase::MemoryData::MEMORY_HOST_NEW) { + memoryData.deviceId = -1; + } + std::vector outputShape = {}; + for (int k = 0; k < tensorPackageList->tensorpackagevec(i) + .tensorvec(j) + .tensorshape_size(); + ++k) { + outputShape.push_back( + (uint32_t)tensorPackageList->tensorpackagevec(i) + .tensorvec(j) + .tensorshape(k)); + } + MxBase::TensorBase tmpTensor( + memoryData, true, outputShape, + (MxBase::TensorDataType)tensorPackageList->tensorpackagevec(0) + .tensorvec(j) + .tensordatatype()); + tensors.push_back(tmpTensor); + class_nums = tensors[0].GetSize(); + node_nums = memoryData.size/FLOAT16_BYTELEN/tensors[0].GetSize(); + cout<<"node_nums="<(new MxTools::MxpiTensorPackageList, MxTools::g_deleteFuncMxpiTensorPackageList); + auto tensorPackage = tensorPackageList->add_tensorpackagevec(); + auto tensorVec = tensorPackage->add_tensorvec(); + tensorVec->set_tensordataptr((uint64_t)memoryDst.ptrData); + tensorVec->set_tensordatasize(dataSize); + tensorVec->set_tensordatatype(MxBase::TENSOR_DTYPE_FLOAT32); + tensorVec->set_memtype(MxTools::MXPI_MEMORY_HOST_NEW); + tensorVec->set_deviceid(0); + tensorVec->add_tensorshape(1); + tensorVec->add_tensorshape(vec.size()); + + MxStream::MxstProtobufIn dataBuffer; + ostringstream dataSource; + dataSource << "appsrc" << inPluginId; + + dataBuffer.key = dataSource.str(); + dataBuffer.messagePtr = static_pointer_cast(tensorPackageList); + vector dataBufferVec; + dataBufferVec.push_back(dataBuffer); + ret = mxStreamManager->SendProtobuf(STREAM_NAME, inPluginId, dataBufferVec); + return ret; +} + + +vector get_pred(unsigned short *data) +{ + vector ret; + int index; + float temp, hf; 
+ for(int i=0;i get_label(string file) +{ + vector onehots,ret; + ifstream ifile(file); + int k; + for(int i = 0; i < node_nums*class_nums; ++i) + { + ifile >> k; + onehots.push_back(k); + } + ifile.close(); + + for(int i=0;i pred, vector label, int test_nodes_num) +{ + int end = pred.size(); + int start = end - test_nodes_num; + float s = 0; + for(int i=start;i(); + APP_ERROR ret = mxStreamManager->InitManager(); + if (ret != APP_ERR_OK) { + LogError << GetError(ret) << "Failed to init Stream manager."; + return ret; + } + ret = mxStreamManager->CreateMultipleStreams(pipelineConfig); + if (ret != APP_ERR_OK) { + LogError << GetError(ret) << "Failed to create Stream."; + return ret; + } + std::vector filePaths = { + "../data/input/cora/adjacency.txt", "../data/input/cora/feature.txt" + }; + + for (int i = 0; i < INPUT_NUMS; i++) { + ifstream ifile(filePaths[i]); + ostringstream buf; + char ch; + while (buf && ifile.get(ch)) {buf.put(ch);} + std::string str = buf.str(); + stringstream ss(str);//初始化 + float x; + std::vector vec; + while (ss >> x){vec.push_back(x);} + ret = SendEachProtobuf(streamName,i,vec,mxStreamManager); + } + + std::vector keyVec = {"mxpi_tensorinfer0"}; + std::vector output = mxStreamManager->GetProtobuf(streamName, 0, keyVec); + + if (output.size() == 0) { + LogError << "output size is 0"; + return APP_ERR_ACL_FAILURE; + } + if (output[0].errorCode != APP_ERR_OK) { + LogError << "GetProtobuf error. 
errorCode=" << output[0].errorCode; + return output[0].errorCode; + } + LogInfo << "errorCode=" << output[0].errorCode; + LogInfo << "key=" << output[0].messageName; + LogInfo << "value=" << output[0].messagePtr->DebugString(); + + auto tensorPackageList = std::static_pointer_cast(output[0].messagePtr); + vector tensors = {}; + GetTensors(tensorPackageList, tensors); + void *tensorPtr = tensors[0].GetBuffer(); + std::vector sp = tensors[0].GetShape(); + unsigned short *ptr = (unsigned short *)tensorPtr; + + vector pred_label = get_pred(ptr); + vector true_label = get_label("../data/input/cora/label_onehot.txt"); + float accuracy = Acc(pred_label,true_label,TEST_NODE_NUMS); + + cout<<"============================ Infer Result ============================"<DestroyAllStreams(); + + return 0; +} \ No newline at end of file diff --git a/official/gnn/gcn/infer/sdk/main_infer.py b/official/gnn/gcn/infer/sdk/main_infer.py new file mode 100644 index 0000000000000000000000000000000000000000..ffa3b8426c888c574216b601517e25715df13775 --- /dev/null +++ b/official/gnn/gcn/infer/sdk/main_infer.py @@ -0,0 +1,196 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================
+
+"""
+Sample script of GCN infer using the SDK, run in docker.
+"""
+
+import argparse
+import os
+import time
+
+import MxpiDataType_pb2 as MxpiDataType
+import numpy as np
+from StreamManagerApi import StreamManagerApi, MxDataInput, InProtobufVector, \
+    MxProtobufIn, StringVector
+
+def parse_args():
+    """Set and check parameters."""
+    parser = argparse.ArgumentParser(description="gcn process")
+    parser.add_argument("--dataset", type=str, default="cora", help="Dataset name, cora or citeseer")
+    parser.add_argument('--data_dir', type=str, default='../data/input', help='Data path')
+    parser.add_argument("--pipeline", type=str, default="../utils/gcn_cora.pipeline", help="SDK infer pipeline")
+    parser.add_argument('--data_adj', type=str, default="adjacency.txt")
+    parser.add_argument("--data_feature", type=str, default="feature.txt")
+    parser.add_argument('--data_label', type=str, default="label_onehot.txt")
+    parser.add_argument('--test_nodes_num', type=int, default=1000, help='Nodes numbers for test')
+    args_opt = parser.parse_args()
+    return args_opt
+
+args = parse_args()
+
+if args.dataset == "cora":
+    Node_num = 2708
+    Feature_dim = 1433
+    Class_num = 7
+    args.pipeline = "../data/config/gcn_cora.pipeline"
+elif args.dataset == "citeseer":
+    Node_num = 3312
+    Feature_dim = 3703
+    Class_num = 6
+    args.pipeline = "../data/config/gcn_citeseer.pipeline"
+
+adj_shape = [1, Node_num * Node_num]
+feature_shape = [1, Node_num * Feature_dim]
+
+def send_source_data(appsrc_id, filename, stream_name, stream_manager, shape, tp):
+    """
+    Construct the input of the stream and
+    send the input data to the stream specified by stream_name.
+
+    Returns:
+        bool: send data success or not
+    """
+    tensors = (np.loadtxt(os.path.join(args.data_dir, args.dataset, filename), dtype=tp)).astype(np.float32)
+    tensors = tensors.reshape(shape[0], shape[1])
+    tensor_package_list = MxpiDataType.MxpiTensorPackageList()
+    tensor_package = tensor_package_list.tensorPackageVec.add()
+    data_input = MxDataInput()
+    tensor_vec = tensor_package.tensorVec.add()
+    tensor_vec.deviceId = 0
+    tensor_vec.memType = 0
+    for i in shape:
+        tensor_vec.tensorShape.append(i)
+    print(filename + " shape :", tensor_vec.tensorShape)
+    array_bytes = tensors.tobytes()
+    data_input.data = array_bytes
+    tensor_vec.dataStr = data_input.data
+    tensor_vec.tensorDataSize = len(array_bytes)
+    key = "appsrc{}".format(appsrc_id).encode('utf-8')
+    protobuf_vec = InProtobufVector()
+    protobuf = MxProtobufIn()
+    protobuf.key = key
+    protobuf.type = b'MxTools.MxpiTensorPackageList'
+    protobuf.protobuf = tensor_package_list.SerializeToString()
+    protobuf_vec.push_back(protobuf)
+    ret = stream_manager.SendProtobuf(stream_name, appsrc_id, protobuf_vec)
+    if ret < 0:
+        print("Failed to send data to stream.")
+        return False
+    print("Send successfully!")
+    return True
+
+def send_appsrc_data(appsrc_id, file_name, stream_name, stream_manager, shape, tp):
+    """
+    Send one input stream (adjacency matrix or feature matrix) to the infer model.
+
+    Returns:
+        bool: send data success or not
+    """
+    if not send_source_data(appsrc_id, file_name, stream_name, stream_manager, shape, tp):
+        return False
+    return True
+
+def accurate(label, preds):
+    """Accuracy with masking over the last test_nodes_num nodes."""
+    preds = preds.astype(np.float32)
+    correct_prediction = np.equal(np.argmax(preds, axis=1), np.argmax(label, axis=1))
+    accuracy_all = correct_prediction.astype(np.float32)
+    mask = np.zeros([len(preds)]).astype(np.float32)
+    mask[len(preds) - args.test_nodes_num:len(preds)] = 1
+    mask = mask.astype(np.float32)
+    mask_reduce = np.mean(mask)
+    mask = mask / mask_reduce
+    accuracy_all *= mask
+    return np.mean(accuracy_all)
+
+def post_process(args, infer_result):
+    """
+    Process the infer result tensor into predicted labels and accuracy.
+
+    Args:
+        args: param of config.
+        infer_result: logits fetched from the infer result.
+    """
+    # get the infer result
+    result = MxpiDataType.MxpiTensorPackageList()
+    result.ParseFromString(infer_result[0].messageBuf)
+
+    res = np.frombuffer(result.tensorPackageVec[0].tensorVec[0].dataStr, dtype=np.float16)
+    res = res.reshape((Node_num, Class_num))
+
+    label = np.loadtxt(os.path.join(args.data_dir, args.dataset, args.data_label), dtype=np.int32)
+    label = label.reshape((Node_num, Class_num))
+
+    pred_label = np.argmax(res, axis=1)
+    acc = accurate(label, res)
+    print('============================ Infer Result ============================')
+    print("Pred_label label:{}".format(pred_label))
+    print("Infer acc:%f" % (acc))
+    print('=======================================================================')
+
+def run():
+    """
+    Read the pipeline config and do infer.
+    """
+    # init stream manager
+    stream_manager_api = StreamManagerApi()
+    ret = stream_manager_api.InitManager()
+    if ret != 0:
+        print("Failed to init Stream manager, ret=%s" % str(ret))
+        return
+
+    # create streams by pipeline config file
+    with open(os.path.realpath(args.pipeline), 'rb') as f:
+        pipeline_str = f.read()
+    ret = stream_manager_api.CreateMultipleStreams(pipeline_str)
+    if ret != 0:
+        print("Failed to create Stream, ret=%s" % str(ret))
+        return
+
+    stream_name = b'gcn'
+    infer_total_time = 0
+
+    if not send_appsrc_data(0, args.data_adj, stream_name, stream_manager_api, adj_shape, np.float64):
+        return
+
+    if not send_appsrc_data(1, args.data_feature, stream_name, stream_manager_api, feature_shape, np.float32):
+        return
+
+    # Obtain the inference result by specifying streamName and uniqueId.
+    key_vec = StringVector()
+    key_vec.push_back(b'mxpi_tensorinfer0')
+    start_time = time.time()
+    infer_result = stream_manager_api.GetProtobuf(stream_name, 0, key_vec)
+    infer_total_time += time.time() - start_time
+    if infer_result.size() == 0:
+        print("inferResult is null")
+        return
+    if infer_result[0].errorCode != 0:
+        print("GetProtobuf error. errorCode=%d" % (infer_result[0].errorCode))
+        return
+
+    post_process(args, infer_result)
+    stream_manager_api.DestroyAllStreams()
+
+
+if __name__ == '__main__':
+    run()
diff --git a/official/gnn/gcn/infer/sdk/run.sh b/official/gnn/gcn/infer/sdk/run.sh
new file mode 100644
index 0000000000000000000000000000000000000000..c10eed63a48dac27a00466b49211541de86adffb
--- /dev/null
+++ b/official/gnn/gcn/infer/sdk/run.sh
@@ -0,0 +1,43 @@
+#!/bin/bash
+
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+if [ $# != 1 ]
+then
+    echo "Usage: bash run.sh [DATASET_NAME]"
+exit 1
+fi
+
+DATASET_NAME=$1
+echo $DATASET_NAME
+
+set -e
+
+# Simple log helper functions
+info() { echo -e "\033[1;34m[INFO ][MxStream] $1\033[1;37m" ; }
+warn() { echo >&2 -e "\033[1;31m[WARN ][MxStream] $1\033[1;37m" ; }
+
+export MX_SDK_HOME=/home/data/sjtu_liu/mxVision
+export LD_LIBRARY_PATH=${MX_SDK_HOME}/lib:${MX_SDK_HOME}/opensource/lib:${MX_SDK_HOME}/opensource/lib64:/usr/local/Ascend/ascend-toolkit/latest/acllib/lib64:${LD_LIBRARY_PATH}
+export GST_PLUGIN_SCANNER=${MX_SDK_HOME}/opensource/libexec/gstreamer-1.0/gst-plugin-scanner
+export GST_PLUGIN_PATH=${MX_SDK_HOME}/opensource/lib/gstreamer-1.0:${MX_SDK_HOME}/lib/plugins
+
+# to set PYTHONPATH, import the StreamManagerApi.py
+export PYTHONPATH=$PYTHONPATH:${MX_SDK_HOME}/python
+
+python3.7 main_infer.py --dataset=$DATASET_NAME
+exit 0
diff --git a/official/gnn/gcn/mindspore_hub_conf.py b/official/gnn/gcn/mindspore_hub_conf.py
index b4bb80ef6e25002065b0da13ccc949fbae04f5a7..11f5c817e95ad5dcc2a959be2921ea5effd31a91 100644
--- a/official/gnn/gcn/mindspore_hub_conf.py
+++ b/official/gnn/gcn/mindspore_hub_conf.py
@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
diff --git a/official/gnn/gcn/preprocess.py b/official/gnn/gcn/preprocess.py
new file mode 100644
index 0000000000000000000000000000000000000000..5e19efa23504c99e5640dac359228623e74c535b
--- /dev/null
+++ b/official/gnn/gcn/preprocess.py
@@ -0,0 +1,55 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""
+preprocess.
+"""
+import os
+import argparse
+
+import numpy as np
+from src.dataset import get_adj_features_labels
+
+def generate_txt():
+    """Generate txt files."""
+    def w2txt(file, data):
+        # write the flattened array as space-separated values
+        s = " ".join(str(x) for x in data)
+        with open(file, "w") as f:
+            f.write(s)
+
+    parser = argparse.ArgumentParser(description='preprocess')
+    parser.add_argument('--dataset', type=str, default='cora', help='Dataset name')
+    parser.add_argument('--data_dir', type=str, default='./results/data_mr', help='Dataset directory')
+    parser.add_argument('--test_nodes_num', type=int, default=1000, help='Nodes numbers for test')
+    parser.add_argument('--result_path', type=str, default='./results/data', help='Result path')
+    args_opt = parser.parse_args()
+
+    if not os.path.exists(os.path.join(args_opt.result_path, args_opt.dataset)):
+        os.makedirs(os.path.join(args_opt.result_path, args_opt.dataset))
+
+    adj, feature, label_onehot, _ = get_adj_features_labels(os.path.join(args_opt.data_dir, args_opt.dataset))
+    adj = (adj.reshape(-1)).astype(np.float32)
+    feature = (feature.reshape(-1)).astype(np.float32)
+    label_onehot = (label_onehot.reshape(-1)).astype(np.int32)
+
+    w2txt(os.path.join(args_opt.result_path, args_opt.dataset, "adjacency.txt"), adj)
+    w2txt(os.path.join(args_opt.result_path, args_opt.dataset, "feature.txt"), feature)
+    w2txt(os.path.join(args_opt.result_path, args_opt.dataset, "label_onehot.txt"), label_onehot)
+
+if __name__ == '__main__':
+    generate_txt()
\ No newline at end of file
diff --git a/official/gnn/gcn/scripts/run_process_data.sh b/official/gnn/gcn/scripts/run_process_data.sh
index 013cf2f28cbec2cf8c43def7909b122bc1458a09..03d2109b34b7ad97463ab2625926909edc75756b 100644
--- a/official/gnn/gcn/scripts/run_process_data.sh
+++ b/official/gnn/gcn/scripts/run_process_data.sh
@@ -1,5 +1,5 @@
 #!/bin/bash
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -43,7 +43,7 @@ MINDRECORD_PATH=`pwd`/data_mr
 rm -f $MINDRECORD_PATH/$DATASET_NAME
 rm -f $MINDRECORD_PATH/$DATASET_NAME.db
 
-cd ../../utils/graph_to_mindrecord || exit
+cd ../src || exit
 
 python writer.py --mindrecord_script $DATASET_NAME \
 --mindrecord_file "$MINDRECORD_PATH/$DATASET_NAME" \
@@ -52,4 +52,4 @@ python writer.py --mindrecord_script $DATASET_NAME \
 --mindrecord_page_size_by_bit 20 \
 --graph_api_args "$SRC_PATH"
 
-cd - || exit
+cd - || exit
\ No newline at end of file
diff --git a/official/gnn/gcn/scripts/run_train.sh b/official/gnn/gcn/scripts/run_train.sh
index 46dee49b0d7bc3f31a24122f24d5be768f1d7fb4..6129cde010521a597a41f6299dc77a95b0d44405 100644
--- a/official/gnn/gcn/scripts/run_train.sh
+++ b/official/gnn/gcn/scripts/run_train.sh
@@ -1,5 +1,5 @@
 #!/bin/bash
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
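The mxbase sources in this patch decode the model's float16 logits by hand with `half_to_float`: the sign bit moves to bit 31, the 5-bit exponent is re-biased from 15 to 127 (`FE_BIAS = 0x70`), and the 10-bit mantissa is left-aligned into the 23-bit float32 mantissa field. A hedged Python sketch of the same bit manipulation, checked against NumPy's own float16 encoding (function name is mine; like the C version, it handles normal numbers only, not zeros or subnormals):

```python
import struct
import numpy as np

def half_to_float(h: int) -> float:
    """Rebuild a float32 from a raw IEEE-754 half-precision bit pattern (normal numbers only)."""
    fs = (h & 0x8000) << 16            # sign: bit 15 -> bit 31
    fe = ((h & 0x7c00) >> 10) + 0x70   # exponent: re-bias 15 -> 127
    fm = (h & 0x03ff) << 13            # mantissa: 10 bits -> top of 23-bit field
    bits = fs | (fe << 23) | fm
    return struct.unpack('<f', struct.pack('<I', bits))[0]

raw = int(np.float16(1.5).view(np.uint16))  # 0x3e00
print(half_to_float(raw))  # 1.5
```

Treating the buffer returned by `GetBuffer()` as `unsigned short *` and mapping each element through this routine is exactly what `get_pred` does before taking the per-row argmax.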
diff --git a/official/gnn/gcn/src/citeseer/__init__.py b/official/gnn/gcn/src/citeseer/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/official/gnn/gcn/src/citeseer/mr_api.py b/official/gnn/gcn/src/citeseer/mr_api.py
new file mode 100644
index 0000000000000000000000000000000000000000..eef72d1d9f34e2e13abeb41c9b962b1b8cc6249e
--- /dev/null
+++ b/official/gnn/gcn/src/citeseer/mr_api.py
@@ -0,0 +1,129 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""
+User-defined API for MindRecord GNN writer.
+"""
+import os
+
+import pickle as pkl
+import numpy as np
+import scipy.sparse as sp
+from mindspore import log as logger
+
+# parse args from command line parameter 'graph_api_args'
+# args delimiter is ':'
+args = os.environ['graph_api_args'].split(':')
+CITESEER_PATH = args[0]
+dataset_str = 'citeseer'
+
+# profile: (num_features, feature_data_types, feature_shapes)
+node_profile = (2, ["float32", "int32"], [[-1], [-1]])
+edge_profile = (0, [], [])
+
+node_ids = []
+
+
+def _normalize_citeseer_features(features):
+    row_sum = np.array(features.sum(1))
+    r_inv = np.power(row_sum * 1.0, -1).flatten()
+    r_inv[np.isinf(r_inv)] = 0.
+    r_mat_inv = sp.diags(r_inv)
+    features = r_mat_inv.dot(features)
+    return features
+
+
+def _parse_index_file(filename):
+    """Parse index file."""
+    index = []
+    for line in open(filename):
+        index.append(int(line.strip()))
+    return index
+
+
+def yield_nodes(task_id=0):
+    """
+    Generate node data.
+
+    Yields:
+        data (dict): data row which is dict.
+    """
+    logger.info("Node task is {}".format(task_id))
+    names = ['x', 'y', 'tx', 'ty', 'allx', 'ally']
+    objects = []
+    for name in names:
+        with open("{}/ind.{}.{}".format(CITESEER_PATH, dataset_str, name), 'rb') as f:
+            objects.append(pkl.load(f, encoding='latin1'))
+    x, y, tx, ty, allx, ally = tuple(objects)
+    test_idx_reorder = _parse_index_file(
+        "{}/ind.{}.test.index".format(CITESEER_PATH, dataset_str))
+    test_idx_range = np.sort(test_idx_reorder)
+
+    tx = _normalize_citeseer_features(tx)
+    allx = _normalize_citeseer_features(allx)
+
+    # Fix citeseer dataset (there are some isolated nodes in the graph)
+    # Find isolated nodes, add them as zero-vecs into the right position
+    test_idx_range_full = range(min(test_idx_reorder), max(test_idx_reorder)+1)
+    tx_extended = sp.lil_matrix((len(test_idx_range_full), x.shape[1]))
+    tx_extended[test_idx_range-min(test_idx_range), :] = tx
+    tx = tx_extended
+    ty_extended = np.zeros((len(test_idx_range_full), y.shape[1]))
+    ty_extended[test_idx_range-min(test_idx_range), :] = ty
+    ty = ty_extended
+
+    features = sp.vstack((allx, tx)).tolil()
+    features[test_idx_reorder, :] = features[test_idx_range, :]
+    features = features.A
+
+    labels = np.vstack((ally, ty))
+    labels[test_idx_reorder, :] = labels[test_idx_range, :]
+
+    line_count = 0
+    for i, label in enumerate(labels):
+        if not 1 in label.tolist():
+            continue
+        node = {'id': i, 'type': 0, 'feature_1': features[i].tolist(),
+                'feature_2': label.tolist().index(1)}
+        line_count += 1
+        node_ids.append(i)
+        yield node
+    logger.info('Processed {} lines for nodes.'.format(line_count))
+
+
+def yield_edges(task_id=0):
+    """
+    Generate edge data.
+
+    Yields:
+        data (dict): data row which is dict.
+    """
+    logger.info("Edge task is {}".format(task_id))
+    with open("{}/ind.{}.graph".format(CITESEER_PATH, dataset_str), 'rb') as f:
+        graph = pkl.load(f, encoding='latin1')
+    line_count = 0
+    for i in graph:
+        for dst_id in graph[i]:
+            if not i in node_ids:
+                logger.info('Source node {} does not exist.'.format(i))
+                continue
+            if not dst_id in node_ids:
+                logger.info('Destination node {} does not exist.'.format(
+                    dst_id))
+                continue
+            edge = {'id': line_count,
+                    'src_id': i, 'dst_id': dst_id, 'type': 0}
+            line_count += 1
+            yield edge
+    logger.info('Processed {} lines for edges.'.format(line_count))
diff --git a/official/gnn/gcn/src/config.py b/official/gnn/gcn/src/config.py
index 8efa668e2facb13fc4b64e2315558a0602d00d27..d46dd507342d3b608af7d6355adc48fbb14325a0 100644
--- a/official/gnn/gcn/src/config.py
+++ b/official/gnn/gcn/src/config.py
@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
diff --git a/official/gnn/gcn/src/cora/__init__.py b/official/gnn/gcn/src/cora/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/official/gnn/gcn/src/cora/mr_api.py b/official/gnn/gcn/src/cora/mr_api.py
new file mode 100644
index 0000000000000000000000000000000000000000..ee1c1e87dc7e840cdaec06038b297ea2adc83aef
--- /dev/null
+++ b/official/gnn/gcn/src/cora/mr_api.py
@@ -0,0 +1,105 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""
+User-defined API for MindRecord GNN writer.
+"""
+import os
+
+import pickle as pkl
+import numpy as np
+import scipy.sparse as sp
+
+# parse args from command line parameter 'graph_api_args'
+# args delimiter is ':'
+args = os.environ['graph_api_args'].split(':')
+CORA_PATH = args[0]
+dataset_str = 'cora'
+
+# profile: (num_features, feature_data_types, feature_shapes)
+node_profile = (2, ["float32", "int32"], [[-1], [-1]])
+edge_profile = (0, [], [])
+
+
+def _normalize_cora_features(features):
+    row_sum = np.array(features.sum(1))
+    r_inv = np.power(row_sum * 1.0, -1).flatten()
+    r_inv[np.isinf(r_inv)] = 0.
+    r_mat_inv = sp.diags(r_inv)
+    features = r_mat_inv.dot(features)
+    return features
+
+
+def _parse_index_file(filename):
+    """Parse index file."""
+    index = []
+    for line in open(filename):
+        index.append(int(line.strip()))
+    return index
+
+
+def yield_nodes(task_id=0):
+    """
+    Generate node data.
+
+    Yields:
+        data (dict): data row which is dict.
+    """
+    print("Node task is {}".format(task_id))
+
+    names = ['tx', 'ty', 'allx', 'ally']
+    objects = []
+    for name in names:
+        with open("{}/ind.{}.{}".format(CORA_PATH, dataset_str, name), 'rb') as f:
+            objects.append(pkl.load(f, encoding='latin1'))
+    tx, ty, allx, ally = tuple(objects)
+    test_idx_reorder = _parse_index_file(
+        "{}/ind.{}.test.index".format(CORA_PATH, dataset_str))
+    test_idx_range = np.sort(test_idx_reorder)
+
+    features = sp.vstack((allx, tx)).tolil()
+    features[test_idx_reorder, :] = features[test_idx_range, :]
+    features = _normalize_cora_features(features)
+    features = features.A
+
+    labels = np.vstack((ally, ty))
+    labels[test_idx_reorder, :] = labels[test_idx_range, :]
+
+    line_count = 0
+    for i, label in enumerate(labels):
+        node = {'id': i, 'type': 0, 'feature_1': features[i].tolist(),
+                'feature_2': label.tolist().index(1)}
+        line_count += 1
+        yield node
+    print('Processed {} lines for nodes.'.format(line_count))
+
+
+def yield_edges(task_id=0):
+    """
+    Generate edge data.
+
+    Yields:
+        data (dict): data row which is dict.
+    """
+    print("Edge task is {}".format(task_id))
+    with open("{}/ind.{}.graph".format(CORA_PATH, dataset_str), 'rb') as f:
+        graph = pkl.load(f, encoding='latin1')
+    line_count = 0
+    for i in graph:
+        for dst_id in graph[i]:
+            edge = {'id': line_count,
+                    'src_id': i, 'dst_id': dst_id, 'type': 0}
+            line_count += 1
+            yield edge
+    print('Processed {} lines for edges.'.format(line_count))
diff --git a/official/gnn/gcn/src/dataset.py b/official/gnn/gcn/src/dataset.py
index 2402b66c4b055fc6095926e07d161b5c4e497dcd..3b0982a2f4a6d8da48f54e905cc385dbbf931f55 100644
--- a/official/gnn/gcn/src/dataset.py
+++ b/official/gnn/gcn/src/dataset.py
@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
diff --git a/official/gnn/gcn/src/gcn.py b/official/gnn/gcn/src/gcn.py
index 7da858a091cf37cda861a1274e3d0c4e2eaaba29..784528c698d25ed2fc5a7d65740d789e8813fc8e 100644
--- a/official/gnn/gcn/src/gcn.py
+++ b/official/gnn/gcn/src/gcn.py
@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -75,7 +75,6 @@ class GraphConvolution(nn.Cell):
 
         fc = self.fc(dropout)
         output_feature = self.matmul(adj, fc)
-
         if self.activation_flag:
             output_feature = self.activation(output_feature)
         return output_feature
@@ -91,13 +90,16 @@ class GCN(nn.Cell):
         feature (numpy.ndarray): Input channel in each layer.
         output_dim (int): The number of output channels, equal to classes num.
     """
-
-    def __init__(self, config, input_dim, output_dim):
+    def __init__(self, config, input_dim, output_dim, node_nums):
         super(GCN, self).__init__()
+        self.input_dim = input_dim
+        self.node_nums = node_nums
         self.layer0 = GraphConvolution(input_dim, config.hidden1, activation="relu", dropout_ratio=config.dropout)
         self.layer1 = GraphConvolution(config.hidden1, output_dim, dropout_ratio=None)
 
     def construct(self, adj, feature):
+        adj = adj.view((self.node_nums, self.node_nums))
+        feature = feature.view((self.node_nums, self.input_dim))
         output0 = self.layer0(adj, feature)
         output1 = self.layer1(adj, output0)
-        return output1
+        return output1
\ No newline at end of file
diff --git a/official/gnn/gcn/src/graph_map_schema.py b/official/gnn/gcn/src/graph_map_schema.py
new file mode 100644
index 0000000000000000000000000000000000000000..71807d53037a9bdab918331b7827b7d3e385ccba
--- /dev/null
+++ b/official/gnn/gcn/src/graph_map_schema.py
@@ -0,0 +1,164 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""
+Graph data convert tool for MindRecord.
+"""
+import numpy as np
+from mindspore import log as logger
+
+__all__ = ['GraphMapSchema']
+
+
+class GraphMapSchema:
+    """
+    Class for transformation from graph data to MindRecord.
+    """
+
+    def __init__(self):
+        """
+        init
+        """
+        self.num_node_features = 0
+        self.num_edge_features = 0
+        self.union_schema_in_mindrecord = {
+            "first_id": {"type": "int64"},
+            "second_id": {"type": "int64"},
+            "third_id": {"type": "int64"},
+            "type": {"type": "int32"},
+            "weight": {"type": "float32"},
+            "attribute": {"type": "string"},  # 'n' for node, 'e' for edge
+            "node_feature_index": {"type": "int32", "shape": [-1]},
+            "edge_feature_index": {"type": "int32", "shape": [-1]}
+        }
+
+    @property
+    def get_schema(self):
+        """
+        Get schema.
+        """
+        return self.union_schema_in_mindrecord
+
+    def set_node_feature_profile(self, num_features, features_data_type, features_shape):
+        """
+        Set node features profile.
+        """
+        if num_features != len(features_data_type) or num_features != len(features_shape):
+            logger.info("Node feature profile does not match.")
+            raise ValueError("Node feature profile does not match.")
+
+        self.num_node_features = num_features
+        for i in range(num_features):
+            k = i + 1
+            field_key = 'node_feature_' + str(k)
+            field_value = {"type": features_data_type[i], "shape": features_shape[i]}
+            self.union_schema_in_mindrecord[field_key] = field_value
+
+    def set_edge_feature_profile(self, num_features, features_data_type, features_shape):
+        """
+        Set edge features profile.
+        """
+        if num_features != len(features_data_type) or num_features != len(features_shape):
+            logger.info("Edge feature profile does not match.")
+            raise ValueError("Edge feature profile does not match.")
+
+        self.num_edge_features = num_features
+        for i in range(num_features):
+            k = i + 1
+            field_key = 'edge_feature_' + str(k)
+            field_value = {"type": features_data_type[i], "shape": features_shape[i]}
+            self.union_schema_in_mindrecord[field_key] = field_value
+
+    def transform_node(self, node):
+        """
+        Executes transformation from node data to union format.
+
+        Args:
+            node(schema): node's data
+
+        Returns:
+            graph data with union schema
+        """
+        if node is None:
+            logger.info("node cannot be None.")
+            raise ValueError("node cannot be None.")
+
+        node_graph = {"first_id": node["id"], "second_id": 0, "third_id": 0, "weight": 1.0, "attribute": 'n',
+                      "type": node["type"], "node_feature_index": []}
+        if "weight" in node:
+            node_graph["weight"] = node["weight"]
+
+        for i in range(self.num_node_features):
+            k = i + 1
+            node_field_key = 'feature_' + str(k)
+            graph_field_key = 'node_feature_' + str(k)
+            graph_field_type = self.union_schema_in_mindrecord[graph_field_key]["type"]
+            if node_field_key in node:
+                node_graph["node_feature_index"].append(k)
+                node_graph[graph_field_key] = np.reshape(np.array(node[node_field_key], dtype=graph_field_type), [-1])
+            else:
+                node_graph[graph_field_key] = np.reshape(np.array([0], dtype=graph_field_type), [-1])
+
+        if node_graph["node_feature_index"]:
+            node_graph["node_feature_index"] = np.array(node_graph["node_feature_index"], dtype="int32")
+        else:
+            node_graph["node_feature_index"] = np.array([-1], dtype="int32")
+
+        node_graph["edge_feature_index"] = np.array([-1], dtype="int32")
+        for i in range(self.num_edge_features):
+            k = i + 1
+            graph_field_key = 'edge_feature_' + str(k)
+            graph_field_type = self.union_schema_in_mindrecord[graph_field_key]["type"]
+            node_graph[graph_field_key] = np.reshape(np.array([0], dtype=graph_field_type), [-1])
+        return node_graph
+
+    def transform_edge(self, edge):
+        """
+        Executes transformation from edge data to union format.
+
+        Args:
+            edge(schema): edge's data
+
+        Returns:
+            graph data with union schema
+        """
+        if edge is None:
+            logger.info("edge cannot be None.")
+            raise ValueError("edge cannot be None.")
+
+        edge_graph = {"first_id": edge["id"], "second_id": edge["src_id"], "third_id": edge["dst_id"], "weight": 1.0,
+                      "attribute": 'e', "type": edge["type"], "edge_feature_index": []}
+
+        if "weight" in edge:
+            edge_graph["weight"] = edge["weight"]
+
+        for i in range(self.num_edge_features):
+            k = i + 1
+            edge_field_key = 'feature_' + str(k)
+            graph_field_key = 'edge_feature_' + str(k)
+            graph_field_type = self.union_schema_in_mindrecord[graph_field_key]["type"]
+            if edge_field_key in edge:
+                edge_graph["edge_feature_index"].append(k)
+                edge_graph[graph_field_key] = np.reshape(np.array(edge[edge_field_key], dtype=graph_field_type), [-1])
+            else:
+                edge_graph[graph_field_key] = np.reshape(np.array([0], dtype=graph_field_type), [-1])
+
+        if edge_graph["edge_feature_index"]:
+            edge_graph["edge_feature_index"] = np.array(edge_graph["edge_feature_index"], dtype="int32")
+        else:
+            edge_graph["edge_feature_index"] = np.array([-1], dtype="int32")
+
+        edge_graph["node_feature_index"] = np.array([-1], dtype="int32")
+        for i in range(self.num_node_features):
+            k = i + 1
+            graph_field_key = 'node_feature_' + str(k)
+            graph_field_type = self.union_schema_in_mindrecord[graph_field_key]["type"]
+            edge_graph[graph_field_key] = np.array([0], dtype=graph_field_type)
+        return edge_graph
diff --git a/official/gnn/gcn/src/metrics.py b/official/gnn/gcn/src/metrics.py
index d11236fa73b2bc2a6acc5b28fda4c3049a442c47..3980dffc94ba0d947934bfcaccfbf4362436fe2e 100644
--- a/official/gnn/gcn/src/metrics.py
+++ b/official/gnn/gcn/src/metrics.py
@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2021 Huawei Technologies
Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/official/gnn/gcn/src/writer.py b/official/gnn/gcn/src/writer.py new file mode 100644 index 0000000000000000000000000000000000000000..8d2cd01a36279ee7095f337a6a5e24e458f96457 --- /dev/null +++ b/official/gnn/gcn/src/writer.py @@ -0,0 +1,160 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +""" +######################## write mindrecord example ######################## +Write mindrecord by data dictionary: +python writer.py mindrecord_script /YourScriptPath ... 
+""" +import argparse +import os +import time +from importlib import import_module +from multiprocessing import Pool + +from mindspore.mindrecord import FileWriter +from src.graph_map_schema import GraphMapSchema + + +def exec_task(task_id, mindrecord_dict_data, graph_map_schema, writer, parallel_writer=True): + """ + Execute task with specified task id + """ + print("exec task {}, parallel: {} ...".format(task_id, parallel_writer)) + imagenet_iter = mindrecord_dict_data(task_id) + batch_size = 512 + transform_count = 0 + while True: + data_list = [] + try: + for _ in range(batch_size): + data = imagenet_iter.__next__() + if 'dst_id' in data: + data = graph_map_schema.transform_edge(data) + else: + data = graph_map_schema.transform_node(data) + data_list.append(data) + transform_count += 1 + writer.write_raw_data(data_list, parallel_writer=parallel_writer) + print("transformed {} record...".format(transform_count)) + except StopIteration: + if data_list: + writer.write_raw_data(data_list, parallel_writer=parallel_writer) + print("transformed {} record...".format(transform_count)) + break + + +def init_writer(mr_schema, mindrecord_file, mindrecord_partitions, mindrecord_header_size_by_bit, mindrecord_page_size_by_bit): + """ + init writer + """ + print("Init writer ...") + mr_writer = FileWriter(mindrecord_file, mindrecord_partitions) + + # set the header size + if mindrecord_header_size_by_bit != 24: + header_size = 1 << mindrecord_header_size_by_bit + mr_writer.set_header_size(header_size) + + # set the page size + if mindrecord_page_size_by_bit != 25: + page_size = 1 << mindrecord_page_size_by_bit + mr_writer.set_page_size(page_size) + + # create the schema + mr_writer.add_schema(mr_schema, "mindrecord_graph_schema") + + # open file and set header + mr_writer.open_and_set_header() + + return mr_writer + + +def run_parallel_workers(num_tasks, mindrecord_workers, mindrecord_dict_data, graph_map_schema, writer): + """ + run parallel workers + """ + # set number of 
workers + num_workers = mindrecord_workers + + task_list = list(range(num_tasks)) + + if num_workers > num_tasks: + num_workers = num_tasks + + if os.name == 'nt': + for window_task_id in task_list: + exec_task(window_task_id, mindrecord_dict_data, graph_map_schema, writer, False) + elif num_tasks > 1: + with Pool(num_workers) as p: + p.map(exec_task, task_list) + else: + exec_task(0, mindrecord_dict_data, graph_map_schema, writer, False) + + +def writer_data(mindrecord_script = "template", + mindrecord_file = "/tmp/mindrecord/xyz", + mindrecord_partitions = 1, + mindrecord_header_size_by_bit = 24, + mindrecord_page_size_by_bit = 25, + mindrecord_workers = 8, + num_node_tasks = 1, + num_edge_tasks = 1, + graph_api_args = "/tmp/nodes.csv:/tmp/edges.csv"): + print(mindrecord_file) + print(graph_api_args) + if os.path.exists(os.path.join(mindrecord_file, mindrecord_script)): + os.remove(os.path.join(mindrecord_file, mindrecord_script)) + os.remove(os.path.join(mindrecord_file, mindrecord_script + ".db")) + + mindrecord_file = os.path.join(mindrecord_file, mindrecord_script) + start_time = time.time() + + # pass mr_api arguments + os.environ['graph_api_args'] = graph_api_args + + try: + mr_api = import_module('src.' 
+ mindrecord_script + '.mr_api') + except ModuleNotFoundError: + raise RuntimeError("Unknown module path: {}".format(mindrecord_script + '.mr_api')) + + # init graph schema + graph_map_schema = GraphMapSchema() + + num_features, feature_data_types, feature_shapes = mr_api.node_profile + graph_map_schema.set_node_feature_profile(num_features, feature_data_types, feature_shapes) + + num_features, feature_data_types, feature_shapes = mr_api.edge_profile + graph_map_schema.set_edge_feature_profile(num_features, feature_data_types, feature_shapes) + + graph_schema = graph_map_schema.get_schema + + # init writer + writer = init_writer(graph_schema, mindrecord_file, mindrecord_partitions, mindrecord_header_size_by_bit, mindrecord_page_size_by_bit) + + # write nodes data + mindrecord_dict_data = mr_api.yield_nodes + run_parallel_workers(num_node_tasks, mindrecord_workers, mindrecord_dict_data, graph_map_schema, writer) + + # write edges data + mindrecord_dict_data = mr_api.yield_edges + run_parallel_workers(num_edge_tasks, mindrecord_workers, mindrecord_dict_data, graph_map_schema, writer) + + # writer wrap up + ret = writer.commit() + + end_time = time.time() + print("") + print("END. Total time: {}".format(end_time - start_time)) + print("") diff --git a/official/gnn/gcn/train.py b/official/gnn/gcn/train.py index 706d4ee67e18fb9112d114e0084f135e6b9ddf5c..219c985ea3879b4423ef6078cb3eff66c3785e92 100644 --- a/official/gnn/gcn/train.py +++ b/official/gnn/gcn/train.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd +# Copyright 2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -26,42 +26,58 @@ from matplotlib import pyplot as plt
 from matplotlib import animation
 from sklearn import manifold
 from mindspore import context
-from mindspore import Tensor
+from mindspore import Tensor, export
 from mindspore.train.serialization import save_checkpoint, load_checkpoint
+from src.writer import writer_data
 from src.gcn import GCN
 from src.metrics import LossAccuracyWrapper, TrainNetWrapper
 from src.config import ConfigGCN
 from src.dataset import get_adj_features_labels, get_mask
 
-
 def t_SNE(out_feature, dim):
     t_sne = manifold.TSNE(n_components=dim, init='pca', random_state=0)
     return t_sne.fit_transform(out_feature)
 
-
 def update_graph(i, data, scat, plot):
     scat.set_offsets(data[i])
     plt.title('t-SNE visualization of Epoch:{0}'.format(i))
     return scat, plot
 
+def w2txt(file, data):
+    s = ""
+    for i in range(len(data)):
+        s = s + str(data[i])
+        s = s + " "
+    with open(file, "w") as f:
+        f.write(s)
 
 def train():
     """Train model."""
     parser = argparse.ArgumentParser(description='GCN')
-    parser.add_argument('--data_dir', type=str, default='./data/cora/cora_mr', help='Dataset directory')
+    parser.add_argument('--dataset', type=str, default='cora', help='Dataset name')
+    parser.add_argument('--data_dir', type=str, default='obs://hw2czq/path/gcn/data_mr/', help='Dataset directory')
+    parser.add_argument('--output_dir', type=str, default='obs://hw2czq/path/gcn/output', help='The path model saved')
     parser.add_argument('--train_nodes_num', type=int, default=140, help='Nodes numbers for training')
     parser.add_argument('--eval_nodes_num', type=int, default=500, help='Nodes numbers for evaluation')
     parser.add_argument('--test_nodes_num', type=int, default=1000, help='Nodes numbers for test')
     parser.add_argument('--save_TSNE', type=ast.literal_eval, default=False, help='Whether to save t-SNE graph')
-    args_opt = parser.parse_args()
-    if not os.path.exists("ckpts"):
-        os.mkdir("ckpts")
+    args_opt, unknown = parser.parse_known_args()
+
+    # Preprocess the training data
+    if not os.path.exists(os.path.join(args_opt.output_dir, "data_mr")):
+        os.mkdir(os.path.join(args_opt.output_dir, "data_mr"))
+    writer_data(mindrecord_script=args_opt.dataset,
+                mindrecord_file=os.path.join(args_opt.output_dir, "data_mr"),
+                mindrecord_partitions=1,
+                mindrecord_header_size_by_bit=18,
+                mindrecord_page_size_by_bit=20,
+                graph_api_args=args_opt.data_dir)
 
     context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False)
     config = ConfigGCN()
-    adj, feature, label_onehot, label = get_adj_features_labels(args_opt.data_dir)
+    adj, feature, label_onehot, label = get_adj_features_labels(os.path.join(args_opt.output_dir, "data_mr", args_opt.dataset))
 
     nodes_num = label_onehot.shape[0]
     train_mask = get_mask(nodes_num, 0, args_opt.train_nodes_num)
@@ -69,8 +85,9 @@ def train():
     test_mask = get_mask(nodes_num, nodes_num - args_opt.test_nodes_num, nodes_num)
 
     class_num = label_onehot.shape[1]
+    node_nums = feature.shape[0]
     input_dim = feature.shape[1]
-    gcn_net = GCN(config, input_dim, class_num)
+    gcn_net = GCN(config, input_dim, class_num, node_nums)
     gcn_net.add_flags_recursive(fp16=True)
 
     adj = Tensor(adj)
@@ -116,11 +133,25 @@ def train():
         if epoch > config.early_stopping and loss_list[-1] > np.mean(loss_list[-(config.early_stopping+1):-1]):
             print("Early stopping...")
             break
-    save_checkpoint(gcn_net, "ckpts/gcn.ckpt")
-    gcn_net_test = GCN(config, input_dim, class_num)
-    load_checkpoint("ckpts/gcn.ckpt", net=gcn_net_test)
+
+    # Save the model
+    if not os.path.exists(os.path.join(args_opt.output_dir, "model")):
+        os.makedirs(os.path.join(args_opt.output_dir, "model"))
+    save_checkpoint(gcn_net, os.path.join(args_opt.output_dir, "model", (args_opt.dataset + '.ckpt')))
+    gcn_net_test = GCN(config, input_dim, class_num, node_nums)
+    load_checkpoint(os.path.join(args_opt.output_dir, "model", (args_opt.dataset + '.ckpt')), net=gcn_net_test)
     gcn_net_test.add_flags_recursive(fp16=True)
 
+    # Freeze the model for export
+    adj_tensor = Tensor(np.zeros((1, node_nums*node_nums), np.float32))
+    feature_tensor = Tensor(np.zeros((1, node_nums*input_dim), np.float32))
+    gcn_net_test.set_train(False)
+    load_checkpoint(os.path.join(args_opt.output_dir, "model", (args_opt.dataset + '.ckpt')), net=gcn_net_test)
+    export(gcn_net_test, adj_tensor, feature_tensor, file_name=os.path.join(args_opt.output_dir, "model", args_opt.dataset), file_format="AIR")
+    export(gcn_net_test, adj_tensor, feature_tensor, file_name=os.path.join(args_opt.output_dir, "model", args_opt.dataset), file_format="ONNX")
+    export(gcn_net_test, adj_tensor, feature_tensor, file_name=os.path.join(args_opt.output_dir, "model", args_opt.dataset), file_format="MINDIR")
+
+    # Accuracy test
     test_net = LossAccuracyWrapper(gcn_net_test, label_onehot, test_mask, config.weight_decay)
     t_test = time.time()
     test_net.set_train(False)
@@ -130,10 +161,21 @@ def train():
     print("Test set results:", "loss=", "{:.5f}".format(test_loss),
           "accuracy=", "{:.5f}".format(test_accuracy), "time=", "{:.5f}".format(time.time() - t_test))
+    # # Generate inference data for SDK and MxBase inference
+    # if not os.path.exists(os.path.join(args_opt.output_dir, "data", args_opt.dataset)):
+    #     os.makedirs(os.path.join(args_opt.output_dir, "data", args_opt.dataset))
+    # adj, feature, label_onehot, label = get_adj_features_labels(os.path.join(args_opt.output_dir, "data_mr", args_opt.dataset))
+    # adj = (adj.reshape(-1)).astype(np.float32)
+    # feature = (feature.reshape(-1)).astype(np.float32)
+    # label_onehot = (label_onehot.reshape(-1)).astype(np.int32)
+    # w2txt(os.path.join(os.path.join(args_opt.output_dir, "data", args_opt.dataset), "adjacency.txt"), adj)
+    # w2txt(os.path.join(os.path.join(args_opt.output_dir, "data", args_opt.dataset), "feature.txt"), feature)
+    # w2txt(os.path.join(os.path.join(args_opt.output_dir, "data", args_opt.dataset), "label_onehot.txt"), label_onehot)
+
     if args_opt.save_TSNE:
         ani = animation.FuncAnimation(fig, update_graph, frames=range(config.epochs + 1),
                                       fargs=(graph_data, scat, plt))
         ani.save('t-SNE_visualization.gif', writer='imagemagick')
 
-
 if __name__ == '__main__':
     train()
+
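For reference, the union-schema mapping that `GraphMapSchema.transform_node` in the patch performs can be sketched as a standalone snippet. `to_union_row` below is a hypothetical helper, not part of the patch; it mirrors the node branch only (edge columns padded with placeholders) and needs nothing beyond `numpy`:

```python
import numpy as np

def to_union_row(node, num_node_features=1, feature_dtype="float32"):
    # One node record becomes one union-schema row; edge-only columns are padded.
    row = {"first_id": node["id"], "second_id": 0, "third_id": 0,
           "weight": node.get("weight", 1.0), "attribute": "n",
           "type": node["type"], "node_feature_index": []}
    for k in range(1, num_node_features + 1):
        src_key, dst_key = "feature_%d" % k, "node_feature_%d" % k
        if src_key in node:
            # Present feature: record its index and flatten it to 1-D.
            row["node_feature_index"].append(k)
            row[dst_key] = np.reshape(np.array(node[src_key], dtype=feature_dtype), [-1])
        else:
            # Absent feature: pad with a single zero so every row has the same columns.
            row[dst_key] = np.reshape(np.array([0], dtype=feature_dtype), [-1])
    row["node_feature_index"] = np.array(row["node_feature_index"] or [-1], dtype="int32")
    row["edge_feature_index"] = np.array([-1], dtype="int32")  # node rows carry no edge features
    return row

row = to_union_row({"id": 7, "type": 0, "feature_1": [0.5, 1.5]})
print(row["first_id"], row["attribute"], row["node_feature_index"], row["node_feature_1"])
```

Note the `-1` sentinels: a reader can tell padded columns from real ones by checking `node_feature_index`/`edge_feature_index`, which is what makes a single schema usable for both node and edge rows.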