1 Star 0 Fork 1

aspiretop/车牌识别

forked from yeye0810/车牌识别 
加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
train.py 3.05 KB
一键复制 编辑 原始数据 按行查看 历史
JiangCe0810 提交于 2018-10-16 15:51 +08:00 . 2018/10/16
import argparse
import torch
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms
import torch.optim as optim
import os
from utils.parse_config import *
from utils.utils import *
from utils.datasets import *
from model import *
# Parse command-line options for training.


def _str2bool(value):
    """Convert a command-line string such as "true"/"False"/"0" to a bool.

    argparse's ``type=bool`` treats every non-empty string (including
    ``"False"`` and ``"0"``) as True, so ``--use_cuda False`` could never
    turn CUDA off. This accepts the usual spellings explicitly.

    Raises:
        argparse.ArgumentTypeError: if the value is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got %r' % value)


parse = argparse.ArgumentParser()
parse.add_argument('--epochs', type=int, default=30, help='number of epochs')
parse.add_argument('--image_folder', type=str, default='data/images', help='path to image dataset')
parse.add_argument('--batch_size', type=int, default=10, help='size of each image batch')
parse.add_argument('--model_config_path', type=str, default='cfg/WPOD_NET.cfg', help='path to model config file')
parse.add_argument('--weights_path', type=str, default='weights/WPOD_NET.h5', help='path to weight file')
parse.add_argument('--train_path', type=str, default='data/train.txt', help='path to images')
parse.add_argument('--n_cpu', type=int, default=8, help='number of cpu threads to use during batch generation')
parse.add_argument('--checkpoint_interval', type=int, default=1, help='interval between saving model weights')
parse.add_argument('--checkpoint_dir', type=str, default='checkpoints', help='directory where model checkpoints are saved')
parse.add_argument('--img_size', type=int, default=416, help='size of each image dimension')
# _str2bool instead of type=bool so "--use_cuda False" actually disables CUDA.
parse.add_argument('--use_cuda', type=_str2bool, default=True, help='whether to use cuda if available')
opt = parse.parse_args()
print(opt)
# Use CUDA only when it is both available and requested on the command line.
cuda = torch.cuda.is_available() and opt.use_cuda
os.makedirs('output', exist_ok=True)
# Create the directory that checkpoints are actually written to. The old
# hard-coded 'checkpoints' only matched the default --checkpoint_dir.
os.makedirs(opt.checkpoint_dir, exist_ok=True)

# The first section returned by parse_model_config is read as the
# hyperparameter dict (keys below must be present in the config file).
hyperparams = parse_model_config(opt.model_config_path)[0]
learning_rate = float(hyperparams['learning_rate'])
momentum = float(hyperparams['momentum'])
decay = float(hyperparams['decay'])
# NOTE(review): burn_in is parsed but not used by the visible training loop.
burn_in = int(hyperparams['burn_in'])
# Build the network, initialise its weights and switch it to training mode.
model = WPOD_NET(opt.model_config_path)
model.apply(weights_init_normal)
if cuda:
    model = model.cuda()
model.train()

# NOTE(review): shuffle=False means every epoch sees batches in the same
# order; training loaders usually shuffle. Kept to preserve behaviour.
# NOTE(review): with num_workers > 0 on spawn-based platforms (e.g. Windows),
# worker processes re-import the main module. This script has no
# `if __name__ == '__main__':` guard, so confirm it is only run where fork is used.
dataloader = torch.utils.data.DataLoader(
    ListDataset(opt.train_path),
    batch_size=opt.batch_size,
    shuffle=False,
    num_workers=opt.n_cpu,
)

# Tensor type each batch is cast to, on the selected device.
Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

# `decay` from the model config was previously read but never applied;
# pass it to SGD as L2 weight decay.
optimizer = optim.SGD(
    model.parameters(),
    lr=learning_rate,
    momentum=momentum,
    weight_decay=decay,
)
# Main training loop: one optimisation step per batch, checkpoint per epoch.
for epoch in range(opt.epochs):
    for batch_i, (_, imgs, targets) in enumerate(dataloader):
        # Cast the batch to the selected float tensor type. Wrapping in
        # torch.autograd.Variable is unnecessary since PyTorch 0.4.
        imgs = imgs.type(Tensor)
        targets = targets.type(Tensor)

        optimizer.zero_grad()
        loss = model(imgs, targets)
        loss.backward()
        optimizer.step()

        # Report the loss just computed. backward() without arguments
        # requires a scalar, so .item() is valid here.
        print('[Epoch %d/%d, Batch %d/%d] Loss: %f' %
              (epoch, opt.epochs, batch_i, len(dataloader), loss.item()))
        model.seen += imgs.size(0)

    # A non-positive interval disables checkpointing instead of raising
    # ZeroDivisionError.
    if opt.checkpoint_interval > 0 and epoch % opt.checkpoint_interval == 0:
        model.save_weights(os.path.join(opt.checkpoint_dir, '%d.weights' % epoch))
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Python
1
https://gitee.com/aspiretop/lpr.git
git@gitee.com:aspiretop/lpr.git
aspiretop
lpr
车牌识别
master

搜索帮助