# BG8CFB/YOLOV7_API - detect_with_API.py
import torch
from numpy import random
from models.experimental import attempt_load
from utils.datasets import MyLoadImages
from utils.general import check_img_size, non_max_suppression, apply_classifier, \
    scale_coords, set_logging
from utils.plots import plot_one_box
from utils.torch_utils import select_device, load_classifier
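
# MyLoadImages is the custom loader added to utils/datasets.py (see the note in
# detect() below): it wraps a list of images already read by cv2 instead of file
# paths and yields (img, im0) pairs. The repository's own implementation is not
# shown in this file; the commented sketch below is only an assumption about what
# it looks like, based on how detect() consumes it.
#
#     import numpy as np
#     from utils.datasets import letterbox
#
#     class MyLoadImages:
#         def __init__(self, imgs, img_size=640, stride=32):
#             self.imgs, self.img_size, self.stride = imgs, img_size, stride
#
#         def __iter__(self):
#             self.count = 0
#             return self
#
#         def __next__(self):
#             if self.count == len(self.imgs):
#                 raise StopIteration
#             im0 = self.imgs[self.count]  # original BGR image from cv2
#             self.count += 1
#             img = letterbox(im0, self.img_size, stride=self.stride)[0]  # resize + pad
#             img = img[:, :, ::-1].transpose(2, 0, 1)  # BGR -> RGB, HWC -> CHW
#             img = np.ascontiguousarray(img)
#             return img, im0

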
class simulation_opt:
    def __init__(self, weights='models/yolov7.pt',
                 img_size=640, conf_thres=0.25,
                 iou_thres=0.45, device='', view_img=False,
                 classes=None, agnostic_nms=False,
                 augment=False, update=False, exist_ok=False):
        self.weights = weights
        self.source = None
        self.img_size = img_size
        self.conf_thres = conf_thres
        self.iou_thres = iou_thres
        self.device = device
        self.view_img = view_img
        self.classes = classes
        self.agnostic_nms = agnostic_nms
        self.augment = augment
        self.update = update
        self.exist_ok = exist_ok


class detectapi:
    def __init__(self, weights, img_size=640):
        self.opt = simulation_opt(weights=weights, img_size=img_size)
        weights, imgsz = self.opt.weights, self.opt.img_size

        # Initialize
        set_logging()
        self.device = select_device(self.opt.device)
        self.half = self.device.type != 'cpu'  # half precision only supported on CUDA

        # Load model
        self.model = attempt_load(weights, map_location=self.device)  # load FP32 model
        self.stride = int(self.model.stride.max())  # model stride
        self.imgsz = check_img_size(imgsz, s=self.stride)  # check img_size
        if self.half:
            self.model.half()  # to FP16

        # Second-stage classifier
        self.classify = False
        if self.classify:
            self.modelc = load_classifier(name='resnet101', n=2)  # initialize
            # load_state_dict() does not return the model, so move/eval separately
            self.modelc.load_state_dict(torch.load('weights/resnet101.pt', map_location=self.device)['model'])
            self.modelc.to(self.device).eval()

        # Read names and colors
        self.names = self.model.module.names if hasattr(self.model, 'module') else self.model.names
        self.colors = [[random.randint(0, 255) for _ in range(3)] for _ in self.names]

    def detect(self, source):  # call this method to run detection on a list of cv2 images
        if not isinstance(source, list):
            raise TypeError('source must be a list of images read by cv2')
        dataset = MyLoadImages(source, img_size=self.imgsz, stride=self.stride)
        # The original script built the dataset from file paths; here source already holds
        # loaded images, so the dataset class had to be rewritten as MyLoadImages in
        # utils/datasets.py (see the commented sketch after the imports above).

        # Run inference
        if self.device.type != 'cpu':
            self.model(torch.zeros(1, 3, self.imgsz, self.imgsz).to(self.device).type_as(next(self.model.parameters())))  # run once
        # t0 = time.time()
        result = []
        # original loop: for path, img, im0s, vid_cap in dataset:
        for img, im0s in dataset:
            img = torch.from_numpy(img).to(self.device)
            img = img.half() if self.half else img.float()  # uint8 to fp16/32
            img /= 255.0  # 0 - 255 to 0.0 - 1.0
            if img.ndimension() == 3:
                img = img.unsqueeze(0)

            # Inference
            # t1 = time_synchronized()
            pred = self.model(img, augment=self.opt.augment)[0]

            # Apply NMS
            pred = non_max_suppression(pred, self.opt.conf_thres, self.opt.iou_thres, classes=self.opt.classes, agnostic=self.opt.agnostic_nms)
            # t2 = time_synchronized()

            # Apply Classifier
            if self.classify:
                pred = apply_classifier(pred, self.modelc, img, im0s)

            # Print time (inference + NMS)
            # print(f'{s}Done. ({t2 - t1:.3f}s)')

            # Process detections
            det = pred[0]
            # The original script saved the output images, so it carried a lot of extra
            # handling for save paths. pred is a list with batch_size elements; this API
            # processes one image at a time, so pred holds a single element and can be
            # taken directly, no for loop needed.
            im0 = im0s.copy()  # im0s shares memory with the image passed in by the caller,
            # so work on a copy to avoid modifying the original image.
            # s += '%gx%g ' % img.shape[2:]  # print string
            # gn = torch.tensor(im0.shape)[[1, 0, 1, 0]]  # normalization gain whwh
            result_txt = []
            # One image may contain several detectable objects, so there may be several
            # result labels. Each detection appends one entry to result_txt; every entry
            # records the class index, the box position on the image and the confidence.
            if len(det):
                # Rescale boxes from img_size to im0 size
                det[:, :4] = scale_coords(img.shape[2:], det[:, :4], im0.shape).round()

                # Write results
                for *xyxy, conf, cls in reversed(det):
                    # xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist()  # normalized xywh
                    line = (int(cls.item()), [int(x.item()) for x in xyxy], conf.item())  # label format
                    result_txt.append(line)
                    label = f'{self.names[int(cls)]} {conf:.2f}'
                    plot_one_box(xyxy, im0, label=label, color=self.colors[int(cls)], line_thickness=3)
            result.append((im0, result_txt))  # for each image, return the annotated image and its label list
        return result, self.names
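

# Minimal usage sketch (not part of the original script). The weight path and the
# image file name below are placeholders; point them at your own files.
if __name__ == '__main__':
    import cv2

    api = detectapi(weights='models/yolov7.pt')   # placeholder weight path
    frame = cv2.imread('test.jpg')                # any image readable by cv2 (BGR)
    results, names = api.detect([frame])          # detect() expects a list of images

    annotated, labels = results[0]                # drawn image + [(cls_id, [x1, y1, x2, y2], conf), ...]
    for cls_id, xyxy, conf in labels:
        print(f'{names[cls_id]} {conf:.2f} at {xyxy}')

    cv2.imshow('detections', annotated)
    cv2.waitKey(0)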