1 Star 0 Fork 0

MrLiuSheep/去水印模型

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
pred2.py 1.37 KB
一键复制 编辑 原始数据 按行查看 历史
import torch
import numpy as np
import random
from model2 import MyModel
from torchvision.transforms import transforms
import cv2
def prediction(model: MyModel, img_path, save_path, device) -> None:
    """Blank out a random rectangle of an image, inpaint it, and save a comparison.

    The image at ``img_path`` is read with ``cv2.imread`` (default colour mode,
    so 3 channels). A random rectangular hole of 50-128 px per side, clipped
    to the image bounds, is zeroed out in place. The model is then run on the
    holed image together with a binary hole mask. The holed input and the
    model output are zero-padded to a common size, stacked vertically (input
    on top) and written to ``save_path``.

    Args:
        model: Network called as ``model(image, mask)``. ``image`` has shape
            (1, 3, H, W) and ``mask`` has shape (1, 1, H, W), both float in
            [0, 1]. The model is switched to eval mode here.
        img_path: Path of the image to read.
        save_path: Destination path for the stacked comparison image.
        device: Torch device the inputs are moved to before inference.

    Raises:
        FileNotFoundError: If OpenCV cannot read ``img_path``.
        OSError: If OpenCV fails to write ``save_path``.
    """
    model.eval()
    img = cv2.imread(img_path)
    if img is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError(f"could not read image: {img_path}")
    h, w = img.shape[:2]

    # Random hole. The top-left corner is chosen so a max-size (128 px) hole
    # would fit; on images smaller than that the hole is clipped below.
    rand_x = random.randint(0, max(0, w - 128 - 1))
    rand_y = random.randint(0, max(0, h - 128 - 1))
    rand_w = random.randint(50, 128)
    rand_h = random.randint(50, 128)
    x_end = min(w, rand_x + rand_w)
    y_end = min(h, rand_y + rand_h)

    # Zero the hole in the image (in place) and mark it with 255 in the mask.
    img[rand_y:y_end, rand_x:x_end, :] = 0
    mask = np.zeros((h, w, 1), dtype=np.uint8)
    mask[rand_y:y_end, rand_x:x_end, :] = 255

    # ToTensor converts HWC uint8 to CHW float in [0, 1]; then add a batch dim.
    inp = transforms.ToTensor()(img).unsqueeze(0)
    mask_t = transforms.ToTensor()(mask).unsqueeze(0)
    with torch.no_grad():
        out = model(inp.to(device), mask_t.to(device))

    # Drop only the batch dim. A bare squeeze() would also drop a size-1
    # channel dim (single-channel output) and break the transpose below.
    # Assumes the model returns (1, C, H', W') -- TODO confirm against MyModel.
    out = out.squeeze(0).cpu().numpy()
    # Clip before casting: out-of-range floats do not saturate when cast to
    # uint8. Round instead of truncating so float error (e.g. 199.99998)
    # does not lose a level.
    img2 = np.rint(np.clip(out, 0.0, 1.0) * 255).astype(np.uint8)
    img2 = np.transpose(img2, (1, 2, 0))
    # NOTE(review): the output is written in the same channel order as the
    # input (OpenCV BGR). If the model was trained on RGB, convert with
    # cv2.cvtColor(img2, cv2.COLOR_RGB2BGR) here -- confirm against training.

    # Stack the holed input above the model output, zero-padding both to a
    # common canvas size.
    h2, w2 = img2.shape[:2]
    canvas_h, canvas_w = max(h, h2), max(w, w2)
    top = np.zeros((canvas_h, canvas_w, 3), dtype=np.uint8)
    bottom = np.zeros((canvas_h, canvas_w, 3), dtype=np.uint8)
    top[:h, :w, :] = img
    bottom[:h2, :w2, :] = img2
    cat_img = np.vstack((top, bottom))
    # cv2.imwrite reports failure via its return value rather than raising.
    if not cv2.imwrite(save_path, cat_img):
        raise OSError(f"could not write image: {save_path}")
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Python
1
https://gitee.com/LiuSheepSpace/remove-watermark-model.git
git@gitee.com:LiuSheepSpace/remove-watermark-model.git
LiuSheepSpace
remove-watermark-model
去水印模型
master

搜索帮助