代码拉取完成,页面将自动刷新
# -*- coding: UTF-8 -*-
import torch
import torch.nn as nn
from torch.autograd import Variable
import my_dataset
from captcha_cnn_model import CNN
import captcha_predict
import os
# Hyper Parameters
num_epochs = 28
learning_rate = 0.001
def main():
if os.path.exists('model.pkl'):
print('find the model,continue training')
# 导入之前的模型,重新训练
cnn = CNN()
cnn.eval()
cnn.load_state_dict(torch.load('model.pkl'))
else:
# 没有模型只能重新训练
print('init net')
cnn = CNN()
cnn.train()
acc = []
criterion = nn.MultiLabelSoftMarginLoss()
optimizer = torch.optim.Adam(cnn.parameters(), lr=learning_rate)
# Train the Model
train_dataloader = my_dataset.get_train_data_loader()
# print(len(list(enumerate(train_dataloader))))
for epoch in range(num_epochs):
for i, (images, labels) in enumerate(train_dataloader):
images = Variable(images)
labels = Variable(labels.float())
predict_labels = cnn(images)
# print(predict_labels.type)
# print(labels.type)
loss = criterion(predict_labels, labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()
if (i+1) % 10 == 0:
print("epoch:", epoch, "step:", i, "loss:", loss.item())
if (i+1) % 500 == 0:
torch.save(cnn.state_dict(), "./model.pkl") #current is model.pkl
print("save model")
print("epoch:", epoch, "step:", i, "loss:", loss.item())
print("-------------------------------start validate------------------------")
acc.append(captcha_predict.validate())
print('-------------------------------end validate--------------------------')
for acc_item in acc:
print(acc_item)
torch.save(cnn.state_dict(), "./model.pkl") #current is model.pkl
print("save last model")
if __name__ == '__main__':
main()
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。