wangchao/CompareTool

forked from liuqiyuan/CompareTool 
run_compare.py (924 bytes)
liuqiyuan committed on 2022-04-28 20:43 +08:00: add
```python
import sklearn
import torch
import numpy as np
from mindspore import context
from src.compare import compare
from module.encoder_pt import cnnEncoder as pt_encoder
from module.encoder_ms import cnnEncoder as ms_encoder
# To compare loss functions instead, use these imports and set module_type="loss":
# from module.ms_mse_loss import MyMseLoss as ms_encoder
# from module.pt_mse_loss import MyMseLoss as pt_encoder

# Run MindSpore on Ascend device 7 (adjust to your hardware).
context.set_context(device_target="Ascend", device_id=7)

if __name__ == "__main__":
    # Define your own networks: a PyTorch module and its MindSpore counterpart.
    PtNet = pt_encoder()
    MsNet = ms_encoder()
    cp = compare(ptmodule=PtNet,
                 msmodule=MsNet,
                 module_type="net",             # "net" or "loss"
                 input_shape=[[1, 1, 28, 28]],  # with 2 inputs, use [[...], [...]]
                 init_mode="random",            # "random" or "ones"
                 input_num=1,                   # 1 or 2; update input_shape to match
                 print_result=True
                 )
    cp.start_compare()
```
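The internals of src.compare are not shown in this file. As a hedged sketch only, assuming the tool builds identical NumPy inputs from input_shape, init_mode and input_num, feeds them to both networks, and then measures how far the outputs diverge, the two steps might look like this in plain NumPy. make_inputs and compare_outputs are hypothetical helpers, not part of CompareTool, and the seed and tolerances are illustrative choices.

```python
import numpy as np


def make_inputs(input_shape, init_mode="random", input_num=1, seed=0):
    """Hypothetical helper (not CompareTool's API): one array per shape.

    Mirrors the rules documented in run_compare.py: input_shape is a list of
    shapes, and its length must match input_num (1 or 2).
    """
    if input_num not in (1, 2):
        raise ValueError("input_num must be 1 or 2")
    if len(input_shape) != input_num:
        raise ValueError("input_shape must contain exactly input_num shapes")
    rng = np.random.default_rng(seed)  # fixed seed so both frameworks see the same data
    inputs = []
    for shape in input_shape:
        if init_mode == "random":
            inputs.append(rng.standard_normal(tuple(shape)).astype(np.float32))
        elif init_mode == "ones":
            inputs.append(np.ones(tuple(shape), dtype=np.float32))
        else:
            raise ValueError('init_mode must be "random" or "ones"')
    return inputs


def compare_outputs(pt_out, ms_out, rtol=1e-5, atol=1e-6):
    """Hypothetical helper: summarize the divergence between two outputs."""
    pt_out = np.asarray(pt_out, dtype=np.float64)
    ms_out = np.asarray(ms_out, dtype=np.float64)
    if pt_out.shape != ms_out.shape:
        raise ValueError(f"shape mismatch: {pt_out.shape} vs {ms_out.shape}")
    abs_diff = np.abs(pt_out - ms_out)
    # Relative error against the PyTorch output; epsilon avoids division by zero.
    rel_diff = abs_diff / (np.abs(pt_out) + 1e-12)
    return {
        "max_abs_diff": float(abs_diff.max()),
        "max_rel_diff": float(rel_diff.max()),
        "allclose": bool(np.allclose(pt_out, ms_out, rtol=rtol, atol=atol)),
    }


if __name__ == "__main__":
    (x,) = make_inputs([[1, 1, 28, 28]], init_mode="ones", input_num=1)
    # Stand-ins for the two frameworks' outputs on the same input.
    pt_out = x.sum(axis=(2, 3))
    ms_out = pt_out.astype(np.float64) + 1e-4
    print(compare_outputs(pt_out, ms_out))
```

Reporting both the maximum absolute and maximum relative error, rather than a single pass/fail flag, makes it easier to tell a genuine porting bug from ordinary floating-point drift between frameworks.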
Clone URLs (branch master):
https://gitee.com/wangchao285/CompareTool.git
git@gitee.com:wangchao285/CompareTool.git