代码拉取完成,页面将自动刷新
import os
import time
import numpy as np
import fastdeploy as fd
import csv
from PIL import Image
from Levenshtein import distance, ratio
from datetime import datetime
import argparse
import cv2
import shutil
from concurrent.futures import ThreadPoolExecutor, as_completed
# 在文件顶部定义全局变量
global ppocr_v3
ppocr_v3 = None
def initialize_model(params):
global ppocr_v3
try:
# 如果模型已经存在,先清理旧模型
if ppocr_v3 is not None:
del ppocr_v3
ppocr_v3 = None
args = parse_arguments(params)
runtime_option = build_option(args)
load_model(args, runtime_option)
return ppocr_v3 is not None
except Exception as e:
print(f"模型初始化失败: {str(e)}")
return False
def parse_arguments(params):
parser = argparse.ArgumentParser()
for key, value in params.items():
parser.add_argument(f"--{key}", type=type(value), default=value)
return parser.parse_args([]) # 传入空列表,使用默认值
def build_option(args):
option = fd.RuntimeOption()
if args.device.lower() == "gpu":
option.use_gpu(args.device_id)
option.set_cpu_thread_num(args.cpu_thread_num)
if args.device.lower() == "kunlunxin":
option.use_kunlunxin()
return option
if args.backend.lower() == "trt":
assert args.device.lower(
) == "gpu", "TensorRT backend require inference on device GPU."
option.use_trt_backend()
elif args.backend.lower() == "pptrt":
assert args.device.lower(
) == "gpu", "Paddle-TensorRT backend require inference on device GPU."
option.use_trt_backend()
option.enable_paddle_trt_collect_shape()
option.enable_paddle_to_trt()
elif args.backend.lower() == "ort":
option.use_ort_backend()
elif args.backend.lower() == "paddle":
option.use_paddle_infer_backend()
elif args.backend.lower() == "openvino":
assert args.device.lower(
) == "cpu", "OpenVINO backend require inference on device CPU."
option.use_openvino_backend()
return option
def load_model(args, runtime_option):
# Detection模型, 检测文字框
det_model_file = os.path.join(args.det_model, "inference.pdmodel")
det_params_file = os.path.join(args.det_model, "inference.pdiparams")
# Classification模型,方向分类,可选
cls_model_file = os.path.join(args.cls_model, "inference.pdmodel")
cls_params_file = os.path.join(args.cls_model, "inference.pdiparams")
# Recognition模型,文字识别模型
rec_model_file = os.path.join(args.rec_model, "inference.pdmodel")
rec_params_file = os.path.join(args.rec_model, "inference.pdiparams")
rec_label_file = args.rec_label_file
# PPOCR的cls和rec模型现在已经支持推理一个Batch的数据
# 定义下面两个变量后, 可用于设置trt输入shape, 并在PPOCR模型初始化后, 完成Batch推理设置
cls_batch_size = 1
rec_batch_size = 6
# 当使用TRT时,分别给三个模型的runtime设置动态shape,并完成模型的创建.
# 注意: 需要在检测模型创建完成后,再设置分类模型的动态输入并创建分类模型, 识别模型同理.
# 如果用户想要自己改动检测模型的输入shape, 我们建议用户把检测模型的长和高设置为32的倍数.
det_option = runtime_option
det_option.trt_option.set_shape("x", [1, 3, 64, 64], [1, 3, 640, 640],
[1, 3, 960, 960])
# 用户可以把TRT引擎文件保存至本地
# det_option.set_trt_cache_file(args.det_model + "/det_trt_cache.trt")
global det_model
det_model = fd.vision.ocr.DBDetector(
det_model_file, det_params_file, runtime_option=det_option)
cls_option = runtime_option
cls_option.trt_option.set_shape("x", [1, 3, 48, 10],
[cls_batch_size, 3, 48, 320],
[cls_batch_size, 3, 48, 1024])
# 用户可以把TRT引擎文件保存至本地
# cls_option.set_trt_cache_file(args.cls_model + "/cls_trt_cache.trt")
global cls_model
cls_model = fd.vision.ocr.Classifier(
cls_model_file, cls_params_file, runtime_option=cls_option)
rec_option = runtime_option
rec_option.trt_option.set_shape("x", [1, 3, 48, 10],
[rec_batch_size, 3, 48, 320],
[rec_batch_size, 3, 48, 2304])
# 用户可以把TRT引擎文件保存至本地
# rec_option.set_trt_cache_file(args.rec_model + "/rec_trt_cache.trt")
global rec_model
rec_model = fd.vision.ocr.Recognizer(
rec_model_file,
rec_params_file,
rec_label_file,
runtime_option=rec_option)
# 创建PP-OCR,串联3个模型,其中cls_model可选,如无需求,可设置为None
global ppocr_v3
if args.use_cls_model == True:
print("使用文字方向检测")
ppocr_v3 = fd.vision.ocr.PPOCRv3(
det_model=det_model, rec_model=rec_model,cls_model=cls_model)#cls_model=cls_model,
else:
ppocr_v3 = fd.vision.ocr.PPOCRv3(
det_model=det_model, rec_model=rec_model,cls_model=None)#cls_model=cls_model,
# 给cls和rec模型设置推理时的batch size
# 此值能为-1, 和1到正无穷
# 当此值为-1时, cls和rec模型的batch size将默认和det模型检测出的框的数量相同
ppocr_v3.cls_batch_size = cls_batch_size
ppocr_v3.rec_batch_size = rec_batch_size
def resize_and_crop_image(image_path, start_h, end_h, base_width=960):
with Image.open(image_path) as img:
# 获取原始图片的宽度和高度
width, height = img.size
# 检查是否需要调整大小
if base_width == 0:
# 如果基础宽度为0,则不调整大小,直接使用原始图像
resized_image = img.resize((width, height), resample=Image.LANCZOS)
else:
# 计算新的高度,保持长宽比
if height > width:
# 图片高度大于宽度,计算新的宽度(按比例)和新的高度(设置为base_width)
new_width = int(base_width * (width / height))
new_height = base_width
else:
# 图片宽度大于或等于高度,计算新的高度(按比例)和新的宽度(设置为base_width)
new_width = base_width
new_height = int(base_width * (height / width))
# 使用resize方法调整图片大小
resized_image = img.resize((new_width, new_height), resample=Image.LANCZOS)
# 检查裁剪比例是否合理
if end_h <= start_h or end_h <= 0:
# 如果不合理,则直接返回调整大小后的图像
return resized_image
# 获取调整大小后的图像的宽度和高度
new_width, new_height = resized_image.size
# 计算裁剪区域的坐标
# 注意:start_h 和 end_h 应该是介于0到1之间的浮点数,表示从顶部开始的百分比
top = int(new_height * start_h)
bottom = int(new_height * end_h)
left = 0
right = new_width
# 进行裁剪
cropped_image = resized_image.crop((left, top, right, bottom))
return cropped_image
def predict(model, image):
im = np.array(image).astype(np.uint8) # 确保图像数据类型为 uint8
result = model.predict(im)
re = []
for box, text, score in zip(result.boxes, result.text, result.rec_scores):
re.append([[[box[0], box[1]], [box[2], box[3]], [box[4], box[5]], [box[6], box[7]]], [text, score]])
print(re)
return re
def get_all_files_in_folder(folder_path):
image_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.gif')
return [
os.path.join(root, file)
for root, dirs, files in os.walk(folder_path)
for file in files
if file.lower().endswith(image_extensions)
]
def is_image_valid(image_path):
try:
img = Image.open(image_path)
img.verify()
return True
except Exception:
return False
def find_right_coordinates(lst, target_text, x_num, up_num, do_num, find_flag):
'''
lst:ocr识别的结果
target_text:定位内容,通过与ocr识别结果中的每一个text内容进行对比判断
x_num: 关键信息框后移num*target_text宽度
up_num: 关键信息框上扩num*target_text高度
do_num:关键信息框下扩num*target_text高度
x_num, up_num, do_num三个参数的目的是为了让定位元素的检测框X中线能穿过关键信息的框
'''
# find_flag = 1 # 1代表精确查询,其他数字模糊查询
# text_num_flag = 1 # 1表示返回定位元素后的一个元素内容,其他则表示定位元素后面所有元素内容
# 找到含有目标文本的元素及其坐标
target_element = None
for element in lst:
if find_flag == 1 and target_text == element[1][0]:
target_element = element
break
if find_flag != 1 and target_text in element[1][0]:
target_element = element
break
if target_element is not None:
# 获取坐标
coordinates = target_element[0]
# 获得左上角left_x右上角得right_x坐标,和右上角y1,右下角y2
left_x = coordinates[0][0]
right_x = coordinates[1][0]
right_y1 = coordinates[1][1]
right_y2 = coordinates[2][1]
# 获得目标向右的关键信息提取
wor = '' # 获得所要得内容
for other_element in lst: # x_num是右target_text识别框长度得比例,up_num和do_num是上下移动相对于target框得比例。是针对每次循环得要素框,所以只能根据要素框来调节参数,目标要素针对指定要素框中线,来调节位置
# 判断下一个元素是否被target_element中线穿过。如果穿过,就确定是他。
if right_x + (right_x - left_x) * x_num>other_element[0][1][0] > right_x and other_element[0][1][1] - (
right_y2 - right_y1) * up_num < (right_y1 + right_y2) / 2 < other_element[0][2][1] + (
right_y2 - right_y1) * do_num:
wor =wor+ other_element[1][0] #获得后面所有的文本
return wor
def writ_csv(csv_file, data):
if not os.path.exists(os.path.dirname(csv_file)):
os.makedirs(os.path.dirname(csv_file))
try:
# 尝试使用gbk编码写入
with open(csv_file, 'a', newline='', encoding='gbk') as f:
writer = csv.writer(f)
writer.writerow(data)
except UnicodeEncodeError:
# 如果列表中的元素是字符串的话,就尝试用GBK编码处理它,并且对于编码过程中遇到的无法编码的字符,选择性地忽略它们。如果列表中的元素不是字符串,则不做任何改变直接保留原样。
try_data = [item.encode('gbk', 'ignore').decode('gbk') if isinstance(item, str) else item for item in data]
with open(csv_file, 'a', newline='', encoding='gbk') as f:
writer = csv.writer(f)
writer.writerow(try_data)
def juge_str_in_list(query_type,keyword_list,re_lsit):
'''
query_type:查询的类型,"title"和"text"
keyword_list:关键字列表
re_lsit:ocr识别结果
'''
# 参数
keywords = []
eles_list = []
for element in re_lsit:
for keyword in keyword_list:
# 执行字符串相似度比较,如果字符比较有两个字符不相符并且整体相似度在80%以上,则判断属于标题字符。
# 执行模糊查询
if query_type == "title" and distance(element[1][0], keyword)<=2 and ratio(element[1][0], keyword) >=0.80:
keywords.append(keyword)
return [",".join(set(keywords))]
if query_type == "text" and keyword in element[1][0]:
keywords.append(keyword)
eles_list.append(element[1][0])
if query_type == "text":
return [",".join(set(keywords)),",".join(set(eles_list))]
def title_process(file, keywords, start_h, end_h, base_width):
'''
ocr标题识别
'''
# 参数
title_keyword_list = keywords # 标题关键词列表
if is_image_valid(file)==False:
return
s_time = time.time()
# resize图片 和 裁剪图片
im = resize_and_crop_image(file,start_h,end_h,base_width)
result = predict(ppocr_v3, im)
# print(result)
csv_text = [os.path.basename(file),file]
ju_text = juge_str_in_list('title',title_keyword_list,result)
if ju_text != None:
csv_text.extend(ju_text)
else:
csv_text.extend([''])
e_time = time.time()
print(f"Processed file: {os.path.basename(file)},标题关键字是:{ju_text}',共花费时间{round((e_time - s_time), 2)}秒")
return csv_text
def text_keyword_process(file, text_keyword_list, start_h, end_h, base_width):
'''
ocr内容关键词提取
'''
if is_image_valid(file) == False:
return
s_time = time.time()
# resize图片 和 裁剪图片
im = resize_and_crop_image(file, start_h, end_h, base_width)
result = predict(ppocr_v3, im)
# print(result)
csv_text = [os.path.basename(file), file]
ju_text = juge_str_in_list('text', text_keyword_list, result)
if ju_text != None:
csv_text.extend(ju_text)
else:
csv_text.extend([''])
e_time = time.time()
print(f"Processed file: {os.path.basename(file)},内容关键字是:{ju_text}',共花费时间{round((e_time - s_time), 2)}秒")
return csv_text
def text_ocr_process(file,tar_text_dic,start_h,end_h,base_width):
'''
ocr关键信息提取
'''
# 参数
# start_h = 0 # 裁剪开始的高度百分比,
# end_h = 0 # 裁剪结束的高度百分比,为0表示不裁剪
# base_width = 960 # 图片尺寸重新定义。0表示不定义
if is_image_valid(file):
s_time = time.time()
# resize图片 和 裁剪图片
im = resize_and_crop_image(file,start_h,end_h,base_width)
result = predict(ppocr_v3, im)
# print(result)
csv_text = [os.path.basename(file)]
for key in tar_text_dic:
text = find_right_coordinates(result, key, **tar_text_dic[key])
if text != None:
csv_text.append('\t' + str(text))
else:
csv_text.append('')
e_time = time.time()
print(f"Processed file: {os.path.basename(file)}关键信息已提取',共花费时间{round((e_time - s_time), 2)}秒")
return csv_text
# 关键信息提取
def Key_information_extraction(input_dir, output_dir, tar_text_dic, start_h, end_h, base_width, progress_callback=None, stop_check=None):
global ppocr_v3
if ppocr_v3 is None:
print("模型尚未初始化,正在尝试初始化...")
if not initialize_model({}): # 使用默认参数
print("模型初始化失败")
return
# 获取202410101151 类型时间戳作为输出目录
formatted_time = datetime.now().strftime('%Y%m%d%H%M')
outdirpath = os.path.join(output_dir, formatted_time)
# 关键信息提取
csv_text_path = os.path.join(outdirpath, f'关键信息提取.csv')
## 写表头部分
header = ['filename'] + list(tar_text_dic.keys()) # 关键信息的表头
writ_csv(csv_text_path, header) # 关键信息提取得输出csv文件写入表头标题
# 批量处理文件夹目录
all_files = get_all_files_in_folder(input_dir)
for i, image_file in enumerate(all_files):
if stop_check and stop_check():
print("操作被用户中止")
return None
csv_text = text_ocr_process(image_file, tar_text_dic, start_h, end_h, base_width)
writ_csv(csv_text_path, csv_text)
if progress_callback:
progress_callback(csv_text, i + 1, len(all_files))
print(f"--------------------------关键信息提取进度{i + 1}/{len(all_files)}")
return outdirpath # 返回输出文件夹路径
def copy_file(src, dst):
shutil.copy2(src, dst)
def multi_threaded_file_copy(file_list, max_workers=4):
with ThreadPoolExecutor(max_workers=max_workers) as executor:
futures = [executor.submit(copy_file, src, dst) for src, dst in file_list]
for future in as_completed(futures):
try:
future.result()
except Exception as e:
print(f"Error copying file: {e}")
#内容提取分类
def content_keyword_categorization_and_classification(dir_path, out_dir, keywords, progress_callback=None, stop_check=None, start_h=0, end_h=0, base_width=960):
print(f"开始执行内容关键字分类,输入目录: {dir_path}, 输出目录: {out_dir}")
global ppocr_v3
if ppocr_v3 is None:
print("模型尚未初始化,正在尝试初始化...")
if not initialize_model({}): # 使用默认参数
print("模型初始化失败")
return
formatted_time = datetime.now().strftime('%Y%m%d%H%M')
outdirpath = os.path.join(out_dir, formatted_time)
csv_text_key_path = os.path.join(outdirpath, f'内容关键词提取.csv')
header = ['filename', 'filepath', 'keyword','content']
writ_csv(csv_text_key_path, header)
all_files = get_all_files_in_folder(dir_path)
file_copy_list = []
for index, file in enumerate(all_files):
if stop_check and stop_check():
print("操作被用户中止")
return None
csv_text = text_keyword_process(file, keywords, start_h, end_h, base_width)
writ_csv(csv_text_key_path, csv_text)
if progress_callback:
progress_callback(csv_text, index + 1, len(all_files))
print(f"--------------------------内容关键字分类进度{index + 1}/{len(all_files)}")
# 在这里准备文件复制列表
if csv_text[2].strip(): # 如果有关键词
keywords_found = csv_text[2].strip().split(',')
for keyword in keywords_found:
keyword = keyword.strip()
if keyword:
keyword_folder = os.path.join(outdirpath, keyword)
if not os.path.exists(keyword_folder):
os.makedirs(keyword_folder)
dest_file = os.path.join(keyword_folder, os.path.basename(file))
file_copy_list.append((file, dest_file))
# 使用多线程复制文件
multi_threaded_file_copy(file_copy_list)
print("内容关键字分类和文件分类完成")
print(f"内容关键字分类完成,返回输出目录: {outdirpath}")
return outdirpath # 返回输出文件夹路径
#标题提取分类
def title_keyword_categorization(input_dir, output_dir, keywords, progress_callback=None, stop_check=None, start_h=0.05, end_h=0.35, base_width=960):
global ppocr_v3
if ppocr_v3 is None:
print("模型尚未初始化,正在尝试初始化...")
if not initialize_model({}): # 使用默认参数
print("模型初始化失败")
return None
# 创建一个新的输出目录,避免与输入目录冲突
formatted_time = datetime.now().strftime('%Y%m%d%H%M')
outdirpath = os.path.join(output_dir, f'title_extraction_{formatted_time}')
os.makedirs(outdirpath, exist_ok=True)
# 标题提取分类
csv_title_path = os.path.join(outdirpath, f'标题提取.csv')
header = ['filename', 'filepath', 'keyword']
writ_csv(csv_title_path, header) # 标题提取得输出csv文件写入表头标题
# 批量处理文件夹目录
all_files = get_all_files_in_folder(input_dir)
file_copy_list = []
for index, file in enumerate(all_files):
if stop_check and stop_check():
print("操作被用户中止")
return None
csv_text = title_process(file, keywords, start_h, end_h, base_width)
if csv_text:
writ_csv(csv_title_path, csv_text)
if progress_callback:
progress_callback(csv_text, index + 1, len(all_files))
print(f"--------------------------标题分类进度{index + 1}/{len(all_files)}")
# 在这里准备文件复制列表
if csv_text[2].strip(): # 如果有关键词
keywords_found = csv_text[2].strip().split(',')
for keyword in keywords_found:
keyword = keyword.strip()
if keyword:
keyword_folder = os.path.join(outdirpath, keyword)
if not os.path.exists(keyword_folder):
os.makedirs(keyword_folder)
dest_file = os.path.join(keyword_folder, os.path.basename(file))
file_copy_list.append((file, dest_file))
# 使用多线程复制文件
multi_threaded_file_copy(file_copy_list)
print("内容关键字分类和文件分类完成")
print(f"内容关键字分类完成,返回输出目录: {outdirpath}")
return outdirpath # 返回新创建的输出文件夹路径
def single_image_ocr(single_img_path, visualize=True):
global ppocr_v3
if ppocr_v3 is None:
print("模型尚未初始化,正在尝试初始化...")
if not initialize_model({}): # 使用默认参数
print("模型初始化失败")
return None, None
if not is_image_valid(single_img_path):
return None, None
im = resize_and_crop_image(single_img_path, 0, 0, base_width=0)
result, vis_path = predict_single(ppocr_v3, im, visualize)
# 返回识别结果列表和可视化结果的路径(如果有的话)
return [item[1][0] for item in result], vis_path
def predict_single(model, image, visualize=True):
im = np.array(image).astype(np.uint8) # 确保图像数据类型为 uint8
results = model.predict(im)
re = []
for box, text, score in zip(results.boxes, results.text, results.rec_scores):
re.append([[[box[0], box[1]], [box[2], box[3]], [box[4], box[5]], [box[6], box[7]]], [text, score]])
vis_path = None
if visualize:
# 可视化检测结果
vis_im = fd.vision.vis_ppocr(im, results)
# 生成带时间戳的文件名
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
vis_path = os.path.join('visionImages', f"visualized_result_{timestamp}.jpg")
if not os.path.exists(os.path.dirname(vis_path)):
os.makedirs(os.path.dirname(vis_path))
cv2.imwrite(vis_path, vis_im)
return re, vis_path
# 在主程序中初始化模型
# if __name__ == '__main__':
# initialize_model()
# 现在可以调用各个函数,它们会共用已初始化的模型
# single_image_ocr(single_img_path)
# title_keyword_categorization(dir_path, out_dir)
# content_keyword_categorization(dir_path, out_dir)
# Key_information_extraction(dir_path, out_dir)
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。