Ai
1 Star 0 Fork 1

Barneys/pytorch_captcha_recognition

加入 Gitee
与超过 1200 万开发者一起发现、参与优秀开源项目，私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件（LICENSE），使用请关注具体项目描述及其代码上游依赖。
克隆/下载
captcha_train.py 1.99 KB
一键复制 编辑 原始数据 按行查看 历史
Barneys 提交于 2021-12-11 13:49 +08:00 . 20211211
# -*- coding: UTF-8 -*-
import torch
import torch.nn as nn
from torch.autograd import Variable
import my_dataset
from captcha_cnn_model import CNN
import captcha_predict
import os
# Training hyper-parameters consumed by main().
learning_rate = 0.001  # step size handed to the Adam optimizer
num_epochs = 28  # number of passes over the training data loader
def main():
    """Train the captcha CNN, resuming from ``model.pkl`` if it exists.

    Optimises the network with Adam (``learning_rate``) against
    ``nn.MultiLabelSoftMarginLoss`` for ``num_epochs`` passes over
    ``my_dataset.get_train_data_loader()``. The weights are written to
    ``./model.pkl`` every 500 steps and once more after the final epoch.
    ``captcha_predict.validate()`` is called after every epoch; its return
    values are collected and printed when training finishes.
    """
    if os.path.exists('model.pkl'):
        print('find the model,continue training')
        # Resume training from the previous checkpoint. map_location keeps
        # the load working on a CPU-only machine even if the checkpoint was
        # written from GPU tensors; this script never moves the model off CPU.
        cnn = CNN()
        cnn.load_state_dict(torch.load('model.pkl', map_location='cpu'))
    else:
        # No checkpoint on disk: start from a freshly initialised network.
        print('init net')
        cnn = CNN()
    # Always train in training mode. The old resume path left the model in
    # eval() mode, which changes how dropout / batch-norm layers (if CNN has
    # any) behave during optimisation.
    cnn.train()

    acc = []
    criterion = nn.MultiLabelSoftMarginLoss()
    optimizer = torch.optim.Adam(cnn.parameters(), lr=learning_rate)
    train_dataloader = my_dataset.get_train_data_loader()

    for epoch in range(num_epochs):
        # Reset per epoch so the summary print below cannot hit an unbound
        # name (or a stale value) if the loader yields no batches.
        i = -1
        loss = None
        for i, (images, labels) in enumerate(train_dataloader):
            # The multi-label soft-margin loss needs float targets.
            labels = labels.float()
            predict_labels = cnn(images)
            loss = criterion(predict_labels, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if (i + 1) % 10 == 0:
                print("epoch:", epoch, "step:", i, "loss:", loss.item())
            if (i + 1) % 500 == 0:
                torch.save(cnn.state_dict(), "./model.pkl")  # current is model.pkl
                print("save model")
        if loss is not None:
            print("epoch:", epoch, "step:", i, "loss:", loss.item())
        print("-------------------------------start validate------------------------")
        # NOTE(review): validate() receives no model object. If it reads
        # ./model.pkl, it sees the last 500-step checkpoint rather than the
        # current weights -- confirm against captcha_predict.
        acc.append(captcha_predict.validate())
        print('-------------------------------end validate--------------------------')

    for acc_item in acc:
        print(acc_item)
    torch.save(cnn.state_dict(), "./model.pkl")  # current is model.pkl
    print("save last model")
# Start training only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Python
1
https://gitee.com/Barneys/pytorch_captcha_recognition.git
git@gitee.com:Barneys/pytorch_captcha_recognition.git
Barneys
pytorch_captcha_recognition
pytorch_captcha_recognition
master

搜索帮助