Ai
1 Star 0 Fork 292

王潇潇1107/PaddleDetection

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
Fall_detection.py 5.26 KB
一键复制 编辑 原始数据 按行查看 历史
王潇潇1107 提交于 2023-09-24 14:51 +08:00 . 创建
import os
import cv2
import paddle
import yaml
from deploy.python.keypoint_infer import KeyPointDetector
from deploy.python.det_keypoint_unite_utils import *
from deploy.python.infer import Detector, DetectorPicoDet, PredictConfig, print_arguments, get_test_images, bench_log
from deploy.python.det_keypoint_unite_infer1 import my_topdown_unite_predict as predict
from matplotlib import pyplot as plt
# 首先先进行初始环境的初始化操作
def init_all(det_model_dir,keypoint_model_dir):
paddle.enable_static()
deploy_file = os.path.join(det_model_dir, 'infer_cfg.yml')
with open(deploy_file) as f:
yml_conf = yaml.safe_load(f)
arch = yml_conf['arch']
detector_func = 'Detector'
if arch == 'PicoDet':
detector_func = 'DetectorPicoDet'
detector = eval(detector_func)(det_model_dir,
device="CPU",
threshold=0.35)
topdown_keypoint_detector = KeyPointDetector(
keypoint_model_dir,
device="CPU")
return detector,topdown_keypoint_detector
class DataHandler():
"""处理框数据与关键点数据"""
def __init__(self,threshold=0.35):
self.threshold = threshold
self.buffer = {"a":[],"b":[]}
def keypoints_handler(self,json_result)->list:
"""记录肩膀、手腕、臀、脚踝"""
json_result = json_result[0][1][0][0]
keypoints_list = []
# print(json_result)
if json_result[5][-1] > self.threshold and json_result[9][-1] > self.threshold and json_result[12][-1] > self.threshold and json_result[15][-1] > self.threshold:
keypoints_list += json_result[5],json_result[9],json_result[12],json_result[15]
return keypoints_list
else:
return []
def comm_handler():
"""如果有异常就调用进行处理"""
pass
if __name__ == "__main__":
"""模型初始化地址参数,需要修改"""
det_model_dir = "output_inference/mot_ppyoloe_l_36e_pipeline"
keypoint_model_dir = "output_inference/dark_hrnet_w32_256x192"
detector,topdown_keypoint_detector = init_all(det_model_dir,keypoint_model_dir)
pose_handler = DataHandler()
"""单图测试"""
single_img = "demo/hrnet_demo.jpg"
img = cv2.imread(single_img)
print("开始检测")
json_result,image = predict(detector, topdown_keypoint_detector, img,keypoint_batch_size=1,if_show=True)
# print(json_result)
# print("box信息",json_result[0][0])
# print("关键点信息",json_result[0][1][0])
"""额外处理环节"""
keypoints_list = pose_handler.keypoints_handler(json_result)
print("关键点结果:",keypoints_list)
print("肩膀与脚踝距离 与 手腕与脚踝距离的比值:",abs(keypoints_list[0][1]-keypoints_list[3][1])/abs(keypoints_list[1][1]-keypoints_list[3][1]))
cv2.imshow("123",image)
cv2.imwrite("test_my_image.png",image)
cv2.waitKey(0)
cv2.destroyAllWindows()
"""视频检测测试"""
# stream = "/home/aistudio/work/50WaystoFall.mp4"
# capture = cv2.VideoCapture(stream)
# # cv2.namedWindow("video")
# count = 0
# json_result = [[[],[[]]]]
# # 高度缓冲区队列,用于判断是否高度快速下降
# high_list = []
#
# while (1):
# ret, frame = capture.read()
# count+=1
# if not ret:
# break
#
#
# json_result,frame = predict(detector, topdown_keypoint_detector, frame,keypoint_batch_size=1,if_show=True)
# count = 0
#
# if json_result != [[[],[[]]]] and json_result != []:
# keypoints_list = pose_handler.keypoints_handler(json_result)
# # 注意没有检测到几个位置或置信度低就没有结果。
# if keypoints_list == []:
# pass
# else:
# high = keypoints_list[0][1]
# if len(high_list)==0:
# high_list.append(high)
#
# # 注意 这里的逻辑要根据opencv的坐标轴设定,向下为正
# if high > high_list[-1]:
# high_list.append(high)
# elif high < high_list[-1]:
# high_list.remove(high_list[-1])
#
# # 对视频和实际测试后得到的最优缓冲区参数是 7
# if len(high_list) > 7:
# if abs(high_list[-1]-high_list[0])/high_list[0] > 0.3:
# print("高度快速下降,可能摔倒!请检查")
# high_list = []
#
# ratio = abs(keypoints_list[0][1]-keypoints_list[2][1])/abs(keypoints_list[2][1]-keypoints_list[3][1])
#
# # 测试后摔倒比值的最优参数是 1.6
# if ratio > 1.6:
# print("摔倒!!")
# # print("肩膀与臀部距离 与 臀部与脚踝距离的比值:",ratio)
# # print("关键点信息",json_result[0][1][0])
#
# # print("关键点结果:",keypoints_list)
# # print("box信息",json_result[0][0])
# plt.imshow(frame)
# aistudio中,不要使用cv2展示
# cv2.imshow("video",frame)
# if cv2.waitKey(1) & 0xFF == ord('q'):
# break
# capture.release()
# cv2.destroyAllWindows()
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Python
1
https://gitee.com/wangxx1107_admin/PaddleDetection.git
git@gitee.com:wangxx1107_admin/PaddleDetection.git
wangxx1107_admin
PaddleDetection
PaddleDetection
release/2.6

搜索帮助