代码拉取完成,页面将自动刷新
import torch
import numpy as np
import random
from model2 import MyModel
from torchvision.transforms import transforms
import cv2
def prediction(model:MyModel, img_path, save_path, device):
model.eval()
img = cv2.imread(img_path)
h, w = img.shape[:2]
rand_x = random.randint(0, max(0, w-128-1))
rand_y = random.randint(0, max(0, h-128-1))
rand_w = random.randint(50, 128)
rand_h = random.randint(50, 128)
img[rand_y:min(h, rand_y+rand_h), rand_x:min(w, rand_x+rand_w), :] = 0
mask = np.zeros([img.shape[0], img.shape[1], 1], dtype=np.uint8)
mask[rand_y:rand_y+rand_h, rand_x:rand_x+rand_w, :] = 255
inp = transforms.ToTensor()(img)
mask = transforms.ToTensor()(mask)
inp = inp.unsqueeze(0)
mask = mask.unsqueeze(0)
with torch.no_grad():
out = model(inp.to(device), mask.to(device))
out = torch.squeeze(out)
img2 = out.to('cpu')
img2 = img2.numpy()
img2 = (img2 * 255).astype(np.uint8)
img2 = np.transpose(img2, (1, 2, 0))
# img2 = cv2.cvtColor(img2, cv2.COLOR_RGB2BGR)
# 拼接输出和输入
h, w = img.shape[:2]
h2, w2 = img2.shape[:2]
m1 = np.zeros((max(h, h2), max(w, w2), 3), dtype=np.uint8)
m2 = np.zeros((max(h, h2), max(w, w2), 3), dtype=np.uint8)
m1[:h, :w, :] = img
m2[:h2, :w2, :] = img2
cat_img = np.vstack((m1, m2))
cv2.imwrite(save_path, cat_img)
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。