# [Web-page residue, not part of the script] Code fetch complete; the page will refresh automatically.
import torch
import torch.nn as nn
from pytorch_msssim import ms_ssim
import os
import random
from tqdm import tqdm
from torch.utils.data import DataLoader
from dataset2 import MyDataset
from model2 import MyModel, GAN
from pred2 import prediction
def crop_local_regions(images, local_area):
    """Crop one rectangular window out of every image in a batch.

    Args:
        images: Batched tensor; iterating it yields one image per sample and
            each image is sliced as ``image[:, y1:y2, x1:x2]``.
        local_area: Per-sample iterable of ``(x1, x2, y1, y2)`` bounds, in the
            order the training loop unpacks them.

    Returns:
        The crops stacked along a new leading dimension. ``torch.stack``
        requires every crop to have the same size.
    """
    crops = []
    for image, box in zip(images, local_area):
        # int() accepts Python ints as well as single-element tensors.
        x1, x2, y1, y2 = (int(v) for v in box)
        crops.append(image[:, y1:y2, x1:x2])
    return torch.stack(crops, dim=0)


if __name__ == '__main__':
    # NOTE(review): the original file had lost all indentation; the nesting
    # below is reconstructed from the statement order.

    # ---- Configuration ----
    img_rootPath = r'../JPEGImages'
    saveModel_dir = r'checkpoints'
    verification = r'verification'
    load_model = r''  # checkpoint to resume from; an empty string skips loading
    epoch = 300
    # warm_up = 90000
    batch_size = 8
    lr = 1e-4
    alpha = 4e-4  # weight of the whole-image MSE term and of the adversarial term

    # torch.save does not create missing parent directories, so make sure the
    # output locations exist before the first epoch finishes.
    os.makedirs(saveModel_dir, exist_ok=True)
    os.makedirs(verification, exist_ok=True)

    # ---- Data ----
    train_ds = MyDataset(img_rootPath, 'train', 0.9)
    test_ds = MyDataset(img_rootPath, 'test', 0.9)
    train_dl = DataLoader(train_ds, shuffle=True, batch_size=batch_size, drop_last=True)
    test_dl = DataLoader(test_ds, shuffle=True, batch_size=batch_size, drop_last=True)

    # ---- Models, optimizers, losses ----
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = MyModel()
    if load_model:
        # map_location lets a checkpoint saved on GPU be resumed on a CPU-only host.
        model.load_state_dict(torch.load(load_model, map_location=device), strict=False)
    model.to(device)
    gan = GAN()  # used as the discriminator: scores (full image, local crop) pairs
    gan.to(device)
    optim = torch.optim.Adam(model.parameters(), lr)
    gan_optim = torch.optim.Adam(gan.parameters(), lr)
    loss_fn = nn.MSELoss()
    gan_loss_fn = nn.BCELoss()
    maxAcc_test = 0

    # Disabled warm-up phase (plain MSE training of the main model), kept for reference.
    # warm_up_times = 0
    # while warm_up_times < warm_up:
    #     model.train()
    #     with tqdm(train_dl) as train_dl_buffer:
    #         # main model training
    #         total_loss = 0
    #         for batch, (img, mask, targets, local_area) in enumerate(train_dl_buffer):
    #             out = model(img.to(device), mask.to(device))
    #             loss = loss_fn(out, targets.to(device))
    #             optim.zero_grad()
    #             loss.backward()
    #             optim.step()
    #             total_loss += loss.item()
    #             batch += 1
    #             warm_up_times += 1
    #             train_dl_buffer.set_description("Epoch {}/{}: trainloss {:.4f}".format(warm_up_times+1, warm_up, total_loss/batch))

    for e in range(epoch):
        # ---- Training ----
        model.train()
        with tqdm(train_dl) as train_dl_buffer:
            gan_total_loss = 0   # running sum of the discriminator loss
            main_total_loss = 0  # running sum of the main model (generator) loss
            for step, (img, mask, target, local_area) in enumerate(train_dl_buffer, start=1):
                img = img.to(device)
                mask = mask.to(device)
                target = target.to(device)
                # The target crops are identical for both updates; compute them once.
                local_target = crop_local_regions(target, local_area)

                # Discriminator update: real pairs -> 1, generated pairs -> 0.
                # The generator output is produced without gradients so this
                # step only trains `gan`.
                with torch.no_grad():
                    out_img = model(img, mask)
                local_out_img = crop_local_regions(out_img, local_area)
                real = gan(target, local_target)
                fake = gan(out_img, local_out_img)
                d_loss = gan_loss_fn(real, torch.ones(real.shape[0], 1, device=device)) + \
                    gan_loss_fn(fake, torch.zeros(fake.shape[0], 1, device=device))
                gan_optim.zero_grad()
                d_loss.backward()
                gan_optim.step()
                # Fix: the discriminator loss used to be added to main_total_loss
                # (and the generator loss to gan_total_loss), so the two numbers
                # in the progress bar were shown under each other's label.
                gan_total_loss += d_loss.item()

                # Generator update: reconstruction losses plus a weighted
                # adversarial term that rewards fooling the discriminator.
                out_img = model(img, mask)
                local_out_img = crop_local_regions(out_img, local_area)
                fake = gan(out_img, local_out_img)
                main_loss = alpha * loss_fn(out_img, target) + loss_fn(local_out_img, local_target)
                gan_loss = gan_loss_fn(fake, torch.ones(fake.shape[0], 1, device=device))
                loss = main_loss + alpha * gan_loss
                optim.zero_grad()
                # Clear discriminator grads too: backward() below writes into
                # them, but only the main model's optimizer steps here.
                gan_optim.zero_grad()
                loss.backward()
                optim.step()
                main_total_loss += loss.item()

                train_dl_buffer.set_description(
                    "Epoch {}: mainloss {:.8f} ganloss {:.8f}".format(
                        e+1, main_total_loss/step, gan_total_loss/step))

        # ---- Save the latest weights ----
        path = os.path.join(saveModel_dir, 'checkpoint.pth')
        torch.save(model.state_dict(), path)
        print("保存模型===》{}".format(path))

        # ---- Evaluation (mean MS-SSIM over the test loader) ----
        total_acc = 0
        test_batches = 0
        model.eval()
        with tqdm(test_dl) as test_dl_buffer:
            for batch, (img, mask, target, local_area) in enumerate(test_dl_buffer):
                with torch.no_grad():
                    out = model(img.to(device), mask.to(device))
                    acc = ms_ssim(out, target.to(device), data_range=1, size_average=True)
                total_acc += acc.item()
                test_batches = batch + 1
                test_dl_buffer.set_description(
                    "Epoch {}: testAcc {:.4f}".format(e+1, total_acc/test_batches))

        # Use a dedicated counter so an empty test loader cannot silently
        # reuse the batch count left over from the training loop.
        mean_acc = total_acc / test_batches if test_batches else 0
        if mean_acc > maxAcc_test:
            maxAcc_test = mean_acc
            path = os.path.join(saveModel_dir, 'best_test.pth')
            torch.save(model.state_dict(), path)
            print("保存模型===》{}".format(path))

        # ---- Visual spot check on 10 random test images ----
        # NOTE(review): assumed to run every epoch (not only on a new best
        # score); the lost indentation makes this ambiguous - confirm.
        for i in range(10):
            img_path = test_ds.img_path_arr[random.randint(0, len(test_ds)-1)]
            save_path = os.path.join(verification, f'{i}.jpg')
            prediction(model, img_path, save_path, device)
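# [Web-page residue, not part of the script] This content may be unsuitable for display and is hidden by the page. You can review and modify it using the relevant editing functions.
# If you confirm the content involves no inappropriate language, pure advertising, violence, vulgarity or pornography, infringement, piracy, false or worthless content, or content that violates applicable national laws and regulations, you may click Submit to appeal, and we will handle it as soon as possible.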