From 2e73972a5dc240fcc301fa240e0bcf6529bb117d Mon Sep 17 00:00:00 2001 From: may Date: Mon, 22 Jul 2024 11:00:56 +0800 Subject: [PATCH 1/7] Add yolov6 and yolov8. --- models/cv/detection/yolov6/ixrt/README.md | 71 ++++ .../cv/detection/yolov6/ixrt/build_engine.py | 94 +++++ models/cv/detection/yolov6/ixrt/common.py | 335 ++++++++++++++++++ models/cv/detection/yolov6/ixrt/deploy.py | 99 ++++++ models/cv/detection/yolov6/ixrt/inference.py | 253 +++++++++++++ models/cv/detection/yolov6/ixrt/quant.py | 105 ++++++ .../scripts/infer_yolov6s_fp16_accuracy.sh | 65 ++++ .../scripts/infer_yolov6s_fp16_performance.sh | 78 ++++ .../scripts/infer_yolov6s_int8_accuracy.sh | 85 +++++ .../scripts/infer_yolov6s_int8_performance.sh | 86 +++++ models/cv/detection/yolov8/ixrt/README.md | 59 +++ .../cv/detection/yolov8/ixrt/build_engine.py | 94 +++++ models/cv/detection/yolov8/ixrt/common.py | 335 ++++++++++++++++++ models/cv/detection/yolov8/ixrt/export.py | 43 +++ models/cv/detection/yolov8/ixrt/inference.py | 237 +++++++++++++ models/cv/detection/yolov8/ixrt/quant.py | 105 ++++++ .../scripts/infer_yolov8n_fp16_accuracy.sh | 65 ++++ .../scripts/infer_yolov8n_fp16_performance.sh | 66 ++++ .../scripts/infer_yolov8n_int8_accuracy.sh | 85 +++++ .../scripts/infer_yolov8n_int8_performance.sh | 85 +++++ 20 files changed, 2445 insertions(+) create mode 100644 models/cv/detection/yolov6/ixrt/README.md create mode 100644 models/cv/detection/yolov6/ixrt/build_engine.py create mode 100644 models/cv/detection/yolov6/ixrt/common.py create mode 100644 models/cv/detection/yolov6/ixrt/deploy.py create mode 100644 models/cv/detection/yolov6/ixrt/inference.py create mode 100644 models/cv/detection/yolov6/ixrt/quant.py create mode 100644 models/cv/detection/yolov6/ixrt/scripts/infer_yolov6s_fp16_accuracy.sh create mode 100644 models/cv/detection/yolov6/ixrt/scripts/infer_yolov6s_fp16_performance.sh create mode 100644 models/cv/detection/yolov6/ixrt/scripts/infer_yolov6s_int8_accuracy.sh create mode 100644 models/cv/detection/yolov6/ixrt/scripts/infer_yolov6s_int8_performance.sh create mode 100644 models/cv/detection/yolov8/ixrt/README.md create mode 100644 models/cv/detection/yolov8/ixrt/build_engine.py create mode 100644 models/cv/detection/yolov8/ixrt/common.py create mode 100644 models/cv/detection/yolov8/ixrt/export.py create mode 100644 models/cv/detection/yolov8/ixrt/inference.py create mode 100644 models/cv/detection/yolov8/ixrt/quant.py create mode 100644 models/cv/detection/yolov8/ixrt/scripts/infer_yolov8n_fp16_accuracy.sh create mode 100644 models/cv/detection/yolov8/ixrt/scripts/infer_yolov8n_fp16_performance.sh create mode 100644 models/cv/detection/yolov8/ixrt/scripts/infer_yolov8n_int8_accuracy.sh create mode 100644 models/cv/detection/yolov8/ixrt/scripts/infer_yolov8n_int8_performance.sh diff --git a/models/cv/detection/yolov6/ixrt/README.md b/models/cv/detection/yolov6/ixrt/README.md new file mode 100644 index 00000000..8d78557a --- /dev/null +++ b/models/cv/detection/yolov6/ixrt/README.md @@ -0,0 +1,71 @@ +# YOLOv6 + +## Description + +YOLOv6 integrates cutting-edge object detection advancements from industry and academia, incorporating recent innovations in network design, training strategies, testing techniques, quantization, and optimization methods. This culmination results in a suite of deployment-ready networks, accommodating varied use cases across different scales. + +## Setup + +### Install + +```bash +# Install libGL +## CentOS +yum install -y mesa-libGL +## Ubuntu +apt install -y libgl1-mesa-dev + +pip3 install tqdm +pip3 install onnx +pip3 install onnxsim +pip3 install pycocotools +``` + +### Download + +Pretrained model: + +Dataset: to download the validation dataset. + +### Model Conversion + +```bash +# install yolov6 +git clone https://github.com/meituan/YOLOv6.git +cd YOLOv6 +pip3 install -r requirements.txt + +# export onnx model +python3 deploy/ONNX/export_onnx.py --weights ../yolov6s.pt --img 640 --batch-size 32 --simplify + +cd .. +``` + +## Inference + +```bash +export DATASETS_DIR=/Path/to/coco/ +``` + +### FP16 + +```bash +# Accuracy +bash scripts/infer_yolov6s_fp16_accuracy.sh +# Performance +bash scripts/infer_yolov6s_fp16_performance.sh +``` + +### INT8 + +```bash +# Accuracy +bash scripts/infer_yolov6s_int8_accuracy.sh +# Performance +bash scripts/infer_yolov6s_int8_performance.sh +``` + + +## Reference + +YOLOv6: diff --git a/models/cv/detection/yolov6/ixrt/build_engine.py b/models/cv/detection/yolov6/ixrt/build_engine.py new file mode 100644 index 00000000..f5e1719a --- /dev/null +++ b/models/cv/detection/yolov6/ixrt/build_engine.py @@ -0,0 +1,94 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +import os +import cv2 +import argparse +import numpy as np + +import torch +import tensorrt +from tensorrt import Dims + + +def build_engine_trtapi_staticshape(config): + IXRT_LOGGER = tensorrt.Logger(tensorrt.Logger.WARNING) + builder = tensorrt.Builder(IXRT_LOGGER) + EXPLICIT_BATCH = 1 << (int)(tensorrt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH) + network = builder.create_network(EXPLICIT_BATCH) + build_config = builder.create_builder_config() + parser = tensorrt.OnnxParser(network, IXRT_LOGGER) + parser.parse_from_file(config.model) + + precision = tensorrt.BuilderFlag.INT8 if config.precision == "int8" else tensorrt.BuilderFlag.FP16 + # print("precision : ", precision) + build_config.set_flag(precision) + + plan = builder.build_serialized_network(network, build_config) + engine_file_path = config.engine + with open(engine_file_path, "wb") as f: + f.write(plan) + print("Build static shape engine done!") + + +def build_engine_trtapi_dynamicshape(config): + IXRT_LOGGER = tensorrt.Logger(tensorrt.Logger.WARNING) + builder = tensorrt.Builder(IXRT_LOGGER) + EXPLICIT_BATCH = 1 << (int)(tensorrt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH) + network = builder.create_network(EXPLICIT_BATCH) + build_config = builder.create_builder_config() + + profile = builder.create_optimization_profile() + profile.set_shape("input", + Dims([1, 3, 608, 608]), + Dims([32, 3, 608, 608]), + Dims([64, 3, 608, 608]), + ) + build_config.add_optimization_profile(profile) + + parser = tensorrt.OnnxParser(network, IXRT_LOGGER) + parser.parse_from_file(config.model) + precision = tensorrt.BuilderFlag.INT8 if config.precision == "int8" else tensorrt.BuilderFlag.FP16 + # print("precision : ", precision) + build_config.set_flag(precision) + + # set dynamic + num_inputs = network.num_inputs + for i in range(num_inputs): + input_tensor = network.get_input(i) + input_tensor.shape = Dims([-1, 3, 608, 608]) + + plan = builder.build_serialized_network(network, build_config) + engine_file_path = config.engine + with open(engine_file_path, "wb") as f: + f.write(plan) + print("Build dynamic shape engine done!") + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--model", type=str) + parser.add_argument("--precision", type=str, choices=["float16", "int8", "float32"], default="int8", + help="The precision of datatype") + # engine args + parser.add_argument("--engine", type=str, default=None) + + args = parser.parse_args() + return args + + +if __name__ == "__main__": + args = parse_args() + build_engine_trtapi_staticshape(args) + # build_engine_trtapi_dynamicshape(args) diff --git a/models/cv/detection/yolov6/ixrt/common.py b/models/cv/detection/yolov6/ixrt/common.py new file mode 100644 index 00000000..dc3c2766 --- /dev/null +++ b/models/cv/detection/yolov6/ixrt/common.py @@ -0,0 +1,335 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +import os +import cv2 +import glob +import time +import numpy as np +from tqdm import tqdm + +import tensorrt +import pycuda.driver as cuda + + +def load_class_names(namesfile): + class_names = [] + with open(namesfile, 'r') as fp: + lines = fp.readlines() + for line in lines: + line = line.rstrip() + class_names.append(line) + return class_names + +# input : [bsz, box_num, 5(cx, cy, w, h, conf) + class_num(prob[0], prob[1], ...)] +# output : [bsz, box_num, 6(left_top_x, left_top_y, right_bottom_x, right_bottom_y, class_id, max_prob*conf)] +def box_class85to6(input): + center_x_y = input[:, :2] + side = input[:, 2:4] + conf = input[:, 4:5] + class_id = np.argmax(input[:, 5:], axis = -1) + class_id = class_id.astype(np.float32).reshape(-1, 1) + 1 + max_prob = np.max(input[:, 5:], axis = -1).reshape(-1, 1) + x1_y1 = center_x_y - 0.5 * side + x2_y2 = center_x_y + 0.5 * side + nms_input = np.concatenate([x1_y1, x2_y2, class_id, max_prob*conf], axis = -1) + return nms_input + +def save2json(batch_img_id, pred_boxes, json_result, class_trans): + for i, boxes in enumerate(pred_boxes): + if boxes is not None: + image_id = int(batch_img_id[i]) + # have no target + if image_id == -1: + continue + + for x1, y1, x2, y2, _, p, c in boxes: + x1, y1, x2, y2, p = float(x1), float(y1), float(x2), float(y2), float(p) + c = int(c) + x = x1 + y = y1 + w = x2 - x1 + h = y2 - y1 + + json_result.append( + { + "image_id": image_id, + "category_id": class_trans[c - 1], + "bbox": [x, y, w, h], + "score": p, + } + ) + +################## About TensorRT ################# +def create_engine_context(engine_path, logger): + with open(engine_path, "rb") as f: + runtime = tensorrt.Runtime(logger) + assert runtime + engine = runtime.deserialize_cuda_engine(f.read()) + assert engine + context = engine.create_execution_context() + assert context + + return engine, context + +def setup_io_bindings(engine, context): + # Setup I/O bindings + inputs = [] + outputs = [] + allocations = [] + + for i in range(engine.num_bindings): + is_input = False + if engine.binding_is_input(i): + is_input = True + name = engine.get_binding_name(i) + dtype = engine.get_binding_dtype(i) + shape = context.get_binding_shape(i) + if is_input: + batch_size = shape[0] + size = np.dtype(tensorrt.nptype(dtype)).itemsize + for s in shape: + size *= s + allocation = cuda.mem_alloc(size) + binding = { + "index": i, + "name": name, + "dtype": np.dtype(tensorrt.nptype(dtype)), + "shape": list(shape), + "allocation": allocation, + } + # print(f"binding {i}, name : {name} dtype : {np.dtype(tensorrt.nptype(dtype))} shape : {list(shape)}") + allocations.append(allocation) + if engine.binding_is_input(i): + inputs.append(binding) + else: + outputs.append(binding) + return inputs, outputs, allocations +########################################################## + + +################## About Loading Dataset ################# +def load_images(images_path): + """ + If image path is given, return it directly + For txt file, read it and return each line as image path + In other case, it's a folder, return a list with names of each + jpg, jpeg and png file + """ + input_path_extension = images_path.split('.')[-1] + if input_path_extension in ['jpg', 'jpeg', 'png']: + return [images_path] + elif input_path_extension == "txt": + with open(images_path, "r") as f: + return f.read().splitlines() + else: + return glob.glob( + os.path.join(images_path, "*.jpg")) + \ + glob.glob(os.path.join(images_path, "*.png")) + \ + glob.glob(os.path.join(images_path, "*.jpeg")) + +def prepare_batch(images_path, bs=16, input_size=(608, 608)): + + width, height = input_size + + batch_names = [] + batch_images = [] + batch_shapes = [] + + temp_names = [] + temp_images = [] + temp_shapes = [] + + for i, image_path in tqdm(enumerate(images_path), desc="Loading coco data"): + name = os.path.basename(image_path) + image = cv2.imread(image_path) + h, w, _ = image.shape + image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + image_resized = cv2.resize(image_rgb, (width, height), + interpolation=cv2.INTER_LINEAR) + custom_image = image_resized.transpose(2, 0, 1).astype(np.float32) / 255. + custom_image = np.expand_dims(custom_image, axis=0) + + if i != 0 and i % bs == 0: + batch_names.append(temp_names) + batch_images.append(np.concatenate(temp_images, axis=0)) + batch_shapes.append(temp_shapes) + + temp_names = [name] + temp_images = [custom_image] + temp_shapes = [(h, w)] + else: + temp_names.append(name) + temp_images.append(custom_image) + temp_shapes.append((h, w)) + + return batch_names, batch_images, batch_shapes +########################################################## + + +################## About Operating box ################# +def letterbox(im, new_shape=(640, 640), color=(114, 114, 114), auto=True, scaleFill=False, scaleup=True, stride=32): + # Resize and pad image while meeting stride-multiple constraints + shape = im.shape[:2] # current shape [height, width] + if isinstance(new_shape, int): + new_shape = (new_shape, new_shape) + + # Scale ratio (new / old) + r = min(new_shape[0] / shape[0], new_shape[1] / shape[1]) + if not scaleup: # only scale down, do not scale up (for better val mAP) + r = min(r, 1.0) + + # Compute padding + ratio = r, r # width, height ratios + new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r)) + dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1] # wh padding + if auto: # minimum rectangle + dw, dh = np.mod(dw, stride), np.mod(dh, stride) # wh padding + elif scaleFill: # stretch + dw, dh = 0.0, 0.0 + new_unpad = (new_shape[1], new_shape[0]) + ratio = new_shape[1] / shape[1], new_shape[0] / shape[0] # width, height ratios + + dw /= 2 # divide padding into 2 sides + dh /= 2 + + if shape[::-1] != new_unpad: # resize + im = cv2.resize(im, new_unpad, interpolation=cv2.INTER_LINEAR) + top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1)) + left, right = int(round(dw - 0.1)), int(round(dw + 0.1)) + im = cv2.copyMakeBorder(im, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color) # add border + return im, ratio, (dw, dh) + +def scale_boxes(net_shape, boxes, ori_shape, use_letterbox=False): + # Rescale boxes (xyxy) from net_shape to ori_shape + + if use_letterbox: + + gain = min( + net_shape[0] / ori_shape[0], net_shape[1] / ori_shape[1] + ) # gain = new / old + pad = (net_shape[1] - ori_shape[1] * gain) / 2, ( + net_shape[0] - ori_shape[0] * gain + ) / 2.0 + + boxes[:, [0, 2]] -= pad[0] # x padding + boxes[:, [1, 3]] -= pad[1] # y padding + boxes[:, :4] /= gain + else: + x_scale, y_scale = net_shape[1] / ori_shape[1], net_shape[0] / ori_shape[0] + + boxes[:, 0] /= x_scale + boxes[:, 1] /= y_scale + boxes[:, 2] /= x_scale + boxes[:, 3] /= y_scale + + clip_boxes(boxes, ori_shape) + return boxes + +def clip_boxes(boxes, shape): + + boxes[:, [0, 2]] = boxes[:, [0, 2]].clip(0, shape[1]) # x1, x2 + boxes[:, [1, 3]] = boxes[:, [1, 3]].clip(0, shape[0]) # y1, y2 +########################################################## + + +################## About pre and post processing ######### +def pre_processing(src_img, imgsz=608): + resized = cv2.resize(src_img, (imgsz, imgsz), interpolation=cv2.INTER_LINEAR) + in_img = cv2.cvtColor(resized, cv2.COLOR_BGR2RGB) + in_img = np.transpose(in_img, (2, 0, 1)).astype(np.float32) + in_img = np.expand_dims(in_img, axis=0) + in_img /= 255.0 + return in_img + +def nms_cpu(boxes, confs, nms_thresh=0.5, min_mode=False): + # print(boxes.shape) + x1 = boxes[:, 0] + y1 = boxes[:, 1] + x2 = boxes[:, 2] + y2 = boxes[:, 3] + + areas = (x2 - x1) * (y2 - y1) + order = confs.argsort()[::-1] + + keep = [] + while order.size > 0: + idx_self = order[0] + idx_other = order[1:] + + keep.append(idx_self) + + xx1 = np.maximum(x1[idx_self], x1[idx_other]) + yy1 = np.maximum(y1[idx_self], y1[idx_other]) + xx2 = np.minimum(x2[idx_self], x2[idx_other]) + yy2 = np.minimum(y2[idx_self], y2[idx_other]) + + w = np.maximum(0.0, xx2 - xx1) + h = np.maximum(0.0, yy2 - yy1) + inter = w * h + + if min_mode: + over = inter / np.minimum(areas[order[0]], areas[order[1:]]) + else: + over = inter / (areas[order[0]] + areas[order[1:]] - inter) + + inds = np.where(over <= nms_thresh)[0] + order = order[inds + 1] + + return np.array(keep) + + +def post_processing(img, conf_thresh, nms_thresh, output, num_classes=80): + + # [batch, num, 1, 4] + box_array = output[:, :, :4] + # [batch, num, 2] + class_confs = output[:, :, 4:] + + max_conf = class_confs[:, :, 1] + max_id = class_confs[:, :, 0] + + bboxes_batch = [] + for i in range(box_array.shape[0]): + + argwhere = max_conf[i] > conf_thresh + l_box_array = box_array[i, argwhere, :] + l_max_conf = max_conf[i, argwhere] + l_max_id = max_id[i, argwhere] + + bboxes = [] + # nms for each class + for j in range(num_classes): + + cls_argwhere = l_max_id == j + ll_box_array = l_box_array[cls_argwhere, :] + ll_max_conf = l_max_conf[cls_argwhere] + ll_max_id = l_max_id[cls_argwhere] + + keep = nms_cpu(ll_box_array, ll_max_conf, nms_thresh) + + if (keep.size > 0): + ll_box_array = ll_box_array[keep, :] + ll_max_conf = ll_max_conf[keep] + ll_max_id = ll_max_id[keep] + + for k in range(ll_box_array.shape[0]): + bboxes.append([ll_box_array[k, 0], ll_box_array[k, 1], ll_box_array[k, 2], + ll_box_array[k, 3], ll_max_conf[k], ll_max_conf[k], ll_max_id[k]]) + + bboxes_batch.append(bboxes) + + return bboxes_batch +########################################################## + diff --git a/models/cv/detection/yolov6/ixrt/deploy.py b/models/cv/detection/yolov6/ixrt/deploy.py new file mode 100644 index 00000000..f73d14b2 --- /dev/null +++ b/models/cv/detection/yolov6/ixrt/deploy.py @@ -0,0 +1,99 @@ +# !/usr/bin/env python +# -*- coding: utf-8 -*- +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +import argparse +import copy + +from typing import Union, Callable, List + +from tensorrt.deploy.api import * +from tensorrt.deploy.backend.onnx.converter import default_converter +from tensorrt.deploy.backend.torch.executor.operators._operators import to_py_type +from tensorrt.deploy.ir.operator_attr import BaseOperatorAttr, EmptyAttr +from tensorrt.deploy.ir.operator_type import OperatorType as OP +from tensorrt.deploy.ir import operator_attr as attr, Operator, generate_operator_name +from tensorrt.deploy.fusion import BasePass, PatternGraph, build_sequence_graph, GraphMatcher, PassSequence +from tensorrt.deploy.ir import Graph +from tensorrt.deploy.quantizer.quant_operator.base import quant_single_input_operator +from tensorrt.deploy.backend.onnx.converter import convert_onnx_operator +from tensorrt.deploy.api import GraphTransform, create_source, create_target + +class FuseSiLUPass(BasePass): + def process(self, graph: Graph) -> Graph: + pattern = build_sequence_graph([OP.SIGMOID, OP.MUL]) + + matcher = GraphMatcher(pattern, strict=False) + self.transform = GraphTransform(graph) + matcher.findall(graph, self.fuse_mish) + return graph + + def fuse_mish(self, graph: Graph, pattern_graph: PatternGraph): + sigmoid = pattern_graph.nodes[0].operator + mul = pattern_graph.nodes[-1].operator + + if not self.can_fused(graph, pattern_graph): + return + + self.transform.delete_operators_between_op_op(sigmoid, mul) + + silu_op = Operator( + name=generate_operator_name(graph, pattern="SiLU_{idx}"), + op_type=OP.SILU, + inputs=copy.copy(sigmoid.inputs), + outputs=copy.copy(mul.outputs), + ) + silu_op.is_quant_operator = sigmoid.is_quant_operator and mul.is_quant_operator + graph.add_operator(silu_op) + + def can_fused(self, graph: Graph, pattern_graph: PatternGraph): + sigmoid = pattern_graph.nodes[0].operator + mul = pattern_graph.nodes[-1].operator + + # 如果 sigmoid 的结果 被多个 OP 使用,则不能融合 + if len(self.transform.get_next_operators(sigmoid)) > 1: + return False + + # 检查 mul 的输入是不是和 sigmoid 是同源的 + softplus_prev_op = graph.get_previous_operators(sigmoid) + if len(softplus_prev_op) != 1: + return False + + mul_prev_op = graph.get_previous_operators(mul) + if len(mul_prev_op) != 2: + return False + + for op in mul_prev_op: + if op is softplus_prev_op[0]: + return True + + return False + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--src", type=str) + parser.add_argument("--dst", type=str) + args = parser.parse_args() + return args + + +if __name__ == "__main__": + + args = parse_args() + graph = create_source(args.src)() + graph = FuseSiLUPass().process(graph) + create_target(saved_path=args.dst).export(graph) + print("Surged onnx lies on", args.dst) diff --git a/models/cv/detection/yolov6/ixrt/inference.py b/models/cv/detection/yolov6/ixrt/inference.py new file mode 100644 index 00000000..836f13b2 --- /dev/null +++ b/models/cv/detection/yolov6/ixrt/inference.py @@ -0,0 +1,253 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import os +import sys +sys.path.insert(0, "YOLOv6") +import json +import argparse +import time +import tensorrt +from tensorrt import Dims +import pycuda.autoinit +import pycuda.driver as cuda +import torch +import numpy as np +from tqdm import tqdm + +from common import create_engine_context, setup_io_bindings + +from pycocotools.coco import COCO +from pycocotools.cocoeval import COCOeval + +from yolov6.core.evaler import Evaler +from yolov6.utils.events import NCOLS +from yolov6.utils.nms import non_max_suppression +from yolov6.data.data_load import create_dataloader + + +coco_classes = { + 0: 'person', 1: 'bicycle', 2: 'car', 3: 'motorcycle', 4: 'airplane', 5: 'bus', 6: 'train', 7: 'truck', 8: 'boat', 9: 'traffic light', + 10: 'fire hydrant', 11: 'stop sign', 12: 'parking meter', 13: 'bench', 14: 'bird', 15: 'cat', 16: 'dog', 17: 'horse', 18: 'sheep', 19: 'cow', + 20: 'elephant', 21: 'bear', 22: 'zebra', 23: 'giraffe', 24: 'backpack', 25: 'umbrella', 26: 'handbag', 27: 'tie', 28: 'suitcase', 29: 'frisbee', + 30: 'skis', 31: 'snowboard', 32: 'sports ball', 33: 'kite', 34: 'baseball bat', 35: 'baseball glove', 36: 'skateboard', 37: 'surfboard', 38: 'tennis racket', 39: 'bottle', + 40: 'wine glass', 41: 'cup', 42: 'fork', 43: 'knife', 44: 'spoon', 45: 'bowl', 46: 'banana', 47: 'apple', 48: 'sandwich', 49: 'orange', + 50: 'broccoli', 51: 'carrot', 52: 'hot dog', 53: 'pizza', 54: 'donut', 55: 'cake', 56: 'chair', 57: 'couch', 58: 'potted plant', 59: 'bed', + 60: 'dining table', 61: 'toilet', 62: 'tv', 63: 'laptop', 64: 'mouse', 65: 'remote', 66: 'keyboard', 67: 'cell phone', 68: 'microwave', 69: 'oven', + 70: 'toaster', 71: 'sink', 72: 'refrigerator', 73: 'book', 74: 'clock', 75: 'vase', 76: 'scissors', 77: 'teddy bear', 78: 'hair drier', 79: 'toothbrush' +} + +class EvalerIXRT(Evaler): + def eval_ixrt(self, args, stride=32): + self.stride = stride + def init_data(dataloader, task): + self.is_coco = self.data.get("is_coco", False) + self.ids = self.coco80_to_coco91_class() if self.is_coco else list(range(1000)) + pad = 0.0 + dataloader = create_dataloader( + self.data[task], self.img_size, self.batch_size, self.stride, + check_labels=True, pad=pad, rect=False, data_dict=self.data, task=task)[0] + return dataloader + + dataloader = init_data(None,'val') + pred_results = [] + + input_name = "input" + host_mem = tensorrt.IHostMemory + logger = tensorrt.Logger(tensorrt.Logger.ERROR) + engine, context = create_engine_context(args.model_engine, logger) + input_idx = engine.get_binding_index(input_name) + context.set_binding_shape(input_idx, Dims((args.bsz,3,args.imgsz,args.imgsz))) + inputs, outputs, allocations = setup_io_bindings(engine, context) + + if args.warm_up > 0: + print("\nWarm Start.") + for i in range(args.warm_up): + context.execute_v2(allocations) + print("Warm Done.") + + pbar = tqdm(dataloader, desc="Inferencing model in validation dataset.", ncols=NCOLS) + + forward_time = 0.0 + num_samples = 0 + for imgs, targes, paths, shapes in pbar: + imgs = imgs.float() + pad_batch = len(imgs) != self.batch_size + if pad_batch: + origin_size = len(imgs) + imgs = np.resize(imgs, (self.batch_size, *imgs.shape[1:])) + imgs /= 255.0 + # print(imgs.shape) + batch_data = np.ascontiguousarray(imgs) + data_shape = batch_data.shape + + cur_bsz_sample = batch_data.shape[0] + num_samples += cur_bsz_sample + + # Set input + input_idx = engine.get_binding_index(input_name) + context.set_binding_shape(input_idx, Dims(data_shape)) + inputs, outputs, allocations = setup_io_bindings(engine, context) + + cuda.memcpy_htod(inputs[0]["allocation"], batch_data) + # Prepare the output data + output = np.zeros(outputs[0]["shape"], outputs[0]["dtype"]) + + + start_time = time.time() + context.execute_v2(allocations) + end_time = time.time() + forward_time += end_time - start_time + + cuda.memcpy_dtoh(output, outputs[0]["allocation"]) + + if not args.perf_only: + if pad_batch: + output = output[:origin_size] + + outputs = torch.from_numpy(output) + outputs = non_max_suppression(outputs, self.conf_thres, self.iou_thres, multi_label=True) + pred_results.extend(self.convert_to_coco_format(outputs, imgs, paths, shapes, self.ids)) + if args.perf_only: + fps = num_samples / forward_time + return fps + else: + return dataloader, pred_results + + def eval_ixrt_map(self, pred_results, dataloader, task): + '''Evaluate models + For task speed, this function only evaluates the speed of model and outputs inference time. + For task val, this function evaluates the speed and mAP by pycocotools, and returns + inference time and mAP value. + ''' + if not self.do_coco_metric and self.do_pr_metric: + return self.pr_metric_result + print(f'\nEvaluating mAP by pycocotools.') + if task != 'speed' and len(pred_results): + if 'anno_path' in self.data: + anno_json = self.data['anno_path'] + else: + # generated coco format labels in dataset initialization + task = 'val' if task == 'train' else task + dataset_root = os.path.dirname(os.path.dirname(self.data[task])) + base_name = os.path.basename(self.data[task]) + anno_json = os.path.join(dataset_root, 'annotations', f'instances_{base_name}.json') + pred_json = os.path.join(self.save_dir, "predictions.json") + print(f'Saving {pred_json}...') + with open(pred_json, 'w') as f: + json.dump(pred_results, f) + + anno = COCO(anno_json) + pred = anno.loadRes(pred_json) + cocoEval = COCOeval(anno, pred, 'bbox') + if self.is_coco: + imgIds = [int(os.path.basename(x).split(".")[0]) + for x in dataloader.dataset.img_paths] + cocoEval.params.imgIds = imgIds + cocoEval.evaluate() + cocoEval.accumulate() + cocoEval.summarize() + + return cocoEval.stats + else: + print("pred_results is none") + return None + +def parse_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--model_engine", + type=str, + default="", + help="model engine path", + ) + + parser.add_argument("--bsz", type=int, default=32, help="test batch size") + parser.add_argument( + "--imgsz", + "--img", + "--img-size", + type=int, + default=608, + help="inference size h,w", + ) + + parser.add_argument("--datasets", + type=str, + required=True, + help="datasets path.") + + parser.add_argument("--warm_up", type=int, default=3, help="warm_up count") + + parser.add_argument("--acc_target", + type=float, + default=None, + help="Model inference Accuracy target.") + + parser.add_argument("--fps_target", + type=float, + default=None, + help="Model inference FPS target.") + + parser.add_argument("--perf_only", + type=bool, + default=False, + help="Run performance test only") + + args = parser.parse_args() + + return args + +def main(): + args = parse_args() + + task = 'val' + + batch_size = args.bsz + data_path = os.path.join(args.datasets, "images", "val2017") + label_path = os.path.join(args.datasets, "annotations", "instances_val2017.json") + + + data = { + 'task': 'val', + 'val': data_path, + 'anno_path': label_path, + 'names': coco_classes, + 'is_coco': True, + 'nc': 80 + } + + evaluator = EvalerIXRT(data, batch_size) + + if args.perf_only: + fps = evaluator.eval_ixrt(args) + print("FPS : ", fps) + print(f"Performance Check : Test {fps} >= target {args.fps_target}") + else: + dataloader, pred_results = evaluator.eval_ixrt(args) + eval_result = evaluator.eval_ixrt_map(pred_results, dataloader, task) + map, map50 = eval_result[:2] + print("MAP@0.5 : ", map50) + print(f"Accuracy Check : Test {map50} >= target {args.acc_target}") + if map50 >= args.acc_target: + print("pass!") + exit() + else: + print("failed!") + exit(1) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/models/cv/detection/yolov6/ixrt/quant.py b/models/cv/detection/yolov6/ixrt/quant.py new file mode 100644 index 00000000..70265cbc --- /dev/null +++ b/models/cv/detection/yolov6/ixrt/quant.py @@ -0,0 +1,105 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +import os +import cv2 +import random +import argparse +import numpy as np +from tensorrt.deploy import static_quantize + +import torch +import torchvision.datasets +from torch.utils.data import DataLoader +from common import letterbox + + +def setseed(seed=42): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--model_name", type=str) + parser.add_argument("--model", type=str, default="yolov4_bs16_without_decoder.onnx") + parser.add_argument("--dataset_dir", type=str, default="./coco2017/val2017") + parser.add_argument("--ann_file", type=str, default="./coco2017/annotations/instances_val2017.json") + parser.add_argument("--observer", type=str, choices=["hist_percentile", "percentile", "minmax", "entropy", "ema"], default="hist_percentile") + parser.add_argument("--disable_quant_names", nargs='*', type=str) + parser.add_argument("--save_quant_model", type=str, help="save the quantization model path", default=None) + parser.add_argument("--bsz", type=int, default=16) + parser.add_argument("--step", type=int, default=32) + parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--imgsz", type=int, default=608) + parser.add_argument("--use_letterbox", action="store_true") + args = parser.parse_args() + return args + +args = parse_args() +setseed(args.seed) +model_name = args.model_name + + +def get_dataloader(data_dir, step=32, batch_size=16, new_shape=[608, 608], use_letterbox=False): + num = step * batch_size + val_list = [os.path.join(data_dir, x) for x in os.listdir(data_dir)] + random.shuffle(val_list) + pic_list = val_list[:num] + + calibration_dataset = [] + for file_path in pic_list: + pic_data = cv2.imread(file_path) + org_img = pic_data + assert org_img is not None, 'Image not Found ' + file_path + h0, w0 = org_img.shape[:2] + + if use_letterbox: + img, ratio, dwdh = letterbox(org_img, new_shape=(new_shape[1], new_shape[0]), auto=False, scaleup=True) + else: + img = cv2.resize(org_img, new_shape) + img = img.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB + img = np.ascontiguousarray(img) / 255.0 # 0~1 np array + img = torch.from_numpy(img).float() + + calibration_dataset.append(img) + + calibration_dataloader = DataLoader( + calibration_dataset, + shuffle=True, + batch_size=batch_size, + drop_last=True + ) + return calibration_dataloader + +dataloader = get_dataloader( + data_dir=args.dataset_dir, + step=args.step, + batch_size=args.bsz, + new_shape=(args.imgsz, args.imgsz), + use_letterbox=args.use_letterbox +) + +dirname = os.path.dirname(args.save_quant_model) +quant_json_path = os.path.join(dirname, f"quantized_{model_name}.json") + +static_quantize(args.model, + calibration_dataloader=dataloader, + save_quant_onnx_path=args.save_quant_model, + save_quant_params_path=quant_json_path, + observer=args.observer, + data_preprocess=lambda x: x.to("cuda"), + quant_format="qdq", + disable_quant_names=args.disable_quant_names) diff --git a/models/cv/detection/yolov6/ixrt/scripts/infer_yolov6s_fp16_accuracy.sh b/models/cv/detection/yolov6/ixrt/scripts/infer_yolov6s_fp16_accuracy.sh new file mode 100644 index 00000000..0360c1a1 --- /dev/null +++ b/models/cv/detection/yolov6/ixrt/scripts/infer_yolov6s_fp16_accuracy.sh @@ -0,0 +1,65 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +#!/bin/bash +EXIT_STATUS=0 +check_status() +{ + if ((${PIPESTATUS[0]} != 0));then + EXIT_STATUS=1 + fi +} + +PROJ_DIR=$(cd $(dirname $0);cd ../; pwd) +DATASETS_DIR="${PROJ_DIR}/data/coco" +COCO_GT=${DATASETS_DIR}/annotations/instances_val2017.json +EVAL_DIR=${DATASETS_DIR}/images/val2017 +CHECKPOINTS_DIR="${PROJ_DIR}/data" +RUN_DIR="${PROJ_DIR}" +ORIGINE_MODEL=${CHECKPOINTS_DIR} + +echo CHECKPOINTS_DIR : ${CHECKPOINTS_DIR} +echo DATASETS_DIR : ${DATASETS_DIR} +echo RUN_DIR : ${RUN_DIR} +echo ====================== Model Info ====================== +echo Model Name : yolov6s +echo Onnx Path : ${ORIGINE_MODEL} + +BATCH_SIZE=32 +CURRENT_MODEL=${CHECKPOINTS_DIR}/yolov6s.onnx + +# Build Engine +echo Build Engine +ENGINE_FILE=${CHECKPOINTS_DIR}/yolov6s_fp16.engine +if [ -f $ENGINE_FILE ];then + echo " "Build Engine Skip, $ENGINE_FILE has been existed +else + python3 ${RUN_DIR}/build_engine.py \ + --precision float16 \ + --model ${CURRENT_MODEL} \ + --engine ${ENGINE_FILE} + echo " "Generate Engine ${ENGINE_FILE} +fi + +# Inference +echo Inference +RUN_BATCH_SIZE=32 +python3 ${RUN_DIR}/inference.py \ + --model_engine ${ENGINE_FILE} \ + --warm_up 2 \ + --bsz ${RUN_BATCH_SIZE} \ + --imgsz 640 \ + --datasets ${DATASETS_DIR} \ + --acc_target 0.3 +exit ${EXIT_STATUS} diff --git a/models/cv/detection/yolov6/ixrt/scripts/infer_yolov6s_fp16_performance.sh b/models/cv/detection/yolov6/ixrt/scripts/infer_yolov6s_fp16_performance.sh new file mode 100644 index 00000000..07a103a2 --- /dev/null +++ b/models/cv/detection/yolov6/ixrt/scripts/infer_yolov6s_fp16_performance.sh @@ -0,0 +1,78 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +#!/bin/bash +EXIT_STATUS=0 +check_status() +{ + if ((${PIPESTATUS[0]} != 0));then + EXIT_STATUS=1 + fi +} + +PROJ_DIR=$(cd $(dirname $0);cd ../; pwd) +DATASETS_DIR="${PROJ_DIR}/data/coco" +COCO_GT=${DATASETS_DIR}/annotations/instances_val2017.json +EVAL_DIR=${DATASETS_DIR}/images/val2017 +CHECKPOINTS_DIR="${PROJ_DIR}/data" +RUN_DIR="${PROJ_DIR}" +ORIGINE_MODEL=${CHECKPOINTS_DIR} + +echo CHECKPOINTS_DIR : ${CHECKPOINTS_DIR} +echo DATASETS_DIR : ${DATASETS_DIR} +echo RUN_DIR : ${RUN_DIR} +echo ====================== Model Info ====================== +echo Model Name : yolov6s +echo Onnx Path : ${ORIGINE_MODEL} + +BATCH_SIZE=32 +CURRENT_MODEL=${CHECKPOINTS_DIR}/yolov6s.onnx + +# fuse silu +# FINAL_MODEL=${CHECKPOINTS_DIR}/yolov6_bs${BATCH_SIZE}_fused.onnx +# if [ -f $FINAL_MODEL ];then +# echo " "Fuse silu Skip, $FINAL_MODEL has been existed +# else +# python3 ${RUN_DIR}/deploy.py \ +# --src ${CURRENT_MODEL} \ +# --dst ${FINAL_MODEL} +# echo " "Generate ${FINAL_MODEL} +# fi +# CURRENT_MODEL=${FINAL_MODEL} + +# Build Engine +echo Build Engine +ENGINE_FILE=${CHECKPOINTS_DIR}/yolov6s_fp16.engine +if [ -f $ENGINE_FILE ];then + echo " "Build Engine Skip, $ENGINE_FILE has been existed +else + python3 ${RUN_DIR}/build_engine.py \ + --precision float16 \ + --model ${CURRENT_MODEL} \ + --engine ${ENGINE_FILE} + echo " "Generate Engine ${ENGINE_FILE} +fi + +# Inference +echo Inference +RUN_BATCH_SIZE=32 +python3 ${RUN_DIR}/inference.py \ + --model_engine ${ENGINE_FILE} \ + --warm_up 2 \ + --bsz ${RUN_BATCH_SIZE} \ + --imgsz 640 \ + --datasets ${DATASETS_DIR} \ + --perf_only true \ + --fps_target 0.0 +exit ${EXIT_STATUS} diff --git a/models/cv/detection/yolov6/ixrt/scripts/infer_yolov6s_int8_accuracy.sh b/models/cv/detection/yolov6/ixrt/scripts/infer_yolov6s_int8_accuracy.sh new file mode 100644 index 00000000..3bb4c743 --- /dev/null +++ b/models/cv/detection/yolov6/ixrt/scripts/infer_yolov6s_int8_accuracy.sh @@ -0,0 +1,85 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +#!/bin/bash +EXIT_STATUS=0 +check_status() +{ + if ((${PIPESTATUS[0]} != 0));then + EXIT_STATUS=1 + fi +} + +PROJ_DIR=$(cd $(dirname $0);cd ../; pwd) +DATASETS_DIR="${PROJ_DIR}/data/coco" +COCO_GT=${DATASETS_DIR}/annotations/instances_val2017.json +EVAL_DIR=${DATASETS_DIR}/images/val2017 +CHECKPOINTS_DIR="${PROJ_DIR}/data" +RUN_DIR="${PROJ_DIR}" +ORIGINE_MODEL=${CHECKPOINTS_DIR} + +echo CHECKPOINTS_DIR : ${CHECKPOINTS_DIR} +echo DATASETS_DIR : ${DATASETS_DIR} +echo RUN_DIR : ${RUN_DIR} +echo ====================== Model Info ====================== +echo Model Name : yolov6s +echo Onnx Path : ${ORIGINE_MODEL} + +BATCH_SIZE=32 +CURRENT_MODEL=${CHECKPOINTS_DIR}/yolov6s.onnx + +# quant +FINAL_MODEL=${CHECKPOINTS_DIR}/quantized_yolov6s_bs${BATCH_SIZE}.onnx +if [ -f $FINAL_MODEL ];then + echo " "Change Batchsize Skip, $FINAL_MODEL has been existed +else + python3 ${RUN_DIR}/quant.py \ + --model_name "YOLOV6s" \ + --model ${CURRENT_MODEL} \ + --bsz ${BATCH_SIZE} \ + --dataset_dir ${EVAL_DIR} \ + --ann_file ${COCO_GT} \ + --observer "hist_percentile" \ + --save_quant_model ${FINAL_MODEL} \ + --imgsz 640 \ + --disable_quant_names '/detect/Split' '/detect/Div' '/detect/Sub' '/detect/Add' '/detect/Add_1' '/detect/Sub_1' '/detect/Div' '/detect/Concat_6' '/detect/Mul' '/detect/Concat_7' \ + --use_letterbox + echo " "Generate ${FINAL_MODEL} +fi +CURRENT_MODEL=${FINAL_MODEL} + +# Build Engine +echo Build Engine +ENGINE_FILE=${CHECKPOINTS_DIR}/yolov6s_int8.engine +if [ -f $ENGINE_FILE ];then + echo " "Build Engine Skip, $ENGINE_FILE has been existed +else + python3 ${RUN_DIR}/build_engine.py \ + --precision int8 \ + --model ${CURRENT_MODEL} \ + --engine ${ENGINE_FILE} + echo " "Generate Engine ${ENGINE_FILE} +fi + +# Inference +echo Inference +RUN_BATCH_SIZE=32 +python3 ${RUN_DIR}/inference.py \ + --model_engine ${ENGINE_FILE} \ + --warm_up 2 \ + --bsz ${RUN_BATCH_SIZE} \ + --imgsz 640 \ + --datasets ${DATASETS_DIR} \ + --acc_target 0.3 +exit ${EXIT_STATUS} diff --git a/models/cv/detection/yolov6/ixrt/scripts/infer_yolov6s_int8_performance.sh b/models/cv/detection/yolov6/ixrt/scripts/infer_yolov6s_int8_performance.sh new file mode 100644 index 00000000..53ca3397 --- /dev/null +++ b/models/cv/detection/yolov6/ixrt/scripts/infer_yolov6s_int8_performance.sh @@ -0,0 +1,86 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +#!/bin/bash +EXIT_STATUS=0 +check_status() +{ + if ((${PIPESTATUS[0]} != 0));then + EXIT_STATUS=1 + fi +} + +PROJ_DIR=$(cd $(dirname $0);cd ../; pwd) +DATASETS_DIR="${PROJ_DIR}/data/coco" +COCO_GT=${DATASETS_DIR}/annotations/instances_val2017.json +EVAL_DIR=${DATASETS_DIR}/images/val2017 +CHECKPOINTS_DIR="${PROJ_DIR}/data" +RUN_DIR="${PROJ_DIR}" +ORIGINE_MODEL=${CHECKPOINTS_DIR} + +echo CHECKPOINTS_DIR : ${CHECKPOINTS_DIR} +echo DATASETS_DIR : ${DATASETS_DIR} +echo RUN_DIR : ${RUN_DIR} +echo ====================== Model Info ====================== +echo Model Name : yolov6s +echo Onnx Path : ${ORIGINE_MODEL} + +BATCH_SIZE=32 +CURRENT_MODEL=${CHECKPOINTS_DIR}/yolov6s.onnx + +# quant +FINAL_MODEL=${CHECKPOINTS_DIR}/quantized_yolov6s_bs${BATCH_SIZE}.onnx +if [ -f $FINAL_MODEL ];then + echo " "Change Batchsize Skip, $FINAL_MODEL has been existed +else + python3 ${RUN_DIR}/quant.py \ + --model_name "YOLOV6s" \ + --model ${CURRENT_MODEL} \ + --bsz ${BATCH_SIZE} \ + --dataset_dir ${EVAL_DIR} \ + --ann_file ${COCO_GT} \ + --observer "hist_percentile" \ + --save_quant_model ${FINAL_MODEL} \ + --imgsz 640 \ + --disable_quant_names '/detect/Split' '/detect/Div' '/detect/Sub' '/detect/Add' '/detect/Add_1' '/detect/Sub_1' '/detect/Div' '/detect/Concat_6' '/detect/Mul' '/detect/Concat_7' \ + --use_letterbox + echo " "Generate ${FINAL_MODEL} +fi +CURRENT_MODEL=${FINAL_MODEL} + +# Build Engine +echo Build Engine +ENGINE_FILE=${CHECKPOINTS_DIR}/yolov6s_int8.engine +if [ -f $ENGINE_FILE ];then + echo " "Build Engine Skip, $ENGINE_FILE has been existed +else + python3 ${RUN_DIR}/build_engine.py \ + --precision int8 \ + --model ${CURRENT_MODEL} \ + --engine ${ENGINE_FILE} + echo " "Generate Engine ${ENGINE_FILE} +fi + +# Inference +echo Inference +RUN_BATCH_SIZE=32 +python3 ${RUN_DIR}/inference.py \ + --model_engine ${ENGINE_FILE} \ + --warm_up 2 \ + --bsz ${RUN_BATCH_SIZE} \ + --imgsz 640 \ + --datasets ${DATASETS_DIR} \ + --perf_only true \ + --fps_target 0.0 +exit ${EXIT_STATUS} \ No newline at end of file diff --git a/models/cv/detection/yolov8/ixrt/README.md b/models/cv/detection/yolov8/ixrt/README.md new file mode 100644 index 00000000..c5560b14 --- /dev/null +++ b/models/cv/detection/yolov8/ixrt/README.md @@ -0,0 +1,59 @@ +# YOLOv8 + +## Description + +Yolov8 combines speed and accuracy in real-time object detection tasks. With a focus on simplicity and efficiency, this model employs a single neural network to make predictions, enabling fast and accurate identification of objects in images or video streams. + +## Setup + +### Install + +```bash +# Install libGL +## CentOS +yum install -y mesa-libGL +## Ubuntu +apt install -y libgl1-mesa-dev + +pip3 install tqdm +pip3 install onnx +pip3 install pycocotools +pip3 install ultralytics +``` + +### Download + +Pretrained model: + +Dataset: to download the validation dataset. + +### Model Conversion + +```bash +python3 export.py --weight yolov8n.pt --batch 32 +onnxsim yolov8n.onnx ./data/yolov8n.onnx +``` + +## Inference + +```bash +export DATASETS_DIR=/Path/to/coco/ +``` + +### FP16 + +```bash +# Accuracy +bash scripts/infer_yolov8n_fp16_accuracy.sh +# Performance +bash scripts/infer_yolov8n_fp16_performance.sh +``` + +### INT8 + +```bash +# Accuracy +bash scripts/infer_yolov8n_int8_accuracy.sh +# Performance +bash scripts/infer_yolov8n_int8_performance.sh +``` diff --git a/models/cv/detection/yolov8/ixrt/build_engine.py b/models/cv/detection/yolov8/ixrt/build_engine.py new file mode 100644 index 00000000..f5e1719a --- /dev/null +++ b/models/cv/detection/yolov8/ixrt/build_engine.py @@ -0,0 +1,94 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +import os +import cv2 +import argparse +import numpy as np + +import torch +import tensorrt +from tensorrt import Dims + + +def build_engine_trtapi_staticshape(config): + IXRT_LOGGER = tensorrt.Logger(tensorrt.Logger.WARNING) + builder = tensorrt.Builder(IXRT_LOGGER) + EXPLICIT_BATCH = 1 << (int)(tensorrt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH) + network = builder.create_network(EXPLICIT_BATCH) + build_config = builder.create_builder_config() + parser = tensorrt.OnnxParser(network, IXRT_LOGGER) + parser.parse_from_file(config.model) + + precision = tensorrt.BuilderFlag.INT8 if config.precision == "int8" else tensorrt.BuilderFlag.FP16 + # print("precision : ", precision) + build_config.set_flag(precision) + + plan = builder.build_serialized_network(network, build_config) + engine_file_path = config.engine + with open(engine_file_path, "wb") as f: + f.write(plan) + print("Build static shape engine done!") + + +def build_engine_trtapi_dynamicshape(config): + IXRT_LOGGER = tensorrt.Logger(tensorrt.Logger.WARNING) + builder = tensorrt.Builder(IXRT_LOGGER) + EXPLICIT_BATCH = 1 << (int)(tensorrt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH) + network = builder.create_network(EXPLICIT_BATCH) + build_config = builder.create_builder_config() + + profile = builder.create_optimization_profile() + profile.set_shape("input", + Dims([1, 3, 608, 608]), + Dims([32, 3, 608, 608]), + Dims([64, 3, 608, 608]), + ) + build_config.add_optimization_profile(profile) + + parser = tensorrt.OnnxParser(network, IXRT_LOGGER) + parser.parse_from_file(config.model) + precision = tensorrt.BuilderFlag.INT8 if config.precision == "int8" else tensorrt.BuilderFlag.FP16 + # print("precision : ", precision) + build_config.set_flag(precision) + + # set dynamic + num_inputs = network.num_inputs + for i in range(num_inputs): + input_tensor = network.get_input(i) + input_tensor.shape = Dims([-1, 3, 608, 608]) + + plan = builder.build_serialized_network(network, build_config) + engine_file_path = config.engine + with open(engine_file_path, "wb") as f: + f.write(plan) + print("Build dynamic shape engine done!") + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--model", type=str) + parser.add_argument("--precision", type=str, choices=["float16", "int8", "float32"], default="int8", + help="The precision of datatype") + # engine args + parser.add_argument("--engine", type=str, default=None) + + args = parser.parse_args() + return args + + +if __name__ == "__main__": + args = parse_args() + build_engine_trtapi_staticshape(args) + # build_engine_trtapi_dynamicshape(args) diff --git a/models/cv/detection/yolov8/ixrt/common.py b/models/cv/detection/yolov8/ixrt/common.py new file mode 100644 index 00000000..dc3c2766 --- /dev/null +++ b/models/cv/detection/yolov8/ixrt/common.py @@ -0,0 +1,335 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +import os +import cv2 +import glob +import time +import numpy as np +from tqdm import tqdm + +import tensorrt +import pycuda.driver as cuda + + +def load_class_names(namesfile): + class_names = [] + with open(namesfile, 'r') as fp: + lines = fp.readlines() + for line in lines: + line = line.rstrip() + class_names.append(line) + return class_names + +# input : [bsz, box_num, 5(cx, cy, w, h, conf) + class_num(prob[0], prob[1], ...)] +# output : [bsz, box_num, 6(left_top_x, left_top_y, right_bottom_x, right_bottom_y, class_id, max_prob*conf)] +def box_class85to6(input): + center_x_y = input[:, :2] + side = input[:, 2:4] + conf = input[:, 4:5] + class_id = np.argmax(input[:, 5:], axis = -1) + class_id = class_id.astype(np.float32).reshape(-1, 1) + 1 + max_prob = np.max(input[:, 5:], axis = -1).reshape(-1, 1) + x1_y1 = center_x_y - 0.5 * side + x2_y2 = center_x_y + 0.5 * side + nms_input = np.concatenate([x1_y1, x2_y2, class_id, max_prob*conf], axis = -1) + return nms_input + +def save2json(batch_img_id, pred_boxes, json_result, class_trans): + for i, boxes in enumerate(pred_boxes): + if boxes is not None: + image_id = int(batch_img_id[i]) + # have no target + if image_id == -1: + continue + + for x1, y1, x2, y2, _, p, c in boxes: + x1, y1, x2, y2, p = float(x1), float(y1), float(x2), float(y2), float(p) + c = int(c) + x = x1 + y = y1 + w = x2 - x1 + h = y2 - y1 + + json_result.append( + { + "image_id": image_id, + "category_id": class_trans[c - 1], + "bbox": [x, y, w, h], + "score": p, + } + ) + +################## About TensorRT ################# +def create_engine_context(engine_path, logger): + with open(engine_path, "rb") as f: + runtime = tensorrt.Runtime(logger) + assert runtime + engine = runtime.deserialize_cuda_engine(f.read()) + assert engine + context = engine.create_execution_context() + assert context + + return engine, context + +def setup_io_bindings(engine, context): + # Setup I/O bindings + inputs = [] + outputs = [] + allocations = [] + + for i in range(engine.num_bindings): + is_input = False + if engine.binding_is_input(i): + is_input = True + name = engine.get_binding_name(i) + dtype = engine.get_binding_dtype(i) + shape = context.get_binding_shape(i) + if is_input: + batch_size = shape[0] + size = np.dtype(tensorrt.nptype(dtype)).itemsize + for s in shape: + size *= s + allocation = cuda.mem_alloc(size) + binding = { + "index": i, + "name": name, + "dtype": np.dtype(tensorrt.nptype(dtype)), + "shape": list(shape), + "allocation": allocation, + } + # print(f"binding {i}, name : {name} dtype : {np.dtype(tensorrt.nptype(dtype))} shape : {list(shape)}") + allocations.append(allocation) + if engine.binding_is_input(i): + inputs.append(binding) + else: + outputs.append(binding) + return inputs, outputs, allocations +########################################################## + + +################## About Loading Dataset ################# +def load_images(images_path): + """ + If image path is given, return it directly + For txt file, read it and return each line as image path + In other case, it's a folder, return a list with names of each + jpg, jpeg and png file + """ + input_path_extension = images_path.split('.')[-1] + if input_path_extension in ['jpg', 'jpeg', 'png']: + return [images_path] + elif input_path_extension == "txt": + with open(images_path, "r") as f: + return f.read().splitlines() + else: + return glob.glob( + os.path.join(images_path, "*.jpg")) + \ + glob.glob(os.path.join(images_path, "*.png")) + \ + glob.glob(os.path.join(images_path, "*.jpeg")) + +def prepare_batch(images_path, bs=16, input_size=(608, 608)): + + width, height = input_size + + batch_names = [] + batch_images = [] + batch_shapes = [] + + temp_names = [] + temp_images = [] + temp_shapes = [] + + for i, image_path in tqdm(enumerate(images_path), desc="Loading coco data"): + name = os.path.basename(image_path) + image = cv2.imread(image_path) + h, w, _ = image.shape + image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + image_resized = cv2.resize(image_rgb, (width, height), + interpolation=cv2.INTER_LINEAR) + custom_image = image_resized.transpose(2, 0, 1).astype(np.float32) / 255. + custom_image = np.expand_dims(custom_image, axis=0) + + if i != 0 and i % bs == 0: + batch_names.append(temp_names) + batch_images.append(np.concatenate(temp_images, axis=0)) + batch_shapes.append(temp_shapes) + + temp_names = [name] + temp_images = [custom_image] + temp_shapes = [(h, w)] + else: + temp_names.append(name) + temp_images.append(custom_image) + temp_shapes.append((h, w)) + + return batch_names, batch_images, batch_shapes +########################################################## + + +################## About Operating box ################# +def letterbox(im, new_shape=(640, 640), color=(114, 114, 114), auto=True, scaleFill=False, scaleup=True, stride=32): + # Resize and pad image while meeting stride-multiple constraints + shape = im.shape[:2] # current shape [height, width] + if isinstance(new_shape, int): + new_shape = (new_shape, new_shape) + + # Scale ratio (new / old) + r = min(new_shape[0] / shape[0], new_shape[1] / shape[1]) + if not scaleup: # only scale down, do not scale up (for better val mAP) + r = min(r, 1.0) + + # Compute padding + ratio = r, r # width, height ratios + new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r)) + dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1] # wh padding + if auto: # minimum rectangle + dw, dh = np.mod(dw, stride), np.mod(dh, stride) # wh padding + elif scaleFill: # stretch + dw, dh = 0.0, 0.0 + new_unpad = (new_shape[1], new_shape[0]) + ratio = new_shape[1] / shape[1], new_shape[0] / shape[0] # width, height ratios + + dw /= 2 # divide padding into 2 sides + dh /= 2 + + if shape[::-1] != new_unpad: # resize + im = cv2.resize(im, new_unpad, interpolation=cv2.INTER_LINEAR) + top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1)) + left, right = int(round(dw - 0.1)), int(round(dw + 0.1)) + im = cv2.copyMakeBorder(im, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color) # add border + return im, ratio, (dw, dh) + +def scale_boxes(net_shape, boxes, ori_shape, use_letterbox=False): + # Rescale boxes (xyxy) from net_shape to ori_shape + + if use_letterbox: + + gain = min( + net_shape[0] / ori_shape[0], net_shape[1] / ori_shape[1] + ) # gain = new / old + pad = (net_shape[1] - ori_shape[1] * gain) / 2, ( + net_shape[0] - ori_shape[0] * gain + ) / 2.0 + + boxes[:, [0, 2]] -= pad[0] # x padding + boxes[:, [1, 3]] -= pad[1] # y padding + boxes[:, :4] /= gain + else: + x_scale, y_scale = net_shape[1] / ori_shape[1], net_shape[0] / ori_shape[0] + + boxes[:, 0] /= x_scale + boxes[:, 1] /= y_scale + boxes[:, 2] /= x_scale + boxes[:, 3] /= y_scale + + clip_boxes(boxes, ori_shape) + return boxes + +def clip_boxes(boxes, shape): + + boxes[:, [0, 2]] = boxes[:, [0, 2]].clip(0, shape[1]) # x1, x2 + boxes[:, [1, 3]] = boxes[:, [1, 3]].clip(0, shape[0]) # y1, y2 +########################################################## + + +################## About pre and post processing ######### +def pre_processing(src_img, imgsz=608): + resized = cv2.resize(src_img, (imgsz, imgsz), interpolation=cv2.INTER_LINEAR) + in_img = cv2.cvtColor(resized, cv2.COLOR_BGR2RGB) + in_img = np.transpose(in_img, (2, 0, 1)).astype(np.float32) + in_img = np.expand_dims(in_img, axis=0) + in_img /= 255.0 + return in_img + +def nms_cpu(boxes, confs, nms_thresh=0.5, min_mode=False): + # print(boxes.shape) + x1 = boxes[:, 0] + y1 = boxes[:, 1] + x2 = boxes[:, 2] + y2 = boxes[:, 3] + + areas = (x2 - x1) * (y2 - y1) + order = confs.argsort()[::-1] + + keep = [] + while order.size > 0: + idx_self = order[0] + idx_other = order[1:] + + keep.append(idx_self) + + xx1 = np.maximum(x1[idx_self], x1[idx_other]) + yy1 = np.maximum(y1[idx_self], y1[idx_other]) + xx2 = np.minimum(x2[idx_self], x2[idx_other]) + yy2 = np.minimum(y2[idx_self], y2[idx_other]) + + w = np.maximum(0.0, xx2 - xx1) + h = np.maximum(0.0, yy2 - yy1) + inter = w * h + + if min_mode: + over = inter / np.minimum(areas[order[0]], areas[order[1:]]) + else: + over = inter / (areas[order[0]] + areas[order[1:]] - inter) + + inds = np.where(over <= nms_thresh)[0] + order = order[inds + 1] + + return np.array(keep) + + +def post_processing(img, conf_thresh, nms_thresh, output, num_classes=80): + + # [batch, num, 1, 4] + box_array = output[:, :, :4] + # [batch, num, 2] + class_confs = output[:, :, 4:] + + max_conf = class_confs[:, :, 1] + max_id = class_confs[:, :, 0] + + bboxes_batch = [] + for i in range(box_array.shape[0]): + + argwhere = max_conf[i] > conf_thresh + l_box_array = box_array[i, argwhere, :] + l_max_conf = max_conf[i, argwhere] + l_max_id = max_id[i, argwhere] + + bboxes = [] + # nms for each class + for j in range(num_classes): + + cls_argwhere = l_max_id == j + ll_box_array = l_box_array[cls_argwhere, :] + ll_max_conf = l_max_conf[cls_argwhere] + ll_max_id = l_max_id[cls_argwhere] + + keep = nms_cpu(ll_box_array, ll_max_conf, nms_thresh) + + if (keep.size > 0): + ll_box_array = ll_box_array[keep, :] + ll_max_conf = ll_max_conf[keep] + ll_max_id = ll_max_id[keep] + + for k in range(ll_box_array.shape[0]): + bboxes.append([ll_box_array[k, 0], ll_box_array[k, 1], ll_box_array[k, 2], + ll_box_array[k, 3], ll_max_conf[k], ll_max_conf[k], ll_max_id[k]]) + + bboxes_batch.append(bboxes) + + return bboxes_batch +########################################################## + diff --git a/models/cv/detection/yolov8/ixrt/export.py b/models/cv/detection/yolov8/ixrt/export.py new file mode 100644 index 00000000..383b327e --- /dev/null +++ b/models/cv/detection/yolov8/ixrt/export.py @@ -0,0 +1,43 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import argparse +from ultralytics import YOLO + +def parse_args(): + parser = argparse.ArgumentParser() + + parser.add_argument("--weight", + type=str, + required=True, + help="pytorch model weight.") + + parser.add_argument("--batch", + type=int, + required=True, + help="batchsize of the model.") + args = parser.parse_args() + + return args + +def main(): + args = parse_args() + + model = YOLO(args.weight).cpu() + + model.export(format='onnx', batch=args.batch, imgsz=(640, 640), opset=11) + +if __name__ == "__main__": + main() diff --git a/models/cv/detection/yolov8/ixrt/inference.py b/models/cv/detection/yolov8/ixrt/inference.py new file mode 100644 index 00000000..d83b0136 --- /dev/null +++ b/models/cv/detection/yolov8/ixrt/inference.py @@ -0,0 +1,237 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import os +import json +import argparse +import time +import tensorrt +from tensorrt import Dims +import pycuda.autoinit +import pycuda.driver as cuda +import torch +import numpy as np +from tqdm import tqdm + +from common import create_engine_context, setup_io_bindings + +from pathlib import Path + +from ultralytics.cfg import get_cfg +from ultralytics.data import converter +from ultralytics.utils import DEFAULT_CFG +from ultralytics.data.utils import check_det_dataset +from ultralytics.utils.metrics import ConfusionMatrix +from ultralytics.models.yolo.detect import DetectionValidator + +coco_classes = {0: 'person', 1: 'bicycle', 2: 'car', 3: 'motorcycle', 4: 'airplane', 5: 'bus', 6: 'train', 7: 'truck', 8: 'boat', 9: 'traffic light', + 10: 'fire hydrant', 11: 'stop sign', 12: 'parking meter', 13: 'bench', 14: 'bird', 15: 'cat', 16: 'dog', 17: 'horse', 18: 'sheep', 19: 'cow', + 20: 'elephant', 21: 'bear', 22: 'zebra', 23: 'giraffe', 24: 'backpack', 25: 'umbrella', 26: 'handbag', 27: 'tie', 28: 'suitcase', 29: 'frisbee', + 30: 'skis', 31: 'snowboard', 32: 'sports ball', 33: 'kite', 34: 'baseball bat', 35: 'baseball glove', 36: 'skateboard', 37: 'surfboard', 38: 'tennis racket', 39: 'bottle', + 40: 'wine glass', 41: 'cup', 42: 'fork', 43: 'knife', 44: 'spoon', 45: 'bowl', 46: 'banana', 47: 'apple', 48: 'sandwich', 49: 'orange', + 50: 'broccoli', 51: 'carrot', 52: 'hot dog', 53: 'pizza', 54: 'donut', 55: 'cake', 56: 'chair', 57: 'couch', 58: 'potted plant', 59: 'bed', + 60: 'dining table', 61: 'toilet', 62: 'tv', 63: 'laptop', 64: 'mouse', 65: 'remote', 66: 'keyboard', 67: 'cell phone', 68: 'microwave', 69: 'oven', + 70: 'toaster', 71: 'sink', 72: 'refrigerator', 73: 'book', 74: 'clock', 75: 'vase', 76: 'scissors', 77: 'teddy bear', 78: 'hair drier', 79: 'toothbrush'} + +def parse_args(): + parser = argparse.ArgumentParser() + + parser.add_argument("--model_engine", + type=str, + required=True, + help="ixrt engine path.") + + parser.add_argument("--bsz", + type=int, + required=True, + help="inference batch size.") + + parser.add_argument( + "--imgsz", + "--img", + "--img-size", + type=int, + default=640, + help="inference size h,w", + ) + + parser.add_argument("--datasets", + type=str, + required=True, + help="datasets path.") + + parser.add_argument("--warm_up", + type=int, + default=3, + help="number of warmup before test.") + + parser.add_argument("--num_workers", + type=int, + default=16, + help="number of workers used in pytorch dataloader.") + + parser.add_argument("--acc_target", + type=float, + default=0.0, + help="Model inference Accuracy target.") + + parser.add_argument("--fps_target", + type=float, + default=0.0, + help="Model inference FPS target.") + + parser.add_argument("--conf", + type=float, + default=0.001, + help="confidence threshold.") + + parser.add_argument("--iou", + type=float, + default=0.65, + help="iou threshold.") + + parser.add_argument("--perf_only", + type=bool, + default=False, + help="Run performance test only") + + args = parser.parse_args() + + return args + +class IxRT_Validator(DetectionValidator): + def __call__(self, config, data): + self.data = data + self.stride = 32 + self.dataloader = self.get_dataloader(self.data.get(self.args.split), self.args.batch) + self.init_metrics() + + total_num = 0 + + input_name = "input" + host_mem = tensorrt.IHostMemory + logger = tensorrt.Logger(tensorrt.Logger.ERROR) + engine, context = create_engine_context(config.model_engine, logger) + input_idx = engine.get_binding_index(input_name) + context.set_binding_shape(input_idx, Dims((config.bsz,3,config.imgsz,config.imgsz))) + inputs, outputs, allocations = setup_io_bindings(engine, context) + + if config.warm_up > 0: + print("\nWarm Start.") + for i in range(config.warm_up): + context.execute_v2(allocations) + print("Warm Done.") + + forward_time = 0.0 + num_samples = 0 + + for batch in tqdm(self.dataloader): + batch = self.preprocess(batch) + + imgs = batch['img'] + pad_batch = len(imgs) != self.args.batch + if pad_batch: + origin_size = len(imgs) + imgs = np.resize(imgs, (self.args.batch, *imgs.shape[1:])) + + batch_data = np.ascontiguousarray(imgs) + data_shape = batch_data.shape + + cur_bsz_sample = batch_data.shape[0] + num_samples += cur_bsz_sample + + # Set input + input_idx = engine.get_binding_index(input_name) + context.set_binding_shape(input_idx, Dims(data_shape)) + inputs, outputs, allocations = setup_io_bindings(engine, context) + + cuda.memcpy_htod(inputs[0]["allocation"], batch_data) + # Prepare the output data + output = np.zeros(outputs[0]["shape"], outputs[0]["dtype"]) + + + start_time = time.time() + context.execute_v2(allocations) + end_time = time.time() + forward_time += end_time - start_time + + cuda.memcpy_dtoh(output, outputs[0]["allocation"]) + if pad_batch: + output = output[:origin_size] + + outputs = torch.from_numpy(output) + + preds = self.postprocess([outputs]) + + self.update_metrics(preds, batch) + + if config.perf_only: + fps = num_samples / forward_time + return fps + else: + stats = self.get_stats() + + if self.args.save_json and self.jdict: + with open(str(self.save_dir / 'predictions.json'), 'w') as f: + print(f'Saving {f.name} ...') + json.dump(self.jdict, f) # flatten and save + + stats = self.eval_json(stats) + + return stats + + def init_metrics(self): + """Initialize evaluation metrics for YOLO.""" + val = self.data.get(self.args.split, '') # validation path + self.is_coco = isinstance(val, str) and 'coco' in val and val.endswith(f'{os.sep}val2017.txt') # is COCO + self.class_map = converter.coco80_to_coco91_class() if self.is_coco else list(range(1000)) + self.args.save_json |= self.is_coco and not self.training # run on final val if training COCO + self.names = self.data['names'] + self.nc = len(self.names) + self.metrics.names = self.names + self.confusion_matrix = ConfusionMatrix(nc=80) + self.seen = 0 + self.jdict = [] + self.stats = dict(tp=[], conf=[], pred_cls=[], target_cls=[], target_img=[]) + +def main(): + config = parse_args() + + batch_size = config.bsz + + overrides = {'mode': 'val'} + cfg_args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides) + + cfg_args.batch = batch_size + cfg_args.save_json = True + + data = { + 'path': Path(config.datasets), + 'val': os.path.join(config.datasets, 'val2017.txt'), + 'names': coco_classes + } + + validator = IxRT_Validator(args=cfg_args, save_dir=Path('.')) + + if config.perf_only: + fps = validator(config, data) + print("FPS : ", fps) + print(f"Performance Check : Test {fps} >= target {config.fps_target}") + else: + stats = validator(config, data) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/models/cv/detection/yolov8/ixrt/quant.py b/models/cv/detection/yolov8/ixrt/quant.py new file mode 100644 index 00000000..70265cbc --- /dev/null +++ b/models/cv/detection/yolov8/ixrt/quant.py @@ -0,0 +1,105 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +import os +import cv2 +import random +import argparse +import numpy as np +from tensorrt.deploy import static_quantize + +import torch +import torchvision.datasets +from torch.utils.data import DataLoader +from common import letterbox + + +def setseed(seed=42): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--model_name", type=str) + parser.add_argument("--model", type=str, default="yolov4_bs16_without_decoder.onnx") + parser.add_argument("--dataset_dir", type=str, default="./coco2017/val2017") + parser.add_argument("--ann_file", type=str, default="./coco2017/annotations/instances_val2017.json") + parser.add_argument("--observer", type=str, choices=["hist_percentile", "percentile", "minmax", "entropy", "ema"], default="hist_percentile") + parser.add_argument("--disable_quant_names", nargs='*', type=str) + parser.add_argument("--save_quant_model", type=str, help="save the quantization model path", default=None) + parser.add_argument("--bsz", type=int, default=16) + parser.add_argument("--step", type=int, default=32) + parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--imgsz", type=int, default=608) + parser.add_argument("--use_letterbox", action="store_true") + args = parser.parse_args() + return args + +args = parse_args() +setseed(args.seed) +model_name = args.model_name + + +def get_dataloader(data_dir, step=32, batch_size=16, new_shape=[608, 608], use_letterbox=False): + num = step * batch_size + val_list = [os.path.join(data_dir, x) for x in os.listdir(data_dir)] + random.shuffle(val_list) + pic_list = val_list[:num] + + calibration_dataset = [] + for file_path in pic_list: + pic_data = cv2.imread(file_path) + org_img = pic_data + assert org_img is not None, 'Image not Found ' + file_path + h0, w0 = org_img.shape[:2] + + if use_letterbox: + img, ratio, dwdh = letterbox(org_img, new_shape=(new_shape[1], new_shape[0]), auto=False, scaleup=True) + else: + img = cv2.resize(org_img, new_shape) + img = img.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB + img = np.ascontiguousarray(img) / 255.0 # 0~1 np array + img = torch.from_numpy(img).float() + + calibration_dataset.append(img) + + calibration_dataloader = DataLoader( + calibration_dataset, + shuffle=True, + batch_size=batch_size, + drop_last=True + ) + return calibration_dataloader + +dataloader = get_dataloader( + data_dir=args.dataset_dir, + step=args.step, + batch_size=args.bsz, + new_shape=(args.imgsz, args.imgsz), + use_letterbox=args.use_letterbox +) + +dirname = os.path.dirname(args.save_quant_model) +quant_json_path = os.path.join(dirname, f"quantized_{model_name}.json") + +static_quantize(args.model, + calibration_dataloader=dataloader, + save_quant_onnx_path=args.save_quant_model, + save_quant_params_path=quant_json_path, + observer=args.observer, + data_preprocess=lambda x: x.to("cuda"), + quant_format="qdq", + disable_quant_names=args.disable_quant_names) diff --git a/models/cv/detection/yolov8/ixrt/scripts/infer_yolov8n_fp16_accuracy.sh b/models/cv/detection/yolov8/ixrt/scripts/infer_yolov8n_fp16_accuracy.sh new file mode 100644 index 00000000..8868533d --- /dev/null +++ b/models/cv/detection/yolov8/ixrt/scripts/infer_yolov8n_fp16_accuracy.sh @@ -0,0 +1,65 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +#!/bin/bash +EXIT_STATUS=0 +check_status() +{ + if ((${PIPESTATUS[0]} != 0));then + EXIT_STATUS=1 + fi +} + +PROJ_DIR=$(cd $(dirname $0);cd ../; pwd) +DATASETS_DIR="${PROJ_DIR}/data/coco" +COCO_GT=${DATASETS_DIR}/annotations/instances_val2017.json +EVAL_DIR=${DATASETS_DIR}/images/val2017 +CHECKPOINTS_DIR="${PROJ_DIR}/data" +RUN_DIR="${PROJ_DIR}" +ORIGINE_MODEL=${CHECKPOINTS_DIR} + +echo CHECKPOINTS_DIR : ${CHECKPOINTS_DIR} +echo DATASETS_DIR : ${DATASETS_DIR} +echo RUN_DIR : ${RUN_DIR} +echo ====================== Model Info ====================== +echo Model Name : yolov8n +echo Onnx Path : ${ORIGINE_MODEL} + +BATCH_SIZE=32 +CURRENT_MODEL=${CHECKPOINTS_DIR}/yolov8n.onnx + +# Build Engine +echo Build Engine +ENGINE_FILE=${CHECKPOINTS_DIR}/yolov8n_fp16.engine +if [ -f $ENGINE_FILE ];then + echo " "Build Engine Skip, $ENGINE_FILE has been existed +else + python3 ${RUN_DIR}/build_engine.py \ + --precision float16 \ + --model ${CURRENT_MODEL} \ + --engine ${ENGINE_FILE} + echo " "Generate Engine ${ENGINE_FILE} +fi + +# Inference +echo Inference +RUN_BATCH_SIZE=32 +python3 ${RUN_DIR}/inference.py \ + --model_engine ${ENGINE_FILE} \ + --warm_up 2 \ + --bsz ${RUN_BATCH_SIZE} \ + --imgsz 640 \ + --datasets ${DATASETS_DIR} \ + --acc_target 0.3 +exit ${EXIT_STATUS} diff --git a/models/cv/detection/yolov8/ixrt/scripts/infer_yolov8n_fp16_performance.sh b/models/cv/detection/yolov8/ixrt/scripts/infer_yolov8n_fp16_performance.sh new file mode 100644 index 00000000..b9a28a3a --- /dev/null +++ b/models/cv/detection/yolov8/ixrt/scripts/infer_yolov8n_fp16_performance.sh @@ -0,0 +1,66 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +#!/bin/bash +EXIT_STATUS=0 +check_status() +{ + if ((${PIPESTATUS[0]} != 0));then + EXIT_STATUS=1 + fi +} + +PROJ_DIR=$(cd $(dirname $0);cd ../; pwd) +DATASETS_DIR="${PROJ_DIR}/data/coco" +COCO_GT=${DATASETS_DIR}/annotations/instances_val2017.json +EVAL_DIR=${DATASETS_DIR}/images/val2017 +CHECKPOINTS_DIR="${PROJ_DIR}/data" +RUN_DIR="${PROJ_DIR}" +ORIGINE_MODEL=${CHECKPOINTS_DIR} + +echo CHECKPOINTS_DIR : ${CHECKPOINTS_DIR} +echo DATASETS_DIR : ${DATASETS_DIR} +echo RUN_DIR : ${RUN_DIR} +echo ====================== Model Info ====================== +echo Model Name : yolov8n +echo Onnx Path : ${ORIGINE_MODEL} + +BATCH_SIZE=32 +CURRENT_MODEL=${CHECKPOINTS_DIR}/yolov8n.onnx + +# Build Engine +echo Build Engine +ENGINE_FILE=${CHECKPOINTS_DIR}/yolov8n_fp16.engine +if [ -f $ENGINE_FILE ];then + echo " "Build Engine Skip, $ENGINE_FILE has been existed +else + python3 ${RUN_DIR}/build_engine.py \ + --precision float16 \ + --model ${CURRENT_MODEL} \ + --engine ${ENGINE_FILE} + echo " "Generate Engine ${ENGINE_FILE} +fi + +# Inference +echo Inference +RUN_BATCH_SIZE=32 +python3 ${RUN_DIR}/inference.py \ + --model_engine ${ENGINE_FILE} \ + --warm_up 2 \ + --bsz ${RUN_BATCH_SIZE} \ + --imgsz 640 \ + --datasets ${DATASETS_DIR} \ + --perf_only true \ + --fps_target 0.0 +exit ${EXIT_STATUS} diff --git a/models/cv/detection/yolov8/ixrt/scripts/infer_yolov8n_int8_accuracy.sh b/models/cv/detection/yolov8/ixrt/scripts/infer_yolov8n_int8_accuracy.sh new file mode 100644 index 00000000..f3259c58 --- /dev/null +++ b/models/cv/detection/yolov8/ixrt/scripts/infer_yolov8n_int8_accuracy.sh @@ -0,0 +1,85 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +#!/bin/bash +EXIT_STATUS=0 +check_status() +{ + if ((${PIPESTATUS[0]} != 0));then + EXIT_STATUS=1 + fi +} + +PROJ_DIR=$(cd $(dirname $0);cd ../; pwd) +DATASETS_DIR="${PROJ_DIR}/data/coco" +COCO_GT=${DATASETS_DIR}/annotations/instances_val2017.json +EVAL_DIR=${DATASETS_DIR}/images/val2017 +CHECKPOINTS_DIR="${PROJ_DIR}/data" +RUN_DIR="${PROJ_DIR}" +ORIGINE_MODEL=${CHECKPOINTS_DIR} +DISABLE_NAMES=('/model.22/Concat' '/model.22/Concat_1' '/model.22/Concat_2' '/model.22/Reshape' '/model.22/Reshape_1' '/model.22/Reshape_2' '/model.22/Concat_3' '/model.22/Split' '/model.22/dfl/Reshape' '/model.22/dfl/Transpose' '/model.22/dfl/Softmax' '/model.22/dfl/Transpose_1' '/model.22/dfl/conv/Conv' '/model.22/dfl/Reshape_1' '/model.22/Slice' '/model.22/Slice_1' '/model.22/Sub' '/model.22/Add_1' '/model.22/Add_2' '/model.22/Div_1' '/model.22/Sub_1' '/model.22/Concat_4' '/model.22/Mul_2' '/model.22/Sigmoid' '/model.22/Concat_5') + +echo CHECKPOINTS_DIR : ${CHECKPOINTS_DIR} +echo DATASETS_DIR : ${DATASETS_DIR} +echo RUN_DIR : ${RUN_DIR} +echo ====================== Model Info ====================== +echo Model Name : yolov8n +echo Onnx Path : ${ORIGINE_MODEL} + +BATCH_SIZE=32 +CURRENT_MODEL=${CHECKPOINTS_DIR}/yolov8n.onnx + +# quant +FINAL_MODEL=${CHECKPOINTS_DIR}/quantized_yolov8n_bs${BATCH_SIZE}.onnx +if [ -f $FINAL_MODEL ];then + echo " "Quantize Skip, $FINAL_MODEL has been existed +else + python3 ${RUN_DIR}/quant.py \ + --model_name "YOLOV8N" \ + --model ${CURRENT_MODEL} \ + --bsz ${BATCH_SIZE} \ + --dataset_dir ${EVAL_DIR} \ + --ann_file ${COCO_GT} \ + --observer "hist_percentile" \ + --save_quant_model ${FINAL_MODEL} \ + --disable_quant_names '/model.22/Concat' '/model.22/Concat_1' '/model.22/Concat_2' '/model.22/Reshape' '/model.22/Reshape_1' '/model.22/Reshape_2' '/model.22/Concat_3' '/model.22/Split' '/model.22/dfl/Reshape' '/model.22/dfl/Transpose' '/model.22/dfl/Softmax' '/model.22/dfl/Transpose_1' '/model.22/dfl/conv/Conv' '/model.22/dfl/Reshape_1' '/model.22/Slice' '/model.22/Slice_1' '/model.22/Sub' '/model.22/Add_1' '/model.22/Add_2' '/model.22/Div_1' '/model.22/Sub_1' '/model.22/Concat_4' '/model.22/Mul_2' '/model.22/Sigmoid' '/model.22/Concat_5' \ + --imgsz 640 + echo " "Generate ${FINAL_MODEL} +fi +CURRENT_MODEL=${FINAL_MODEL} + +# Build Engine +echo Build Engine +ENGINE_FILE=${CHECKPOINTS_DIR}/yolov8n_int8.engine +if [ -f $ENGINE_FILE ];then + echo " "Build Engine Skip, $ENGINE_FILE has been existed +else + python3 ${RUN_DIR}/build_engine.py \ + --precision int8 \ + --model ${CURRENT_MODEL} \ + --engine ${ENGINE_FILE} + echo " "Generate Engine ${ENGINE_FILE} +fi + +# Inference +echo Inference +RUN_BATCH_SIZE=32 +python3 ${RUN_DIR}/inference.py \ + --model_engine ${ENGINE_FILE} \ + --warm_up 2 \ + --bsz ${RUN_BATCH_SIZE} \ + --imgsz 640 \ + --datasets ${DATASETS_DIR} \ + --acc_target 0.3 +exit ${EXIT_STATUS} diff --git a/models/cv/detection/yolov8/ixrt/scripts/infer_yolov8n_int8_performance.sh b/models/cv/detection/yolov8/ixrt/scripts/infer_yolov8n_int8_performance.sh new file mode 100644 index 00000000..735035d8 --- /dev/null +++ b/models/cv/detection/yolov8/ixrt/scripts/infer_yolov8n_int8_performance.sh @@ -0,0 +1,85 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +#!/bin/bash +EXIT_STATUS=0 +check_status() +{ + if ((${PIPESTATUS[0]} != 0));then + EXIT_STATUS=1 + fi +} + +PROJ_DIR=$(cd $(dirname $0);cd ../; pwd) +DATASETS_DIR="${PROJ_DIR}/data/coco" +COCO_GT=${DATASETS_DIR}/annotations/instances_val2017.json +EVAL_DIR=${DATASETS_DIR}/images/val2017 +CHECKPOINTS_DIR="${PROJ_DIR}/data" +RUN_DIR="${PROJ_DIR}" +ORIGINE_MODEL=${CHECKPOINTS_DIR} + +echo CHECKPOINTS_DIR : ${CHECKPOINTS_DIR} +echo DATASETS_DIR : ${DATASETS_DIR} +echo RUN_DIR : ${RUN_DIR} +echo ====================== Model Info ====================== +echo Model Name : yolov8n +echo Onnx Path : ${ORIGINE_MODEL} + +BATCH_SIZE=32 +CURRENT_MODEL=${CHECKPOINTS_DIR}/yolov8n.onnx + +# quant +FINAL_MODEL=${CHECKPOINTS_DIR}/quantized_yolov8n_bs${BATCH_SIZE}.onnx +if [ -f $FINAL_MODEL ];then + echo " "Quantize Skip, $FINAL_MODEL has been existed +else + python3 ${RUN_DIR}/quant.py \ + --model_name "YOLOV8N" \ + --model ${CURRENT_MODEL} \ + --bsz ${BATCH_SIZE} \ + --dataset_dir ${EVAL_DIR} \ + --ann_file ${COCO_GT} \ + --observer "hist_percentile" \ + --save_quant_model ${FINAL_MODEL} \ + --disable_quant_names '/model.22/Concat' '/model.22/Concat_1' '/model.22/Concat_2' '/model.22/Reshape' '/model.22/Reshape_1' '/model.22/Reshape_2' '/model.22/Concat_3' '/model.22/Split' '/model.22/dfl/Reshape' '/model.22/dfl/Transpose' '/model.22/dfl/Softmax' '/model.22/dfl/Transpose_1' '/model.22/dfl/conv/Conv' '/model.22/dfl/Reshape_1' '/model.22/Slice' '/model.22/Slice_1' '/model.22/Sub' '/model.22/Add_1' '/model.22/Add_2' '/model.22/Div_1' '/model.22/Sub_1' '/model.22/Concat_4' '/model.22/Mul_2' '/model.22/Sigmoid' '/model.22/Concat_5' \ + --imgsz 640 + echo " "Generate ${FINAL_MODEL} +fi +CURRENT_MODEL=${FINAL_MODEL} + +# Build Engine +echo Build Engine +ENGINE_FILE=${CHECKPOINTS_DIR}/yolov8n_int8.engine +if [ -f $ENGINE_FILE ];then + echo " "Build Engine Skip, $ENGINE_FILE has been existed +else + python3 ${RUN_DIR}/build_engine.py \ + --precision int8 \ + --model ${CURRENT_MODEL} \ + --engine ${ENGINE_FILE} + echo " "Generate Engine ${ENGINE_FILE} +fi + +# Inference +echo Inference +RUN_BATCH_SIZE=32 +python3 ${RUN_DIR}/inference.py \ + --model_engine ${ENGINE_FILE} \ + --warm_up 2 \ + --bsz ${RUN_BATCH_SIZE} \ + --imgsz 640 \ + --datasets ${DATASETS_DIR} \ + --perf_only true \ + --fps_target 0.0 +exit ${EXIT_STATUS} \ No newline at end of file -- Gitee From efd7bc9314ce6c034935df3b27503c095d45bfe6 Mon Sep 17 00:00:00 2001 From: may Date: Mon, 22 Jul 2024 15:31:58 +0800 Subject: [PATCH 2/7] Add transformer_asr --- .../transformer_asr/ixrt/README.md | 56 ++ .../transformer_asr/ixrt/aishell_prepare.py | 141 ++++ .../transformer_asr/ixrt/beam_search.py | 381 +++++++++++ .../transformer_asr/ixrt/build.sh | 23 + .../transformer_asr/ixrt/builder.py | 466 ++++++++++++++ .../transformer_asr/ixrt/convert.py | 95 +++ .../transformer_asr/ixrt/ctc.py | 394 ++++++++++++ .../ixrt/faster_cat/__init__.py | 13 + .../transformer_asr/ixrt/faster_cat/build.sh | 22 + .../transformer_asr/ixrt/faster_cat/kernel.cu | 79 +++ .../transformer_asr/ixrt/faster_cat/setup.py | 48 ++ .../transformer_asr/ixrt/faster_cat/test.cpp | 21 + .../transformer_asr/ixrt/faster_cat/test.py | 37 ++ .../ixrt/faster_layer_norm/__init__.py | 16 + .../ixrt/faster_layer_norm/build.sh | 22 + .../ixrt/faster_layer_norm/kernel.cu | 168 +++++ .../ixrt/faster_layer_norm/setup.py | 48 ++ .../ixrt/faster_layer_norm/test.cpp | 22 + .../faster_layer_norm/transformer_helper.cuh | 295 +++++++++ .../ixrt/faster_logsumexp/__init__.py | 38 ++ .../ixrt/faster_logsumexp/build.sh | 22 + .../ixrt/faster_logsumexp/kernel.cu | 155 +++++ .../ixrt/faster_logsumexp/setup.py | 48 ++ .../ixrt/faster_logsumexp/test.cpp | 27 + .../ixrt/faster_logsumexp/test.py | 50 ++ .../ixrt/faster_stack/__init__.py | 33 + .../ixrt/faster_stack/build.sh | 22 + .../ixrt/faster_stack/kernel.cu | 146 +++++ .../ixrt/faster_stack/setup.py | 48 ++ .../ixrt/faster_stack/test.cpp | 29 + .../transformer_asr/ixrt/faster_stack/test.py | 74 +++ .../ixrt/hparams/train_ASR_transformer.yaml | 253 ++++++++ .../transformer_asr/ixrt/inference.py | 606 ++++++++++++++++++ .../transformer_asr/ixrt/load_ixrt_plugin.py | 26 + 34 files changed, 3924 insertions(+) create mode 100644 models/speech/speech_recognition/transformer_asr/ixrt/README.md create mode 100644 models/speech/speech_recognition/transformer_asr/ixrt/aishell_prepare.py create mode 100644 models/speech/speech_recognition/transformer_asr/ixrt/beam_search.py create mode 100644 models/speech/speech_recognition/transformer_asr/ixrt/build.sh create mode 100644 models/speech/speech_recognition/transformer_asr/ixrt/builder.py create mode 100644 models/speech/speech_recognition/transformer_asr/ixrt/convert.py create mode 100644 models/speech/speech_recognition/transformer_asr/ixrt/ctc.py create mode 100644 models/speech/speech_recognition/transformer_asr/ixrt/faster_cat/__init__.py create mode 100644 models/speech/speech_recognition/transformer_asr/ixrt/faster_cat/build.sh create mode 100644 models/speech/speech_recognition/transformer_asr/ixrt/faster_cat/kernel.cu create mode 100644 models/speech/speech_recognition/transformer_asr/ixrt/faster_cat/setup.py create mode 100644 models/speech/speech_recognition/transformer_asr/ixrt/faster_cat/test.cpp create mode 100644 models/speech/speech_recognition/transformer_asr/ixrt/faster_cat/test.py create mode 100644 models/speech/speech_recognition/transformer_asr/ixrt/faster_layer_norm/__init__.py create mode 100644 models/speech/speech_recognition/transformer_asr/ixrt/faster_layer_norm/build.sh create mode 100644 models/speech/speech_recognition/transformer_asr/ixrt/faster_layer_norm/kernel.cu create mode 100644 models/speech/speech_recognition/transformer_asr/ixrt/faster_layer_norm/setup.py create mode 100644 models/speech/speech_recognition/transformer_asr/ixrt/faster_layer_norm/test.cpp create mode 100644 models/speech/speech_recognition/transformer_asr/ixrt/faster_layer_norm/transformer_helper.cuh create mode 100644 models/speech/speech_recognition/transformer_asr/ixrt/faster_logsumexp/__init__.py create mode 100644 models/speech/speech_recognition/transformer_asr/ixrt/faster_logsumexp/build.sh create mode 100644 models/speech/speech_recognition/transformer_asr/ixrt/faster_logsumexp/kernel.cu create mode 100644 models/speech/speech_recognition/transformer_asr/ixrt/faster_logsumexp/setup.py create mode 100644 models/speech/speech_recognition/transformer_asr/ixrt/faster_logsumexp/test.cpp create mode 100644 models/speech/speech_recognition/transformer_asr/ixrt/faster_logsumexp/test.py create mode 100644 models/speech/speech_recognition/transformer_asr/ixrt/faster_stack/__init__.py create mode 100644 models/speech/speech_recognition/transformer_asr/ixrt/faster_stack/build.sh create mode 100644 models/speech/speech_recognition/transformer_asr/ixrt/faster_stack/kernel.cu create mode 100644 models/speech/speech_recognition/transformer_asr/ixrt/faster_stack/setup.py create mode 100644 models/speech/speech_recognition/transformer_asr/ixrt/faster_stack/test.cpp create mode 100644 models/speech/speech_recognition/transformer_asr/ixrt/faster_stack/test.py create mode 100644 models/speech/speech_recognition/transformer_asr/ixrt/hparams/train_ASR_transformer.yaml create mode 100644 models/speech/speech_recognition/transformer_asr/ixrt/inference.py create mode 100644 models/speech/speech_recognition/transformer_asr/ixrt/load_ixrt_plugin.py diff --git a/models/speech/speech_recognition/transformer_asr/ixrt/README.md b/models/speech/speech_recognition/transformer_asr/ixrt/README.md new file mode 100644 index 00000000..7560b5eb --- /dev/null +++ b/models/speech/speech_recognition/transformer_asr/ixrt/README.md @@ -0,0 +1,56 @@ +# Asr transformer fp16 inference (BeamSearch) + +## Description + +Beam search allows us to exert control over the output of text generation. This is useful because we sometimes know exactly what we want inside the output. For example, in a Neural Machine Translation task, we might know which words must be included in the final translation with a dictionary lookup. + + +## Setup + +### Install + +``` +pip3 install speechbrain==0.5.13 +``` + +* ixrt 4.0.1_MR release + +### Download + +Pretrained model: + +Dataset: to download the Aishell dataset. + +``` +# Make sure the checkpoint path is results/transformer/8886/save +mkdir -p results/transformer/8886/save +# Make sure the dataset path is results/transformer/8886/save +mkdir -p /home/data/speechbrain +``` + +## Inference + +### Build faster kernels + +```bash +bash build.sh +``` + +### Build engine + +max_batch_size and max_seq_len depend on the situation. + +``` +python3 builder.py \ +--ckpt_path results/transformer/8886/save \ +--head_num 4 \ +--max_batch_size 64 \ +--max_seq_len 1024 \ +--engine_path transformer.engine +``` + +### Run engine + +``` +python3 inference.py hparams/train_ASR_transformer.yaml --data_folder=/home/data/speechbrain/aishell --engine_path transformer.engine +``` \ No newline at end of file diff --git a/models/speech/speech_recognition/transformer_asr/ixrt/aishell_prepare.py b/models/speech/speech_recognition/transformer_asr/ixrt/aishell_prepare.py new file mode 100644 index 00000000..ba319394 --- /dev/null +++ b/models/speech/speech_recognition/transformer_asr/ixrt/aishell_prepare.py @@ -0,0 +1,141 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +import os +import shutil +import logging +from speechbrain.dataio.dataio import read_audio +from speechbrain.utils.data_utils import download_file +import glob +import csv +import argparse + +logger = logging.getLogger(__name__) + + +def prepare_aishell(data_folder, save_folder, skip_prep=False): + """ + This function prepares the AISHELL-1 dataset. + If the folder does not exist, the zip file will be extracted. If the zip file does not exist, it will be downloaded. + + data_folder : path to AISHELL-1 dataset. + save_folder: path where to store the manifest csv files. + skip_prep: If True, skip data preparation. + + """ + if skip_prep: + return + + # If the data folders do not exist, we need to extract the data + if not os.path.isdir(os.path.join(data_folder, "data_aishell/wav")): + # # Check for zip file and download if it doesn't exist + # zip_location = os.path.join(data_folder, "data_aishell.tgz") + # if not os.path.exists(zip_location): + # url = "https://www.openslr.org/resources/33/data_aishell.tgz" + # download_file(url, zip_location, unpack=True) + # logger.info("Extracting data_aishell.tgz...") + # shutil.unpack_archive(zip_location, data_folder) + + wav_dir = os.path.join(data_folder, "data_aishell/wav") + tgz_list = glob.glob(wav_dir + "/*.tar.gz") + for tgz in tgz_list: + shutil.unpack_archive(tgz, wav_dir) + os.remove(tgz) + + # Create filename-to-transcript dictionary + filename2transcript = {} + with open( + os.path.join( + data_folder, "data_aishell/transcript/aishell_transcript_v0.8.txt" + ), + "r", + ) as f: + lines = f.readlines() + for line in lines: + key = line.split()[0] + value = " ".join(line.split()[1:]) + filename2transcript[key] = value + + splits = [ + # "train", + "dev", + "test", + ] + ID_start = 0 # needed to have a unique ID for each audio + for split in splits: + new_filename = os.path.join(save_folder, split) + ".csv" + if os.path.exists(new_filename): + continue + logger.info("Preparing %s..." % new_filename) + + csv_output = [["ID", "duration", "wav", "transcript"]] + entry = [] + + all_wavs = glob.glob( + os.path.join(data_folder, "data_aishell/wav") + "/" + split + "/*/*.wav" + ) + for i in range(len(all_wavs)): + filename = all_wavs[i].split("/")[-1].split(".wav")[0] + if filename not in filename2transcript: + continue + signal = read_audio(all_wavs[i]) + duration = signal.shape[0] / 16000 + transcript_ = filename2transcript[filename] + csv_line = [ + ID_start + i, + str(duration), + all_wavs[i], + transcript_, + ] + entry.append(csv_line) + + csv_output = csv_output + entry + + with open(new_filename, mode="w") as csv_f: + csv_writer = csv.writer( + csv_f, delimiter=",", quotechar='"', quoting=csv.QUOTE_MINIMAL + ) + for line in csv_output: + csv_writer.writerow(line) + + msg = "\t%s successfully created!" % (new_filename) + logger.info(msg) + + ID_start += len(all_wavs) + + +def parse_config(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--data_folder", + type=str, + default="/home/data/speechbrain/aishell", + help="data folder", + ) + parser.add_argument( + "--save_folder", + type=str, + default="/home/data/speechbrain/aishell/csv_data", + help="csv save folder", + ) + + config = parser.parse_args() + print("Config:", config) + return config + + +if __name__ == "__main__": + + config = parse_config() + prepare_aishell(config.data_folder, config.save_folder, skip_prep=False) diff --git a/models/speech/speech_recognition/transformer_asr/ixrt/beam_search.py b/models/speech/speech_recognition/transformer_asr/ixrt/beam_search.py new file mode 100644 index 00000000..61e5c794 --- /dev/null +++ b/models/speech/speech_recognition/transformer_asr/ixrt/beam_search.py @@ -0,0 +1,381 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +import torch +from ctc import CTCPrefixScorer +import time + +def forward(self, enc_states, wav_len): # noqa: C901 + """Applies beamsearch and returns the predicted tokens.""" + enc_lens = torch.round(enc_states.shape[1] * wav_len).int() + device = enc_states.device + batch_size = enc_states.shape[0] + + memory = self.reset_mem(batch_size * self.beam_size, device=device) + + if self.lm_weight > 0: + lm_memory = self.reset_lm_mem(batch_size * self.beam_size, device) + + if self.ctc_weight > 0: + # (batch_size * beam_size, L, vocab_size) + ctc_outputs = self.ctc_forward_step(enc_states) + ctc_scorer = CTCPrefixScorer( + ctc_outputs, + enc_lens, + batch_size, + self.beam_size, + self.blank_index, + self.eos_index, + self.ctc_window_size, + ) + ctc_memory = None + + # Inflate the enc_states and enc_len by beam_size times + enc_states = inflate_tensor(enc_states, times=self.beam_size, dim=0) + enc_lens = inflate_tensor(enc_lens, times=self.beam_size, dim=0) + + # Using bos as the first input + inp_tokens = ( + torch.zeros(batch_size * self.beam_size, device=device) + .fill_(self.bos_index) + .long() + ) + + # The first index of each sentence. + self.beam_offset = ( + torch.arange(batch_size, device=device) * self.beam_size + ) + + # initialize sequence scores variables. + sequence_scores = torch.empty( + batch_size * self.beam_size, device=device + ) + sequence_scores.fill_(float("-inf")) + + # keep only the first to make sure no redundancy. + sequence_scores.index_fill_(0, self.beam_offset, 0.0) + + # keep the hypothesis that reaches eos and their corresponding score and log_probs. + hyps_and_scores = [[] for _ in range(batch_size)] + + # keep the sequences that still not reaches eos. + alived_seq = torch.empty( + batch_size * self.beam_size, 0, device=device + ).long() + + # Keep the log-probabilities of alived sequences. + alived_log_probs = torch.empty( + batch_size * self.beam_size, 0, device=device + ) + + min_decode_steps = int(enc_states.shape[1] * self.min_decode_ratio) + max_decode_steps = int(enc_states.shape[1] * self.max_decode_ratio) + + # Initialize the previous attention peak to zero + # This variable will be used when using_max_attn_shift=True + prev_attn_peak = torch.zeros(batch_size * self.beam_size, device=device) + + for t in range(max_decode_steps): + # terminate condition + if self._check_full_beams(hyps_and_scores, self.beam_size): + break + + log_probs, memory, attn = self.forward_step( + inp_tokens, memory, enc_states, enc_lens + ) + log_probs = self.att_weight * log_probs + + # Keep the original value + log_probs_clone = log_probs.clone().reshape(batch_size, -1) + vocab_size = log_probs.shape[-1] + + if self.using_max_attn_shift: + # Block the candidates that exceed the max shift + cond, attn_peak = self._check_attn_shift(attn, prev_attn_peak) + log_probs = mask_by_condition( + log_probs, cond, fill_value=self.minus_inf + ) + prev_attn_peak = attn_peak + + # Set eos to minus_inf when less than minimum steps. + if t < min_decode_steps: + log_probs[:, self.eos_index] = self.minus_inf + + # Set the eos prob to minus_inf when it doesn't exceed threshold. + if self.using_eos_threshold: + cond = self._check_eos_threshold(log_probs) + log_probs[:, self.eos_index] = mask_by_condition( + log_probs[:, self.eos_index], + cond, + fill_value=self.minus_inf, + ) + + # adding LM scores to log_prob if lm_weight > 0 + if self.lm_weight > 0: + lm_log_probs, lm_memory = self.lm_forward_step( + inp_tokens, lm_memory + ) + log_probs = log_probs + self.lm_weight * lm_log_probs + + # adding CTC scores to log_prob if ctc_weight > 0 + if self.ctc_weight > 0: + g = alived_seq + # block blank token + log_probs[:, self.blank_index] = self.minus_inf + if self.ctc_weight != 1.0 and self.ctc_score_mode == "partial": + # pruning vocab for ctc_scorer + _, ctc_candidates = log_probs.topk( + self.beam_size * 2, dim=-1 + ) + else: + ctc_candidates = None + + ctc_log_probs, ctc_memory = ctc_scorer.forward_step( + g, ctc_memory, ctc_candidates, attn + ) + log_probs = log_probs + self.ctc_weight * ctc_log_probs + + scores = sequence_scores.unsqueeze(1).expand(-1, vocab_size) + scores = scores + log_probs + + # length normalization + if self.length_normalization: + scores = scores / (t + 1) + + # keep topk beams + scores, candidates = scores.view(batch_size, -1).topk( + self.beam_size, dim=-1 + ) + + # The input for the next step, also the output of current step. + inp_tokens = (candidates % vocab_size).view( + batch_size * self.beam_size + ) + + scores = scores.view(batch_size * self.beam_size) + sequence_scores = scores + + # recover the length normalization + if self.length_normalization: + sequence_scores = sequence_scores * (t + 1) + + # The index of which beam the current top-K output came from in (t-1) timesteps. + predecessors = ( + torch.div(candidates, vocab_size, rounding_mode="floor") + + self.beam_offset.unsqueeze(1).expand_as(candidates) + ).view(batch_size * self.beam_size) + + # Permute the memory to synchoronize with the output. + memory = self.permute_mem(memory, index=predecessors) + if self.lm_weight > 0: + lm_memory = self.permute_lm_mem(lm_memory, index=predecessors) + + if self.ctc_weight > 0: + ctc_memory = ctc_scorer.permute_mem(ctc_memory, candidates) + + # If using_max_attn_shift, then the previous attn peak has to be permuted too. + if self.using_max_attn_shift: + prev_attn_peak = torch.index_select( + prev_attn_peak, dim=0, index=predecessors + ) + + # Add coverage penalty + if self.coverage_penalty > 0: + cur_attn = torch.index_select(attn, dim=0, index=predecessors) + + # coverage: cumulative attention probability vector + if t == 0: + # Init coverage + self.coverage = cur_attn + + # the attn of transformer is [batch_size*beam_size, current_step, source_len] + if len(cur_attn.size()) > 2: + self.converage = torch.sum(cur_attn, dim=1) + else: + # Update coverage + self.coverage = torch.index_select( + self.coverage, dim=0, index=predecessors + ) + self.coverage = self.coverage + cur_attn + + # Compute coverage penalty and add it to scores + penalty = torch.max( + self.coverage, self.coverage.clone().fill_(0.5) + ).sum(-1) + penalty = penalty - self.coverage.size(-1) * 0.5 + penalty = penalty.view(batch_size * self.beam_size) + penalty = ( + penalty / (t + 1) if self.length_normalization else penalty + ) + scores = scores - penalty * self.coverage_penalty + + # Update alived_seq + alived_seq = torch.cat( + [ + torch.index_select(alived_seq, dim=0, index=predecessors), + inp_tokens.unsqueeze(1), + ], + dim=-1, + ) + + # Takes the log-probabilities + beam_log_probs = log_probs_clone[ + torch.arange(batch_size).unsqueeze(1), candidates + ].reshape(batch_size * self.beam_size) + alived_log_probs = torch.cat( + [ + torch.index_select( + alived_log_probs, dim=0, index=predecessors + ), + beam_log_probs.unsqueeze(1), + ], + dim=-1, + ) + + is_eos = self._update_hyp_and_scores( + inp_tokens, + alived_seq, + alived_log_probs, + hyps_and_scores, + scores, + timesteps=t, + ) + + # Block the paths that have reached eos. + sequence_scores.masked_fill_(is_eos, float("-inf")) + + if not self._check_full_beams(hyps_and_scores, self.beam_size): + # Using all eos to fill-up the hyps. + eos = ( + torch.zeros(batch_size * self.beam_size, device=device) + .fill_(self.eos_index) + .long() + ) + _ = self._update_hyp_and_scores( + eos, + alived_seq, + alived_log_probs, + hyps_and_scores, + scores, + timesteps=max_decode_steps, + ) + + ( + topk_hyps, + topk_scores, + topk_lengths, + log_probs, + ) = self._get_top_score_prediction(hyps_and_scores, topk=self.topk,) + # pick the best hyp + predictions = topk_hyps[:, 0, :] + predictions = batch_filter_seq2seq_output( + predictions, eos_id=self.eos_index + ) + + if self.return_log_probs: + return predictions, topk_scores, log_probs + else: + return predictions, topk_scores + + +def inflate_tensor(tensor, times, dim): + """This function inflates the tensor for times along dim. + + Arguments + --------- + tensor : torch.Tensor + The tensor to be inflated. + times : int + The tensor will inflate for this number of times. + dim : int + The dim to be inflated. + + Returns + ------- + torch.Tensor + The inflated tensor. + + Example + ------- + >>> tensor = torch.Tensor([[1,2,3], [4,5,6]]) + >>> new_tensor = inflate_tensor(tensor, 2, dim=0) + >>> new_tensor + tensor([[1., 2., 3.], + [1., 2., 3.], + [4., 5., 6.], + [4., 5., 6.]]) + """ + return torch.repeat_interleave(tensor, times, dim=dim) + +def batch_filter_seq2seq_output(prediction, eos_id=-1): + """Calling batch_size times of filter_seq2seq_output. + + Arguments + --------- + prediction : list of torch.Tensor + A list containing the output ints predicted by the seq2seq system. + eos_id : int, string + The id of the eos. + + Returns + ------ + list + The output predicted by seq2seq model. + + Example + ------- + >>> predictions = [torch.IntTensor([1,2,3,4]), torch.IntTensor([2,3,4,5,6])] + >>> predictions = batch_filter_seq2seq_output(predictions, eos_id=4) + >>> predictions + [[1, 2, 3], [2, 3]] + """ + outputs = [] + for p in prediction: + res = filter_seq2seq_output(p.tolist(), eos_id=eos_id) + outputs.append(res) + return outputs + +def filter_seq2seq_output(string_pred, eos_id=-1): + """Filter the output until the first eos occurs (exclusive). + + Arguments + --------- + string_pred : list + A list containing the output strings/ints predicted by the seq2seq system. + eos_id : int, string + The id of the eos. + + Returns + ------ + list + The output predicted by seq2seq model. + + Example + ------- + >>> string_pred = ['a','b','c','d','eos','e'] + >>> string_out = filter_seq2seq_output(string_pred, eos_id='eos') + >>> string_out + ['a', 'b', 'c', 'd'] + """ + if isinstance(string_pred, list): + try: + eos_index = next( + i for i, v in enumerate(string_pred) if v == eos_id + ) + except StopIteration: + eos_index = len(string_pred) + string_out = string_pred[:eos_index] + else: + raise ValueError("The input must be a list.") + return string_out \ No newline at end of file diff --git a/models/speech/speech_recognition/transformer_asr/ixrt/build.sh b/models/speech/speech_recognition/transformer_asr/ixrt/build.sh new file mode 100644 index 00000000..a8991234 --- /dev/null +++ b/models/speech/speech_recognition/transformer_asr/ixrt/build.sh @@ -0,0 +1,23 @@ +#!/usr/bin/env bash + +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +for i in fast* +do + cd $i + bash build.sh + cd .. +done diff --git a/models/speech/speech_recognition/transformer_asr/ixrt/builder.py b/models/speech/speech_recognition/transformer_asr/ixrt/builder.py new file mode 100644 index 00000000..5c19a9f4 --- /dev/null +++ b/models/speech/speech_recognition/transformer_asr/ixrt/builder.py @@ -0,0 +1,466 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +import argparse +import torch +from tensorrt.deploy.api import GraphTransform, create_source, create_target +from tensorrt.deploy.ir.data_type import DataType +from tensorrt.deploy.ir.variable import Variable, VariableOptions +from tensorrt.deploy.ir.graph import Graph +from collections import OrderedDict +import math +import re +import glob +import os +from onnx import numpy_helper +import subprocess + + +def parse_args(): + parser = argparse.ArgumentParser( + description="build ixrt engine", usage="" + ) + parser.add_argument( + "--ckpt_path", + type=str, + required=True, + help="", + ) + parser.add_argument( + "--head_num", + type=int, + required=True, + help="", + ) + parser.add_argument( + "--max_batch_size", + type=int, + required=True, + help="", + ) + parser.add_argument( + "--max_seq_len", + type=int, + required=True, + help="", + ) + parser.add_argument( + "--onnx_path", + type=str, + default=".tmp.onnx", + help="", + ) + parser.add_argument( + "--engine_path", + type=str, + required=True, + help="", + ) + args = parser.parse_args() + return args + + +def add_make_mask_op(graph, state_dict, args): + attributes = {} + + t = graph + inputs = [ + graph.make_variable('length_radio', dtype=DataType.FLOAT16), + graph.make_variable('input', dtype=DataType.FLOAT16), + ] + + outputs = [t.make_variable("attention_mask", dtype=DataType.INT32)] + + t.make_operator( + "MakeMaskByRadio_IxRT", inputs=inputs, outputs=outputs, **attributes + ) + + +def add_custom_linear_op(graph, state_dict, args): + linear_keys = [ + "1.custom_src_module.layers.0.w.weight", + "1.custom_src_module.layers.0.w.bias" + ] + W = numpy_helper.from_array(state_dict[linear_keys[0]].cpu().numpy(), name="W") + B = numpy_helper.from_array(state_dict[linear_keys[1]].cpu().numpy(), name="B") + attributes = { + "out_dims": state_dict["1.custom_src_module.layers.0.w.weight"].size(0), + "type_id": 1, + "W": W, + "B": B, + } + assert state_dict['1.custom_src_module.layers.0.w.weight'].size( + 0) == state_dict["1.custom_src_module.layers.0.w.bias"].size(0) + + t = graph + inputs = [ + graph.get_variable('input'), + ] + + outputs = [t.make_variable("custom_src_output")] + t.make_operator( + "CustomFCPluginDynamic_IxRT", inputs=inputs, outputs=outputs, **attributes + ) + + +# def add_custom_linear_op(graph, state_dict, args): +# linear_keys = [ +# "1.custom_src_module.layers.0.w.weight", +# "1.custom_src_module.layers.0.w.bias" +# ] +# attributes = { +# "linear_dim": state_dict["1.custom_src_module.layers.0.w.weight"].size(0), +# "hidden_size": state_dict["1.custom_src_module.layers.0.w.weight"].size(1), +# "has_bias": 1, +# "act_type": "none", +# } +# assert state_dict['1.custom_src_module.layers.0.w.weight'].size( +# 0) == state_dict["1.custom_src_module.layers.0.w.bias"].size(0) +# +# t = graph +# inputs = [ +# graph.get_variable('input'), +# ] +# +# outputs = [t.make_variable("custom_src_output",dtype=DataType.FLOAT16)] +# for key in linear_keys: +# inputs.append(t.make_variable(name=key, value=state_dict[key].half())) +# t.make_operator( +# "LinearFP16", inputs=inputs, outputs=outputs, **attributes +# ) + + +def add_pos_encode_op(graph, state_dict, args): + attributes = {} + t = graph + inputs = [ + graph.get_variable('custom_src_output'), + ] + outputs = [t.make_variable("hidden_state", dtype=DataType.FLOAT16)] + t.make_operator( + "PosEncodeSinCos_IxRT", inputs=inputs, outputs=outputs, **attributes + ) + + +def add_transformer_op(graph, state_dict, args): + enc_tensor_layer_fp16_keys = OrderedDict([ + ["1.encoder.layers.{}.norm1.norm.weight", [args.hidden_size]], + ["1.encoder.layers.{}.norm1.norm.bias", [args.hidden_size]], + ["1.encoder.layers.{}.self_att.att.in_proj_weight", + [args.hidden_size * 3, args.hidden_size]], + ["1.encoder.layers.{}.self_att.att.in_proj_bias", [args.hidden_size * 3]], + ["1.encoder.layers.{}.self_att.att.out_proj.weight", + [args.hidden_size, args.hidden_size]], + ["1.encoder.layers.{}.self_att.att.out_proj.bias", [args.hidden_size]], + ["1.encoder.layers.{}.pos_ffn.ffn.0.weight", + [args.inner_size, args.hidden_size]], + ["1.encoder.layers.{}.pos_ffn.ffn.0.bias", [args.inner_size]], + ["1.encoder.layers.{}.pos_ffn.ffn.3.weight", + [args.hidden_size, args.inner_size]], + ["1.encoder.layers.{}.pos_ffn.ffn.3.bias", [args.hidden_size]], + ["1.encoder.layers.{}.norm2.norm.weight", [args.hidden_size]], + ["1.encoder.layers.{}.norm2.norm.bias", [args.hidden_size]], + ]) + attributes_legcy = { + "hidden_size": args.hidden_size, + "num_layers": args.num_layers, + "head_num": args.head_num, + "head_dim": args.head_dim, + "inner_size": args.inner_size, + "act_type": "gelu", + "normalize_before": 1, + "is_fmha": 1, + "atten_scaler": 1 / math.sqrt(args.head_dim) + } + + + attributes = { + "hidden_size": int(args.hidden_size), + "num_layers": int(args.num_layers), + "head_num": int(args.head_num), + "head_dim": int(args.head_dim), + "inner_size": int(args.inner_size), + "act_type": 12, #gelu + "normalize_before": 1, + "is_fmha": 1, + "atten_scaler": 1.0 / math.sqrt(args.head_dim), + "max_seq_len": int(args.max_seq_len), + "max_batch_size": int(args.max_batch_size), + + } + + t = graph + inputs = [ + graph.get_variable('hidden_state'), + graph.get_variable('attention_mask'), + ] + outputs = [t.make_variable("encoder_out", dtype=DataType.FLOAT16)] + for layer_id in range(args.num_layers): + for key, shape in enc_tensor_layer_fp16_keys.items(): + # we need cat qkv gemm's weight and bias + new_key = key.format(layer_id) + w = state_dict[new_key] + if list(w.shape) != shape: + print("weights shape error!") + print("key: ", key) + print("need shape: ", shape) + print("weight shape: ", w.shape) + exit(1) + inputs.append(t.make_variable(name=new_key, value=w.half())) + t.make_operator( + "TransformerEncoderFp16_IxRT", inputs=inputs, outputs=outputs, **attributes + ) + + +def add_layer_norm_op(graph, state_dict, args): + enc_ln_tensor_fp16_keys = OrderedDict([ + ["1.encoder.norm.norm.weight", [args.hidden_size]], + ["1.encoder.norm.norm.bias", [args.hidden_size]], + ]) + attributes = { + "epsilon": 1e-5, + "axis": -1, + "stash_type": 1 + } + t = graph + inputs = [ + graph.get_variable('encoder_out'), + ] + outputs = [t.make_variable("encoder_ln_out")] + for key, shape in enc_ln_tensor_fp16_keys.items(): + new_key = key + w = state_dict[new_key] + if list(w.shape) != shape: + print("weights shape error!") + print("key: ", key) + print("need shape: ", shape) + print("weight shape: ", w.shape) + exit(1) + inputs.append(t.make_variable(name=new_key, value=w.half())) + t.make_operator( + "LayerNormalization", inputs=inputs, outputs=outputs, **attributes + ) + + +# def add_layer_norm_op(graph, state_dict, args): +# enc_ln_tensor_fp16_keys = OrderedDict([ +# ["1.encoder.norm.norm.weight", [args.hidden_size]], +# ["1.encoder.norm.norm.bias", [args.hidden_size]], +# ]) +# attributes = { +# "hidden_size": args.hidden_size, +# } +# t = graph +# inputs = [ +# graph.get_variable('encoder_out'), +# ] +# outputs = [t.make_variable("encoder_ln_out",dtype=DataType.FLOAT16)] +# for key, shape in enc_ln_tensor_fp16_keys.items(): +# new_key = key +# w = state_dict[new_key] +# if list(w.shape) != shape: +# print("weights shape error!") +# print("key: ", key) +# print("need shape: ", shape) +# print("weight shape: ", w.shape) +# exit(1) +# inputs.append(t.make_variable(name=new_key, value=w.half())) +# t.make_operator( +# "LayerNormFp16", inputs=inputs, outputs=outputs, **attributes +# ) + +def add_linear_op(graph, state_dict, args): + linear_keys = [ + "3.w.weight", + "3.w.bias" + ] + W = numpy_helper.from_array(state_dict[linear_keys[0]].cpu().numpy(), name="W") + B = numpy_helper.from_array(state_dict[linear_keys[1]].cpu().numpy(), name="B") + attributes = { + "out_dims": state_dict["3.w.weight"].size(0), + "type_id": 1, + "W": W, + "B": B, + } + assert state_dict['3.w.weight'].size(0) == state_dict["3.w.bias"].size(0) + + t = graph + inputs = [ + graph.get_variable('encoder_ln_out'), + ] + + outputs = [t.make_variable("lin_output")] + t.make_operator( + "CustomFCPluginDynamic_IxRT", inputs=inputs, outputs=outputs, **attributes + ) + + +# +# def add_linear_op(graph, state_dict, args): +# lin_keys = [ +# "3.w.weight", +# "3.w.bias" +# ] +# attributes = { +# "linear_dim": state_dict["3.w.weight"].size(0), +# "hidden_size": state_dict["3.w.weight"].size(1), +# "has_bias": 1, +# "act_type": "none", +# } +# assert state_dict['3.w.weight'].size(0) == state_dict["3.w.bias"].size(0) +# +# t = graph +# inputs = [ +# graph.get_variable('encoder_ln_out'), +# ] +# +# outputs = [t.make_variable("lin_output",dtype=DataType.FLOAT16)] +# for key in lin_keys: +# inputs.append(t.make_variable(name=key, value=state_dict[key].half())) +# t.make_operator( +# "LinearFP16", inputs=inputs, outputs=outputs, **attributes +# ) + + +def add_log_softmax_op(graph, state_dict, args): + attributes = { + "axis": "-1", + } + + t = graph + inputs = [ + graph.get_variable('lin_output'), + ] + + outputs = [t.make_variable("log_softmax_output", dtype=DataType.FLOAT16)] + + t.make_operator( + "LogSoftmax", inputs=inputs, outputs=outputs, **attributes + ) + + +def add_search_node(graph, state_dict, args): + attributes = { + "vocab_size": args.vocab_size, + "eos_id": args.vocab_size, + "pad_id": -10000, + "beam_size": 1, + "attr1": 1.0, + "min_decode_ratio": 0.0, + "max_decode_ratio": 1.0, + "ctc_weight": 0.40, + "using_eos_threshold": 0, + "length_normalization": 1, + } + t = graph + inputs = [ + graph.get_variable('lin_output'), + ] + + outputs = [t.make_variable("output_tokens", dtype=DataType.INT32)] + list_value_half = [] + list_key_half = [] + for key in state_dict.keys(): + if "decoder" in key or "custom_tgt_module" in key or "2.w.weight" in key or "2.w.bias" in key: + list_key_half.append(key) + list_value_half.append(state_dict[key].half()) + for i, item in enumerate(list_key_half): + inputs.append(t.make_variable(name=list_key_half[i], value=list_value_half[i])) + t.make_operator( + "Search", inputs=inputs, outputs=outputs, **attributes + ) + + +def get_num_layers(state_dict): + num_layers = -1 + for key in state_dict: + layer_id = re.search( + "1.encoder.layers.([0-9]+).pos_ffn.ffn.0.bias", key) + if layer_id: + layer_id = layer_id.group(1) + num_layers = max(num_layers, int(layer_id) + 1) + assert num_layers > 0 + return num_layers + + +def build_engine(onnx_file, engine_file, max_batch_size,max_seq_len): + cmd = f"ixrtexec --onnx {onnx_file} --min_shape input:1x32x5120,length_radio:1 --opt_shape input:8x64x5120,length_radio:8 --max_shape input:{max_batch_size}x{max_seq_len}x5120,length_radio:64 --plugins ixrt_plugin --save_engine {engine_file}" + subprocess.run(cmd.split(), check=True) + + +def main(args): + graph = Graph() + transform = GraphTransform(graph) + ckpt_path = glob.glob(os.path.join(args.ckpt_path, "*/model.ckpt"))[0] + print("load ckpt from: ", ckpt_path) + state_dict = torch.load(ckpt_path) + + # print([i for i in state_dict ]) + # print(state_dict['3.w.bias']) + args.hidden_size = state_dict['1.encoder.layers.0.norm1.norm.weight'].size( + 0) + args.head_dim = args.hidden_size / args.head_num + args.inner_size = state_dict['1.encoder.layers.0.pos_ffn.ffn.0.bias'].size( + 0) + args.vocab_size = state_dict['3.w.weight'].size(0) + + args.num_layers = get_num_layers(state_dict) + + args.src_len = state_dict["1.custom_src_module.layers.0.w.weight"].size(1) + + # args.num_layers = 1 + add_make_mask_op(transform, state_dict, args) + add_custom_linear_op(transform, state_dict, args) + add_pos_encode_op(transform, state_dict, args) + add_transformer_op(transform, state_dict, args) + add_layer_norm_op(transform, state_dict, args) + # add_linear_op(transform, state_dict, args) + # add_log_softmax_op(transform, state_dict, args) + # add_search_node(transform, state_dict, args) + + # IO attributes + length_radio = graph.get_variable('length_radio') + length_radio.set_shape(["batch_size"]) + length_radio.dtype = "float16" + graph.add_input(length_radio) + + input = graph.get_variable('input') + input.set_shape(["batch_size", "seq_len", "src_len"]) + input.dtype = "float16" + graph.add_input(input) + + output = graph.get_variable('encoder_ln_out') + output.dtype = "float16" + graph.add_output(output) + + create_target(saved_path=args.onnx_path).export(graph) + + build_engine(args.onnx_path, args.engine_path, args.max_batch_size, args.max_seq_len) + print("save engine: ", args.engine_path) + + +if __name__ == "__main__": + args = parse_args() + ckpt_path = args.ckpt_path + + main(args) + +""" +python3 builder.py \ +--ckpt_path results/transformer/8886/save \ +--head_num 4 \ +--max_batch_size 64 \ +--max_seq_len 1024 \ +--engine_path transformer.engine +""" diff --git a/models/speech/speech_recognition/transformer_asr/ixrt/convert.py b/models/speech/speech_recognition/transformer_asr/ixrt/convert.py new file mode 100644 index 00000000..11d71a56 --- /dev/null +++ b/models/speech/speech_recognition/transformer_asr/ixrt/convert.py @@ -0,0 +1,95 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +import torch +from faster_layer_norm import FasterLayerNorm + +def replace_layer_norm(model): + module_output = model + + if isinstance(model, torch.nn.modules.normalization.LayerNorm): + return FasterLayerNorm(model.weight, model.bias) + + for name, child in model.named_children(): + module_output.add_module( + name, replace_layer_norm(child) + ) + return module_output + + +def convert_decoder_model(model): + model = replace_layer_norm(model) + # for layer in model.layers: + # norm = layer.norm1.norm + # print(type(norm)) + # exit() + # new_norm = FasterLayerNorm(norm.weight, norm.bias) + # layer.norm1.norm = new_norm + + # norm = layer.norm2.norm + # new_norm = FasterLayerNorm(norm.weight, norm.bias) + # layer.norm2.norm = new_norm + + # norm = layer.norm3.norm + # new_norm = FasterLayerNorm(norm.weight, norm.bias) + # layer.norm3.norm = new_norm + return model + +# def find_layers(module, layers=[nn.Conv2d, nn.Linear], name=''): +# if type(module) in layers: +# return {name: module} +# res = {} +# for name1, child in module.named_children(): +# res.update(find_layers( +# child, layers=layers, name=name + '.' + name1 if name != '' else name1 +# )) +# return res + +def find_node(module): + if type(module) in [torch.nn.LayerNorm]: + print(module) + return + res = {} + for name1, child in module.named_children(): + find_node(child) + return + + +def patch_get_lookahead_mask(padded_input): + """Creates a binary mask for each sequence which maskes future frames. + + Arguments + --------- + padded_input: torch.Tensor + Padded input tensor. + + Example + ------- + >>> a = torch.LongTensor([[1,1,0], [2,3,0], [4,5,0]]) + >>> get_lookahead_mask(a) + tensor([[0., -inf, -inf], + [0., 0., -inf], + [0., 0., 0.]]) + """ + seq_len = padded_input.shape[1] + mask = ( + torch.triu(torch.ones((seq_len, seq_len), device=padded_input.device)) + == 1 + ).transpose(0, 1) + mask = ( + mask.float() + .masked_fill(mask == 0, float("-inf")) + .masked_fill(mask == 1, float(0.0)) + ) + return mask.detach().to(padded_input.device).to(torch.float16) \ No newline at end of file diff --git a/models/speech/speech_recognition/transformer_asr/ixrt/ctc.py b/models/speech/speech_recognition/transformer_asr/ixrt/ctc.py new file mode 100644 index 00000000..9db6ab7e --- /dev/null +++ b/models/speech/speech_recognition/transformer_asr/ixrt/ctc.py @@ -0,0 +1,394 @@ +"""Decoders and output normalization for CTC. + +Authors + * Mirco Ravanelli 2020 + * Aku Rouhe 2020 + * Sung-Lin Yeh 2020 +""" +import torch +from itertools import groupby +from speechbrain.dataio.dataio import length_to_mask +from faster_logsumexp import FasterLogSumExp +from faster_stack import FasterStack +from faster_cat import FastCat + + +class CTCPrefixScorer: + """This class implements the CTC prefix scorer of Algorithm 2 in + reference: https://www.merl.com/publications/docs/TR2017-190.pdf. + Official implementation: https://github.com/espnet/espnet/blob/master/espnet/nets/ctc_prefix_score.py + + Arguments + --------- + x : torch.Tensor + The encoder states. + enc_lens : torch.Tensor + The actual length of each enc_states sequence. + batch_size : int + The size of the batch. + beam_size : int + The width of beam. + blank_index : int + The index of the blank token. + eos_index : int + The index of the end-of-sequence (eos) token. + ctc_window_size: int + Compute the ctc scores over the time frames using windowing based on attention peaks. + If 0, no windowing applied. + """ + + def __init__( + self, + x, + enc_lens, + batch_size, + beam_size, + blank_index, + eos_index, + ctc_window_size=0, + ): + self.blank_index = blank_index + self.eos_index = eos_index + self.max_enc_len = x.size(1) + self.batch_size = batch_size + self.beam_size = beam_size + self.vocab_size = x.size(-1) + self.device = x.device + self.minus_inf = -1e4 + self.last_frame_index = enc_lens - 1 + self.ctc_window_size = ctc_window_size + + # mask frames > enc_lens + mask = 1 - length_to_mask(enc_lens) + mask = mask.unsqueeze(-1).expand(-1, -1, x.size(-1)).eq(1) + x.masked_fill_(mask, self.minus_inf) + x[:, :, 0] = x[:, :, 0].masked_fill_(mask[:, :, 0], 0) + + # dim=0: xnb, nonblank posteriors, dim=1: xb, blank posteriors + xnb = x.transpose(0, 1) + xb = ( + xnb[:, :, self.blank_index] + .unsqueeze(2) + .expand(-1, -1, self.vocab_size) + ) + + # (2, L, batch_size * beam_size, vocab_size) + # self.x = torch.stack([xnb, xb]) + self.x = FasterStack([xnb.contiguous(), xb.contiguous()]) + + # The first index of each sentence. + self.beam_offset = ( + torch.arange(batch_size, device=self.device) * self.beam_size + ) + # The first index of each candidates. + self.cand_offset = ( + torch.arange(batch_size, device=self.device) * self.vocab_size + ) + + def forward_step(self, g, state, candidates=None, attn=None): + """This method if one step of forwarding operation + for the prefix ctc scorer. + + Arguments + --------- + g : torch.Tensor + The tensor of prefix label sequences, h = g + c. + state : tuple + Previous ctc states. + candidates : torch.Tensor + (batch_size * beam_size, ctc_beam_size), The topk candidates for rescoring. + The ctc_beam_size is set as 2 * beam_size. If given, performing partial ctc scoring. + """ + + prefix_length = g.size(1) + last_char = [gi[-1] for gi in g] if prefix_length > 0 else [0] * len(g) + self.num_candidates = ( + self.vocab_size if candidates is None else candidates.size(-1) + ) + if state is None: + # r_prev: (L, 2, batch_size * beam_size) + r_prev = torch.full( + (self.max_enc_len, 2, self.batch_size, self.beam_size), + self.minus_inf, + device=self.device, + dtype=torch.float16 + ) + + # Accumulate blank posteriors at each step + r_prev[:, 1] = torch.cumsum( + self.x[0, :, :, self.blank_index], 0 + ).unsqueeze(2) + r_prev = r_prev.view(-1, 2, self.batch_size * self.beam_size) + psi_prev = 0.0 + else: + r_prev, psi_prev = state + r_prev = r_prev.half() + + # for partial search + if candidates is not None: + scoring_table = torch.full( + (self.batch_size * self.beam_size, self.vocab_size), + -1, + dtype=torch.long, + device=self.device, + ) + # Assign indices of candidates to their positions in the table + col_index = torch.arange( + self.batch_size * self.beam_size, device=self.device + ).unsqueeze(1) + scoring_table[col_index, candidates] = torch.arange( + self.num_candidates, device=self.device + ) + # Select candidates indices for scoring + scoring_index = ( + candidates + + self.cand_offset.unsqueeze(1) + .repeat(1, self.beam_size) + .view(-1, 1) + ).view(-1) + x_inflate = torch.index_select( + self.x.view(2, -1, self.batch_size * self.vocab_size), + 2, + scoring_index, + ).view(2, -1, self.batch_size * self.beam_size, self.num_candidates) + # for full search + else: + scoring_table = None + x_inflate = ( + self.x.unsqueeze(3) + .repeat(1, 1, 1, self.beam_size, 1) + .view( + 2, -1, self.batch_size * self.beam_size, self.num_candidates + ) + ) + + # Prepare forward probs + r = torch.full( + ( + self.max_enc_len, + 2, + self.batch_size * self.beam_size, + self.num_candidates, + ), + self.minus_inf, + device=self.device, + dtype=torch.float16 + ) + r.fill_(self.minus_inf) + + # (Alg.2-6) + if prefix_length == 0: + r[0, 0] = x_inflate[0, 0] + # (Alg.2-10): phi = prev_nonblank + prev_blank = r_t-1^nb(g) + r_t-1^b(g) + r_sum = FasterLogSumExp(r_prev, 1) + phi = r_sum.unsqueeze(2).repeat(1, 1, self.num_candidates) + + # (Alg.2-10): if last token of prefix g in candidates, phi = prev_b + 0 + if candidates is not None: + for i in range(self.batch_size * self.beam_size): + pos = scoring_table[i, last_char[i]] + if pos != -1: + phi[:, i, pos] = r_prev[:, 1, i] + else: + for i in range(self.batch_size * self.beam_size): + phi[:, i, last_char[i]] = r_prev[:, 1, i] + + # Start, end frames for scoring (|g| < |h|). + # Scoring based on attn peak if ctc_window_size > 0 + if self.ctc_window_size == 0 or attn is None: + start = max(1, prefix_length) + end = self.max_enc_len + else: + _, attn_peak = torch.max(attn, dim=1) + max_frame = torch.max(attn_peak).item() + self.ctc_window_size + min_frame = torch.min(attn_peak).item() - self.ctc_window_size + start = max(max(1, prefix_length), int(min_frame)) + end = min(self.max_enc_len, int(max_frame)) + + # Compute forward prob log(r_t^nb(h)) and log(r_t^b(h)): + for t in range(start, end): + # (Alg.2-11): dim=0, p(h|cur step is nonblank) = [p(prev step=y) + phi] * p(c) + rnb_prev = r[t - 1, 0] + # (Alg.2-12): dim=1, p(h|cur step is blank) = [p(prev step is blank) + p(prev step is nonblank)] * p(blank) + rb_prev = r[t - 1, 1] + # r_ = torch.stack([rnb_prev, phi[t - 1], rnb_prev, rb_prev]).view( + # 2, 2, self.batch_size * self.beam_size, self.num_candidates + # ) + r_ = FasterStack([rnb_prev, phi[t - 1], rnb_prev, rb_prev]).view( + 2, 2, self.batch_size * self.beam_size, self.num_candidates + ) + r[t] = FasterLogSumExp(r_, 1) + x_inflate[:, t] + + # Compute the predix prob, psi + psi_init = r[start - 1, 0].unsqueeze(0) + # phi is prob at t-1 step, shift one frame and add it to the current prob p(c) + phix = FastCat((phi[0].unsqueeze(0), phi[:-1]), dim=0) + x_inflate[0] + + # (Alg.2-13): psi = psi + phi * p(c) + if candidates is not None: + psi = torch.full( + (self.batch_size * self.beam_size, self.vocab_size), + self.minus_inf, + device=self.device, + dtype=torch.float16 + ) + psi_ = FasterLogSumExp( + FastCat((phix[start:end], psi_init), dim=0), dim=0 + ) + # only assign prob to candidates + for i in range(self.batch_size * self.beam_size): + psi[i, candidates[i]] = psi_[i] + else: + psi = FastCat((phix[start:end], psi_init), dim=0) + psi = FasterLogSumExp(psi, dim=0) + + # (Alg.2-3): if c = , psi = log(r_T^n(g) + r_T^b(g)), where T is the length of max frames + for i in range(self.batch_size * self.beam_size): + psi[i, self.eos_index] = r_sum[ + self.last_frame_index[i // self.beam_size], i + ] + + # Exclude blank probs for joint scoring + psi[:, self.blank_index] = self.minus_inf + + return psi - psi_prev, (r, psi, scoring_table) + + def permute_mem(self, memory, index): + """This method permutes the CTC model memory + to synchronize the memory index with the current output. + + Arguments + --------- + memory : No limit + The memory variable to be permuted. + index : torch.Tensor + The index of the previous path. + + Return + ------ + The variable of the memory being permuted. + + """ + r, psi, scoring_table = memory + # The index of top-K vocab came from in (t-1) timesteps. + best_index = ( + index + + (self.beam_offset.unsqueeze(1).expand_as(index) * self.vocab_size) + ).view(-1) + # synchronize forward prob + psi = torch.index_select(psi.view(-1), dim=0, index=best_index) + psi = ( + psi.view(-1, 1) + .repeat(1, self.vocab_size) + .view(self.batch_size * self.beam_size, self.vocab_size) + ) + + # synchronize ctc states + if scoring_table is not None: + effective_index = ( + index // self.vocab_size + self.beam_offset.view(-1, 1) + ).view(-1) + selected_vocab = (index % self.vocab_size).view(-1) + score_index = scoring_table[effective_index, selected_vocab] + score_index[score_index == -1] = 0 + best_index = score_index + effective_index * self.num_candidates + + r = torch.index_select( + r.view( + -1, 2, self.batch_size * self.beam_size * self.num_candidates + ), + dim=-1, + index=best_index, + ) + r = r.view(-1, 2, self.batch_size * self.beam_size) + + return r, psi + + +def filter_ctc_output(string_pred, blank_id=-1): + """Apply CTC output merge and filter rules. + + Removes the blank symbol and output repetitions. + + Arguments + --------- + string_pred : list + A list containing the output strings/ints predicted by the CTC system. + blank_id : int, string + The id of the blank. + + Returns + ------- + list + The output predicted by CTC without the blank symbol and + the repetitions. + + Example + ------- + >>> string_pred = ['a','a','blank','b','b','blank','c'] + >>> string_out = filter_ctc_output(string_pred, blank_id='blank') + >>> print(string_out) + ['a', 'b', 'c'] + """ + + if isinstance(string_pred, list): + # Filter the repetitions + string_out = [ + v + for i, v in enumerate(string_pred) + if i == 0 or v != string_pred[i - 1] + ] + + # Remove duplicates + string_out = [i[0] for i in groupby(string_out)] + + # Filter the blank symbol + string_out = list(filter(lambda elem: elem != blank_id, string_out)) + else: + raise ValueError("filter_ctc_out can only filter python lists") + return string_out + + +def ctc_greedy_decode(probabilities, seq_lens, blank_id=-1): + """Greedy decode a batch of probabilities and apply CTC rules. + + Arguments + --------- + probabilities : torch.tensor + Output probabilities (or log-probabilities) from the network with shape + [batch, probabilities, time] + seq_lens : torch.tensor + Relative true sequence lengths (to deal with padded inputs), + the longest sequence has length 1.0, others a value between zero and one + shape [batch, lengths]. + blank_id : int, string + The blank symbol/index. Default: -1. If a negative number is given, + it is assumed to mean counting down from the maximum possible index, + so that -1 refers to the maximum possible index. + + Returns + ------- + list + Outputs as Python list of lists, with "ragged" dimensions; padding + has been removed. + + Example + ------- + >>> import torch + >>> probs = torch.tensor([[[0.3, 0.7], [0.0, 0.0]], + ... [[0.2, 0.8], [0.9, 0.1]]]) + >>> lens = torch.tensor([0.51, 1.0]) + >>> blank_id = 0 + >>> ctc_greedy_decode(probs, lens, blank_id) + [[1], [1]] + """ + if isinstance(blank_id, int) and blank_id < 0: + blank_id = probabilities.shape[-1] + blank_id + batch_max_len = probabilities.shape[1] + batch_outputs = [] + for seq, seq_len in zip(probabilities, seq_lens): + actual_size = int(torch.round(seq_len * batch_max_len)) + scores, predictions = torch.max(seq.narrow(0, 0, actual_size), dim=1) + out = filter_ctc_output(predictions.tolist(), blank_id=blank_id) + batch_outputs.append(out) + return batch_outputs diff --git a/models/speech/speech_recognition/transformer_asr/ixrt/faster_cat/__init__.py b/models/speech/speech_recognition/transformer_asr/ixrt/faster_cat/__init__.py new file mode 100644 index 00000000..537d35c5 --- /dev/null +++ b/models/speech/speech_recognition/transformer_asr/ixrt/faster_cat/__init__.py @@ -0,0 +1,13 @@ +import torch +from faster_cat import sp_opt + +def FastCat(inputs,dim=0): + if len(inputs) == 2 and dim==0: + a,b = inputs + in_shape = a.shape + if len(in_shape)>1: + res, = sp_opt.test_opt_2(a.view(a.shape[0],-1),b.view(b.shape[0],-1)) + new_shape = (a.shape[0]+b.shape[0],) + in_shape[1:] + res = res.view(*new_shape) + return res + return torch.cat(inputs,dim=dim) \ No newline at end of file diff --git a/models/speech/speech_recognition/transformer_asr/ixrt/faster_cat/build.sh b/models/speech/speech_recognition/transformer_asr/ixrt/faster_cat/build.sh new file mode 100644 index 00000000..f679258d --- /dev/null +++ b/models/speech/speech_recognition/transformer_asr/ixrt/faster_cat/build.sh @@ -0,0 +1,22 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +set -euox pipefail + +rm -rf build +rm -rf *.so + +python3 setup.py build + +cp build/lib*/*.so . \ No newline at end of file diff --git a/models/speech/speech_recognition/transformer_asr/ixrt/faster_cat/kernel.cu b/models/speech/speech_recognition/transformer_asr/ixrt/faster_cat/kernel.cu new file mode 100644 index 00000000..022fac39 --- /dev/null +++ b/models/speech/speech_recognition/transformer_asr/ixrt/faster_cat/kernel.cu @@ -0,0 +1,79 @@ +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace iluvatar::inferrt::transformer { + +__global__ void Cat(half* a, half* b, half* output, int m1, int m2, int k) { + int i = blockIdx.y * blockDim.x + threadIdx.x; + // a + if (blockIdx.x < m1) { + half2* h2_a = reinterpret_cast(a + blockIdx.x * k); + half2* h2_out_a = reinterpret_cast(output + blockIdx.x * k); + if (i < k / 2) { + h2_out_a[i] = h2_a[i]; + } + } + // b + if (blockIdx.x < m2) { + half2* h2_b = reinterpret_cast(b + blockIdx.x * k); + half2* h2_out_b = + reinterpret_cast(output + blockIdx.x * k + m1 * k); + if (i < k / 2) { + h2_out_b[i] = h2_b[i]; + } + } +} + +void IxinferCatLauncher(half* a, half* b, half* output, int m1, int m2, int k, + cudaStream_t stream) { + if (k % 2 != 0) { + throw std::runtime_error("IxinferStackLauncher: size error!"); + } + int m = std::max(m1, m2); + int num_threads = 1024; + int half_k = k / 2; + int num_roll = (half_k - 1 + num_threads) / num_threads; + dim3 grid(m, num_roll); + dim3 block(num_threads); + Cat<<>>(a, b, output, m1, m2, k); +} + +} // namespace iluvatar::inferrt::transformer + +std::vector one_test_opt_2(at::Tensor a, at::Tensor b) { + TORCH_CHECK(a.scalar_type() == at::ScalarType::Half); + TORCH_CHECK(a.is_cuda()); + TORCH_CHECK(a.is_contiguous()); + + TORCH_CHECK(b.scalar_type() == at::ScalarType::Half); + TORCH_CHECK(b.is_cuda()); + TORCH_CHECK(b.is_contiguous()); + + TORCH_CHECK(a.dim() == 2); + TORCH_CHECK(b.dim() == 2); + + int m1 = a.size(0); + int m2 = b.size(0); + + int k = a.size(1); + + TORCH_CHECK(b.size(1) == k); + + at::Tensor output = a.new_empty({(m1 + m2), k}); + + half* p_a = (half*)a.data_ptr(); + half* p_b = (half*)b.data_ptr(); + half* p_out = (half*)output.data_ptr(); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + iluvatar::inferrt::transformer::IxinferCatLauncher(p_a, p_b, p_out, m1, m2, k, + stream); + return {output}; +} diff --git a/models/speech/speech_recognition/transformer_asr/ixrt/faster_cat/setup.py b/models/speech/speech_recognition/transformer_asr/ixrt/faster_cat/setup.py new file mode 100644 index 00000000..a031577c --- /dev/null +++ b/models/speech/speech_recognition/transformer_asr/ixrt/faster_cat/setup.py @@ -0,0 +1,48 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +import os +import glob +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension + +CUR_DIR = os.path.dirname(os.path.abspath(__file__)) + +# cpp_files = glob.glob(os.path.join(CUR_DIR,"*.cpp")) +# cu_files = glob.glob(os.path.join(CUR_DIR,'*.cu')) +# source_files = cpp_files + cu_files +# print("source files:") +# for i in source_files: +# print(i) +source_files = [ + os.path.join(CUR_DIR,'test.cpp'), + os.path.join(CUR_DIR,'kernel.cu'), +] + +for i in source_files: + assert os.path.isfile(i) + print(i) + +setup( + name="test", + ext_modules=[ + CUDAExtension( + name="sp_opt", + libraries=["cuinfer"], + sources=source_files) + ], + cmdclass={ + "build_ext": BuildExtension + } +) diff --git a/models/speech/speech_recognition/transformer_asr/ixrt/faster_cat/test.cpp b/models/speech/speech_recognition/transformer_asr/ixrt/faster_cat/test.cpp new file mode 100644 index 00000000..11720811 --- /dev/null +++ b/models/speech/speech_recognition/transformer_asr/ixrt/faster_cat/test.cpp @@ -0,0 +1,21 @@ +#include +#include +#include +#include +#include +#include + +#include +#include +#include + + +std::vector one_test_opt_2(at::Tensor a, at::Tensor b); + +std::vector test_opt_2(at::Tensor a, at::Tensor b) { + return one_test_opt_2(a, b); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("test_opt_2", &test_opt_2, ""); +} diff --git a/models/speech/speech_recognition/transformer_asr/ixrt/faster_cat/test.py b/models/speech/speech_recognition/transformer_asr/ixrt/faster_cat/test.py new file mode 100644 index 00000000..2713dae2 --- /dev/null +++ b/models/speech/speech_recognition/transformer_asr/ixrt/faster_cat/test.py @@ -0,0 +1,37 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +import torch +import sp_opt + +if __name__ == "__main__": + m1 = 320 + m2 = 321 + hidden_size = 5000 + + a = torch.randn([m1,hidden_size]).cuda().half() + b = torch.randn([m2,hidden_size]).cuda().half() + + + res_pt = torch.cat([a,b],dim=0) + + res_cu, = sp_opt.test_opt_2(a,b) + + + diff = torch.abs(res_pt-res_cu) + print(diff) + print(diff.max()) + + for i in range(20): + res_cu, = sp_opt.test_opt_2(a,b) \ No newline at end of file diff --git a/models/speech/speech_recognition/transformer_asr/ixrt/faster_layer_norm/__init__.py b/models/speech/speech_recognition/transformer_asr/ixrt/faster_layer_norm/__init__.py new file mode 100644 index 00000000..20603650 --- /dev/null +++ b/models/speech/speech_recognition/transformer_asr/ixrt/faster_layer_norm/__init__.py @@ -0,0 +1,16 @@ +import torch +from faster_layer_norm import sp_opt + +class FasterLayerNorm(torch.nn.Module): + def __init__(self, weight, bias): + super(FasterLayerNorm, self).__init__() + self.weight = weight + self.bias = bias + + def forward(self, inputs, *args, **kwargs): + hidden_size = self.weight.size(0) + in_shape = inputs.shape + inputs = inputs.view(-1,hidden_size) + output, = sp_opt.test_opt(inputs,self.weight,self.bias) + output = output.view(*in_shape) + return output diff --git a/models/speech/speech_recognition/transformer_asr/ixrt/faster_layer_norm/build.sh b/models/speech/speech_recognition/transformer_asr/ixrt/faster_layer_norm/build.sh new file mode 100644 index 00000000..f679258d --- /dev/null +++ b/models/speech/speech_recognition/transformer_asr/ixrt/faster_layer_norm/build.sh @@ -0,0 +1,22 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +set -euox pipefail + +rm -rf build +rm -rf *.so + +python3 setup.py build + +cp build/lib*/*.so . \ No newline at end of file diff --git a/models/speech/speech_recognition/transformer_asr/ixrt/faster_layer_norm/kernel.cu b/models/speech/speech_recognition/transformer_asr/ixrt/faster_layer_norm/kernel.cu new file mode 100644 index 00000000..852db917 --- /dev/null +++ b/models/speech/speech_recognition/transformer_asr/ixrt/faster_layer_norm/kernel.cu @@ -0,0 +1,168 @@ +#include +#include +#include +#include +#include +#include +#include + +#include + +#include "transformer_helper.cuh" + +namespace iluvatar::inferrt::transformer { + +template +__global__ void LnOpt2Kernel(half* input, half* ln_weight, half* ln_bias, + half* output, int hidden_size, + float layernorm_eps) { + input += blockIdx.x * hidden_size; + output += blockIdx.x * hidden_size; + + half2* p_in = reinterpret_cast(input); + half2* p_out = reinterpret_cast(output); + half2* p_wei = reinterpret_cast(ln_weight); + half2* p_bias = reinterpret_cast(ln_bias); + int half_hidden_size = hidden_size / 2; + + extern __shared__ half2 shmem[]; + + float s_mean; + float s_variance; + float x_sum = 0.0f; + float x2_sum = 0.0f; +#pragma unroll UNROLL_FACTOR + for (int i = 0; i < UNROLL_FACTOR; ++i) { + int index = i * blockDim.x + threadIdx.x; + if (index < half_hidden_size) { + half2 value = p_in[index]; + shmem[index] = value; + float val_1 = __half2float(value.x); + float val_2 = __half2float(value.y); + x_sum += val_1 + val_2; + x2_sum += val_1 * val_1 + val_2 * val_2; + } + } + float sums[2]; // 和,平方和 + sums[0] = x_sum; + sums[1] = x2_sum; + blockReduceSumV2(sums); + + s_mean = sums[0] / hidden_size; + s_variance = rsqrtf(sums[1] / hidden_size - s_mean * s_mean + layernorm_eps); + +#pragma unroll UNROLL_FACTOR + for (int i = 0; i < UNROLL_FACTOR; ++i) { + int index = i * blockDim.x + threadIdx.x; + if (index < half_hidden_size) { + half2 wei_value = p_wei[index]; + half2 bias_value = p_bias[index]; + half2 vals_value = shmem[index]; + + float2 norm_value; + norm_value.x = (__half2float(vals_value.x) - s_mean) * s_variance * + __half2float(wei_value.x) + + __half2float(bias_value.x); + norm_value.y = (__half2float(vals_value.y) - s_mean) * s_variance * + __half2float(wei_value.y) + + __half2float(bias_value.y); + + __half2 res; + res.x = __float2half(norm_value.x); + res.y = __float2half(norm_value.y); + + p_out[index] = res; + } + } +} + +// FasterTransformer/src/fastertransformer/kernels/layernorm_kernels.cu +void IxinferLnLauncherOpt2(__half* input, __half* ln_weight, __half* ln_bias, + __half* output, int batch_tokens, int hidden_size, + cudaStream_t stream) { + const float layernorm_eps = 1e-5; + if (hidden_size % 2 != 0) { + throw std::runtime_error("layer norm error: hidden_size % 2 != 0"); + } + dim3 grid(batch_tokens); + int half_n = hidden_size / 2; + int half_n_warp = (half_n + warpSize - 1) / warpSize * warpSize; + dim3 block(std::min(half_n_warp, 1024)); + int rolls_per_thread = (half_n + block.x - 1) / block.x; + switch (rolls_per_thread) { + case 1: + LnOpt2Kernel<1><<>>( + input, ln_weight, ln_bias, output, hidden_size, layernorm_eps); + break; + case 2: + LnOpt2Kernel<2><<>>( + input, ln_weight, ln_bias, output, hidden_size, layernorm_eps); + break; + case 3: + LnOpt2Kernel<3><<>>( + input, ln_weight, ln_bias, output, hidden_size, layernorm_eps); + break; + case 4: + LnOpt2Kernel<4><<>>( + input, ln_weight, ln_bias, output, hidden_size, layernorm_eps); + break; + case 5: + LnOpt2Kernel<5><<>>( + input, ln_weight, ln_bias, output, hidden_size, layernorm_eps); + break; + case 6: + LnOpt2Kernel<6><<>>( + input, ln_weight, ln_bias, output, hidden_size, layernorm_eps); + break; + case 7: + LnOpt2Kernel<7><<>>( + input, ln_weight, ln_bias, output, hidden_size, layernorm_eps); + break; + case 8: + LnOpt2Kernel<8><<>>( + input, ln_weight, ln_bias, output, hidden_size, layernorm_eps); + break; + default: + std::cout << "hidden_size: " << hidden_size << std::endl; + throw std::runtime_error("layer norm error, unsupport hidden size! "); + break; + } +} +} // namespace iluvatar::inferrt::transformer + +std::vector one_test_opt(at::Tensor input, at::Tensor ln_weight, + at::Tensor ln_bias) { + TORCH_CHECK(input.scalar_type() == at::ScalarType::Half); + TORCH_CHECK(input.is_cuda()); + TORCH_CHECK(input.is_contiguous()); + + TORCH_CHECK(ln_weight.scalar_type() == at::ScalarType::Half); + TORCH_CHECK(ln_weight.is_cuda()); + TORCH_CHECK(ln_weight.is_contiguous()); + + TORCH_CHECK(ln_bias.scalar_type() == at::ScalarType::Half); + TORCH_CHECK(ln_bias.is_cuda()); + TORCH_CHECK(ln_bias.is_contiguous()); + + TORCH_CHECK(input.dim() == 2); + TORCH_CHECK(ln_weight.dim() == 1); + TORCH_CHECK(ln_bias.dim() == 1); + + int batch_tokens = input.size(0); + int hidden_size = input.size(1); + + TORCH_CHECK(ln_weight.size(0) == hidden_size); + TORCH_CHECK(ln_bias.size(0) == hidden_size); + + at::Tensor output = at::empty_like(input); + + half* p_in = (half*)input.data_ptr(); + half* p_wei = (half*)ln_weight.data_ptr(); + half* p_bias = (half*)ln_bias.data_ptr(); + half* p_out = (half*)output.data_ptr(); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + iluvatar::inferrt::transformer::IxinferLnLauncherOpt2( + p_in, p_wei, p_bias, p_out, batch_tokens, hidden_size, stream); + return {output}; +} diff --git a/models/speech/speech_recognition/transformer_asr/ixrt/faster_layer_norm/setup.py b/models/speech/speech_recognition/transformer_asr/ixrt/faster_layer_norm/setup.py new file mode 100644 index 00000000..a031577c --- /dev/null +++ b/models/speech/speech_recognition/transformer_asr/ixrt/faster_layer_norm/setup.py @@ -0,0 +1,48 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +import os +import glob +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension + +CUR_DIR = os.path.dirname(os.path.abspath(__file__)) + +# cpp_files = glob.glob(os.path.join(CUR_DIR,"*.cpp")) +# cu_files = glob.glob(os.path.join(CUR_DIR,'*.cu')) +# source_files = cpp_files + cu_files +# print("source files:") +# for i in source_files: +# print(i) +source_files = [ + os.path.join(CUR_DIR,'test.cpp'), + os.path.join(CUR_DIR,'kernel.cu'), +] + +for i in source_files: + assert os.path.isfile(i) + print(i) + +setup( + name="test", + ext_modules=[ + CUDAExtension( + name="sp_opt", + libraries=["cuinfer"], + sources=source_files) + ], + cmdclass={ + "build_ext": BuildExtension + } +) diff --git a/models/speech/speech_recognition/transformer_asr/ixrt/faster_layer_norm/test.cpp b/models/speech/speech_recognition/transformer_asr/ixrt/faster_layer_norm/test.cpp new file mode 100644 index 00000000..f925c1b4 --- /dev/null +++ b/models/speech/speech_recognition/transformer_asr/ixrt/faster_layer_norm/test.cpp @@ -0,0 +1,22 @@ +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +std::vector one_test_opt(at::Tensor input, at::Tensor ln_weight, + at::Tensor ln_bias); + +std::vector test_opt(at::Tensor input, at::Tensor ln_weight, + at::Tensor ln_bias) { + return one_test_opt(input, ln_weight, ln_bias); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("test_opt", &test_opt, "fast depthwise conv1d forward"); +} diff --git a/models/speech/speech_recognition/transformer_asr/ixrt/faster_layer_norm/transformer_helper.cuh b/models/speech/speech_recognition/transformer_asr/ixrt/faster_layer_norm/transformer_helper.cuh new file mode 100644 index 00000000..f8a57622 --- /dev/null +++ b/models/speech/speech_recognition/transformer_asr/ixrt/faster_layer_norm/transformer_helper.cuh @@ -0,0 +1,295 @@ +#pragma once +#include +#include + +namespace iluvatar { +namespace inferrt { +namespace transformer { + +__forceinline__ int nearest_4(int x) { + if (x % 4 == 0) { + return x; + } else { + int padding = 4 - x % 4; + return x + padding; + } +} + +__forceinline__ int nearest_2(int x) { + if (x % 2 == 0) { + return x; + } else { + int padding = 2 - x % 2; + return x + padding; + } +} + +__forceinline__ int nearest_num(int x, int value) { + if (x % value == 0) { + return x; + } else { + int padding = value - x % value; + return x + padding; + } +} + +__device__ int8_t float2int8(float x, float quant_scale) { + float i8_f = x * quant_scale; + int32_t i8 = floorf(i8_f + 0.5); + i8 = i8 < -127 ? -127 : (i8 > 127 ? 127 : i8); + return int8_t(i8); +} + +__device__ void WelfordCombine(float val, float *mean, float *m2, + float *count) { + // Use Welford Online algorithem to compute mean and variance + // For more details you can refer to: + // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm + *count += 1; + float delta1 = val - *mean; + *mean += delta1 / *count; + float delta2 = val - *mean; + *m2 += delta1 * delta2; +} + +__device__ void WelfordCombine(float b_mean, float b_m2, float b_count, + float *mean, float *m2, float *count) { + if (b_count == 0) { + return; + } + float new_count = *count + b_count; + float nb_over_n = b_count / new_count; + float delta = b_mean - *mean; + *mean += delta * nb_over_n; + *m2 += b_m2 + delta * delta * (*count) * nb_over_n; + *count = new_count; +} + +__device__ void WelfordWarpReduce(float thread_mean, float thread_m2, + float thread_count, float *mean, float *m2, + float *count) { + *mean = thread_mean; + *m2 = thread_m2; + *count = thread_count; + for (int mask = warpSize / 2; mask > 0; mask /= 2) { + float b_mean = __shfl_down_sync(0xffffffff, *mean, mask); + float b_m2 = __shfl_down_sync(0xffffffff, *m2, mask); + float b_count = __shfl_down_sync(0xffffffff, *count, mask); + WelfordCombine(b_mean, b_m2, b_count, mean, m2, count); + } +} + +// load 两个 half2, 保存到 float4 +__device__ void load_float4_from_half(float4 &vals, __half2 *input, int index) { + __half2 i1 = input[index * 2]; + __half2 i2 = input[index * 2 + 1]; + + vals.x = __half2float(i1.x); + vals.y = __half2float(i1.y); + vals.z = __half2float(i2.x); + vals.w = __half2float(i2.y); +} + +__device__ char4 float42char4(float4 vals, float quant_scale) { + char4 res; + res.x = float2int8(vals.x, quant_scale); + res.y = float2int8(vals.y, quant_scale); + res.z = float2int8(vals.z, quant_scale); + res.w = float2int8(vals.w, quant_scale); + return res; +} + +__device__ float4 char4addhalf2_dequant(char4 input_4, half2 residual_1, + half2 residual_2, float dequant_scale) { + float4 res; + res.x = + __int2float_rn(input_4.x) * dequant_scale + __half2float(residual_1.x); + res.y = + __int2float_rn(input_4.y) * dequant_scale + __half2float(residual_1.y); + res.z = + __int2float_rn(input_4.z) * dequant_scale + __half2float(residual_2.x); + res.w = + __int2float_rn(input_4.w) * dequant_scale + __half2float(residual_2.y); + return res; +} + +__device__ float4 compute_float4_norm_value(float4 vals, float mean, float m2, + int hidden_size, float epsilon, + half2 scale_1, half2 scale_2, + half2 bias_1, half2 bias_2) { + float4 norm_value; + norm_value.x = (vals.x - mean) * rsqrtf(m2 / hidden_size + epsilon) * + __half2float(scale_1.x) + + __half2float(bias_1.x); + norm_value.y = (vals.y - mean) * rsqrtf(m2 / hidden_size + epsilon) * + __half2float(scale_1.y) + + __half2float(bias_1.y); + norm_value.z = (vals.z - mean) * rsqrtf(m2 / hidden_size + epsilon) * + __half2float(scale_2.x) + + __half2float(bias_2.x); + norm_value.w = (vals.w - mean) * rsqrtf(m2 / hidden_size + epsilon) * + __half2float(scale_2.y) + + __half2float(bias_2.y); + return norm_value; +} + +// softmax +__forceinline__ __host__ __device__ int log2_ceil(int value) { + int log2_value = 0; + while ((1 << log2_value) < value) ++log2_value; + return log2_value; +} +template +__device__ T WARP_SHFL_XOR(T value, int laneMask, int width) { + unsigned int mask = 0xffffffff; +#if !(defined(__HIP_PLATFORM_HCC__) || defined(__ILUVATAR__)) + return __shfl_xor_sync(mask, value, laneMask, width); +#else + return __shfl_xor(value, laneMask, width); +#endif +} + +template +struct Add { + __device__ T operator()(T a, T b) const { return a + b; } +}; + +template +struct Max { + __device__ T operator()(T a, T b) const { return a < b ? b : a; } +}; +template class ReduceOp> +__device__ void warp_reduce(acc_t *sum) { + ReduceOp r; +#pragma unroll + for (int offset = REDUCE_WARP_SIZE / 2; offset > 0; offset /= 2) { + acc_t b = WARP_SHFL_XOR(*sum, offset, REDUCE_WARP_SIZE); + *sum = r(*sum, b); + } +} + +__device__ void warp_argmax(float &value, int32_t &idx) { + for (int offset = warpSize / 2; offset > 0; offset /= 2) { + float next_value = WARP_SHFL_XOR(value, offset, warpSize); + float next_idx = WARP_SHFL_XOR(idx, offset, warpSize); + if (next_value > value) { + value = next_value; + idx = next_idx; + } + } +} + +// gelu +// IxinferBiasGeluI8II8OKernel +template +__device__ T tanhf_exp(T x) { + // float e1 = __expf(x); + // float e2 = 1.0f / e1; + // return (e1 - e2) / (e1 + e2); + + return (2.f / (1.f + __expf(-2.f * x)) - 1.f); +} + +template +__device__ T gelu(T x) { + float cdf = + 0.5f * + (1.0f + tanhf_exp((0.7978845608028654f * (x + 0.044715f * x * x * x)))); + return x * cdf; +} + +/* fp16 gelu */ +template <> +__forceinline__ __device__ __half2 gelu<__half2>(__half2 val) { + __half2 val_pow3 = __hmul2(val, __hmul2(val, val)); + float2 tmp_pow = __half22float2(val_pow3); + float2 tmp = __half22float2(val); + + tmp.x = + 0.5f * + (1.0f + tanhf((0.7978845608028654f * (tmp.x + 0.044715f * tmp_pow.x)))); + tmp.y = + 0.5f * + (1.0f + tanhf((0.7978845608028654f * (tmp.y + 0.044715f * tmp_pow.y)))); + return __hmul2(val, __float22half2_rn(tmp)); +} + +/* Convert vector index to 3-dim tensor index */ +__forceinline__ __host__ __device__ void decompose_3dim(int src, int dim1, + int dim2, int *id0, + int *id1, int *id2) { + *id2 = src % dim2; + src /= dim2; + + *id1 = src % dim1; + *id0 = src / dim1; +} + +template +__inline__ __device__ T warpReduceSumV2(T *val) { +#pragma unroll + for (int i = 0; i < NUM; i++) { +#pragma unroll + for (int mask = warpSize / 2; mask > 0; mask >>= 1) + val[i] += __shfl_xor_sync(0xffffffff, val[i], mask, warpSize); + } + return (T)(0.0f); +} + +template +__inline__ __device__ T blockReduceSumV2(T *val) { + static __shared__ T shared[NUM][warpSize + 1]; + int lane = threadIdx.x % warpSize; + int wid = threadIdx.x / warpSize; + + warpReduceSumV2(val); + + if (lane == 0) { +#pragma unroll + for (int i = 0; i < NUM; i++) { + shared[i][wid] = val[i]; + } + } + + __syncthreads(); + + bool is_mask = lane < (blockDim.x / warpSize); +#pragma unroll + for (int i = 0; i < NUM; i++) { + val[i] = is_mask ? shared[i][lane] : (T)(0.0f); + } + warpReduceSumV2(val); + return (T)0.0f; +} + +__inline__ __device__ void warpReduceSum2Number(float *x, float *y) { +#pragma unroll + for (int mask = warpSize / 2; mask > 0; mask >>= 1) { + *x += __shfl_xor_sync(0xffffffff, *x, mask, warpSize); + *y += __shfl_xor_sync(0xffffffff, *y, mask, warpSize); + } +} + +__inline__ __device__ void blockReduceSum2Number(float *x, float *y) { + static __shared__ float shared[2][warpSize + 1]; + int lane = threadIdx.x % warpSize; + int wid = threadIdx.x / warpSize; + + warpReduceSum2Number(x, y); + if (lane == 0) { + shared[0][wid] = *x; + shared[1][wid] = *y; + } + __syncthreads(); + bool is_mask = lane < (blockDim.x / warpSize); + *x = is_mask ? shared[0][lane] : 0.0f; + *y = is_mask ? shared[0][lane] : 0.0f; + + warpReduceSum2Number(x, y); +} + +} // namespace transformer + +} // namespace inferrt +} // namespace iluvatar diff --git a/models/speech/speech_recognition/transformer_asr/ixrt/faster_logsumexp/__init__.py b/models/speech/speech_recognition/transformer_asr/ixrt/faster_logsumexp/__init__.py new file mode 100644 index 00000000..d50b3758 --- /dev/null +++ b/models/speech/speech_recognition/transformer_asr/ixrt/faster_logsumexp/__init__.py @@ -0,0 +1,38 @@ +import torch +from faster_logsumexp import sp_opt + +# class FasterLogSumExp(torch.nn.Module): +# def __init__(self, weight, bias): +# super(FasterLogSumExp, self).__init__() +# self.weight = weight +# self.bias = bias + +# def forward(self, inputs, *args, **kwargs): +# hidden_size = self.weight.size(0) +# in_shape = inputs.shape +# inputs = inputs.view(-1,hidden_size) +# output, = sp_opt.test_opt(inputs,self.weight,self.bias) +# output = output.view(*in_shape) +# return output + +def FasterLogSumExp(inputs,dim): + # print(inputs.shape, dim) + if dim == 1 and len(inputs.shape)>2 and inputs.size(1)==2: + in_shape = inputs.shape + inputs = inputs.view(in_shape[0],in_shape[1],-1) + res, = sp_opt.test_opt(inputs) + new_shape = (in_shape[0],) + in_shape[2:] + res = res.view(*new_shape) + return res + # dim==0 现在的实现会有bug? + # if dim == 0 and len(inputs.shape)>=2: + # in_shape = inputs.shape + # inputs = inputs.view(in_shape[0],-1) + # res, = sp_opt.test_opt_dim0(inputs) + # new_shape = in_shape[1:] + # res = res.view(*new_shape) + # return res + # print(f"not support shape: {inputs.shape} dim: {dim}") + res = torch.logsumexp(inputs, dim=dim) + return res + diff --git a/models/speech/speech_recognition/transformer_asr/ixrt/faster_logsumexp/build.sh b/models/speech/speech_recognition/transformer_asr/ixrt/faster_logsumexp/build.sh new file mode 100644 index 00000000..f679258d --- /dev/null +++ b/models/speech/speech_recognition/transformer_asr/ixrt/faster_logsumexp/build.sh @@ -0,0 +1,22 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +set -euox pipefail + +rm -rf build +rm -rf *.so + +python3 setup.py build + +cp build/lib*/*.so . \ No newline at end of file diff --git a/models/speech/speech_recognition/transformer_asr/ixrt/faster_logsumexp/kernel.cu b/models/speech/speech_recognition/transformer_asr/ixrt/faster_logsumexp/kernel.cu new file mode 100644 index 00000000..56eb0810 --- /dev/null +++ b/models/speech/speech_recognition/transformer_asr/ixrt/faster_logsumexp/kernel.cu @@ -0,0 +1,155 @@ +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace iluvatar::inferrt::transformer { + +__global__ void LogSumExpWith2(half* input, half* output, int H) { + half2* h2_in1 = reinterpret_cast(input + blockIdx.x * 2 * H); + half2* h2_in2 = reinterpret_cast(input + blockIdx.x * 2 * H + H); + half2* h2_out = reinterpret_cast(output + blockIdx.x * H); + + int i = blockIdx.y * blockDim.x + threadIdx.x; + if (i < H / 2) { + float2 res; + half2 value1 = h2_in1[i]; + half2 value2 = h2_in2[i]; + + res.x = std::log(__expf(__half2float(value1.x)) + + __expf(__half2float(value2.x))); + res.y = std::log(__expf(__half2float(value1.y)) + + __expf(__half2float(value2.y))); + + half2 res_h2; + res_h2.x = __float2half(res.x); + res_h2.y = __float2half(res.y); + h2_out[i] = res_h2; + } +} + +void IxinferLogSumExpLauncher(half* input, half* output, int N, int C, int H, + cudaStream_t stream) { + const float layernorm_eps = 1e-5; + if (H % 2 != 0) { + throw std::runtime_error("IxinferLogSumExpLauncher: size error!"); + } + int num_threads = 1024; + int half_h = H / 2; + int num_roll = (half_h - 1 + num_threads) / num_threads; + dim3 grid(N, num_roll); + dim3 block(num_threads); + switch (C) { + case 2: + LogSumExpWith2<<>>(input, output, H); + break; + default: + throw std::runtime_error( + "IxinferLogSumExpLauncher error, unsupport size! "); + break; + } +} + +// https://zhuanlan.zhihu.com/p/153535799 +__global__ void LogSumExpDim0(half* input, half* output, int N, int H) { + half2* h2_out = reinterpret_cast(output); + + int i = blockIdx.x * blockDim.x + threadIdx.x; + + float2 res; + res.x = 0.f; + res.y = 0.f; + + float2 max_values; + max_values.x = -1000.f; + max_values.y = -1000.f; + + for (int batch_idx = 0; batch_idx < N; batch_idx++) { + half2* h2_in = reinterpret_cast(input + batch_idx * H); + half2 value = h2_in[i]; + + if (max_values.x < __half2float(value.x)) { + max_values.x = __half2float(value.x); + } + if (max_values.y < __half2float(value.y)) { + max_values.y = __half2float(value.y); + } + } + + for (int batch_idx = 0; batch_idx < N; batch_idx++) { + half2* h2_in = reinterpret_cast(input + batch_idx * H); + half2 value = h2_in[i]; + + res.x += __expf(__half2float(value.x) - max_values.x); + res.y += __expf(__half2float(value.y) - max_values.y); + } + + half2 res_h2; + res_h2.x = __float2half(std::log(res.x) + max_values.x); + res_h2.y = __float2half(std::log(res.y) + max_values.y); + + h2_out[i] = res_h2; +} + +void IxinferLogSumExpLauncher(half* input, half* output, int N, int H, + cudaStream_t stream) { + if (H % 2 != 0) { + throw std::runtime_error("IxinferLogSumExpLauncher: size error!"); + } + int num_threads = 1024; + int half_h = H / 2; + int num_roll = (half_h - 1 + num_threads) / num_threads; + dim3 grid(num_roll); + dim3 block(num_threads); + LogSumExpDim0<<>>(input, output, N, H); +} + +} // namespace iluvatar::inferrt::transformer + +std::vector one_test_opt(at::Tensor input) { + TORCH_CHECK(input.scalar_type() == at::ScalarType::Half); + TORCH_CHECK(input.is_cuda()); + TORCH_CHECK(input.is_contiguous()); + + TORCH_CHECK(input.dim() == 3); + + int N = input.size(0); + int C = input.size(1); + int H = input.size(2); + + at::Tensor output = input.new_empty({N, H}); + + half* p_in = (half*)input.data_ptr(); + half* p_out = (half*)output.data_ptr(); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + iluvatar::inferrt::transformer::IxinferLogSumExpLauncher(p_in, p_out, N, C, H, + stream); + return {output}; +} + +std::vector one_test_dim0(at::Tensor input) { + TORCH_CHECK(input.scalar_type() == at::ScalarType::Half); + TORCH_CHECK(input.is_cuda()); + TORCH_CHECK(input.is_contiguous()); + + TORCH_CHECK(input.dim() == 2); + + int N = input.size(0); + int H = input.size(1); + + at::Tensor output = input.new_empty({H}); + + half* p_in = (half*)input.data_ptr(); + half* p_out = (half*)output.data_ptr(); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + iluvatar::inferrt::transformer::IxinferLogSumExpLauncher(p_in, p_out, N, H, + stream); + return {output}; +} \ No newline at end of file diff --git a/models/speech/speech_recognition/transformer_asr/ixrt/faster_logsumexp/setup.py b/models/speech/speech_recognition/transformer_asr/ixrt/faster_logsumexp/setup.py new file mode 100644 index 00000000..a031577c --- /dev/null +++ b/models/speech/speech_recognition/transformer_asr/ixrt/faster_logsumexp/setup.py @@ -0,0 +1,48 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +import os +import glob +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension + +CUR_DIR = os.path.dirname(os.path.abspath(__file__)) + +# cpp_files = glob.glob(os.path.join(CUR_DIR,"*.cpp")) +# cu_files = glob.glob(os.path.join(CUR_DIR,'*.cu')) +# source_files = cpp_files + cu_files +# print("source files:") +# for i in source_files: +# print(i) +source_files = [ + os.path.join(CUR_DIR,'test.cpp'), + os.path.join(CUR_DIR,'kernel.cu'), +] + +for i in source_files: + assert os.path.isfile(i) + print(i) + +setup( + name="test", + ext_modules=[ + CUDAExtension( + name="sp_opt", + libraries=["cuinfer"], + sources=source_files) + ], + cmdclass={ + "build_ext": BuildExtension + } +) diff --git a/models/speech/speech_recognition/transformer_asr/ixrt/faster_logsumexp/test.cpp b/models/speech/speech_recognition/transformer_asr/ixrt/faster_logsumexp/test.cpp new file mode 100644 index 00000000..5eaf6fe1 --- /dev/null +++ b/models/speech/speech_recognition/transformer_asr/ixrt/faster_logsumexp/test.cpp @@ -0,0 +1,27 @@ +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +std::vector one_test_opt(at::Tensor input); + +std::vector test_opt(at::Tensor input) { + return one_test_opt(input); +} + +std::vector one_test_dim0(at::Tensor input); + +std::vector test_opt_dim0(at::Tensor input) { + return one_test_dim0(input); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("test_opt", &test_opt, ""); + m.def("test_opt_dim0", &test_opt_dim0, ""); +} diff --git a/models/speech/speech_recognition/transformer_asr/ixrt/faster_logsumexp/test.py b/models/speech/speech_recognition/transformer_asr/ixrt/faster_logsumexp/test.py new file mode 100644 index 00000000..7b22dbdd --- /dev/null +++ b/models/speech/speech_recognition/transformer_asr/ixrt/faster_logsumexp/test.py @@ -0,0 +1,50 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +import torch +import sp_opt + +if __name__ == "__main__": + batch_tokens = 2 + c = 2 + hidden_size = 320*5000 + + inputs = torch.randn([batch_tokens,c, hidden_size]).cuda().half() + + # res1 = torch.log(torch.sum(torch.exp(inputs),dim=-1)) + # res2 = torch.logsumexp(inputs,dim=-1) + # diff = torch.abs(res1-res2) + # print(diff.max()) + + res_pt = torch.logsumexp(inputs,dim=1) + + res_cu, = sp_opt.test_opt(inputs) + + diff = torch.abs(res_pt - res_cu) + print(diff.max()) + + for i in range(20): + res_cu, = sp_opt.test_opt(inputs) + + batch_tokens = 55 + hidden_size = 320*5000 + inputs = torch.randn([batch_tokens,hidden_size]).cuda().half() + res_pt = torch.logsumexp(inputs,dim=0) + res_cu, = sp_opt.test_opt_dim0(inputs) + + diff = torch.abs(res_pt - res_cu) + print(diff.max()) + for i in range(20): + res_cu, = sp_opt.test_opt_dim0(inputs) + diff --git a/models/speech/speech_recognition/transformer_asr/ixrt/faster_stack/__init__.py b/models/speech/speech_recognition/transformer_asr/ixrt/faster_stack/__init__.py new file mode 100644 index 00000000..48d0cf5b --- /dev/null +++ b/models/speech/speech_recognition/transformer_asr/ixrt/faster_stack/__init__.py @@ -0,0 +1,33 @@ +import torch +from faster_stack import sp_opt + +# class FasterLogSumExp(torch.nn.Module): +# def __init__(self, weight, bias): +# super(FasterLogSumExp, self).__init__() +# self.weight = weight +# self.bias = bias + +# def forward(self, inputs, *args, **kwargs): +# hidden_size = self.weight.size(0) +# in_shape = inputs.shape +# inputs = inputs.view(-1,hidden_size) +# output, = sp_opt.test_opt(inputs,self.weight,self.bias) +# output = output.view(*in_shape) +# return output + +def FasterStack(inputs): + if len(inputs) == 4: + a,b,c,d = inputs + in_shape = a.shape + res, = sp_opt.test_opt(a.view(-1),b.view(-1),c.view(-1),d.view(-1)) + new_shape = (4,) + in_shape + res = res.view(*new_shape) + return res + if len(inputs) == 2: + a,b = inputs + in_shape = a.shape + res, = sp_opt.test_opt_2(a.view(-1),b.view(-1)) + new_shape = (2,) + in_shape + res = res.view(*new_shape) + return res + return torch.stack(inputs) \ No newline at end of file diff --git a/models/speech/speech_recognition/transformer_asr/ixrt/faster_stack/build.sh b/models/speech/speech_recognition/transformer_asr/ixrt/faster_stack/build.sh new file mode 100644 index 00000000..f679258d --- /dev/null +++ b/models/speech/speech_recognition/transformer_asr/ixrt/faster_stack/build.sh @@ -0,0 +1,22 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +set -euox pipefail + +rm -rf build +rm -rf *.so + +python3 setup.py build + +cp build/lib*/*.so . \ No newline at end of file diff --git a/models/speech/speech_recognition/transformer_asr/ixrt/faster_stack/kernel.cu b/models/speech/speech_recognition/transformer_asr/ixrt/faster_stack/kernel.cu new file mode 100644 index 00000000..0fdff649 --- /dev/null +++ b/models/speech/speech_recognition/transformer_asr/ixrt/faster_stack/kernel.cu @@ -0,0 +1,146 @@ +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace iluvatar::inferrt::transformer { + +__global__ void Stack(half* a, half* b, half* c, half* d, half* output, int H) { + half2* h2_a = reinterpret_cast(a); + half2* h2_b = reinterpret_cast(b); + half2* h2_c = reinterpret_cast(c); + half2* h2_d = reinterpret_cast(d); + + half2* h2_out_a = reinterpret_cast(output); + half2* h2_out_b = reinterpret_cast(output + H); + half2* h2_out_c = reinterpret_cast(output + H * 2); + half2* h2_out_d = reinterpret_cast(output + H * 3); + + int i = blockIdx.x * blockDim.x + threadIdx.x; + + if (i < H / 2) { + h2_out_a[i] = h2_a[i]; + h2_out_b[i] = h2_b[i]; + h2_out_c[i] = h2_c[i]; + h2_out_d[i] = h2_d[i]; + } +} + +void IxinferStackLauncher(half* a, half* b, half* c, half* d, half* output, + int H, cudaStream_t stream) { + if (H % 2 != 0) { + throw std::runtime_error("IxinferStackLauncher: size error!"); + } + int num_threads = 1024; + int half_h = H / 2; + int num_roll = (half_h - 1 + num_threads) / num_threads; + dim3 grid(num_roll); + dim3 block(num_threads); + Stack<<>>(a, b, c, d, output, H); +} + +__global__ void Stack(half* a, half* b, half* output, int H) { + half2* h2_a = reinterpret_cast(a); + half2* h2_b = reinterpret_cast(b); + + half2* h2_out_a = reinterpret_cast(output); + half2* h2_out_b = reinterpret_cast(output + H); + + int i = blockIdx.x * blockDim.x + threadIdx.x; + + if (i < H / 2) { + h2_out_a[i] = h2_a[i]; + h2_out_b[i] = h2_b[i]; + } +} + +void IxinferStackLauncher(half* a, half* b, half* output, int H, + cudaStream_t stream) { + if (H % 2 != 0) { + throw std::runtime_error("IxinferStackLauncher: size error!"); + } + int num_threads = 1024; + int half_h = H / 2; + int num_roll = (half_h - 1 + num_threads) / num_threads; + dim3 grid(num_roll); + dim3 block(num_threads); + Stack<<>>(a, b, output, H); +} + +} // namespace iluvatar::inferrt::transformer + +std::vector one_test_opt(at::Tensor a, at::Tensor b, at::Tensor c, + at::Tensor d) { + TORCH_CHECK(a.scalar_type() == at::ScalarType::Half); + TORCH_CHECK(a.is_cuda()); + TORCH_CHECK(a.is_contiguous()); + + TORCH_CHECK(b.scalar_type() == at::ScalarType::Half); + TORCH_CHECK(b.is_cuda()); + TORCH_CHECK(b.is_contiguous()); + + TORCH_CHECK(c.scalar_type() == at::ScalarType::Half); + TORCH_CHECK(c.is_cuda()); + TORCH_CHECK(c.is_contiguous()); + + TORCH_CHECK(d.scalar_type() == at::ScalarType::Half); + TORCH_CHECK(d.is_cuda()); + TORCH_CHECK(d.is_contiguous()); + + TORCH_CHECK(a.dim() == 1); + TORCH_CHECK(b.dim() == 1); + TORCH_CHECK(c.dim() == 1); + TORCH_CHECK(d.dim() == 1); + + int N = a.size(0); + + TORCH_CHECK(b.size(0) == N); + TORCH_CHECK(c.size(0) == N); + TORCH_CHECK(d.size(0) == N); + + at::Tensor output = a.new_empty({N * 4}); + + half* p_a = (half*)a.data_ptr(); + half* p_b = (half*)b.data_ptr(); + half* p_c = (half*)c.data_ptr(); + half* p_d = (half*)d.data_ptr(); + half* p_out = (half*)output.data_ptr(); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + iluvatar::inferrt::transformer::IxinferStackLauncher(p_a, p_b, p_c, p_d, + p_out, N, stream); + return {output}; +} + +std::vector one_test_opt_2(at::Tensor a, at::Tensor b) { + TORCH_CHECK(a.scalar_type() == at::ScalarType::Half); + TORCH_CHECK(a.is_cuda()); + TORCH_CHECK(a.is_contiguous()); + + TORCH_CHECK(b.scalar_type() == at::ScalarType::Half); + TORCH_CHECK(b.is_cuda()); + TORCH_CHECK(b.is_contiguous()); + + TORCH_CHECK(a.dim() == 1); + TORCH_CHECK(b.dim() == 1); + + int N = a.size(0); + + TORCH_CHECK(b.size(0) == N); + + at::Tensor output = a.new_empty({N * 2}); + + half* p_a = (half*)a.data_ptr(); + half* p_b = (half*)b.data_ptr(); + half* p_out = (half*)output.data_ptr(); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + iluvatar::inferrt::transformer::IxinferStackLauncher(p_a, p_b, p_out, N, + stream); + return {output}; +} diff --git a/models/speech/speech_recognition/transformer_asr/ixrt/faster_stack/setup.py b/models/speech/speech_recognition/transformer_asr/ixrt/faster_stack/setup.py new file mode 100644 index 00000000..a031577c --- /dev/null +++ b/models/speech/speech_recognition/transformer_asr/ixrt/faster_stack/setup.py @@ -0,0 +1,48 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +import os +import glob +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension + +CUR_DIR = os.path.dirname(os.path.abspath(__file__)) + +# cpp_files = glob.glob(os.path.join(CUR_DIR,"*.cpp")) +# cu_files = glob.glob(os.path.join(CUR_DIR,'*.cu')) +# source_files = cpp_files + cu_files +# print("source files:") +# for i in source_files: +# print(i) +source_files = [ + os.path.join(CUR_DIR,'test.cpp'), + os.path.join(CUR_DIR,'kernel.cu'), +] + +for i in source_files: + assert os.path.isfile(i) + print(i) + +setup( + name="test", + ext_modules=[ + CUDAExtension( + name="sp_opt", + libraries=["cuinfer"], + sources=source_files) + ], + cmdclass={ + "build_ext": BuildExtension + } +) diff --git a/models/speech/speech_recognition/transformer_asr/ixrt/faster_stack/test.cpp b/models/speech/speech_recognition/transformer_asr/ixrt/faster_stack/test.cpp new file mode 100644 index 00000000..08703064 --- /dev/null +++ b/models/speech/speech_recognition/transformer_asr/ixrt/faster_stack/test.cpp @@ -0,0 +1,29 @@ +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +std::vector one_test_opt(at::Tensor a, at::Tensor b, at::Tensor c, + at::Tensor d); + +std::vector test_opt(at::Tensor a, at::Tensor b, at::Tensor c, + at::Tensor d) { + return one_test_opt(a, b, c, d); +} + +std::vector one_test_opt_2(at::Tensor a, at::Tensor b); + +std::vector test_opt_2(at::Tensor a, at::Tensor b) { + return one_test_opt_2(a, b); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("test_opt", &test_opt, ""); + m.def("test_opt_2", &test_opt_2, ""); +} diff --git a/models/speech/speech_recognition/transformer_asr/ixrt/faster_stack/test.py b/models/speech/speech_recognition/transformer_asr/ixrt/faster_stack/test.py new file mode 100644 index 00000000..185b829b --- /dev/null +++ b/models/speech/speech_recognition/transformer_asr/ixrt/faster_stack/test.py @@ -0,0 +1,74 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +import torch +import sp_opt + +if __name__ == "__main__": + batch_tokens = 320 + hidden_size = 5000 + + a = torch.randn([batch_tokens,hidden_size]).cuda().half() + b = torch.randn([batch_tokens,hidden_size]).cuda().half() + c = torch.randn([batch_tokens,hidden_size]).cuda().half() + d = torch.randn([batch_tokens,hidden_size]).cuda().half() + + res_pt = torch.stack([a,b,c,d]) + + res_cu, = sp_opt.test_opt(a.view(-1),b.view(-1),c.view(-1),d.view(-1)) + res_cu = res_cu.view(4,batch_tokens,hidden_size) + + diff = torch.abs(res_pt-res_cu) + print(diff) + print(diff.max()) + + for i in range(20): + res_cu, = sp_opt.test_opt(a.view(-1),b.view(-1),c.view(-1),d.view(-1)) + + res_pt = torch.stack([a,b]) + + res_cu, = sp_opt.test_opt_2(a.view(-1),b.view(-1)) + res_cu = res_cu.view(2,batch_tokens,hidden_size) + + diff = torch.abs(res_pt-res_cu) + print(diff) + print(diff.max()) + for i in range(20): + res_cu, = sp_opt.test_opt_2(a.view(-1),b.view(-1)) + # # res1 = torch.log(torch.sum(torch.exp(inputs),dim=-1)) + # # res2 = torch.logsumexp(inputs,dim=-1) + # # diff = torch.abs(res1-res2) + # # print(diff.max()) + + # res_pt = torch.logsumexp(inputs,dim=1) + + # res_cu, = sp_opt.test_opt(inputs) + + # diff = torch.abs(res_pt - res_cu) + # print(diff.max()) + + # for i in range(20): + # res_cu, = sp_opt.test_opt(inputs) + + # batch_tokens = 55 + # hidden_size = 320*5000 + # inputs = torch.randn([batch_tokens,hidden_size]).cuda().half() + # res_pt = torch.logsumexp(inputs,dim=0) + # res_cu, = sp_opt.test_opt_dim0(inputs) + + # diff = torch.abs(res_pt - res_cu) + # print(diff.max()) + # for i in range(20): + # res_cu, = sp_opt.test_opt_dim0(inputs) + diff --git a/models/speech/speech_recognition/transformer_asr/ixrt/hparams/train_ASR_transformer.yaml b/models/speech/speech_recognition/transformer_asr/ixrt/hparams/train_ASR_transformer.yaml new file mode 100644 index 00000000..859d09f3 --- /dev/null +++ b/models/speech/speech_recognition/transformer_asr/ixrt/hparams/train_ASR_transformer.yaml @@ -0,0 +1,253 @@ +# ############################################################################ +# Model: E2E ASR with Transformer +# Encoder: Transformer Encoder +# Decoder: Transformer Decoder + (CTC/ATT joint) beamsearch +# Tokens: BPE with unigram +# losses: CTC + KLdiv (Label Smoothing loss) +# Training: AISHELL-1 +# Authors: Jianyuan Zhong, Titouan Parcollet +# ############################################################################ +# Seed needs to be set at top of yaml, before objects with parameters are made +seed: 8886 +__set_seed: !apply:torch.manual_seed [!ref ] +output_folder: !ref results/transformer/ +cer_file: !ref /cer.txt +save_folder: !ref /save +train_log: !ref /train_log.txt + +# Data files +data_folder: !PLACEHOLDER # e,g./path/to/aishell +# noise/ris dataset will automatically be downloaded +data_folder_rirs: !ref # Change this is needed +skip_prep: False +ckpt_interval_minutes: 15 # save checkpoint every N min +train_data: !ref /csv_data/train.csv +valid_data: !ref /csv_data/dev.csv +test_data: !ref /csv_data/test.csv +tokenizer_file: speechbrain/asr-transformer-aishell/tokenizer.ckpt + +# Training parameters +number_of_epochs: 50 +batch_size: 64 +ctc_weight: 0.3 +gradient_accumulation: 4 +loss_reduction: 'batchmean' +sorting: ascending + +dynamic_batching: False +dynamic_batch_sampler: + feats_hop_size: 0.01 + max_batch_len: 15 # in terms of "duration" in annotations by default, second here + left_bucket_len: 200 # old implementation attributs + multiplier: 1.1 # old implementation attributs + shuffle_ex: False # if true re-creates batches at each epoch shuffling examples. + num_buckets: 10 # floor(log(max_batch_len/left_bucket_len, multiplier)) + 1 + batch_ordering: ascending + +num_workers: 6 + +# stages related parameters +stage_one_epochs: 40 +lr_adam: 1.0 +lr_sgd: 0.000025 + +# Feature parameters +sample_rate: 16000 +n_fft: 400 +n_mels: 80 + +# Dataloader options +train_dataloader_opts: + batch_size: !ref + shuffle: True + +valid_dataloader_opts: + batch_size: !ref + +test_dataloader_opts: + batch_size: !ref + +####################### Model parameters ########################### +# Transformer +d_model: 256 +nhead: 4 +num_encoder_layers: 12 +num_decoder_layers: 6 +d_ffn: 2048 +transformer_dropout: 0.1 +activation: !name:torch.nn.GELU +output_neurons: 5000 + +# Outputs +blank_index: 0 +label_smoothing: 0.1 +pad_index: 0 +bos_index: 1 +eos_index: 2 + +# Decoding parameters +min_decode_ratio: 0.0 +max_decode_ratio: 1.0 # 1.0 +valid_search_interval: 10 +valid_beam_size: 10 +test_beam_size: 1 +ctc_weight_decode: 0.40 + +############################## models ################################ + +CNN: !new:speechbrain.lobes.models.convolution.ConvolutionFrontEnd + input_shape: (8, 10, 80) + num_blocks: 2 + num_layers_per_block: 1 + out_channels: (256, 256) + kernel_sizes: (3, 3) + strides: (2, 2) + residuals: (False, False) + +Transformer: !new:speechbrain.lobes.models.transformer.TransformerASR.TransformerASR # yamllint disable-line rule:line-length + input_size: 5120 + tgt_vocab: !ref + d_model: !ref + nhead: !ref + num_encoder_layers: !ref + num_decoder_layers: !ref + d_ffn: !ref + dropout: !ref + activation: !ref + normalize_before: True + +tokenizer: !new:sentencepiece.SentencePieceProcessor + +ctc_lin: !new:speechbrain.nnet.linear.Linear + input_size: !ref + n_neurons: !ref + +seq_lin: !new:speechbrain.nnet.linear.Linear + input_size: !ref + n_neurons: !ref + +env_corrupt: !new:speechbrain.lobes.augment.EnvCorrupt + openrir_folder: !ref + babble_prob: 0.0 + reverb_prob: 0.0 + noise_prob: 1.0 + noise_snr_low: 0 + noise_snr_high: 15 + +modules: + CNN: !ref + Transformer: !ref + seq_lin: !ref + ctc_lin: !ref + env_corrupt: !ref + +model: !new:torch.nn.ModuleList + - [!ref , !ref , !ref , !ref ] + +# define two optimizers here for two-stage training +Adam: !name:torch.optim.Adam + lr: 0 + betas: (0.9, 0.98) + eps: 0.000000001 + +SGD: !name:torch.optim.SGD + lr: !ref + momentum: 0.99 + nesterov: True + + +valid_search: !new:speechbrain.decoders.S2STransformerBeamSearch + modules: [!ref , !ref , !ref ] + bos_index: !ref + eos_index: !ref + blank_index: !ref + min_decode_ratio: !ref + max_decode_ratio: !ref + beam_size: !ref + ctc_weight: !ref + using_eos_threshold: False + length_normalization: True + +test_search: !new:speechbrain.decoders.S2STransformerBeamSearch + modules: [!ref , !ref , !ref ] + bos_index: !ref + eos_index: !ref + blank_index: !ref + min_decode_ratio: !ref + max_decode_ratio: !ref + beam_size: !ref + ctc_weight: !ref + using_eos_threshold: False + length_normalization: True + +log_softmax: !new:torch.nn.LogSoftmax + dim: -1 + +ctc_cost: !name:speechbrain.nnet.losses.ctc_loss + blank_index: !ref + reduction: !ref + +seq_cost: !name:speechbrain.nnet.losses.kldiv_loss + label_smoothing: !ref + reduction: !ref + +noam_annealing: !new:speechbrain.nnet.schedulers.NoamScheduler + lr_initial: !ref + n_warmup_steps: 25000 + model_size: !ref + +checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer + checkpoints_dir: !ref + recoverables: + model: !ref + noam_scheduler: !ref + normalizer: !ref + counter: !ref + +epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter + limit: !ref + +normalize: !new:speechbrain.processing.features.InputNormalization + norm_type: global + update_until_epoch: 4 + +augmentation: !new:speechbrain.lobes.augment.SpecAugment + time_warp: True + time_warp_window: 5 + time_warp_mode: bicubic + freq_mask: True + n_freq_mask: 2 + time_mask: True + n_time_mask: 2 + replace_with_zero: False + freq_mask_width: 30 + time_mask_width: 40 + +compute_features: !new:speechbrain.lobes.features.Fbank + sample_rate: !ref + n_fft: !ref + n_mels: !ref + +train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger + save_file: !ref + +# AISHELL-1 has spaces between words in the transcripts, +# which Chinese writing normally does not do. +# If remove_spaces, spaces are removed +# from the transcript before computing CER. +# (e.g., 祝 可爱 的 你 —> 祝可爱的你) +remove_spaces: True +split_tokens: !apply:operator.not_ [!ref ] + +cer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats + split_tokens: !ref +acc_computer: !name:speechbrain.utils.Accuracy.AccuracyStats + +pretrainer: !new:speechbrain.utils.parameter_transfer.Pretrainer + collect_in: !ref + loadables: + tokenizer: !ref + paths: + tokenizer: !ref +engine_path: transformer.engine +ckpt_path: /home/data/speechbrain/results \ No newline at end of file diff --git a/models/speech/speech_recognition/transformer_asr/ixrt/inference.py b/models/speech/speech_recognition/transformer_asr/ixrt/inference.py new file mode 100644 index 00000000..68ef0e40 --- /dev/null +++ b/models/speech/speech_recognition/transformer_asr/ixrt/inference.py @@ -0,0 +1,606 @@ +#!/usr/bin/env/python3 +""" + +AISHELL-1 transformer model recipe. (Adapted from the LibriSpeech recipe.) + +""" +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +import sys +import time +import torch +import logging +import speechbrain as sb +from speechbrain import Stage +from speechbrain.dataio.dataloader import LoopedLoader +from speechbrain.utils.distributed import run_on_main +from hyperpyyaml import load_hyperpyyaml +from speechbrain.utils.checkpoints import Checkpointer +import numpy as np +from speechbrain.utils import data_utils +import tensorrt +from torch.utils.data import DataLoader +from tqdm import tqdm +import convert +import beam_search +from load_ixrt_plugin import load_ixrt_plugin +from tensorrt import Dims +from speechbrain.lobes.models.transformer import Transformer +Transformer.get_lookahead_mask = convert.patch_get_lookahead_mask +load_ixrt_plugin() +logger = logging.getLogger(__name__) + + +def volume(shape): + result = 1 + for i in shape: + result *= i + return result + + +class ASR(sb.core.Brain): + def __init__(self, engine_path, *args, **kwargs): + super().__init__(*args, **kwargs) + # + self.forward_time = 0 + # ixrt + self.logger = tensorrt.Logger(tensorrt.Logger.ERROR) + with open(engine_path, "rb") as f, tensorrt.Runtime(self.logger) as self.runtime: + self.engine = self.runtime.deserialize_cuda_engine(f.read()) + assert self.engine + self.context = self.engine.create_execution_context() + assert self.context + self.encoder_ln_out = torch.zeros((64,2048,256), dtype=torch.float16).cuda() + self.infer_time = 0 + self.hparams.valid_search.return_log_probs = True + self.modules.CNN = self.modules.CNN.half() + self.hparams.valid_search = self.hparams.valid_search.half() + self.hparams.valid_search.model = self.hparams.valid_search.model.half() + self.hparams.valid_search.fc = self.hparams.valid_search.fc.half() + self.hparams.valid_search.ctc_fc = self.hparams.valid_search.ctc_fc.half() + self.hparams.valid_search.minus_inf = -10000 + self.hparams.valid_search.softmax = self.hparams.valid_search.softmax.half() + self.hparams.valid_search.model.decoder = convert.convert_decoder_model(self.hparams.valid_search.model.decoder) + # Given all input/output bindings, run in a dynamic shape way + def ixrt_infer(self, engine, context, bindings): + assert engine.num_bindings == len(bindings) + io_buffers = [0] * engine.num_bindings + for name, arr in bindings.items(): + idx = engine.get_binding_index(name) + io_buffers[idx] = arr.data_ptr() + # dynamic input + if engine.binding_is_input(idx): + context.set_binding_shape(idx, Dims(arr.shape)) + + forward_start_time = time.time() + assert context.execute_v2(io_buffers) + + torch.cuda.synchronize() + self.forward_time += time.time() - forward_start_time + outputs = {} + for name, arr in bindings.items(): + idx = engine.get_binding_index(name) + if not engine.binding_is_input(idx): + # dynamic output + shape = context.get_binding_shape(idx) + outputs[name] = arr.view(-1)[:volume(shape)].view(*shape) + return outputs + + def compute_forward(self, batch, stage): + """Forward computations from the waveform batches to the output probabilities.""" + + batch = batch.to(self.device) + wavs, wav_lens = batch.sig + tokens_bos, _ = batch.tokens_bos + + # Add augmentation if specified + if stage == sb.Stage.TRAIN: + if hasattr(self.modules, "env_corrupt"): + wavs_noise = self.modules.env_corrupt(wavs, wav_lens) + wavs = torch.cat([wavs, wavs_noise], dim=0) + wav_lens = torch.cat([wav_lens, wav_lens]) + tokens_bos = torch.cat([tokens_bos, tokens_bos], dim=0) + + torch.cuda.synchronize() + start_time = time.time() + + # compute features + feats = self.hparams.compute_features(wavs) + current_epoch = self.hparams.epoch_counter.current + feats = self.hparams.normalize(feats, wav_lens, epoch=current_epoch) + + if stage == sb.Stage.TRAIN: + if hasattr(self.hparams, "augmentation"): + feats = self.hparams.augmentation(feats) + + # forward modules + src = self.modules.CNN(feats.half()) + + # Orignal PyTorch implementation, comment this to compare + # enc_out, _ = self.modules.Transformer( + # src, tokens_bos, wav_lens, pad_idx=self.hparams.pad_index + # ) + # logits = self.modules.ctc_lin(enc_out) + # p_ctc = self.hparams.log_softmax(logits) + # hyps, _ = self.hparams.test_search( + # enc_out.detach(), wav_lens + # ) + # return p_ctc, wav_lens, hyps + + # transformer + if src.ndim == 4: + bz, t, ch1, ch2 = src.shape + src = src.reshape(bz, t, ch1 * ch2) + + # ixrt inference + t1 = time.time() + bindings = {"input": src.half(), "length_radio": wav_lens.half(), + "encoder_ln_out": self.encoder_ln_out} + + infer_result = self.ixrt_infer(self.engine, self.context, bindings) + encoder_ln_out = infer_result["encoder_ln_out"] + t2 = time.time() + + hyps, _, p_ctc = beam_search.forward(self.hparams.valid_search, encoder_ln_out.half(), wav_lens.half()) + torch.cuda.synchronize() + infer_time = time.time() - start_time + + self.infer_time += infer_time + + return p_ctc, wav_lens, hyps + + def compute_objectives(self, predictions, batch, stage): + """Computes the loss (CTC+NLL) given predictions and targets.""" + + # ( + # p_ctc, + # p_seq, + # wav_lens, + # hyps, + # ) = predictions + + # 去除 seq2seq log-probabilities + ( + p_ctc, + wav_lens, + hyps, + ) = predictions + + ids = batch.id + tokens_eos, tokens_eos_lens = batch.tokens_eos + tokens, tokens_lens = batch.tokens + + if hasattr(self.modules, "env_corrupt") and stage == sb.Stage.TRAIN: + tokens_eos = torch.cat([tokens_eos, tokens_eos], dim=0) + tokens_eos_lens = torch.cat( + [tokens_eos_lens, tokens_eos_lens], dim=0) + tokens = torch.cat([tokens, tokens], dim=0) + tokens_lens = torch.cat([tokens_lens, tokens_lens], dim=0) + + # 去除 seq2seq 部分 loss + # loss = self.hparams.ctc_cost(p_ctc, tokens, wav_lens, tokens_lens) + + if stage != sb.Stage.TRAIN: + current_epoch = self.hparams.epoch_counter.current + valid_search_interval = self.hparams.valid_search_interval + + if current_epoch % valid_search_interval == 0 or (stage == sb.Stage.TEST): + # Decode token terms to words + predicted_words = [ + tokenizer.decode_ids(utt_seq).split(" ") for utt_seq in hyps + ] + target_words = [wrd.split(" ") for wrd in batch.wrd] + if self.hparams.remove_spaces: + predicted_words = ["".join(p) for p in predicted_words] + target_words = ["".join(t) for t in target_words] + self.cer_metric.append(ids, predicted_words, target_words) + + # 不计算 acc 部分 + # # compute the accuracy of the one-step-forward prediction + # self.acc_metric.append(p_seq, tokens_eos, tokens_eos_lens) + return -torch.ones([1]) + + def fit_batch(self, batch): + """Train the parameters given a single batch in input""" + # check if we need to switch optimizer + # if so change the optimizer from Adam to SGD + self.check_and_reset_optimizer() + + predictions = self.compute_forward(batch, sb.Stage.TRAIN) + loss = self.compute_objectives(predictions, batch, sb.Stage.TRAIN) + + # normalize the loss by gradient_accumulation step + (loss / self.hparams.gradient_accumulation).backward() + + if self.step % self.hparams.gradient_accumulation == 0: + # gradient clipping & early stop if loss is not fini + self.check_gradients(loss) + + self.optimizer.step() + self.optimizer.zero_grad() + + # anneal lr every update + self.hparams.noam_annealing(self.optimizer) + + return loss.detach() + + def evaluate_batch(self, batch, stage): + """Computations needed for validation/test batches""" + with torch.no_grad(): + predictions = self.compute_forward(batch, stage=stage) + loss = self.compute_objectives(predictions, batch, stage=stage) + return loss + + def on_stage_start(self, stage, epoch): + """Gets called at the beginning of each epoch""" + if stage != sb.Stage.TRAIN: + # self.acc_metric = self.hparams.acc_computer() + self.cer_metric = self.hparams.cer_computer() + + def on_stage_end(self, stage, stage_loss, epoch): + """Gets called at the end of a epoch.""" + # Compute/store important stats + stage_stats = {"forward time": self.forward_time} + if stage == sb.Stage.TRAIN: + self.train_stats = stage_stats + else: + # stage_stats["ACC"] = self.acc_metric.summarize() + current_epoch = self.hparams.epoch_counter.current + valid_search_interval = self.hparams.valid_search_interval + if current_epoch % valid_search_interval == 0 or stage == sb.Stage.TEST: + stage_stats["CER"] = self.cer_metric.summarize("error_rate") + + # log stats and save checkpoint at end-of-epoch + if stage == sb.Stage.VALID and sb.utils.distributed.if_main_process(): + + # report different epoch stages according current stage + current_epoch = self.hparams.epoch_counter.current + if current_epoch <= self.hparams.stage_one_epochs: + lr = self.hparams.noam_annealing.current_lr + steps = self.hparams.noam_annealing.n_steps + optimizer = self.optimizer.__class__.__name__ + else: + lr = self.hparams.lr_sgd + steps = -1 + optimizer = self.optimizer.__class__.__name__ + + epoch_stats = { + "epoch": epoch, + "lr": lr, + "steps": steps, + "optimizer": optimizer, + } + self.hparams.train_logger.log_stats( + stats_meta=epoch_stats, + train_stats=self.train_stats, + valid_stats=stage_stats, + ) + self.checkpointer.save_and_keep_only( + meta={"ACC": stage_stats["ACC"], "epoch": epoch}, + max_keys=["ACC"], + num_to_keep=10, + ) + + elif stage == sb.Stage.TEST: + self.hparams.train_logger.log_stats( + stats_meta={ + "Epoch loaded": self.hparams.epoch_counter.current}, + test_stats=stage_stats, + ) + with open(self.hparams.cer_file, "w") as w: + self.cer_metric.write_stats(w) + + def check_and_reset_optimizer(self): + """reset the optimizer if training enters stage 2""" + current_epoch = self.hparams.epoch_counter.current + if not hasattr(self, "switched"): + self.switched = False + if isinstance(self.optimizer, torch.optim.SGD): + self.switched = True + + if self.switched is True: + return + + if current_epoch > self.hparams.stage_one_epochs: + self.optimizer = self.hparams.SGD(self.modules.parameters()) + + if self.checkpointer is not None: + self.checkpointer.add_recoverable("optimizer", self.optimizer) + + self.switched = True + + def on_fit_start(self): + """Initialize the right optimizer on the training start""" + super().on_fit_start() + + # if the model is resumed from stage two, reinitialize the optimizer + current_epoch = self.hparams.epoch_counter.current + current_optimizer = self.optimizer + if current_epoch > self.hparams.stage_one_epochs: + del self.optimizer + self.optimizer = self.hparams.SGD(self.modules.parameters()) + + # Load latest checkpoint to resume training if interrupted + if self.checkpointer is not None: + + # do not reload the weights if training is interrupted right before stage 2 + group = current_optimizer.param_groups[0] + if "momentum" not in group: + return + + self.checkpointer.recover_if_possible( + device=torch.device(self.device)) + + def on_evaluate_start(self, max_key=None, min_key=None): + """perform checkpoint averge if needed""" + super().on_evaluate_start() + + ckpts = self.checkpointer.find_checkpoints( + max_key=max_key, min_key=min_key) + ckpt = sb.utils.checkpoints.average_checkpoints( + ckpts, recoverable_name="model", device=self.device + ) + + self.hparams.model.load_state_dict(ckpt, strict=True) + self.hparams.model.eval() + + def evaluate( + self, + test_set, + max_key=None, + min_key=None, + progressbar=None, + test_loader_kwargs={}, + ): + self.debug = False + self.debug_batches = 1 + if progressbar is None: + progressbar = not self.noprogressbar + + if not ( + isinstance(test_set, DataLoader) + or isinstance(test_set, LoopedLoader) + ): + test_loader_kwargs["ckpt_prefix"] = None + test_set = self.make_dataloader( + test_set, Stage.TEST, **test_loader_kwargs + ) + self.on_evaluate_start(max_key=max_key, min_key=min_key) + self.on_stage_start(Stage.TEST, epoch=None) + self.modules.eval() + avg_test_loss = 0.0 + self.step = 0 + with torch.no_grad(): + for batch in tqdm( + test_set, dynamic_ncols=True, disable=not progressbar + ): + self.step += 1 + loss = self.evaluate_batch(batch, stage=Stage.TEST) + avg_test_loss = self.update_average(loss, avg_test_loss) + + # Profile only if desired (steps allow the profiler to know when all is warmed up) + if self.profiler is not None: + if self.profiler.record_steps: + self.profiler.step() + + # Debug mode only runs a few batches + if self.debug and self.step == self.debug_batches: + break + + # Only run evaluation "on_stage_end" on main process + run_on_main( + self.on_stage_end, args=[Stage.TEST, avg_test_loss, None] + ) + self.step = 0 + return avg_test_loss + + +def dataio_prepare(hparams): + """This function prepares the datasets to be used in the brain class. + It also defines the data processing pipeline through user-defined functions.""" + data_folder = hparams["data_folder"] + + train_data = sb.dataio.dataset.DynamicItemDataset.from_csv( + csv_path=hparams["train_data"], + replacements={"data_root": data_folder}, + ) + + if hparams["sorting"] == "ascending": + # we sort training data to speed up training and get better results. + train_data = train_data.filtered_sorted(sort_key="duration") + # when sorting do not shuffle in dataloader ! otherwise is pointless + hparams["train_dataloader_opts"]["shuffle"] = False + + elif hparams["sorting"] == "descending": + train_data = train_data.filtered_sorted( + sort_key="duration", reverse=True) + # when sorting do not shuffle in dataloader ! otherwise is pointless + hparams["train_dataloader_opts"]["shuffle"] = False + + elif hparams["sorting"] == "random": + pass + + else: + raise NotImplementedError( + "sorting must be random, ascending or descending") + + valid_data = sb.dataio.dataset.DynamicItemDataset.from_csv( + csv_path=hparams["valid_data"], + replacements={"data_root": data_folder}, + ) + valid_data = valid_data.filtered_sorted(sort_key="duration") + + test_data = sb.dataio.dataset.DynamicItemDataset.from_csv( + csv_path=hparams["test_data"], + replacements={"data_root": data_folder}, + ) + test_data = test_data.filtered_sorted(sort_key="duration") + + datasets = [train_data, valid_data, test_data] + + # Defining tokenizer and loading it + tokenizer = hparams["tokenizer"] + + # 2. Define audio pipeline: + @sb.utils.data_pipeline.takes("wav") + @sb.utils.data_pipeline.provides("sig") + def audio_pipeline(wav): + sig = sb.dataio.dataio.read_audio(wav) + return sig + + sb.dataio.dataset.add_dynamic_item(datasets, audio_pipeline) + + # 3. Define text pipeline: + @sb.utils.data_pipeline.takes("transcript") + @sb.utils.data_pipeline.provides( + "wrd", "tokens_list", "tokens_bos", "tokens_eos", "tokens" + ) + def text_pipeline(wrd): + yield wrd + tokens_list = tokenizer.encode_as_ids(wrd) + yield tokens_list + tokens_bos = torch.LongTensor([hparams["bos_index"]] + (tokens_list)) + yield tokens_bos + tokens_eos = torch.LongTensor(tokens_list + [hparams["eos_index"]]) + yield tokens_eos + tokens = torch.LongTensor(tokens_list) + yield tokens + + sb.dataio.dataset.add_dynamic_item(datasets, text_pipeline) + + # 4. Set output: + sb.dataio.dataset.set_output_keys( + datasets, + ["id", "sig", "wrd", "tokens_bos", "tokens_eos", "tokens"], + ) + + # 5. If Dynamic Batching is used, we instantiate the needed samplers. + train_batch_sampler = None + valid_batch_sampler = None + if hparams["dynamic_batching"]: + from speechbrain.dataio.sampler import DynamicBatchSampler # noqa + + dynamic_hparams = hparams["dynamic_batch_sampler"] + num_buckets = dynamic_hparams["num_buckets"] + + train_batch_sampler = DynamicBatchSampler( + train_data, + dynamic_hparams["max_batch_len"], + num_buckets=num_buckets, + length_func=lambda x: x["duration"], + shuffle=dynamic_hparams["shuffle_ex"], + batch_ordering=dynamic_hparams["batch_ordering"], + ) + + valid_batch_sampler = DynamicBatchSampler( + valid_data, + dynamic_hparams["max_batch_len"], + num_buckets=num_buckets, + length_func=lambda x: x["duration"], + shuffle=dynamic_hparams["shuffle_ex"], + batch_ordering=dynamic_hparams["batch_ordering"], + ) + + return ( + train_data, + valid_data, + test_data, + tokenizer, + train_batch_sampler, + valid_batch_sampler, + ) + + +if __name__ == "__main__": + # CLI: + hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:]) + with open(hparams_file) as fin: + hparams = load_hyperpyyaml(fin, overrides) + + # If --distributed_launch then + # create ddp_group with the right communication protocol + sb.utils.distributed.ddp_init_group(run_opts) + + # Create experiment directory + sb.create_experiment_directory( + experiment_directory=hparams["output_folder"], + hyperparams_to_save=hparams_file, + overrides=overrides, + ) + + # 1. # Dataset prep (parsing Librispeech) + from aishell_prepare import prepare_aishell # noqa + + # multi-gpu (ddp) save data preparation + run_on_main( + prepare_aishell, + kwargs={ + "data_folder": hparams["data_folder"], + "save_folder": hparams["output_folder"], + "skip_prep": hparams["skip_prep"], + }, + ) + + # here we create the datasets objects as well as tokenization and encoding + ( + train_data, + valid_data, + test_data, + tokenizer, + train_bsampler, + valid_bsampler, + ) = dataio_prepare(hparams) + + hparams["pretrainer"].collect_files(default_source=hparams['ckpt_path']) + hparams["pretrainer"].load_collected(device=run_opts["device"]) + + # Trainer initialization + asr_brain = ASR( + modules=hparams["modules"], + opt_class=hparams["Adam"], + hparams=hparams, + run_opts=run_opts, + checkpointer=hparams["checkpointer"], + engine_path=hparams['engine_path'] + ) + + asr_brain.tokenizer = tokenizer + + # Changing the samplers if dynamic batching is activated + train_dataloader_opts = hparams["train_dataloader_opts"] + valid_dataloader_opts = hparams["valid_dataloader_opts"] + + if train_bsampler is not None: + train_dataloader_opts = { + "batch_sampler": train_bsampler, + "num_workers": hparams["num_workers"], + } + if valid_bsampler is not None: + valid_dataloader_opts = {"batch_sampler": valid_bsampler} + + # evaluation + print("*** start evaluation ***") + start_time = time.time() + asr_brain.evaluate( + test_data, test_loader_kwargs=hparams["test_dataloader_opts"]) + eval_time = asr_brain.infer_time + + ## 统计数据总音频时长 + duration = 0.0 + for value in test_data.data.values(): + duration = duration + value['duration'] + num_samples = len(test_data) + print(f"samples: {num_samples}, QPS: {num_samples / eval_time} ") + print(f"infer time :{eval_time},RTF: {eval_time / duration} ") diff --git a/models/speech/speech_recognition/transformer_asr/ixrt/load_ixrt_plugin.py b/models/speech/speech_recognition/transformer_asr/ixrt/load_ixrt_plugin.py new file mode 100644 index 00000000..2bb0abc2 --- /dev/null +++ b/models/speech/speech_recognition/transformer_asr/ixrt/load_ixrt_plugin.py @@ -0,0 +1,26 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +import ctypes +import tensorrt +from os.path import join, dirname, exists +def load_ixrt_plugin(logger=tensorrt.Logger(tensorrt.Logger.INFO), namespace="", dynamic_path=""): + if not dynamic_path: + dynamic_path = join(dirname(tensorrt.__file__), "lib", "libixrt_plugin.so") + if not exists(dynamic_path): + raise FileNotFoundError( + f"The ixrt_plugin lib {dynamic_path} is not existed, please provided effective plugin path!") + ctypes.CDLL(dynamic_path) + tensorrt.init_libnvinfer_plugins(logger, namespace) + print(f"Loaded plugin from {dynamic_path}") -- Gitee From 5b076cb0ccb8a1fd16ca0a9f7beee33985b283d4 Mon Sep 17 00:00:00 2001 From: majorli Date: Wed, 7 Aug 2024 11:35:16 +0800 Subject: [PATCH 3/7] update yolov6 and yolov8 results and format Signed-off-by: majorli --- models/cv/detection/yolov6/ixrt/README.md | 27 ++++++++++++++----- .../scripts/infer_yolov6s_fp16_accuracy.sh | 3 ++- .../scripts/infer_yolov6s_fp16_performance.sh | 3 ++- .../scripts/infer_yolov6s_int8_accuracy.sh | 3 ++- .../scripts/infer_yolov6s_int8_performance.sh | 3 ++- models/cv/detection/yolov8/ixrt/README.md | 23 ++++++++++++---- .../scripts/infer_yolov8n_fp16_accuracy.sh | 3 ++- .../scripts/infer_yolov8n_fp16_performance.sh | 3 ++- .../scripts/infer_yolov8n_int8_accuracy.sh | 3 ++- .../scripts/infer_yolov8n_int8_performance.sh | 3 ++- 10 files changed, 54 insertions(+), 20 deletions(-) diff --git a/models/cv/detection/yolov6/ixrt/README.md b/models/cv/detection/yolov6/ixrt/README.md index 8d78557a..166c2fa0 100644 --- a/models/cv/detection/yolov6/ixrt/README.md +++ b/models/cv/detection/yolov6/ixrt/README.md @@ -13,12 +13,13 @@ YOLOv6 integrates cutting-edge object detection advancements from industry and a ## CentOS yum install -y mesa-libGL ## Ubuntu -apt install -y libgl1-mesa-dev +apt install -y libgl1-mesa-glx pip3 install tqdm pip3 install onnx pip3 install onnxsim pip3 install pycocotools +pip3 install pycuda ``` ### Download @@ -27,26 +28,32 @@ Pretrained model: to download the validation dataset. +```bash +# get yolov6s.pt +wget https://github.com/meituan/YOLOv6/releases/download/0.4.0/yolov6s.pt +# set coco path +mkdir -p data/ +ln -s /Path/to/coco/ data/coco +``` + ### Model Conversion ```bash # install yolov6 git clone https://github.com/meituan/YOLOv6.git -cd YOLOv6 + +pushd YOLOv6 pip3 install -r requirements.txt # export onnx model python3 deploy/ONNX/export_onnx.py --weights ../yolov6s.pt --img 640 --batch-size 32 --simplify +mv ../yolov6s.onnx ../data/ -cd .. +popd ``` ## Inference -```bash -export DATASETS_DIR=/Path/to/coco/ -``` - ### FP16 ```bash @@ -65,6 +72,12 @@ bash scripts/infer_yolov6s_int8_accuracy.sh bash scripts/infer_yolov6s_int8_performance.sh ``` +## Results + +| Model | BatchSize | Precision | FPS | MAP@0.5 | +| ------ | --------- | --------- | -------- | ------- | +| YOLOv6 | 32 | FP16 | 1107.511 | 0.355 | +| YOLOv6 | 32 | INT8 | 2080.475 | - | ## Reference diff --git a/models/cv/detection/yolov6/ixrt/scripts/infer_yolov6s_fp16_accuracy.sh b/models/cv/detection/yolov6/ixrt/scripts/infer_yolov6s_fp16_accuracy.sh index 0360c1a1..09cc0ac0 100644 --- a/models/cv/detection/yolov6/ixrt/scripts/infer_yolov6s_fp16_accuracy.sh +++ b/models/cv/detection/yolov6/ixrt/scripts/infer_yolov6s_fp16_accuracy.sh @@ -1,3 +1,4 @@ +#!/bin/bash # Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. # All Rights Reserved. # @@ -12,7 +13,7 @@ # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the # License for the specific language governing permissions and limitations # under the License. -#!/bin/bash + EXIT_STATUS=0 check_status() { diff --git a/models/cv/detection/yolov6/ixrt/scripts/infer_yolov6s_fp16_performance.sh b/models/cv/detection/yolov6/ixrt/scripts/infer_yolov6s_fp16_performance.sh index 07a103a2..409fd354 100644 --- a/models/cv/detection/yolov6/ixrt/scripts/infer_yolov6s_fp16_performance.sh +++ b/models/cv/detection/yolov6/ixrt/scripts/infer_yolov6s_fp16_performance.sh @@ -1,3 +1,4 @@ +#!/bin/bash # Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. # All Rights Reserved. # @@ -12,7 +13,7 @@ # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the # License for the specific language governing permissions and limitations # under the License. -#!/bin/bash + EXIT_STATUS=0 check_status() { diff --git a/models/cv/detection/yolov6/ixrt/scripts/infer_yolov6s_int8_accuracy.sh b/models/cv/detection/yolov6/ixrt/scripts/infer_yolov6s_int8_accuracy.sh index 3bb4c743..701f80f0 100644 --- a/models/cv/detection/yolov6/ixrt/scripts/infer_yolov6s_int8_accuracy.sh +++ b/models/cv/detection/yolov6/ixrt/scripts/infer_yolov6s_int8_accuracy.sh @@ -1,3 +1,4 @@ +#!/bin/bash # Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. # All Rights Reserved. # @@ -12,7 +13,7 @@ # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the # License for the specific language governing permissions and limitations # under the License. -#!/bin/bash + EXIT_STATUS=0 check_status() { diff --git a/models/cv/detection/yolov6/ixrt/scripts/infer_yolov6s_int8_performance.sh b/models/cv/detection/yolov6/ixrt/scripts/infer_yolov6s_int8_performance.sh index 53ca3397..58f77417 100644 --- a/models/cv/detection/yolov6/ixrt/scripts/infer_yolov6s_int8_performance.sh +++ b/models/cv/detection/yolov6/ixrt/scripts/infer_yolov6s_int8_performance.sh @@ -1,3 +1,4 @@ +#!/bin/bash # Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. # All Rights Reserved. # @@ -12,7 +13,7 @@ # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the # License for the specific language governing permissions and limitations # under the License. -#!/bin/bash + EXIT_STATUS=0 check_status() { diff --git a/models/cv/detection/yolov8/ixrt/README.md b/models/cv/detection/yolov8/ixrt/README.md index c5560b14..07558edf 100644 --- a/models/cv/detection/yolov8/ixrt/README.md +++ b/models/cv/detection/yolov8/ixrt/README.md @@ -13,12 +13,14 @@ Yolov8 combines speed and accuracy in real-time object detection tasks. With a f ## CentOS yum install -y mesa-libGL ## Ubuntu -apt install -y libgl1-mesa-dev +apt install -y libgl1-mesa-glx pip3 install tqdm pip3 install onnx +pip3 install onnxsim pip3 install pycocotools pip3 install ultralytics +pip3 install pycuda ``` ### Download @@ -27,6 +29,14 @@ Pretrained model: to download the validation dataset. +```bash +# get yolov8n.pt +wget https://github.com/ultralytics/assets/releases/download/v8.2.0/yolov8n.pt +# set coco path +mkdir -p data/ +ln -s /Path/to/coco/ data/coco +``` + ### Model Conversion ```bash @@ -36,10 +46,6 @@ onnxsim yolov8n.onnx ./data/yolov8n.onnx ## Inference -```bash -export DATASETS_DIR=/Path/to/coco/ -``` - ### FP16 ```bash @@ -57,3 +63,10 @@ bash scripts/infer_yolov8n_int8_accuracy.sh # Performance bash scripts/infer_yolov8n_int8_performance.sh ``` + +## Results + +| Model | BatchSize | Precision | FPS | MAP@0.5 | +| ------ | --------- | --------- | -------- | ------- | +| YOLOv8 | 32 | FP16 | 1511.366 | 0.525 | +| YOLOv8 | 32 | INT8 | 1841.017 | 0.517 | diff --git a/models/cv/detection/yolov8/ixrt/scripts/infer_yolov8n_fp16_accuracy.sh b/models/cv/detection/yolov8/ixrt/scripts/infer_yolov8n_fp16_accuracy.sh index 8868533d..44e75376 100644 --- a/models/cv/detection/yolov8/ixrt/scripts/infer_yolov8n_fp16_accuracy.sh +++ b/models/cv/detection/yolov8/ixrt/scripts/infer_yolov8n_fp16_accuracy.sh @@ -1,3 +1,4 @@ +#!/bin/bash # Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. # All Rights Reserved. # @@ -12,7 +13,7 @@ # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the # License for the specific language governing permissions and limitations # under the License. -#!/bin/bash + EXIT_STATUS=0 check_status() { diff --git a/models/cv/detection/yolov8/ixrt/scripts/infer_yolov8n_fp16_performance.sh b/models/cv/detection/yolov8/ixrt/scripts/infer_yolov8n_fp16_performance.sh index b9a28a3a..1ab3808f 100644 --- a/models/cv/detection/yolov8/ixrt/scripts/infer_yolov8n_fp16_performance.sh +++ b/models/cv/detection/yolov8/ixrt/scripts/infer_yolov8n_fp16_performance.sh @@ -1,3 +1,4 @@ +#!/bin/bash # Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. # All Rights Reserved. # @@ -12,7 +13,7 @@ # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the # License for the specific language governing permissions and limitations # under the License. -#!/bin/bash + EXIT_STATUS=0 check_status() { diff --git a/models/cv/detection/yolov8/ixrt/scripts/infer_yolov8n_int8_accuracy.sh b/models/cv/detection/yolov8/ixrt/scripts/infer_yolov8n_int8_accuracy.sh index f3259c58..a2257463 100644 --- a/models/cv/detection/yolov8/ixrt/scripts/infer_yolov8n_int8_accuracy.sh +++ b/models/cv/detection/yolov8/ixrt/scripts/infer_yolov8n_int8_accuracy.sh @@ -1,3 +1,4 @@ +#!/bin/bash # Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. # All Rights Reserved. # @@ -12,7 +13,7 @@ # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the # License for the specific language governing permissions and limitations # under the License. -#!/bin/bash + EXIT_STATUS=0 check_status() { diff --git a/models/cv/detection/yolov8/ixrt/scripts/infer_yolov8n_int8_performance.sh b/models/cv/detection/yolov8/ixrt/scripts/infer_yolov8n_int8_performance.sh index 735035d8..f1774d5b 100644 --- a/models/cv/detection/yolov8/ixrt/scripts/infer_yolov8n_int8_performance.sh +++ b/models/cv/detection/yolov8/ixrt/scripts/infer_yolov8n_int8_performance.sh @@ -1,3 +1,4 @@ +#!/bin/bash # Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. # All Rights Reserved. # @@ -12,7 +13,7 @@ # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the # License for the specific language governing permissions and limitations # under the License. -#!/bin/bash + EXIT_STATUS=0 check_status() { -- Gitee From 2199561f26570783d1b479bf276584f4243abe67 Mon Sep 17 00:00:00 2001 From: majorli Date: Wed, 7 Aug 2024 13:36:29 +0800 Subject: [PATCH 4/7] update transformer results and format Signed-off-by: majorli --- .../transformer_asr/ixrt/README.md | 47 +++++++++++++++---- 1 file changed, 37 insertions(+), 10 deletions(-) diff --git a/models/speech/speech_recognition/transformer_asr/ixrt/README.md b/models/speech/speech_recognition/transformer_asr/ixrt/README.md index 7560b5eb..0c2e1b45 100644 --- a/models/speech/speech_recognition/transformer_asr/ixrt/README.md +++ b/models/speech/speech_recognition/transformer_asr/ixrt/README.md @@ -1,31 +1,52 @@ -# Asr transformer fp16 inference (BeamSearch) +# Transformer ASR(BeamSearch) ## Description Beam search allows us to exert control over the output of text generation. This is useful because we sometimes know exactly what we want inside the output. For example, in a Neural Machine Translation task, we might know which words must be included in the final translation with a dictionary lookup. - ## Setup ### Install -``` +```bash pip3 install speechbrain==0.5.13 ``` -* ixrt 4.0.1_MR release - ### Download Pretrained model: Dataset: to download the Aishell dataset. -``` +```bash # Make sure the checkpoint path is results/transformer/8886/save mkdir -p results/transformer/8886/save +# The data path like below: +results/transformer/8886 +├── cer.txt +├── dev.csv +├── env.log +├── hyperparams.yaml +├── inference_encoder_ctc.py +├── inference.py +├── log.txt +├── save +│ ├── CKPT+2023-03-29+06-31-40+00 +│ │ ├── brain.ckpt +│ │ ├── CKPT.yaml +│ │ ├── counter.ckpt +│ │ ├── model.ckpt +│ │ ├── noam_scheduler.ckpt +│ │ └── normalizer.ckpt +│ └── tokenizer.ckpt +├── test.csv +├── train.csv +└── train_log.txt + # Make sure the dataset path is results/transformer/8886/save -mkdir -p /home/data/speechbrain +mkdir -p /home/data/speechbrain/aishell/csv_data +ln -s /PATH/to/data_aishell /home/data/speechbrain/aishell/ +cp results/transformer/8886/*.csv /home/data/speechbrain/aishell/csv_data ``` ## Inference @@ -40,7 +61,7 @@ bash build.sh max_batch_size and max_seq_len depend on the situation. -``` +```bash python3 builder.py \ --ckpt_path results/transformer/8886/save \ --head_num 4 \ @@ -51,6 +72,12 @@ python3 builder.py \ ### Run engine -``` +```bash python3 inference.py hparams/train_ASR_transformer.yaml --data_folder=/home/data/speechbrain/aishell --engine_path transformer.engine -``` \ No newline at end of file +``` + +## Results + +| Model | BatchSize | Precision | QPS | CER | +| --------------- | --------- | --------- | ----- | ---- | +| Transformer ASR | 32 | FP16 | 15.64 | 5.95 | -- Gitee From 732bfe4b79ad3a7d0bb02fc82487ebf5695490a2 Mon Sep 17 00:00:00 2001 From: "yanlong.hao" Date: Wed, 31 Jul 2024 20:12:42 +0800 Subject: [PATCH 5/7] add conformer ixrt model. --- README.md | 2 +- .../conformer/ixrt/README.md | 49 ++ .../conformer/ixrt/build_engine.py | 145 +++++ .../conformer/ixrt/common.py | 136 +++++ .../conformer/ixrt/convert2onnx.py | 529 +++++++++++++++++ .../conformer/ixrt/ixrt_inference_accuracy.py | 285 +++++++++ .../ixrt/ixrt_inference_performance.py | 273 +++++++++ .../conformer/ixrt/postprocess/__init__.py | 1 + .../conformer/ixrt/postprocess/search.py | 103 ++++ .../ixrt/scripts/aishell_data_prepare.sh | 61 ++ .../infer_conformer_fp16_accuracy_ixrt.sh | 49 ++ .../infer_conformer_fp16_performance_ixrt.sh | 59 ++ .../conformer/ixrt/tools/__init__.py | 0 .../conformer/ixrt/tools/compute_cer.py | 532 +++++++++++++++++ .../conformer/ixrt/tools/filter_scp.pl | 87 +++ .../conformer/ixrt/tools/make_raw_list.py | 59 ++ .../conformer/ixrt/tools/make_shard_list.py | 181 ++++++ .../conformer/ixrt/tools/text2token.py | 171 ++++++ .../conformer/ixrt/utils/__init__.py | 39 ++ .../conformer/ixrt/utils/embedding.py | 133 +++++ .../conformer/ixrt/wenet/__init__.py | 0 .../conformer/ixrt/wenet/dataset.py | 179 ++++++ .../conformer/ixrt/wenet/file_utils.py | 66 +++ .../conformer/ixrt/wenet/processor.py | 550 ++++++++++++++++++ 24 files changed, 3688 insertions(+), 1 deletion(-) create mode 100644 models/speech/speech_recognition/conformer/ixrt/README.md create mode 100644 models/speech/speech_recognition/conformer/ixrt/build_engine.py create mode 100644 models/speech/speech_recognition/conformer/ixrt/common.py create mode 100644 models/speech/speech_recognition/conformer/ixrt/convert2onnx.py create mode 100644 models/speech/speech_recognition/conformer/ixrt/ixrt_inference_accuracy.py create mode 100644 models/speech/speech_recognition/conformer/ixrt/ixrt_inference_performance.py create mode 100644 models/speech/speech_recognition/conformer/ixrt/postprocess/__init__.py create mode 100644 models/speech/speech_recognition/conformer/ixrt/postprocess/search.py create mode 100755 models/speech/speech_recognition/conformer/ixrt/scripts/aishell_data_prepare.sh create mode 100644 models/speech/speech_recognition/conformer/ixrt/scripts/infer_conformer_fp16_accuracy_ixrt.sh create mode 100644 models/speech/speech_recognition/conformer/ixrt/scripts/infer_conformer_fp16_performance_ixrt.sh create mode 100644 models/speech/speech_recognition/conformer/ixrt/tools/__init__.py create mode 100755 models/speech/speech_recognition/conformer/ixrt/tools/compute_cer.py create mode 100755 models/speech/speech_recognition/conformer/ixrt/tools/filter_scp.pl create mode 100755 models/speech/speech_recognition/conformer/ixrt/tools/make_raw_list.py create mode 100755 models/speech/speech_recognition/conformer/ixrt/tools/make_shard_list.py create mode 100755 models/speech/speech_recognition/conformer/ixrt/tools/text2token.py create mode 100644 models/speech/speech_recognition/conformer/ixrt/utils/__init__.py create mode 100644 models/speech/speech_recognition/conformer/ixrt/utils/embedding.py create mode 100644 models/speech/speech_recognition/conformer/ixrt/wenet/__init__.py create mode 100644 models/speech/speech_recognition/conformer/ixrt/wenet/dataset.py create mode 100644 models/speech/speech_recognition/conformer/ixrt/wenet/file_utils.py create mode 100644 models/speech/speech_recognition/conformer/ixrt/wenet/processor.py diff --git a/README.md b/README.md index 90c697a0..5ad55a16 100644 --- a/README.md +++ b/README.md @@ -746,7 +746,7 @@ DeepSparkInference将按季度进行版本更新,后续会逐步丰富模型 Conformer FP16 Supported - - + Supported INT8 diff --git a/models/speech/speech_recognition/conformer/ixrt/README.md b/models/speech/speech_recognition/conformer/ixrt/README.md new file mode 100644 index 00000000..2d4d98f3 --- /dev/null +++ b/models/speech/speech_recognition/conformer/ixrt/README.md @@ -0,0 +1,49 @@ +# Conformer + +## Description + +Conformer is a speech recognition model proposed by Google in 2020. It combines the advantages of CNN and Transformer. CNN efficiently extracts local features, while Transformer is more effective in capturing long sequence dependencies. Conformer applies convolution to the Encoder layer of Transformer, enhancing the performance of Transformer in the ASR (Automatic Speech Recognition) domain. + +## Setup + +### Install + +```bash +pip3 install tqdm +pip3 install onnx +pip3 install typeguard==2.13.3 +pip3 install onnxsim +``` + +### Download + +Pretrained model: + +Dataset: to download the Aishell dataset. + +download and put model in conformer_checkpoints, put data in aishell_test_data. + +### Prepare Data +```bash +# Accuracy +DATA_DIR=./aishell_test_data +Tool_DIR=./tools +bash scripts/aishell_data_prepare.sh ${DATA_DIR} ${Tool_DIR} +``` + +### Model Conversion And Inference + +### FP16 + +```bash +# Accuracy +bash scripts/infer_conformer_fp16_accuracy_ixrt.sh +# Performance +bash scripts/infer_conformer_fp16_performance_ixrt.sh +``` + +## Results + +Model |BatchSize |Precision |QPS |CER | +-----------|-----------|----------|----------|----------| +Conformer | 24 | FP16 | 380.00 | 0.051 | diff --git a/models/speech/speech_recognition/conformer/ixrt/build_engine.py b/models/speech/speech_recognition/conformer/ixrt/build_engine.py new file mode 100644 index 00000000..aa20ee59 --- /dev/null +++ b/models/speech/speech_recognition/conformer/ixrt/build_engine.py @@ -0,0 +1,145 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +""" +Build Engine From FusionPlugin Onnx. +""" + +import os +import ctypes +import json +import onnx +import logging +import argparse + +import tensorrt +import tensorrt as trt +from tensorrt import Dims + + +TRT_LOGGER = trt.Logger(trt.Logger.WARNING) +def load_ixrt_plugin(logger=trt.Logger(trt.Logger.WARNING), namespace="", dynamic_path=""): + if not dynamic_path: + dynamic_path = os.path.join(os.path.dirname(trt.__file__), "lib", "libixrt_plugin.so") + if not os.path.exists(dynamic_path): + raise FileNotFoundError( + f"The ixrt_plugin lib {dynamic_path} is not existed, please provided effective plugin path!" + ) + ctypes.CDLL(dynamic_path, mode=ctypes.RTLD_GLOBAL) + trt.init_libnvinfer_plugins(logger, namespace) + print(f"Loaded plugin from {dynamic_path}") + +load_ixrt_plugin() + + + +def parse_args(): + parser = argparse.ArgumentParser(description="build tensorrt engine of conformer.", usage="") + parser.add_argument( + "--model_name", + type=str, + required=True, + help="conformer", + ) + parser.add_argument( + "--onnx_path", + type=str, + required=True, + help="onnx_path path to save", + ) + parser.add_argument( + "--engine_path", + type=str, + required=True, + help="engine path to save", + ) + parser.add_argument( + "--max_batch_size", + type=int, + required=True, + ) + parser.add_argument( + "--max_seq_len", + type=int, + required=True, + ) + args = parser.parse_args() + return args + +args = parse_args() +MaxBSZ = args.max_batch_size +MaxSeqLen = args.max_seq_len + + +def build_engine_trtapi_dynamicshape(args): + onnx_model = args.onnx_path + assert os.path.isfile(onnx_model), f"The onnx model{onnx_model} must be existed!" + IXRT_LOGGER = tensorrt.Logger(tensorrt.Logger.WARNING) + builder = tensorrt.Builder(IXRT_LOGGER) + EXPLICIT_BATCH = 1 << (int)(tensorrt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH) + network = builder.create_network(EXPLICIT_BATCH) + build_config = builder.create_builder_config() + + profile = builder.create_optimization_profile() + profile.set_shape("input", Dims([MaxBSZ, 100, 80]), Dims([MaxBSZ, 1000, 80]), Dims([MaxBSZ, 1500, 80])) + profile.set_shape("mask", Dims([MaxBSZ, 1, 25]), Dims([MaxBSZ, 1, 250]), Dims([MaxBSZ, 1, 374])) + profile.set_shape("pos_emb", Dims([1, 25, 256]), Dims([1, 250, 256]), Dims([1, 374, 256])) + build_config.add_optimization_profile(profile) + + parser = tensorrt.OnnxParser(network, IXRT_LOGGER) + parser.parse_from_file(onnx_model) + build_config.set_flag(tensorrt.BuilderFlag.FP16) + + # set dynamic + # input + input_tensor = network.get_input(0) + input_tensor.shape = Dims([MaxBSZ, -1, 80]) + # mask + mask_tensor = network.get_input(1) + mask_tensor.shape = Dims([MaxBSZ, 1, -1]) + # pos_emb + pos_emb_tensor = network.get_input(2) + pos_emb_tensor.shape = Dims([1, -1, 256]) + + plan = builder.build_serialized_network(network, build_config) + with open(args.engine_path, "wb") as f: + f.write(plan) + + print("Build dynamic shape engine done!") + + +def build_engine_trtapi_staticshape(args): + onnx_model = args.onnx_path + assert os.path.isfile(onnx_model), f"The onnx model{onnx_model} must be existed!" + IXRT_LOGGER = tensorrt.Logger(tensorrt.Logger.WARNING) + builder = tensorrt.Builder(IXRT_LOGGER) + EXPLICIT_BATCH = 1 << (int)(tensorrt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH) + network = builder.create_network(EXPLICIT_BATCH) + build_config = builder.create_builder_config() + parser = tensorrt.OnnxParser(network, IXRT_LOGGER) + + parser.parse_from_file(onnx_model) + build_config.set_flag(tensorrt.BuilderFlag.FP16) + + plan = builder.build_serialized_network(network, build_config) + with open(args.engine_path, "wb") as f: + f.write(plan) + + print("Build static shape engine done!") + + +if __name__ == "__main__": + build_engine_trtapi_dynamicshape(args) + # build_engine_trtapi_staticshape(args) diff --git a/models/speech/speech_recognition/conformer/ixrt/common.py b/models/speech/speech_recognition/conformer/ixrt/common.py new file mode 100644 index 00000000..89023300 --- /dev/null +++ b/models/speech/speech_recognition/conformer/ixrt/common.py @@ -0,0 +1,136 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import os +import ctypes +import cv2 +import glob +import torch +import tensorrt +import tensorrt as trt +import numpy as np +import pycuda.driver as cuda + +from tensorrt.hook.utils import copy_ixrt_io_tensors_as_np + + +TRT_LOGGER = trt.Logger(trt.Logger.WARNING) +def load_ixrt_plugin(logger=trt.Logger(trt.Logger.WARNING), namespace="", dynamic_path=""): + if not dynamic_path: + dynamic_path = os.path.join(os.path.dirname(trt.__file__), "lib", "libixrt_plugin.so") + if not os.path.exists(dynamic_path): + raise FileNotFoundError( + f"The ixrt_plugin lib {dynamic_path} is not existed, please provided effective plugin path!" + ) + ctypes.CDLL(dynamic_path, mode=ctypes.RTLD_GLOBAL) + trt.init_libnvinfer_plugins(logger, namespace) + print(f"Loaded plugin from {dynamic_path}") +load_ixrt_plugin() + + +def trtapi(engine_file): + datatype = tensorrt.DataType.FLOAT + host_mem = tensorrt.IHostMemory + logger = tensorrt.Logger(tensorrt.Logger.ERROR) + with open(engine_file, "rb") as f, tensorrt.Runtime(logger) as runtime: + runtime = tensorrt.Runtime(logger) + assert runtime + engine = runtime.deserialize_cuda_engine(f.read()) + assert engine + context = engine.create_execution_context() + assert context + + return engine, context + + +def create_engine_context(engine_path, logger): + with open(engine_path, "rb") as f: + runtime = tensorrt.Runtime(logger) + assert runtime + engine = runtime.deserialize_cuda_engine(f.read()) + assert engine + context = engine.create_execution_context() + assert context + + return engine, context + + +def get_io_bindings(engine): + # Setup I/O bindings + inputs = [] + outputs = [] + allocations = [] + + for i in range(engine.num_bindings): + is_input = False + if engine.binding_is_input(i): + is_input = True + name = engine.get_binding_name(i) + dtype = engine.get_binding_dtype(i) + shape = engine.get_binding_shape(i) + if is_input: + batch_size = shape[0] + size = np.dtype(tensorrt.nptype(dtype)).itemsize + for s in shape: + size *= s + allocation = cuda.mem_alloc(size) + binding = { + "index": i, + "name": name, + "dtype": np.dtype(tensorrt.nptype(dtype)), + "shape": list(shape), + "allocation": allocation, + } + print(f"binding {i}, name : {name} dtype : {np.dtype(tensorrt.nptype(dtype))} shape : {list(shape)}") + allocations.append(allocation) + if engine.binding_is_input(i): + inputs.append(binding) + else: + outputs.append(binding) + return inputs, outputs, allocations + + +def setup_io_bindings(engine, context): + # Setup I/O bindings + inputs = [] + outputs = [] + allocations = [] + + for i in range(engine.num_bindings): + is_input = False + if engine.binding_is_input(i): + is_input = True + name = engine.get_binding_name(i) + dtype = engine.get_binding_dtype(i) + shape = context.get_binding_shape(i) + if is_input: + batch_size = shape[0] + size = np.dtype(tensorrt.nptype(dtype)).itemsize + for s in shape: + size *= s + allocation = cuda.mem_alloc(size) + binding = { + "index": i, + "name": name, + "dtype": np.dtype(tensorrt.nptype(dtype)), + "shape": list(shape), + "allocation": allocation, + } + allocations.append(allocation) + if engine.binding_is_input(i): + inputs.append(binding) + else: + outputs.append(binding) + return inputs, outputs, allocations diff --git a/models/speech/speech_recognition/conformer/ixrt/convert2onnx.py b/models/speech/speech_recognition/conformer/ixrt/convert2onnx.py new file mode 100644 index 00000000..823ae321 --- /dev/null +++ b/models/speech/speech_recognition/conformer/ixrt/convert2onnx.py @@ -0,0 +1,529 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +""" +Build Compute Graph(Fusion Plugin Onnx) From Checkpoints. +""" + +import os +import json +import torch +import argparse +import numpy as np +from collections import OrderedDict + +from tensorrt.deploy.api import GraphTransform, create_source, create_target +from tensorrt.deploy.ir.data_type import DataType +from tensorrt.deploy.ir.variable import Variable, VariableOptions +from tensorrt.deploy.ir.graph import Graph + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Build Compute Graph From Checkpoints.", usage="" + ) + parser.add_argument( + "--model_name", + type=str, + required=True, + help="conformer", + ) + parser.add_argument( + "--model_path", + type=str, + required=True, + help="checkpont of conformer", + ) + parser.add_argument( + "--onnx_path", + type=str, + required=True, + help="raw onnx path to save", + ) + parser.add_argument( + "--batch_size", + type=int, + required=True, + help="the batch size for test.", + ) + args = parser.parse_args() + return args + + +def add_global_cmvn_op(graph, state_dict, args): + t = graph + + sub_inputs = [t.make_variable("input", dtype=DataType.FLOAT, shape=(128, 1500, 80))] + key = "encoder.global_cmvn.mean" + sub_inputs.append(t.make_variable(name=key, value=state_dict[key])) + sub_outputs = [t.make_variable("Sub_output_0", dtype=DataType.FLOAT, shape=(128, 1500, 80))] + t.make_operator( + "Sub", + inputs=sub_inputs, + outputs=sub_outputs, + ) + + mul_inputs = sub_outputs + key = "encoder.global_cmvn.istd" + mul_inputs.append(t.make_variable(name=key, value=state_dict[key])) + mul_outputs = [t.make_variable("Mul_output_0", dtype=DataType.FLOAT, shape=(128, 1500, 80))] + t.make_operator( + "Mul", + inputs=mul_inputs, + outputs=mul_outputs, + ) + + unsqueeze_inputs = mul_outputs + unsqueeze_inputs.append(t.make_variable("axes", value=np.array([1], dtype=np.int64))) + unsqueeze_outputs = [t.make_variable("Unsqueeze_output_0", dtype=DataType.FLOAT, shape=(128, 1, 1500, 80))] + t.make_operator( + "Unsqueeze", + inputs=unsqueeze_inputs, + outputs=unsqueeze_outputs, + ) + + +def add_first_submodule_op(graph, state_dict, args): + """ + The firt submodule part contains follows: + 1.Conv2d+ReLU; + 2.Conv2d+ReLU; + 3.Transpose+Reshape; + 4.MatMul+Add+Mul; + """ + + t = graph + conv2d0_weight_keys = [ + "encoder.embed.conv.0.weight", + "encoder.embed.conv.0.bias", + ] + conv2d0_attributes = { + "dilations": [1, 1], + "group": 1, + "kernel_shape": [3, 3], + "pads": [0, 0, 0, 0], + "strides": [2, 2], + } + conv2d0_inputs = [t.get_variable("Unsqueeze_output_0")] + conv2d0_outputs = [t.make_variable("Conv_output_0", dtype=DataType.FLOAT)] + + for key in conv2d0_weight_keys: + conv2d0_inputs.append(t.make_variable(name=key, value=state_dict[key])) + t.make_operator( + "Conv", + inputs=conv2d0_inputs, + outputs=conv2d0_outputs, + **conv2d0_attributes + ) + + relu0_inputs = conv2d0_outputs + relu0_outputs = [t.make_variable("Relu_output_0", dtype=DataType.FLOAT)] + t.make_operator( + "Relu", + inputs=relu0_inputs, + outputs=relu0_outputs + ) + + conv2d1_weight_keys = [ + "encoder.embed.conv.2.weight", + "encoder.embed.conv.2.bias", + ] + conv2d1_attributes = { + "dilations": [1, 1], + "group": 1, + "kernel_shape": [3, 3], + "pads": [0, 0, 0, 0], + "strides": [2, 2], + } + conv2d1_inputs = relu0_outputs + conv2d1_outputs = [t.make_variable("Conv_output_1", dtype=DataType.FLOAT)] + + for key in conv2d1_weight_keys: + conv2d1_inputs.append(t.make_variable(name=key, value=state_dict[key])) + t.make_operator( + "Conv", + inputs=conv2d1_inputs, + outputs=conv2d1_outputs, + **conv2d1_attributes + ) + + relu1_inputs = conv2d1_outputs + relu1_outputs = [t.make_variable("Relu_output_1", dtype=DataType.FLOAT)] + t.make_operator( + "Relu", + inputs=relu1_inputs, + outputs=relu1_outputs + ) + + tran_inputs = relu1_outputs + tran_outputs = [t.make_variable("Transpose_output_0", dtype=DataType.FLOAT)] + tran_attributes = {"perm": [0, 2, 1, 3]} + t.make_operator( + "Transpose", + inputs=tran_inputs, + outputs=tran_outputs, + **tran_attributes + ) + + reshape_inputs = tran_outputs + reshape_inputs.append(t.make_variable(name="constant_0", value=np.array([args.batch_size, -1, 4864]), dtype=DataType.INT64)) + reshape_outputs = [t.make_variable("Reshape_output_0", dtype=DataType.FLOAT)] + t.make_operator( + "Reshape", + inputs=reshape_inputs, + outputs=reshape_outputs, + ) + + matmul_inputs = reshape_outputs + matmul_inputs.append(t.make_variable(name="embed.out.0.weight", value=state_dict["encoder.embed.out.0.weight"].transpose(1, 0))) # (256,4864)--->(4864,256) + matmul_outputs = [t.make_variable("MatMul_output_0", dtype=DataType.FLOAT)] + t.make_operator( + "MatMul", + inputs=matmul_inputs, + outputs=matmul_outputs, + ) + + add_inputs = matmul_outputs + add_inputs.append(t.make_variable(name="embed.out.0.bias", value=state_dict["encoder.embed.out.0.bias"])) + add_outputs = [t.make_variable("Add_output_0", dtype=DataType.FLOAT)] + t.make_operator( + "Add", + inputs=add_inputs, + outputs=add_outputs, + ) + + mul_inputs = add_outputs + mul_inputs.append(t.make_variable(name="constant_1", value=np.array([16.], dtype=np.float32), dtype=DataType.FLOAT)) + mul_outputs = [t.make_variable("Mul_output_1", dtype=DataType.FLOAT)] + t.make_operator( + "Mul", + inputs=mul_inputs, + outputs=mul_outputs, + ) + + +def add_encoder_ff_macaron_op(graph, state_dict, args, index): + + t = graph + ff_macaron_keys = [ + "encoder.encoders.{}.norm_ff_macaron.weight", + "encoder.encoders.{}.norm_ff_macaron.bias", + "encoder.encoders.{}.feed_forward_macaron.w_1.weight", + "encoder.encoders.{}.feed_forward_macaron.w_1.bias", + "encoder.encoders.{}.feed_forward_macaron.w_2.weight", + "encoder.encoders.{}.feed_forward_macaron.w_2.bias", + ] + + attributes = { + "in_feature": 256, + "hidden_size": 2048, + "act_type": 12, + "ff_scale": 0.5, + } + + if index == 0: + inputs = [graph.get_variable("Mul_output_1")] + else: + inputs = [graph.get_variable("norm_final_{}_output".format(index-1))] + + outputs = [t.make_variable("ff_macaron_{}_output".format(index), dtype=DataType.FLOAT)] + + for key in ff_macaron_keys: + key = key.format(index) + inputs.append(t.make_variable(name=key, value=state_dict[key].half(), dtype=DataType.FLOAT16)) + + t.make_operator( + "PositionWiseFFNPluginDynamic_IxRT", + inputs=inputs, + outputs=outputs, + **attributes + ) + + +def add_encoder_mhsa_op(graph, state_dict, args, index): + + t = graph + mhsa_keys = [ + "encoder.encoders.{}.norm_mha.weight", + "encoder.encoders.{}.norm_mha.bias", + "encoder.encoders.{}.self_attn.linear_q.weight", + "encoder.encoders.{}.self_attn.linear_q.bias", + "encoder.encoders.{}.self_attn.linear_k.weight", + "encoder.encoders.{}.self_attn.linear_k.bias", + "encoder.encoders.{}.self_attn.linear_v.weight", + "encoder.encoders.{}.self_attn.linear_v.bias", + "encoder.encoders.{}.self_attn.linear_pos.weight", + "encoder.encoders.{}.self_attn.pos_bias_u", + "encoder.encoders.{}.self_attn.pos_bias_v", + "encoder.encoders.{}.self_attn.linear_out.weight", + "encoder.encoders.{}.self_attn.linear_out.bias", + ] + + attributes = { + "bs": 128, + "seq_len": 374, + "n_head": 4, + "n_feat": 256, + } + + if index == 0: + inputs = [ + graph.get_variable("ff_macaron_{}_output".format(index)), + t.make_variable("mask", dtype=DataType.INT32, shape=(128, 1, 374)), + t.make_variable("pos_emb", dtype=DataType.FLOAT, shape=(1, 374, 256)), + ] + else: + inputs = [ + graph.get_variable("ff_macaron_{}_output".format(index)), + graph.get_variable("mask"), + graph.get_variable("pos_emb"), + ] + + outputs = [t.make_variable("mhsa_{}_output".format(index), dtype=DataType.FLOAT)] + + for key in mhsa_keys: + key = key.format(index) + inputs.append(t.make_variable(name=key, value=state_dict[key].half(), dtype=DataType.FLOAT16)) + + t.make_operator( + "ConformerMultiHeadSelfAttentionPlugin_IxRT", + inputs=inputs, + outputs=outputs, + **attributes + ) + + +def add_encoder_conv_module_op(graph, state_dict, args, index): + + t = graph + conv_module_keys = [ + "encoder.encoders.{}.norm_conv.weight", + "encoder.encoders.{}.norm_conv.bias", + "encoder.encoders.{}.conv_module.pointwise_conv1.weight", + "encoder.encoders.{}.conv_module.pointwise_conv1.bias", + "encoder.encoders.{}.conv_module.depthwise_conv.weight", + "encoder.encoders.{}.conv_module.depthwise_conv.bias", + "encoder.encoders.{}.conv_module.norm.weight", + "encoder.encoders.{}.conv_module.norm.bias", + "encoder.encoders.{}.conv_module.pointwise_conv2.weight", + "encoder.encoders.{}.conv_module.pointwise_conv2.bias", + ] + + attributes = { + "kernel_size_1": 1, + "stride_1": 1, + "odim_1": 512, + "kernel_size_2": 8, + "stride_2": 1, + "odim_2": 256, + "kernel_size_3": 1, + "stride_3": 1, + "odim_3": 256, + } + + inputs = [ + graph.get_variable("mhsa_{}_output".format(index)), + graph.get_variable("mask"), + ] + outputs = [t.make_variable("conv_module_{}_output".format(index), dtype=DataType.FLOAT)] + + for key in conv_module_keys: + key = key.format(index) + + if "conv_module.depthwise_conv.weight" in key: + inputs.append(t.make_variable(name=key, value=state_dict[key].permute(1, 2, 0).half(), dtype=DataType.FLOAT16)) + elif "bias" in key and "norm" not in key: + inputs.append(t.make_variable(name=key, value=state_dict[key], dtype=DataType.FLOAT)) + else: + inputs.append(t.make_variable(name=key, value=state_dict[key].half(), dtype=DataType.FLOAT16)) + + t.make_operator( + "ConformerConvModulePlugin_IxRT", + inputs=inputs, + outputs=outputs, + **attributes + ) + + +def add_encoder_positionwise_ff_op(graph, state_dict, args, index): + + t = graph + positionwise_ff_keys = [ + "encoder.encoders.{}.norm_ff.weight", + "encoder.encoders.{}.norm_ff.bias", + "encoder.encoders.{}.feed_forward.w_1.weight", + "encoder.encoders.{}.feed_forward.w_1.bias", + "encoder.encoders.{}.feed_forward.w_2.weight", + "encoder.encoders.{}.feed_forward.w_2.bias", + ] + + attributes = { + "in_feature": 256, + "hidden_size": 2048, + "act_type": 12, + "ff_scale": 0.5, + } + + inputs = [graph.get_variable('conv_module_{}_output'.format(index))] + outputs = [t.make_variable("positionwise_ff_{}_output".format(index), dtype=DataType.FLOAT)] + + for key in positionwise_ff_keys: + key = key.format(index) + inputs.append(t.make_variable(name=key, value=state_dict[key].half(), dtype=DataType.FLOAT16)) + + t.make_operator( + "PositionWiseFFNPluginDynamic_IxRT", + inputs=inputs, + outputs=outputs, + **attributes + ) + + +def add_encoder_ln_op(graph, state_dict, args, index): + + t = graph + ln_keys = [ + "encoder.encoders.{}.norm_final.weight", + "encoder.encoders.{}.norm_final.bias", + ] + + attributes = { + "axis": -1, + "epsilon": 0.000009999999747378752, + "stash_type": 1, + } + + inputs = [graph.get_variable("positionwise_ff_{}_output".format(index))] + outputs = [t.make_variable("norm_final_{}_output".format(index), dtype=DataType.FLOAT)] + + for key in ln_keys: + key = key.format(index) + inputs.append(t.make_variable(name=key, value=state_dict[key].half(), dtype=DataType.FLOAT16)) + + t.make_operator( + "LayerNormalization", + inputs=inputs, + outputs=outputs, + **attributes + ) + + +def add_final_ln_op(graph, state_dict, args): + + t = graph + ln_keys = [ + "encoder.after_norm.weight", + "encoder.after_norm.bias", + ] + + attributes = { + "axis": -1, + "epsilon": 0.000009999999747378752, + "stash_type": 1, + } + + inputs = [graph.get_variable("norm_final_11_output")] + outputs = [t.make_variable("norm_final_output", dtype=DataType.FLOAT)] + + for key in ln_keys: + inputs.append(t.make_variable(name=key, value=state_dict[key].half(), dtype=DataType.FLOAT16)) + + t.make_operator( + "LayerNormalization", + inputs=inputs, + outputs=outputs, + **attributes + ) + + +def add_ctc_op(graph, state_dict, args): + t = graph + # matmul + matmul_inputs = [graph.get_variable("norm_final_output")] + matmul_inputs.append(t.make_variable(name="ctc.ctc_lo.weight", value=state_dict["ctc.ctc_lo.weight"].transpose(1, 0))) # (4233,256)--->(256,4233) + matmul_outputs = [t.make_variable("MatMul_output_1", dtype=DataType.FLOAT)] + t.make_operator( + "MatMul", + inputs=matmul_inputs, + outputs=matmul_outputs, + ) + + add_inputs = matmul_outputs + add_inputs.append(t.make_variable(name="ctc.ctc_lo.bias", value=state_dict["ctc.ctc_lo.bias"])) + add_outputs = [t.make_variable("Add_output_1", dtype=DataType.FLOAT)] + t.make_operator( + "Add", + inputs=add_inputs, + outputs=add_outputs, + ) + + logsoftmax_inputs = add_outputs + logsoftmax_outputs = [t.make_variable("output", dtype=DataType.FLOAT)] + attributes = { + "axis": 2 + } + t.make_operator( + "LogSoftmax", + inputs=logsoftmax_inputs, + outputs=logsoftmax_outputs, + **attributes + ) + + +def main(args): + graph = Graph() + transform = GraphTransform(graph) + state_dict = torch.load(args.model_path) + + # 0. Global CMVN: sub+mul+unsqueeze + add_global_cmvn_op(transform, state_dict, args) + + # 1. First Submodule: Conv2d+Relu+Transpose+MatMul + add_first_submodule_op(transform, state_dict, args) + + # 2. Second Submodule: ConformerEncoderLayer: 12 layers + for i in range(args.num_layers): + add_encoder_ff_macaron_op(transform, state_dict, args, i) + add_encoder_mhsa_op(transform, state_dict, args, i) + add_encoder_conv_module_op(transform, state_dict, args, i) + add_encoder_positionwise_ff_op(transform, state_dict, args, i) + add_encoder_ln_op(transform, state_dict, args, i) + + # 3. Third Submodule: FinalNorm + add_final_ln_op(transform, state_dict, args) + + # 4.Forth Submodule: CTC+LogSoftmax + add_ctc_op(transform, state_dict, args) + + # 5. set input and output + graph.add_input(graph.get_variable("input")) + graph.add_input(graph.get_variable("mask")) + graph.add_input(graph.get_variable("pos_emb")) + graph.add_output(graph.get_variable("output")) + # 5. export onnx file + create_target(saved_path=args.onnx_path).export(graph) + print("save onnx: ", args.onnx_path) + + +if __name__ == "__main__": + args = parse_args() + model_name = args.model_name.lower() + args.num_layers = 12 + args.hidden_size = 2048 + args.head_num = 4 + args.head_dim = 64 + args.pad_id = 0 + args.inner_size = 3072 + main(args) diff --git a/models/speech/speech_recognition/conformer/ixrt/ixrt_inference_accuracy.py b/models/speech/speech_recognition/conformer/ixrt/ixrt_inference_accuracy.py new file mode 100644 index 00000000..35aad9bb --- /dev/null +++ b/models/speech/speech_recognition/conformer/ixrt/ixrt_inference_accuracy.py @@ -0,0 +1,285 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import os +import sys + +sys.path.append(os.path.dirname(os.path.dirname(__file__))) + +import argparse +import yaml +import copy +import torch +import numpy as np + +from tqdm.contrib import tqdm +from torch.utils.data import DataLoader +from wenet.file_utils import read_symbol_table +from wenet.dataset import Dataset +from tools.compute_cer import Calculator, characterize, normalize, default_cluster +import tensorrt +from tensorrt import Dims +from common import create_engine_context, get_io_bindings,trtapi,setup_io_bindings +import pickle + +import pycuda.autoinit +import pycuda.driver as cuda + +from utils import make_pad_mask, RelPositionalEncoding +from postprocess import ctc_greedy_search + + +rel_positional_encoding = RelPositionalEncoding(256, 0.1) + + +def get_args(): + parser = argparse.ArgumentParser(description="recognize with your model") + parser.add_argument( + "--infer_type", + default="fp16", + choices=["fp16", "int8"], + help="inference type: fp16 or int8", + ) + parser.add_argument("--warm_up", type=int, default=3, help="warm_up count") + parser.add_argument("--batch_size", type=int, default=24) + parser.add_argument("--data_dir", required=True, help="test data directory") + parser.add_argument( + "--model_dir", type=str, required=True, help="model for inference" + ) + args = parser.parse_args() + return args + + +def tensorrt_infer(engine, context, all_inputs): + + input_names = ["input", "mask", "pos_emb"] + output_names = ["output"] + + for input_name, input_data in zip(input_names, all_inputs): + input_idx = engine.get_binding_index(input_name) + input_shape = input_data.shape + context.set_binding_shape(input_idx, Dims(input_shape)) + + inputs, outputs, allocations = setup_io_bindings(engine, context) + pred_output = np.zeros(outputs[0]["shape"], outputs[0]["dtype"]) + + for i, input_data in enumerate(all_inputs): + cuda.memcpy_htod(inputs[i]["allocation"], input_data) + + context.execute_v2(allocations) + cuda.memcpy_dtoh(pred_output, outputs[0]["allocation"]) + return pred_output + + +def engine_init(engine): + host_mem = tensorrt.IHostMemory + logger = tensorrt.Logger(tensorrt.Logger.ERROR) + engine, context = create_engine_context(engine, logger) + + return engine,context + + +def calculate_cer(data, reference_data): + calculator = Calculator() + tochar = True + split = None + case_sensitive = False + ignore_words = set() + rec_set = {} + for line in data: + if tochar: + array = characterize(line) + else: + array = line.strip().split() + if len(array) == 0: + continue + fid = array[0] + rec_set[fid] = normalize(array[1:], ignore_words, case_sensitive, split) + + default_clusters = {} + default_words = {} + for line in reference_data: + if tochar: + array = characterize(line) + else: + array = line.strip().split() + if len(array) == 0: + continue + fid = array[0] + if fid not in rec_set: + continue + lab = normalize(array[1:], ignore_words, case_sensitive, split) + rec = rec_set[fid] + + for word in rec + lab: + if word not in default_words: + default_cluster_name = default_cluster(word) + if default_cluster_name not in default_clusters: + default_clusters[default_cluster_name] = {} + if word not in default_clusters[default_cluster_name]: + default_clusters[default_cluster_name][word] = 1 + default_words[word] = default_cluster_name + result = calculator.calculate(lab, rec) + + result = calculator.overall() + cer = float(result["ins"] + result["sub"] + result["del"]) / result["all"] + corr = result["cor"] / result["all"] + + return cer, corr + + +def main(): + args = get_args() + + # 读取配置文件 + config_fn = os.path.join(args.model_dir, "config.yaml") + with open(config_fn, "r") as fin: + configs = yaml.load(fin, Loader=yaml.FullLoader) + + dataset_conf = copy.deepcopy(configs["dataset_conf"]) + dataset_conf["filter_conf"]["max_length"] = 102400 + dataset_conf["filter_conf"]["min_length"] = 0 + dataset_conf["filter_conf"]["token_max_length"] = 102400 + dataset_conf["filter_conf"]["token_min_length"] = 0 + dataset_conf["filter_conf"]["max_output_input_ratio"] = 102400 + dataset_conf["filter_conf"]["min_output_input_ratio"] = 0 + dataset_conf["speed_perturb"] = False + dataset_conf["spec_aug"] = False + dataset_conf["shuffle"] = False + dataset_conf["sort"] = True + dataset_conf["fbank_conf"]["dither"] = 0.0 + dataset_conf["batch_conf"]["batch_type"] = "static" + dataset_conf["batch_conf"]["batch_size"] = args.batch_size + + # Load dict + dict_fn = os.path.join(args.model_dir, "words.txt") + char_dict = {} + with open(dict_fn, "r", encoding="utf8") as fin: + for line in fin: + arr = line.strip().split() + assert len(arr) == 2 + char_dict[int(arr[1])] = arr[0] + eos = len(char_dict) - 1 + + data_type = "raw" + test_data_fn = os.path.join(args.data_dir, "data.list") + symbol_table = read_symbol_table(dict_fn) + test_dataset = Dataset( + data_type, test_data_fn, symbol_table, dataset_conf, partition=False + ) + test_data_loader = DataLoader(test_dataset, batch_size=None, num_workers=0) + + data_path_pkl = os.path.join(args.data_dir, f"aishell_test_data_bs{args.batch_size}.pkl") + + print("*** 1. Prepare data ***") + if not os.path.isfile(data_path_pkl): + eval_samples = [] + max_batch_size = -1 + max_feature_length = -1 + for batch in test_data_loader: + keys, feats, target, feats_lengths, target_lengths = batch + max_feature_length = max(max_feature_length, feats.size(1)) + max_batch_size = max(max_batch_size, feats.size(0)) + eval_samples.append( + [ + keys, + feats.cpu().numpy().astype(np.float16), + feats_lengths.cpu().numpy().astype(np.int32), + ] + ) + with open(data_path_pkl, "wb") as f: + pickle.dump( + [ + eval_samples, + max_batch_size, + max_feature_length + ], + f, + ) + else: + print(f"load data from tmp: {data_path_pkl}") + with open(data_path_pkl, "rb") as f: + ( + eval_samples, + max_batch_size, + max_feature_length + ) = pickle.load(f) + print( + f"dataset max shape: batch_size: {max_batch_size}, feat_length: {max_feature_length}" + ) + + print("*** 2. Load engine ***") + engine_path = os.path.join(args.model_dir, f"conformer_encoder_fusion.engine") + engine, context = engine_init(engine_path) + + print("*** 3. Warm up ***") + if args.warm_up > 0: + for i in range(args.warm_up): + feats_tmp = np.ones((args.batch_size, 1500, 80)).astype(np.float32) + feats_lengths_tmp = np.ones((args.batch_size)).astype(np.int32) * 1500 + mask_tmp = make_pad_mask(feats_lengths_tmp, 1500) + mask_len_tmp = mask_tmp.shape[-1] + pos_emb_tmp = rel_positional_encoding(mask_len_tmp).numpy() + all_inputs = [feats_tmp, mask_tmp, pos_emb_tmp] + tensorrt_infer(engine, context, all_inputs) + + results = [] + for keys, feats, feats_lengths in tqdm(eval_samples): + b, seq_len, feat = feats.shape + + inputs = feats.astype(np.float32) + mask = make_pad_mask(feats_lengths, seq_len) + mask_len = mask.shape[-1] + pos_emb = rel_positional_encoding(mask_len).numpy() + + all_inputs = [inputs, mask, pos_emb] + hyps = tensorrt_infer( + engine, + context, + all_inputs + ) + + ctc_probs = torch.from_numpy(hyps) + ctc_lens = torch.from_numpy(feats_lengths) + hyps = ctc_greedy_search(ctc_probs, ctc_lens) + + for i, key in enumerate(keys): + line = f"{key} " + for w in hyps[i]: + w = w - 1 + if w == eos: + break + line += char_dict[w] + results.append(line) + + # 3. 计算 CER + reference_file = os.path.join(args.data_dir, "text") + reference_data = [] + for line in open(reference_file, "r", encoding="utf-8"): + reference_data.append(line) + + cer, corr = calculate_cer(results, reference_data) + target_cer = float(os.environ["Accuracy"]) + print("CER: ", cer, "target CER: ", target_cer) + if cer <= target_cer: + print("pass!") + exit() + else: + print("failed!") + exit(1) + + +if __name__ == "__main__": + main() diff --git a/models/speech/speech_recognition/conformer/ixrt/ixrt_inference_performance.py b/models/speech/speech_recognition/conformer/ixrt/ixrt_inference_performance.py new file mode 100644 index 00000000..c19233fa --- /dev/null +++ b/models/speech/speech_recognition/conformer/ixrt/ixrt_inference_performance.py @@ -0,0 +1,273 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import os +import sys +import time + +sys.path.append(os.path.dirname(os.path.dirname(__file__))) + +import argparse +import yaml +import copy +import torch +import numpy as np + +from tqdm.contrib import tqdm +from torch.utils.data import DataLoader +from wenet.file_utils import read_symbol_table +from wenet.dataset import Dataset +from tools.compute_cer import Calculator, characterize, normalize, default_cluster +import tensorrt +from tensorrt import Dims +from common import create_engine_context, get_io_bindings,trtapi,setup_io_bindings +import pickle + +import pycuda.autoinit +import pycuda.driver as cuda + +from utils import make_pad_mask, RelPositionalEncoding +from postprocess import ctc_greedy_search + + +rel_positional_encoding = RelPositionalEncoding(256, 0.1) + + +def get_args(): + parser = argparse.ArgumentParser(description="recognize with your model") + parser.add_argument( + "--infer_type", + default="fp16", + choices=["fp16", "int8"], + help="inference type: fp16 or int8", + ) + parser.add_argument("--warm_up", type=int, default=3, help="warm_up count") + parser.add_argument("--batch_size", type=int, default=24) + parser.add_argument("--data_dir", required=True, help="test data directory") + parser.add_argument( + "--model_dir", type=str, required=True, help="model for inference" + ) + args = parser.parse_args() + return args + + +def tensorrt_infer(engine, context, all_inputs): + + input_names = ["input", "mask", "pos_emb"] + output_names = ["output"] + + for input_name, input_data in zip(input_names, all_inputs): + input_idx = engine.get_binding_index(input_name) + input_shape = input_data.shape + context.set_binding_shape(input_idx, Dims(input_shape)) + + inputs, outputs, allocations = setup_io_bindings(engine, context) + pred_output = np.zeros(outputs[0]["shape"], outputs[0]["dtype"]) + + for i, input_data in enumerate(all_inputs): + cuda.memcpy_htod(inputs[i]["allocation"], input_data) + + context.execute_v2(allocations) + cuda.memcpy_dtoh(pred_output, outputs[0]["allocation"]) + return pred_output + + +def engine_init(engine): + host_mem = tensorrt.IHostMemory + logger = tensorrt.Logger(tensorrt.Logger.ERROR) + engine, context = create_engine_context(engine, logger) + + return engine,context + + +def calculate_cer(data, reference_data): + calculator = Calculator() + tochar = True + split = None + case_sensitive = False + ignore_words = set() + rec_set = {} + for line in data: + if tochar: + array = characterize(line) + else: + array = line.strip().split() + if len(array) == 0: + continue + fid = array[0] + rec_set[fid] = normalize(array[1:], ignore_words, case_sensitive, split) + + default_clusters = {} + default_words = {} + for line in reference_data: + if tochar: + array = characterize(line) + else: + array = line.strip().split() + if len(array) == 0: + continue + fid = array[0] + if fid not in rec_set: + continue + lab = normalize(array[1:], ignore_words, case_sensitive, split) + rec = rec_set[fid] + + for word in rec + lab: + if word not in default_words: + default_cluster_name = default_cluster(word) + if default_cluster_name not in default_clusters: + default_clusters[default_cluster_name] = {} + if word not in default_clusters[default_cluster_name]: + default_clusters[default_cluster_name][word] = 1 + default_words[word] = default_cluster_name + result = calculator.calculate(lab, rec) + + result = calculator.overall() + cer = float(result["ins"] + result["sub"] + result["del"]) / result["all"] + corr = result["cor"] / result["all"] + + return cer, corr + + +def main(): + args = get_args() + + # 读取配置文件 + config_fn = os.path.join(args.model_dir, "config.yaml") + with open(config_fn, "r") as fin: + configs = yaml.load(fin, Loader=yaml.FullLoader) + + dataset_conf = copy.deepcopy(configs["dataset_conf"]) + dataset_conf["filter_conf"]["max_length"] = 102400 + dataset_conf["filter_conf"]["min_length"] = 0 + dataset_conf["filter_conf"]["token_max_length"] = 102400 + dataset_conf["filter_conf"]["token_min_length"] = 0 + dataset_conf["filter_conf"]["max_output_input_ratio"] = 102400 + dataset_conf["filter_conf"]["min_output_input_ratio"] = 0 + dataset_conf["speed_perturb"] = False + dataset_conf["spec_aug"] = False + dataset_conf["shuffle"] = False + dataset_conf["sort"] = True + dataset_conf["fbank_conf"]["dither"] = 0.0 + dataset_conf["batch_conf"]["batch_type"] = "static" + dataset_conf["batch_conf"]["batch_size"] = args.batch_size + + # Load dict + dict_fn = os.path.join(args.model_dir, "words.txt") + char_dict = {} + with open(dict_fn, "r", encoding="utf8") as fin: + for line in fin: + arr = line.strip().split() + assert len(arr) == 2 + char_dict[int(arr[1])] = arr[0] + eos = len(char_dict) - 1 + + data_type = "raw" + test_data_fn = os.path.join(args.data_dir, "data.list") + symbol_table = read_symbol_table(dict_fn) + test_dataset = Dataset( + data_type, test_data_fn, symbol_table, dataset_conf, partition=False + ) + test_data_loader = DataLoader(test_dataset, batch_size=None, num_workers=0) + + data_path_pkl = os.path.join(args.data_dir, f"aishell_test_data_bs{args.batch_size}.pkl") + + print("*** 1. Prepare data ***") + if not os.path.isfile(data_path_pkl): + eval_samples = [] + max_batch_size = -1 + max_feature_length = -1 + for batch in test_data_loader: + keys, feats, target, feats_lengths, target_lengths = batch + max_feature_length = max(max_feature_length, feats.size(1)) + max_batch_size = max(max_batch_size, feats.size(0)) + eval_samples.append( + [ + keys, + feats.cpu().numpy().astype(np.float16), + feats_lengths.cpu().numpy().astype(np.int32), + ] + ) + with open(data_path_pkl, "wb") as f: + pickle.dump( + [ + eval_samples, + max_batch_size, + max_feature_length + ], + f, + ) + else: + print(f"load data from tmp: {data_path_pkl}") + with open(data_path_pkl, "rb") as f: + ( + eval_samples, + max_batch_size, + max_feature_length + ) = pickle.load(f) + print( + f"dataset max shape: batch_size: {max_batch_size}, feat_length: {max_feature_length}" + ) + + print("*** 2. Load engine ***") + engine_path = os.path.join(args.model_dir, f"conformer_encoder_fusion.engine") + engine, context = engine_init(engine_path) + + print("*** 3. Warm up ***") + if args.warm_up > 0: + for i in range(args.warm_up): + feats_tmp = np.ones((args.batch_size, 1500, 80)).astype(np.float32) + feats_lengths_tmp = np.ones((args.batch_size)).astype(np.int32) * 1500 + mask_tmp = make_pad_mask(feats_lengths_tmp, 1500) + mask_len_tmp = mask_tmp.shape[-1] + pos_emb_tmp = rel_positional_encoding(mask_len_tmp).numpy() + all_inputs = [feats_tmp, mask_tmp, pos_emb_tmp] + tensorrt_infer(engine, context, all_inputs) + + print("*** 4. Inference ***") + start_time = time.time() + num_samples = 0 + results = [] + for keys, feats, feats_lengths in tqdm(eval_samples): + b, seq_len, feat = feats.shape + num_samples += b + inputs = feats.astype(np.float32) + mask = make_pad_mask(feats_lengths, seq_len) + mask_len = mask.shape[-1] + pos_emb = rel_positional_encoding(mask_len).numpy() + + all_inputs = [inputs, mask, pos_emb] + hyps = tensorrt_infer( + engine, + context, + all_inputs + ) + + eval_time = time.time() - start_time + + QPS = num_samples / eval_time + print(f"Recognize {num_samples} sentences, {QPS} sentences/s") + target_qps = float(os.environ["Accuracy"]) + print("QPS: = ", QPS, "target QPS: ", target_qps) + if QPS >= target_qps: + print("pass!") + exit() + else: + print("failed!") + exit(1) + + +if __name__ == "__main__": + main() diff --git a/models/speech/speech_recognition/conformer/ixrt/postprocess/__init__.py b/models/speech/speech_recognition/conformer/ixrt/postprocess/__init__.py new file mode 100644 index 00000000..33f8b046 --- /dev/null +++ b/models/speech/speech_recognition/conformer/ixrt/postprocess/__init__.py @@ -0,0 +1 @@ +from .search import ctc_greedy_search diff --git a/models/speech/speech_recognition/conformer/ixrt/postprocess/search.py b/models/speech/speech_recognition/conformer/ixrt/postprocess/search.py new file mode 100644 index 00000000..d2ae5565 --- /dev/null +++ b/models/speech/speech_recognition/conformer/ixrt/postprocess/search.py @@ -0,0 +1,103 @@ +import math +from collections import defaultdict +from typing import List, Dict + +import torch +from torch.nn.utils.rnn import pad_sequence + + +def remove_duplicates_and_blank(hyp: List[int], + blank_id: int = 0) -> List[int]: + new_hyp: List[int] = [] + cur = 0 + while cur < len(hyp): + if hyp[cur] != blank_id: + new_hyp.append(hyp[cur]) + prev = cur + while cur < len(hyp) and hyp[cur] == hyp[prev]: + cur += 1 + return new_hyp + + +def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor: + """Make mask tensor containing indices of padded part. + + See description of make_non_pad_mask. + + Args: + lengths (torch.Tensor): Batch of lengths (B,). + Returns: + torch.Tensor: Mask tensor containing indices of padded part. + + Examples: + >>> lengths = [5, 3, 2] + >>> make_pad_mask(lengths) + masks = [[0, 0, 0, 0 ,0], + [0, 0, 0, 1, 1], + [0, 0, 1, 1, 1]] + """ + batch_size = lengths.size(0) + max_len = max_len if max_len > 0 else lengths.max().item() + seq_range = torch.arange(0, + max_len, + dtype=torch.int64, + device=lengths.device) + seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len) + seq_length_expand = lengths.unsqueeze(-1) + mask = seq_range_expand >= seq_length_expand + + mask = mask[:, 2::2][:, 2::2] + return mask + + +class DecodeResult: + + def __init__(self, + tokens: List[int], + score: float = 0.0, + confidence: float = 0.0, + tokens_confidence: List[float] = None, + times: List[int] = None, + nbest: List[List[int]] = None, + nbest_scores: List[float] = None, + nbest_times: List[List[int]] = None): + """ + Args: + tokens: decode token list + score: the total decode score of this result + confidence: the total confidence of this result, it's in 0~1 + tokens_confidence: confidence of each token + times: timestamp of each token, list of (start, end) + nbest: nbest result + nbest_scores: score of each nbest + nbest_times: + """ + self.tokens = tokens + self.score = score + self.confidence = confidence + self.tokens_confidence = tokens_confidence + self.times = times + self.nbest = nbest + self.nbest_scores = nbest_scores + self.nbest_times = nbest_times + + +def ctc_greedy_search(ctc_probs: torch.Tensor, + ctc_lens: torch.Tensor, + blank_id: int = 0) -> List[DecodeResult]: + + batch_size = ctc_probs.shape[0] + maxlen = ctc_probs.size(1) + topk_prob, topk_index = ctc_probs.topk(1, dim=2) # (B, maxlen, 1) + topk_index = topk_index.view(batch_size, maxlen) # (B, maxlen) + + mask_ctc_lens = ctc_lens[0].item() + mask = make_pad_mask(ctc_lens, mask_ctc_lens) # (B, maxlen) + topk_index = topk_index.masked_fill_(mask, blank_id) # (B, maxlen) + hyps = [hyp.tolist() for hyp in topk_index] + scores = topk_prob.max(1) + results = [] + for hyp in hyps: + results.append(remove_duplicates_and_blank(hyp, blank_id)) + return results + diff --git a/models/speech/speech_recognition/conformer/ixrt/scripts/aishell_data_prepare.sh b/models/speech/speech_recognition/conformer/ixrt/scripts/aishell_data_prepare.sh new file mode 100755 index 00000000..985564c2 --- /dev/null +++ b/models/speech/speech_recognition/conformer/ixrt/scripts/aishell_data_prepare.sh @@ -0,0 +1,61 @@ +#!/bin/bash +# Copyright 2019 Mobvoi Inc. All Rights Reserved. +# set -euox pipefail + +data_dir=$1 +tool_dir=$2 + +wav_dir=${data_dir}/wav +aishell_text=${data_dir}/transcript/aishell_transcript_v0.8.txt + +# data directory check +if [ ! -d $wav_dir ] || [ ! -f $aishell_text ]; then + echo "Error: wav directory and aishell text not found!" + exit 1; +fi + +# find test wav file +local_dir=${data_dir}/local +mkdir -p $local_dir +find $wav_dir -iname "*.wav" > $local_dir/wav.flist || exit 1; + +# Transcriptions preparation +sed -e 's/\.wav//' $local_dir/wav.flist | awk -F '/' '{print $NF}' > $local_dir/utt.list +paste -d' ' $local_dir/utt.list $local_dir/wav.flist > $local_dir/wav.scp_all +${tool_dir}/filter_scp.pl -f 1 $local_dir/utt.list $aishell_text > $local_dir/transcripts.txt +awk '{print $1}' $local_dir/transcripts.txt > $local_dir/utt.list +${tool_dir}/filter_scp.pl -f 1 $local_dir/utt.list $local_dir/wav.scp_all | sort -u > $local_dir/wav.scp +sort -u $local_dir/transcripts.txt > $local_dir/text +echo "Preparing transcriptions succeeded!" + +test_dir=${data_dir}/test +mkdir -p ${test_dir} +for f in wav.scp text; do + cp $local_dir/$f ${test_dir}/$f || exit 1; +done +rm -r ${data_dir}/local + +# data_type can be `raw` or `shard`. Typically, raw is used for small dataset, +# `shard` is used for large dataset which is over 1k hours, and `shard` is +# faster on reading data and training. +data_type=raw +num_utts_per_shard=1000 + +# remove the space between the text labels for Mandarin dataset +cp $test_dir/text $test_dir/text.org +paste -d " " <(cut -f 1 -d" " ${test_dir}/text.org) \ + <(cut -f 2- -d" " ${test_dir}/text.org | tr -d " ") \ + > ${test_dir}/text +rm ${test_dir}/text.org + +# Prepare required format +if [ $data_type == "shard" ]; then + ${tool_dir}/make_shard_list.py --num_utts_per_shard $num_utts_per_shard \ + --num_threads 16 $test_dir/wav.scp $test_dir/text \ + $(realpath $test_dir/shards) $test_dir/data.list +else + ${tool_dir}/make_raw_list.py $test_dir/wav.scp $test_dir/text \ + $test_dir/data.list +fi + +echo "AISHELL data preparation succeeded!" \ No newline at end of file diff --git a/models/speech/speech_recognition/conformer/ixrt/scripts/infer_conformer_fp16_accuracy_ixrt.sh b/models/speech/speech_recognition/conformer/ixrt/scripts/infer_conformer_fp16_accuracy_ixrt.sh new file mode 100644 index 00000000..f1af4bb4 --- /dev/null +++ b/models/speech/speech_recognition/conformer/ixrt/scripts/infer_conformer_fp16_accuracy_ixrt.sh @@ -0,0 +1,49 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +set -euo pipefail + +current_path=$(cd $(dirname "${BASH_SOURCE[0]}") && pwd) + +PROJECT_DIR=${current_path}/.. +DATA_DIR=${current_path}/../aishell_test_data/test +MODEL_DIR=${current_path}/../conformer_checkpoints + +export Accuracy=${Accuracy:=0.052} + +cd ${PROJECT_DIR} + +echo "Step1.Export Onnx From Checkpoints!" +python3 convert2onnx.py \ + --model_name "Conformer" \ + --model_path=${MODEL_DIR}/final.pt \ + --onnx_path=${MODEL_DIR}/conformer_encoder_fusion.onnx \ + --batch_size=8 + +echo "Step2.Build Engine!" +python3 build_engine.py \ + --model_name "Conformer" \ + --onnx_path=${MODEL_DIR}/conformer_encoder_fusion.onnx \ + --engine_path=${MODEL_DIR}/conformer_encoder_fusion.engine \ + --max_batch_size=8 \ + --max_seq_len=1500 + +echo "Step3.Inference(Test ACC)!" +python3 ixrt_inference_accuracy.py \ + --infer_type fp16 \ + --warm_up 3 \ + --batch_size ${BATCH_SIZE:=8} \ + --data_dir ${DATA_DIR} \ + --model_dir ${MODEL_DIR} diff --git a/models/speech/speech_recognition/conformer/ixrt/scripts/infer_conformer_fp16_performance_ixrt.sh b/models/speech/speech_recognition/conformer/ixrt/scripts/infer_conformer_fp16_performance_ixrt.sh new file mode 100644 index 00000000..dc02673c --- /dev/null +++ b/models/speech/speech_recognition/conformer/ixrt/scripts/infer_conformer_fp16_performance_ixrt.sh @@ -0,0 +1,59 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +set -euo pipefail + + +EXIT_STATUS=0 +check_status() +{ + if ((${PIPESTATUS[0]} != 0));then + echo "fails" + EXIT_STATUS=1 + fi +} + +current_path=$(cd $(dirname "${BASH_SOURCE[0]}") && pwd) + +PROJECT_DIR=${current_path}/.. +DATA_DIR=${current_path}/../aishell_test_data/test +MODEL_DIR=${current_path}/../conformer_checkpoints + +export Accuracy=${Accuracy:=350} + +cd ${PROJECT_DIR} + + +echo "Step1.Export Onnx From Checkpoints!" +python3 convert2onnx.py \ + --model_name "Conformer" \ + --model_path=${MODEL_DIR}/final.pt \ + --onnx_path=${MODEL_DIR}/conformer_encoder_fusion.onnx \ + --batch_size=24 + +echo "Step2.Build Engine!" +python3 build_engine.py \ + --model_name "Conformer" \ + --onnx_path=${MODEL_DIR}/conformer_encoder_fusion.onnx \ + --engine_path=${MODEL_DIR}/conformer_encoder_fusion.engine \ + --max_batch_size=24 \ + --max_seq_len=1500 + +echo "Step3.Inference(Test QPS)!" +python3 ixrt_inference_performance.py \ + --infer_type fp16 \ + --batch_size ${BATCH_SIZE:=24} \ + --data_dir ${DATA_DIR} \ + --model_dir ${MODEL_DIR} diff --git a/models/speech/speech_recognition/conformer/ixrt/tools/__init__.py b/models/speech/speech_recognition/conformer/ixrt/tools/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/models/speech/speech_recognition/conformer/ixrt/tools/compute_cer.py b/models/speech/speech_recognition/conformer/ixrt/tools/compute_cer.py new file mode 100755 index 00000000..a5db0897 --- /dev/null +++ b/models/speech/speech_recognition/conformer/ixrt/tools/compute_cer.py @@ -0,0 +1,532 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + + +import sys +import unicodedata +import codecs + +remove_tag = True +spacelist = [' ', '\t', '\r', '\n'] +puncts = ['!', ',', '?', + '、', '。', '!', ',', ';', '?', + ':', '「', '」', '︰', '『', '』', '《', '》'] + +def characterize(string) : + res = [] + i = 0 + while i < len(string): + char = string[i] + if char in puncts: + i += 1 + continue + cat1 = unicodedata.category(char) + # https://unicodebook.readthedocs.io/unicode.html#unicode-categories + if cat1 == 'Zs' or cat1 == 'Cn' or char in spacelist: # space or not assigned + i += 1 + continue + if cat1 == 'Lo': # letter-other + res.append(char) + i += 1 + else: + # some input looks like: , we want to separate it to two words. + sep = ' ' + if char == '<': + sep = '>' + j = i + 1 + while j < len(string): + c = string[j] + if ord(c) >= 128 or (c in spacelist) or (c == sep): + break + j += 1 + if j < len(string) and string[j] == '>': + j += 1 + res.append(string[i:j]) + i = j + return res + +def stripoff_tags(x): + if not x: + return '' + chars = [] + i = 0 + T = len(x) + while i < T: + if x[i] == '<': + while i < T and x[i] != '>': + i += 1 + i += 1 + else: + chars.append(x[i]) + i += 1 + return ''.join(chars) + + +def normalize(sentence, ignore_words, cs, split=None): + """ sentence, ignore_words are both in unicode + """ + new_sentence = [] + for token in sentence: + x = token + if not cs: + x = x.upper() + if x in ignore_words: + continue + if remove_tag: + x = stripoff_tags(x) + if not x: + continue + if split and x in split: + new_sentence += split[x] + if x.isalnum(): + for k in x: + new_sentence.append(k) + else: + new_sentence.append(x) + return new_sentence + +class Calculator : + def __init__(self) : + self.data = {} + self.space = [] + self.cost = {} + self.cost['cor'] = 0 + self.cost['sub'] = 1 + self.cost['del'] = 1 + self.cost['ins'] = 1 + + def calculate(self, lab, rec) : + # Initialization + lab.insert(0, '') + rec.insert(0, '') + while len(self.space) < len(lab) : + self.space.append([]) + for row in self.space : + for element in row : + element['dist'] = 0 + element['error'] = 'non' + while len(row) < len(rec) : + row.append({'dist' : 0, 'error' : 'non'}) + for i in range(len(lab)) : + self.space[i][0]['dist'] = i + self.space[i][0]['error'] = 'del' + for j in range(len(rec)) : + self.space[0][j]['dist'] = j + self.space[0][j]['error'] = 'ins' + self.space[0][0]['error'] = 'non' + for token in lab : + if token not in self.data and len(token) > 0 : + self.data[token] = {'all' : 0, 'cor' : 0, 'sub' : 0, + 'ins' : 0, 'del' : 0} + for token in rec : + if token not in self.data and len(token) > 0 : + self.data[token] = {'all' : 0, 'cor' : 0, 'sub' : 0, + 'ins' : 0, 'del' : 0} + # Computing edit distance + for i, lab_token in enumerate(lab) : + for j, rec_token in enumerate(rec) : + if i == 0 or j == 0 : + continue + min_dist = sys.maxsize + min_error = 'none' + dist = self.space[i - 1][j]['dist'] + self.cost['del'] + error = 'del' + if dist < min_dist : + min_dist = dist + min_error = error + dist = self.space[i][j - 1]['dist'] + self.cost['ins'] + error = 'ins' + if dist < min_dist : + min_dist = dist + min_error = error + if lab_token == rec_token : + dist = self.space[i - 1][j - 1]['dist'] + self.cost['cor'] + error = 'cor' + else : + dist = self.space[i - 1][j - 1]['dist'] + self.cost['sub'] + error = 'sub' + if dist < min_dist : + min_dist = dist + min_error = error + self.space[i][j]['dist'] = min_dist + self.space[i][j]['error'] = min_error + # Tracing back + result = {'lab': [], 'rec': [], 'all': 0, 'cor': 0, 'sub': 0, + 'ins': 0, 'del': 0} + i = len(lab) - 1 + j = len(rec) - 1 + while True : + if self.space[i][j]['error'] == 'cor' : # correct + if len(lab[i]) > 0 : + self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1 + self.data[lab[i]]['cor'] = self.data[lab[i]]['cor'] + 1 + result['all'] = result['all'] + 1 + result['cor'] = result['cor'] + 1 + result['lab'].insert(0, lab[i]) + result['rec'].insert(0, rec[j]) + i = i - 1 + j = j - 1 + elif self.space[i][j]['error'] == 'sub' : # substitution + if len(lab[i]) > 0 : + self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1 + self.data[lab[i]]['sub'] = self.data[lab[i]]['sub'] + 1 + result['all'] = result['all'] + 1 + result['sub'] = result['sub'] + 1 + result['lab'].insert(0, lab[i]) + result['rec'].insert(0, rec[j]) + i = i - 1 + j = j - 1 + elif self.space[i][j]['error'] == 'del' : # deletion + if len(lab[i]) > 0 : + self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1 + self.data[lab[i]]['del'] = self.data[lab[i]]['del'] + 1 + result['all'] = result['all'] + 1 + result['del'] = result['del'] + 1 + result['lab'].insert(0, lab[i]) + result['rec'].insert(0, "") + i = i - 1 + elif self.space[i][j]['error'] == 'ins' : # insertion + if len(rec[j]) > 0 : + self.data[rec[j]]['ins'] = self.data[rec[j]]['ins'] + 1 + result['ins'] = result['ins'] + 1 + result['lab'].insert(0, "") + result['rec'].insert(0, rec[j]) + j = j - 1 + elif self.space[i][j]['error'] == 'non' : # starting point + break + else : # shouldn't reach here + print('this should not happen , i={i} , j={j} , \ + error={error}'. + format(i=i, j=j, error=self.space[i][j]['error'])) + return result + + def overall(self) : + result = {'all': 0, 'cor': 0, 'sub': 0, 'ins': 0, 'del': 0} + for token in self.data : + result['all'] = result['all'] + self.data[token]['all'] + result['cor'] = result['cor'] + self.data[token]['cor'] + result['sub'] = result['sub'] + self.data[token]['sub'] + result['ins'] = result['ins'] + self.data[token]['ins'] + result['del'] = result['del'] + self.data[token]['del'] + return result + + def cluster(self, data) : + result = {'all': 0, 'cor': 0, 'sub': 0, 'ins': 0, 'del': 0} + for token in data : + if token in self.data : + result['all'] = result['all'] + self.data[token]['all'] + result['cor'] = result['cor'] + self.data[token]['cor'] + result['sub'] = result['sub'] + self.data[token]['sub'] + result['ins'] = result['ins'] + self.data[token]['ins'] + result['del'] = result['del'] + self.data[token]['del'] + return result + + def keys(self) : + return list(self.data.keys()) + +def width(string): + return sum(1 + (unicodedata.east_asian_width(c) in "AFW") for c in string) + +def default_cluster(word) : + unicode_names = [unicodedata.name(char) for char in word] + for i in reversed(range(len(unicode_names))) : + if unicode_names[i].startswith('DIGIT') : # 1 + unicode_names[i] = 'Number' # 'DIGIT' + elif (unicode_names[i].startswith('CJK UNIFIED IDEOGRAPH') or + unicode_names[i].startswith('CJK COMPATIBILITY IDEOGRAPH')) : + # 明 / 郎 + unicode_names[i] = 'Mandarin' # 'CJK IDEOGRAPH' + elif (unicode_names[i].startswith('LATIN CAPITAL LETTER') or + unicode_names[i].startswith('LATIN SMALL LETTER')) : + # A / a + unicode_names[i] = 'English' # 'LATIN LETTER' + elif unicode_names[i].startswith('HIRAGANA LETTER') : # は こ め + unicode_names[i] = 'Japanese' # 'GANA LETTER' + elif (unicode_names[i].startswith('AMPERSAND') or + unicode_names[i].startswith('APOSTROPHE') or + unicode_names[i].startswith('COMMERCIAL AT') or + unicode_names[i].startswith('DEGREE CELSIUS') or + unicode_names[i].startswith('EQUALS SIGN') or + unicode_names[i].startswith('FULL STOP') or + unicode_names[i].startswith('HYPHEN-MINUS') or + unicode_names[i].startswith('LOW LINE') or + unicode_names[i].startswith('NUMBER SIGN') or + unicode_names[i].startswith('PLUS SIGN') or + unicode_names[i].startswith('SEMICOLON')) : + # & / ' / @ / ℃ / = / . / - / _ / # / + / ; + del unicode_names[i] + else : + return 'Other' + if len(unicode_names) == 0 : + return 'Other' + if len(unicode_names) == 1 : + return unicode_names[0] + for i in range(len(unicode_names) - 1) : + if unicode_names[i] != unicode_names[i + 1] : + return 'Other' + return unicode_names[0] + +def usage() : + print("compute-wer.py : compute word error rate (WER) \ + and align recognition results and references.") + print(" usage : python compute-wer.py [--cs={0,1}] \ + [--cluster=foo] [--ig=ignore_file] [--char={0,1}] [--v={0,1}] \ + [--padding-symbol={space,underline}] test.ref test.hyp > test.wer") + +if __name__ == '__main__': + if len(sys.argv) == 1 : + usage() + sys.exit(0) + calculator = Calculator() + cluster_file = '' + ignore_words = set() + tochar = False + verbose = 1 + padding_symbol = ' ' + case_sensitive = False + max_words_per_line = sys.maxsize + split = None + while len(sys.argv) > 3: + a = '--maxw=' + if sys.argv[1].startswith(a): + b = sys.argv[1][len(a):] + del sys.argv[1] + max_words_per_line = int(b) + continue + a = '--rt=' + if sys.argv[1].startswith(a): + b = sys.argv[1][len(a):].lower() + del sys.argv[1] + remove_tag = (b == 'true') or (b != '0') + continue + a = '--cs=' + if sys.argv[1].startswith(a): + b = sys.argv[1][len(a):].lower() + del sys.argv[1] + case_sensitive = (b == 'true') or (b != '0') + continue + a = '--cluster=' + if sys.argv[1].startswith(a): + cluster_file = sys.argv[1][len(a):] + del sys.argv[1] + continue + a = '--splitfile=' + if sys.argv[1].startswith(a): + split_file = sys.argv[1][len(a):] + del sys.argv[1] + split = dict() + with codecs.open(split_file, 'r', 'utf-8') as fh: + for line in fh: # line in unicode + words = line.strip().split() + if len(words) >= 2: + split[words[0]] = words[1:] + continue + a = '--ig=' + if sys.argv[1].startswith(a): + ignore_file = sys.argv[1][len(a):] + del sys.argv[1] + with codecs.open(ignore_file, 'r', 'utf-8') as fh: + for line in fh: # line in unicode + line = line.strip() + if len(line) > 0: + ignore_words.add(line) + continue + a = '--char=' + if sys.argv[1].startswith(a): + b = sys.argv[1][len(a):].lower() + del sys.argv[1] + tochar = (b == 'true') or (b != '0') + continue + a = '--v=' + if sys.argv[1].startswith(a): + b = sys.argv[1][len(a):].lower() + del sys.argv[1] + verbose = 0 + try: + verbose = int(b) + except Exception: + if b == 'true' or b != '0': + verbose = 1 + continue + a = '--padding-symbol=' + if sys.argv[1].startswith(a): + b = sys.argv[1][len(a):].lower() + del sys.argv[1] + if b == 'space': + padding_symbol = ' ' + elif b == 'underline': + padding_symbol = '_' + continue + if True or sys.argv[1].startswith('-'): + # ignore invalid switch + del sys.argv[1] + continue + + if not case_sensitive: + ig = set([w.upper() for w in ignore_words]) + ignore_words = ig + + default_clusters = {} + default_words = {} + + ref_file = sys.argv[1] + hyp_file = sys.argv[2] + rec_set = {} + if split and not case_sensitive: + newsplit = dict() + for w in split: + words = split[w] + for i in range(len(words)): + words[i] = words[i].upper() + newsplit[w.upper()] = words + split = newsplit + + with codecs.open(hyp_file, 'r', 'utf-8') as fh: + for line in fh: + if tochar: + array = characterize(line) + else: + array = line.strip().split() + if len(array) == 0: + continue + fid = array[0] + rec_set[fid] = normalize(array[1:], ignore_words, + case_sensitive, split) + + # compute error rate on the interaction of reference file and hyp file + for line in open(ref_file, 'r', encoding='utf-8') : + if tochar: + array = characterize(line) + else: + array = line.rstrip('\n').split() + if len(array) == 0: + continue + fid = array[0] + if fid not in rec_set: + continue + lab = normalize(array[1:], ignore_words, case_sensitive, split) + rec = rec_set[fid] + if verbose: + print('\nutt: %s' % fid) + + for word in rec + lab : + if word not in default_words : + default_cluster_name = default_cluster(word) + if default_cluster_name not in default_clusters : + default_clusters[default_cluster_name] = {} + if word not in default_clusters[default_cluster_name] : + default_clusters[default_cluster_name][word] = 1 + default_words[word] = default_cluster_name + + result = calculator.calculate(lab, rec) + if verbose: + if result['all'] != 0 : + wer = float(result['ins'] + result['sub'] + + result['del']) * 100.0 / result['all'] + else : + wer = 0.0 + print('WER: %4.2f %%' % wer, end=' ') + print('N=%d C=%d S=%d D=%d I=%d' % + (result['all'], result['cor'], result['sub'], + result['del'], result['ins'])) + space = {} + space['lab'] = [] + space['rec'] = [] + for idx in range(len(result['lab'])) : + len_lab = width(result['lab'][idx]) + len_rec = width(result['rec'][idx]) + length = max(len_lab, len_rec) + space['lab'].append(length - len_lab) + space['rec'].append(length - len_rec) + upper_lab = len(result['lab']) + upper_rec = len(result['rec']) + lab1, rec1 = 0, 0 + while lab1 < upper_lab or rec1 < upper_rec: + if verbose > 1: + print('lab(%s):' % fid.encode('utf-8'), end=' ') + else: + print('lab:', end=' ') + lab2 = min(upper_lab, lab1 + max_words_per_line) + for idx in range(lab1, lab2): + token = result['lab'][idx] + print('{token}'.format(token=token), end='') + for n in range(space['lab'][idx]) : + print(padding_symbol, end='') + print(' ', end='') + print() + if verbose > 1: + print('rec(%s):' % fid.encode('utf-8'), end=' ') + else: + print('rec:', end=' ') + rec2 = min(upper_rec, rec1 + max_words_per_line) + for idx in range(rec1, rec2): + token = result['rec'][idx] + print('{token}'.format(token=token), end='') + for n in range(space['rec'][idx]) : + print(padding_symbol, end='') + print(' ', end='') + print('\n', end='\n') + lab1 = lab2 + rec1 = rec2 + + if verbose: + print('===================================================' + '========================') + print() + + result = calculator.overall() + if result['all'] != 0 : + wer = float(result['ins'] + result['sub'] + + result['del']) * 100.0 / result['all'] + else : + wer = 0.0 + print('Overall -> wer %4.2f %% Corr %4.2f %%' % (wer, result['cor']*100/result['all']), end=' ') + print('N=%d C=%d S=%d D=%d I=%d' % + (result['all'], result['cor'], result['sub'], + result['del'], result['ins'])) + if not verbose: + print() + + if verbose: + for cluster_id in default_clusters : + result = calculator.cluster(k for k in default_clusters[cluster_id]) + if result['all'] != 0 : + wer = float(result['ins'] + result['sub'] + + result['del']) * 100.0 / result['all'] + else : + wer = 0.0 + print('%s -> %4.2f %%' % (cluster_id, wer), end=' ') + print('N=%d C=%d S=%d D=%d I=%d' % + (result['all'], result['cor'], result['sub'], + result['del'], result['ins'])) + if len(cluster_file) > 0 : # compute separated WERs for word clusters + cluster_id = '' + cluster = [] + for line in open(cluster_file, 'r', encoding='utf-8') : + for token in line.decode('utf-8').rstrip('\n').split() : + # end of cluster reached, like + if token[0:2] == '' and \ + token.lstrip('') == cluster_id : + result = calculator.cluster(cluster) + if result['all'] != 0 : + wer = float(result['ins'] + result['sub'] + + result['del']) * 100.0 / result['all'] + else : + wer = 0.0 + print('%s -> %4.2f %%' % (cluster_id, wer), end=' ') + print('N=%d C=%d S=%d D=%d I=%d' % + (result['all'], result['cor'], result['sub'], + result['del'], result['ins'])) + cluster_id = '' + cluster = [] + # begin of cluster reached, like + elif (token[0] == '<' and token[len(token) - 1] == '>' and + cluster_id == ''): + cluster_id = token.lstrip('<').rstrip('>') + cluster = [] + # general terms, like WEATHER / CAR / ... + else : + cluster.append(token) + print() + print('=======================================' + '====================================') diff --git a/models/speech/speech_recognition/conformer/ixrt/tools/filter_scp.pl b/models/speech/speech_recognition/conformer/ixrt/tools/filter_scp.pl new file mode 100755 index 00000000..b76d37f4 --- /dev/null +++ b/models/speech/speech_recognition/conformer/ixrt/tools/filter_scp.pl @@ -0,0 +1,87 @@ +#!/usr/bin/env perl +# Copyright 2010-2012 Microsoft Corporation +# Johns Hopkins University (author: Daniel Povey) + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +# MERCHANTABLITY OR NON-INFRINGEMENT. +# See the Apache 2 License for the specific language governing permissions and +# limitations under the License. + + +# This script takes a list of utterance-ids or any file whose first field +# of each line is an utterance-id, and filters an scp +# file (or any file whose "n-th" field is an utterance id), printing +# out only those lines whose "n-th" field is in id_list. The index of +# the "n-th" field is 1, by default, but can be changed by using +# the -f switch + +$exclude = 0; +$field = 1; +$shifted = 0; + +do { + $shifted=0; + if ($ARGV[0] eq "--exclude") { + $exclude = 1; + shift @ARGV; + $shifted=1; + } + if ($ARGV[0] eq "-f") { + $field = $ARGV[1]; + shift @ARGV; shift @ARGV; + $shifted=1 + } +} while ($shifted); + +if(@ARGV < 1 || @ARGV > 2) { + die "Usage: filter_scp.pl [--exclude] [-f ] id_list [in.scp] > out.scp \n" . + "Prints only the input lines whose f'th field (default: first) is in 'id_list'.\n" . + "Note: only the first field of each line in id_list matters. With --exclude, prints\n" . + "only the lines that were *not* in id_list.\n" . + "Caution: previously, the -f option was interpreted as a zero-based field index.\n" . + "If your older scripts (written before Oct 2014) stopped working and you used the\n" . + "-f option, add 1 to the argument.\n" . + "See also: utils/filter_scp.pl .\n"; +} + + +$idlist = shift @ARGV; +open(F, "<$idlist") || die "Could not open id-list file $idlist"; +while() { + @A = split; + @A>=1 || die "Invalid id-list file line $_"; + $seen{$A[0]} = 1; +} + +if ($field == 1) { # Treat this as special case, since it is common. + while(<>) { + $_ =~ m/\s*(\S+)\s*/ || die "Bad line $_, could not get first field."; + # $1 is what we filter on. + if ((!$exclude && $seen{$1}) || ($exclude && !defined $seen{$1})) { + print $_; + } + } +} else { + while(<>) { + @A = split; + @A > 0 || die "Invalid scp file line $_"; + @A >= $field || die "Invalid scp file line $_"; + if ((!$exclude && $seen{$A[$field-1]}) || ($exclude && !defined $seen{$A[$field-1]})) { + print $_; + } + } +} + +# tests: +# the following should print "foo 1" +# ( echo foo 1; echo bar 2 ) | utils/filter_scp.pl <(echo foo) +# the following should print "bar 2". +# ( echo foo 1; echo bar 2 ) | utils/filter_scp.pl -f 2 <(echo 2) diff --git a/models/speech/speech_recognition/conformer/ixrt/tools/make_raw_list.py b/models/speech/speech_recognition/conformer/ixrt/tools/make_raw_list.py new file mode 100755 index 00000000..2f84f015 --- /dev/null +++ b/models/speech/speech_recognition/conformer/ixrt/tools/make_raw_list.py @@ -0,0 +1,59 @@ +#!/usr/bin/env python3 + +# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import json + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='') + parser.add_argument('--segments', default=None, help='segments file') + parser.add_argument('wav_file', help='wav file') + parser.add_argument('text_file', help='text file') + parser.add_argument('output_file', help='output list file') + args = parser.parse_args() + + wav_table = {} + with open(args.wav_file, 'r', encoding='utf8') as fin: + for line in fin: + arr = line.strip().split() + assert len(arr) == 2 + wav_table[arr[0]] = arr[1] + + if args.segments is not None: + segments_table = {} + with open(args.segments, 'r', encoding='utf8') as fin: + for line in fin: + arr = line.strip().split() + assert len(arr) == 4 + segments_table[arr[0]] = (arr[1], float(arr[2]), float(arr[3])) + + with open(args.text_file, 'r', encoding='utf8') as fin, \ + open(args.output_file, 'w', encoding='utf8') as fout: + for line in fin: + arr = line.strip().split(maxsplit=1) + key = arr[0] + txt = arr[1] if len(arr) > 1 else '' + if args.segments is None: + assert key in wav_table + wav = wav_table[key] + line = dict(key=key, wav=wav, txt=txt) + else: + assert key in segments_table + wav_key, start, end = segments_table[key] + wav = wav_table[wav_key] + line = dict(key=key, wav=wav, txt=txt, start=start, end=end) + json_line = json.dumps(line, ensure_ascii=False) + fout.write(json_line + '\n') diff --git a/models/speech/speech_recognition/conformer/ixrt/tools/make_shard_list.py b/models/speech/speech_recognition/conformer/ixrt/tools/make_shard_list.py new file mode 100755 index 00000000..fcd4bcd7 --- /dev/null +++ b/models/speech/speech_recognition/conformer/ixrt/tools/make_shard_list.py @@ -0,0 +1,181 @@ +#!/usr/bin/env python3 + +# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import io +import logging +import os +import tarfile +import time +import multiprocessing + +import torch +import torchaudio +import torchaudio.backend.sox_io_backend as sox + +AUDIO_FORMAT_SETS = set(['flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma']) + + +def write_tar_file(data_list, + no_segments, + tar_file, + resample=16000, + index=0, + total=1): + logging.info('Processing {} {}/{}'.format(tar_file, index, total)) + read_time = 0.0 + save_time = 0.0 + write_time = 0.0 + with tarfile.open(tar_file, "w") as tar: + prev_wav = None + for item in data_list: + if no_segments: + key, txt, wav = item + else: + key, txt, wav, start, end = item + + suffix = wav.split('.')[-1] + assert suffix in AUDIO_FORMAT_SETS + if no_segments: + ts = time.time() + with open(wav, 'rb') as fin: + data = fin.read() + read_time += (time.time() - ts) + else: + if wav != prev_wav: + ts = time.time() + waveforms, sample_rate = sox.load(wav, normalize=False) + read_time += (time.time() - ts) + prev_wav = wav + start = int(start * sample_rate) + end = int(end * sample_rate) + audio = waveforms[:1, start:end] + + # resample + if sample_rate != resample: + audio = torchaudio.transforms.Resample( + sample_rate, resample)(audio) + + ts = time.time() + f = io.BytesIO() + sox.save(f, audio, resample, format="wav", bits_per_sample=16) + # Save to wav for segments file + suffix = "wav" + f.seek(0) + data = f.read() + save_time += (time.time() - ts) + + assert isinstance(txt, str) + ts = time.time() + txt_file = key + '.txt' + txt = txt.encode('utf8') + txt_data = io.BytesIO(txt) + txt_info = tarfile.TarInfo(txt_file) + txt_info.size = len(txt) + tar.addfile(txt_info, txt_data) + + wav_file = key + '.' + suffix + wav_data = io.BytesIO(data) + wav_info = tarfile.TarInfo(wav_file) + wav_info.size = len(data) + tar.addfile(wav_info, wav_data) + write_time += (time.time() - ts) + logging.info('read {} save {} write {}'.format(read_time, save_time, + write_time)) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='') + parser.add_argument('--num_utts_per_shard', + type=int, + default=1000, + help='num utts per shard') + parser.add_argument('--num_threads', + type=int, + default=1, + help='num threads for make shards') + parser.add_argument('--prefix', + default='shards', + help='prefix of shards tar file') + parser.add_argument('--segments', default=None, help='segments file') + parser.add_argument('--resample', + type=int, + default=16000, + help='segments file') + parser.add_argument('wav_file', help='wav file') + parser.add_argument('text_file', help='text file') + parser.add_argument('shards_dir', help='output shards dir') + parser.add_argument('shards_list', help='output shards list file') + args = parser.parse_args() + logging.basicConfig(level=logging.INFO, + format='%(asctime)s %(levelname)s %(message)s') + + torch.set_num_threads(1) + wav_table = {} + with open(args.wav_file, 'r', encoding='utf8') as fin: + for line in fin: + arr = line.strip().split() + assert len(arr) == 2 + wav_table[arr[0]] = arr[1] + + no_segments = True + segments_table = {} + if args.segments is not None: + no_segments = False + with open(args.segments, 'r', encoding='utf8') as fin: + for line in fin: + arr = line.strip().split() + assert len(arr) == 4 + segments_table[arr[0]] = (arr[1], float(arr[2]), float(arr[3])) + + data = [] + with open(args.text_file, 'r', encoding='utf8') as fin: + for line in fin: + arr = line.strip().split(maxsplit=1) + key = arr[0] + txt = arr[1] if len(arr) > 1 else '' + if no_segments: + assert key in wav_table + wav = wav_table[key] + data.append((key, txt, wav)) + else: + wav_key, start, end = segments_table[key] + wav = wav_table[wav_key] + data.append((key, txt, wav, start, end)) + + num = args.num_utts_per_shard + chunks = [data[i:i + num] for i in range(0, len(data), num)] + os.makedirs(args.shards_dir, exist_ok=True) + + # Using thread pool to speedup + pool = multiprocessing.Pool(processes=args.num_threads) + shards_list = [] + tasks_list = [] + num_chunks = len(chunks) + for i, chunk in enumerate(chunks): + tar_file = os.path.join(args.shards_dir, + '{}_{:09d}.tar'.format(args.prefix, i)) + shards_list.append(tar_file) + pool.apply_async( + write_tar_file, + (chunk, no_segments, tar_file, args.resample, i, num_chunks)) + + pool.close() + pool.join() + + with open(args.shards_list, 'w', encoding='utf8') as fout: + for name in shards_list: + fout.write(name + '\n') diff --git a/models/speech/speech_recognition/conformer/ixrt/tools/text2token.py b/models/speech/speech_recognition/conformer/ixrt/tools/text2token.py new file mode 100755 index 00000000..4f4dcc90 --- /dev/null +++ b/models/speech/speech_recognition/conformer/ixrt/tools/text2token.py @@ -0,0 +1,171 @@ +#!/usr/bin/env python3 + +# Copyright 2017 Johns Hopkins University (Shinji Watanabe) +# Copyright 2021 JD AI Lab. All Rights Reserved. (authors: Lu Fan) +# Copyright 2021 Mobvoi Inc. All Rights Reserved. (Di Wu) +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +from __future__ import print_function +from __future__ import unicode_literals + +import argparse +import codecs +import re +import sys + +is_python2 = sys.version_info[0] == 2 + + +def exist_or_not(i, match_pos): + start_pos = None + end_pos = None + for pos in match_pos: + if pos[0] <= i < pos[1]: + start_pos = pos[0] + end_pos = pos[1] + break + + return start_pos, end_pos + +def seg_char(sent): + pattern = re.compile(r'([\u4e00-\u9fa5])') + chars = pattern.split(sent) + chars = [w for w in chars if len(w.strip()) > 0] + return chars + +def get_parser(): + parser = argparse.ArgumentParser( + description='convert raw text to tokenized text', + formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument('--nchar', + '-n', + default=1, + type=int, + help='number of characters to split, i.e., \ + aabb -> a a b b with -n 1 and aa bb with -n 2') + parser.add_argument('--skip-ncols', + '-s', + default=0, + type=int, + help='skip first n columns') + parser.add_argument('--space', + default='', + type=str, + help='space symbol') + parser.add_argument('--bpe-model', + '-m', + default=None, + type=str, + help='bpe model for english part') + parser.add_argument('--non-lang-syms', + '-l', + default=None, + type=str, + help='list of non-linguistic symobles,' + ' e.g., etc.') + parser.add_argument('text', + type=str, + default=False, + nargs='?', + help='input text') + parser.add_argument('--trans_type', + '-t', + type=str, + default="char", + choices=["char", "phn", "cn_char_en_bpe"], + help="""Transcript type. char/phn. e.g., for TIMIT + FADG0_SI1279 - + If trans_type is char, read from + SI1279.WRD file -> "bricks are an alternative" + Else if trans_type is phn, + read from SI1279.PHN file -> + "sil b r ih sil k s aa r er n aa l + sil t er n ih sil t ih v sil" """) + return parser + + +def main(): + parser = get_parser() + args = parser.parse_args() + + rs = [] + if args.non_lang_syms is not None: + with codecs.open(args.non_lang_syms, 'r', encoding="utf-8") as f: + nls = [x.rstrip() for x in f.readlines()] + rs = [re.compile(re.escape(x)) for x in nls] + + if args.bpe_model is not None: + import sentencepiece as spm + sp = spm.SentencePieceProcessor() + sp.load(args.bpe_model) + + if args.text: + f = codecs.open(args.text, encoding="utf-8") + else: + f = codecs.getreader("utf-8")( + sys.stdin if is_python2 else sys.stdin.buffer) + + sys.stdout = codecs.getwriter("utf-8")( + sys.stdout if is_python2 else sys.stdout.buffer) + line = f.readline() + n = args.nchar + while line: + x = line.split() + print(' '.join(x[:args.skip_ncols]), end=" ") + a = ' '.join(x[args.skip_ncols:]) + + # get all matched positions + match_pos = [] + for r in rs: + i = 0 + while i >= 0: + m = r.search(a, i) + if m: + match_pos.append([m.start(), m.end()]) + i = m.end() + else: + break + + if len(match_pos) > 0: + chars = [] + i = 0 + while i < len(a): + start_pos, end_pos = exist_or_not(i, match_pos) + if start_pos is not None: + chars.append(a[start_pos:end_pos]) + i = end_pos + else: + chars.append(a[i]) + i += 1 + a = chars + + if (args.trans_type == "phn"): + a = a.split(" ") + elif args.trans_type == "cn_char_en_bpe": + b = seg_char(a) + a = [] + for j in b: + # we use "▁" to instead of blanks among english words + # warning: here is "▁", not "_" + for l in j.strip().split("▁"): + if not l.encode('UTF-8').isalpha(): + a.append(l) + else: + for k in sp.encode_as_pieces(l): + a.append(k) + else: + a = [a[j:j + n] for j in range(0, len(a), n)] + + a_flat = [] + for z in a: + a_flat.append("".join(z)) + + a_chars = [z.replace(' ', args.space) for z in a_flat] + if (args.trans_type == "phn"): + a_chars = [z.replace("sil", args.space) for z in a_chars] + print(' '.join(a_chars)) + line = f.readline() + + +if __name__ == '__main__': + main() diff --git a/models/speech/speech_recognition/conformer/ixrt/utils/__init__.py b/models/speech/speech_recognition/conformer/ixrt/utils/__init__.py new file mode 100644 index 00000000..c57435c1 --- /dev/null +++ b/models/speech/speech_recognition/conformer/ixrt/utils/__init__.py @@ -0,0 +1,39 @@ +import os +import torch +import numpy as np + +from .embedding import RelPositionalEncoding + + +rel_positional_encoding = RelPositionalEncoding(256, 0.1) + + +def make_pad_mask(lengths: np.ndarray, max_len: int = 0) -> np.ndarray : + """Make mask tensor containing indices of padded part. + + See description of make_non_pad_mask. + + Args: + lengths (numpy.ndarray): Batch of lengths (B,). + Returns: + numpy.ndarray: Mask tensor containing indices of padded part. + + Examples: + >>> lengths = [5, 3, 2] + >>> make_pad_mask(lengths) + masks = [[0, 0, 0, 0 ,0], + [0, 0, 0, 1, 1], + [0, 0, 1, 1, 1]] + """ + + batch_size = lengths.shape[0] + max_len = max_len if max_len > 0 else lengths.max().item() + seq_range = np.arange(0, max_len, dtype=np.int64) + seq_range_expand = np.tile(seq_range, batch_size).reshape(batch_size, max_len) + seq_length_expand = lengths[..., None] + mask = seq_range_expand >= seq_length_expand + mask = np.expand_dims(mask, axis=1) + mask = ~mask + mask = mask[:, :, 2::2][:, :, 2::2] + mask = mask.astype(np.int32) + return mask diff --git a/models/speech/speech_recognition/conformer/ixrt/utils/embedding.py b/models/speech/speech_recognition/conformer/ixrt/utils/embedding.py new file mode 100644 index 00000000..0fd65c4c --- /dev/null +++ b/models/speech/speech_recognition/conformer/ixrt/utils/embedding.py @@ -0,0 +1,133 @@ +"""Positonal Encoding Module.""" + +import math +from typing import Tuple, Union + +import torch +import torch.nn.functional as F +import numpy as np + + +class PositionalEncoding(torch.nn.Module): + """Positional encoding. + + :param int d_model: embedding dim + :param float dropout_rate: dropout rate + :param int max_len: maximum input length + + PE(pos, 2i) = sin(pos/(10000^(2i/dmodel))) + PE(pos, 2i+1) = cos(pos/(10000^(2i/dmodel))) + """ + + def __init__(self, + d_model: int, + dropout_rate: float, + max_len: int = 5000, + reverse: bool = False): + """Construct an PositionalEncoding object.""" + super().__init__() + self.d_model = d_model + self.xscale = math.sqrt(self.d_model) + self.dropout = torch.nn.Dropout(p=dropout_rate) + self.max_len = max_len + + pe = torch.zeros(self.max_len, self.d_model) + position = torch.arange(0, self.max_len, + dtype=torch.float32).unsqueeze(1) + div_term = torch.exp( + torch.arange(0, self.d_model, 2, dtype=torch.float32) * + -(math.log(10000.0) / self.d_model)) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + pe = pe.unsqueeze(0) + self.register_buffer("pe", pe) + + def forward(self, + x: torch.Tensor, + offset: Union[int, torch.Tensor] = 0) \ + -> Tuple[torch.Tensor, torch.Tensor]: + """Add positional encoding. + + Args: + x (torch.Tensor): Input. Its shape is (batch, time, ...) + offset (int, torch.tensor): position offset + + Returns: + torch.Tensor: Encoded tensor. Its shape is (batch, time, ...) + torch.Tensor: for compatibility to RelPositionalEncoding + """ + + pos_emb = self.position_encoding(offset, x.size(1), False) + x = x * self.xscale + pos_emb + return self.dropout(x), self.dropout(pos_emb) + + def position_encoding(self, + offset: Union[int, torch.Tensor], + size: int, + apply_dropout: bool = True) -> torch.Tensor: + """ For getting encoding in a streaming fashion + + Attention!!!!! + we apply dropout only once at the whole utterance level in a none + streaming way, but will call this function several times with + increasing input size in a streaming scenario, so the dropout will + be applied several times. + + Args: + offset (int or torch.tensor): start offset + size (int): required size of position encoding + + Returns: + torch.Tensor: Corresponding encoding + """ + # How to subscript a Union type: + # https://github.com/pytorch/pytorch/issues/69434 + # import ipdb;ipdb.set_trace() + if isinstance(offset, int): + assert offset + size <= self.max_len + pos_emb = self.pe[:, offset:offset + size] + elif isinstance(offset, torch.Tensor) and offset.dim() == 0: # scalar + assert offset + size <= self.max_len + pos_emb = self.pe[:, offset:offset + size] + else: # for batched streaming decoding on GPU + assert torch.max(offset) + size <= self.max_len + index = offset.unsqueeze(1) + \ + torch.arange(0, size).to(offset.device) # B X T + flag = index > 0 + # remove negative offset + index = index * flag + pos_emb = F.embedding(index, self.pe[0]) # B X T X d_model + + if apply_dropout: + pos_emb = self.dropout(pos_emb) + return pos_emb + + +class RelPositionalEncoding(PositionalEncoding): + """Relative positional encoding module. + See : Appendix B in https://arxiv.org/abs/1901.02860 + Args: + d_model (int): Embedding dimension. + dropout_rate (float): Dropout rate. + max_len (int): Maximum input length. + """ + + def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000): + """Initialize class.""" + super().__init__(d_model, dropout_rate, max_len, reverse=True) + + def forward(self, + seq_len: int, + offset: Union[int, torch.Tensor] = 0) \ + -> Tuple[torch.Tensor, torch.Tensor]: + """Compute positional encoding. + Args: + x (torch.Tensor): Input tensor (batch, time, `*`). + Returns: + torch.Tensor: Encoded tensor (batch, time, `*`). + torch.Tensor: Positional embedding tensor (1, time, `*`). + """ + pos_emb = self.position_encoding(offset, seq_len, False) + # return self.dropout(pos_emb) + return pos_emb + diff --git a/models/speech/speech_recognition/conformer/ixrt/wenet/__init__.py b/models/speech/speech_recognition/conformer/ixrt/wenet/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/models/speech/speech_recognition/conformer/ixrt/wenet/dataset.py b/models/speech/speech_recognition/conformer/ixrt/wenet/dataset.py new file mode 100644 index 00000000..88a8cd15 --- /dev/null +++ b/models/speech/speech_recognition/conformer/ixrt/wenet/dataset.py @@ -0,0 +1,179 @@ +# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import random + +import torch +import torch.distributed as dist +from torch.utils.data import IterableDataset + +import wenet.processor as processor +from wenet.file_utils import read_lists + + +class Processor(IterableDataset): + def __init__(self, source, f, *args, **kw): + assert callable(f) + self.source = source + self.f = f + self.args = args + self.kw = kw + + def set_epoch(self, epoch): + self.source.set_epoch(epoch) + + def __iter__(self): + """ Return an iterator over the source dataset processed by the + given processor. + """ + assert self.source is not None + assert callable(self.f) + return self.f(iter(self.source), *self.args, **self.kw) + + def apply(self, f): + assert callable(f) + return Processor(self, f, *self.args, **self.kw) + + +class DistributedSampler: + def __init__(self, shuffle=True, partition=True): + self.epoch = -1 + self.update() + self.shuffle = shuffle + self.partition = partition + + def update(self): + assert dist.is_available() + if dist.is_initialized(): + self.rank = dist.get_rank() + self.world_size = dist.get_world_size() + else: + self.rank = 0 + self.world_size = 1 + worker_info = torch.utils.data.get_worker_info() + if worker_info is None: + self.worker_id = 0 + self.num_workers = 1 + else: + self.worker_id = worker_info.id + self.num_workers = worker_info.num_workers + return dict(rank=self.rank, + world_size=self.world_size, + worker_id=self.worker_id, + num_workers=self.num_workers) + + def set_epoch(self, epoch): + self.epoch = epoch + + def sample(self, data): + """ Sample data according to rank/world_size/num_workers + + Args: + data(List): input data list + + Returns: + List: data list after sample + """ + data = list(range(len(data))) + # TODO(Binbin Zhang): fix this + # We can not handle uneven data for CV on DDP, so we don't + # sample data by rank, that means every GPU gets the same + # and all the CV data + if self.partition: + if self.shuffle: + random.Random(self.epoch).shuffle(data) + data = data[self.rank::self.world_size] + data = data[self.worker_id::self.num_workers] + return data + + +class DataList(IterableDataset): + def __init__(self, lists, shuffle=True, partition=True): + self.lists = lists + self.sampler = DistributedSampler(shuffle, partition) + + def set_epoch(self, epoch): + self.sampler.set_epoch(epoch) + + def __iter__(self): + sampler_info = self.sampler.update() + indexes = self.sampler.sample(self.lists) + for index in indexes: + # yield dict(src=src) + data = dict(src=self.lists[index]) + data.update(sampler_info) + yield data + + +def Dataset(data_type, + data_list_file, + symbol_table, + conf, + bpe_model=None, + non_lang_syms=None, + partition=True): + """ Construct dataset from arguments + + We have two shuffle stage in the Dataset. The first is global + shuffle at shards tar/raw file level. The second is global shuffle + at training samples level. + + Args: + data_type(str): raw/shard + bpe_model(str): model for english bpe part + partition(bool): whether to do data partition in terms of rank + """ + assert data_type in ['raw', 'shard'] + lists = read_lists(data_list_file) + shuffle = conf.get('shuffle', True) + dataset = DataList(lists, shuffle=shuffle, partition=partition) + if data_type == 'shard': + dataset = Processor(dataset, processor.url_opener) + dataset = Processor(dataset, processor.tar_file_and_group) + else: + dataset = Processor(dataset, processor.parse_raw) + + dataset = Processor(dataset, processor.tokenize, symbol_table, bpe_model, + non_lang_syms, conf.get('split_with_space', False)) + filter_conf = conf.get('filter_conf', {}) + dataset = Processor(dataset, processor.filter, **filter_conf) + + resample_conf = conf.get('resample_conf', {}) + dataset = Processor(dataset, processor.resample, **resample_conf) + + speed_perturb = conf.get('speed_perturb', False) + if speed_perturb: + dataset = Processor(dataset, processor.speed_perturb) + + fbank_conf = conf.get('fbank_conf', {}) + dataset = Processor(dataset, processor.compute_fbank, **fbank_conf) + + spec_aug = conf.get('spec_aug', True) + if spec_aug: + spec_aug_conf = conf.get('spec_aug_conf', {}) + dataset = Processor(dataset, processor.spec_aug, **spec_aug_conf) + + if shuffle: + shuffle_conf = conf.get('shuffle_conf', {}) + dataset = Processor(dataset, processor.shuffle, **shuffle_conf) + + sort = conf.get('sort', True) + if sort: + sort_conf = conf.get('sort_conf', {}) + dataset = Processor(dataset, processor.sort, **sort_conf) + + batch_conf = conf.get('batch_conf', {}) + dataset = Processor(dataset, processor.batch, **batch_conf) + dataset = Processor(dataset, processor.padding) + return dataset diff --git a/models/speech/speech_recognition/conformer/ixrt/wenet/file_utils.py b/models/speech/speech_recognition/conformer/ixrt/wenet/file_utils.py new file mode 100644 index 00000000..7b7e516c --- /dev/null +++ b/models/speech/speech_recognition/conformer/ixrt/wenet/file_utils.py @@ -0,0 +1,66 @@ +# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re + + +def read_lists(list_file): + lists = [] + with open(list_file, 'r', encoding='utf8') as fin: + for line in fin: + lists.append(line.strip()) + return lists + + +def read_non_lang_symbols(non_lang_sym_path): + """read non-linguistic symbol from file. + + The file format is like below: + + {NOISE}\n + {BRK}\n + ... + + + Args: + non_lang_sym_path: non-linguistic symbol file path, None means no any + syms. + + """ + if non_lang_sym_path is None: + return None + else: + syms = read_lists(non_lang_sym_path) + non_lang_syms_pattern = re.compile(r"(\[[^\[\]]+\]|<[^<>]+>|{[^{}]+})") + for sym in syms: + if non_lang_syms_pattern.fullmatch(sym) is None: + class BadSymbolFormat(Exception): + pass + raise BadSymbolFormat( + "Non-linguistic symbols should be " + "formatted in {xxx}//[xxx], consider" + " modify '%s' to meet the requirment. " + "More details can be found in discussions here : " + "https://github.com/wenet-e2e/wenet/pull/819" % (sym)) + return syms + + +def read_symbol_table(symbol_table_file): + symbol_table = {} + with open(symbol_table_file, 'r', encoding='utf8') as fin: + for line in fin: + arr = line.strip().split() + assert len(arr) == 2 + symbol_table[arr[0]] = int(arr[1]) + return symbol_table diff --git a/models/speech/speech_recognition/conformer/ixrt/wenet/processor.py b/models/speech/speech_recognition/conformer/ixrt/wenet/processor.py new file mode 100644 index 00000000..9a542a3d --- /dev/null +++ b/models/speech/speech_recognition/conformer/ixrt/wenet/processor.py @@ -0,0 +1,550 @@ +# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import json +import random +import re +import tarfile +from subprocess import PIPE, Popen +from urllib.parse import urlparse + +import torch +import torchaudio +import torchaudio.compliance.kaldi as kaldi +from torch.nn.utils.rnn import pad_sequence + +AUDIO_FORMAT_SETS = set(['flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma']) + + +def url_opener(data): + """ Give url or local file, return file descriptor + Inplace operation. + + Args: + data(Iterable[str]): url or local file list + + Returns: + Iterable[{src, stream}] + """ + for sample in data: + assert 'src' in sample + # TODO(Binbin Zhang): support HTTP + url = sample['src'] + try: + pr = urlparse(url) + # local file + if pr.scheme == '' or pr.scheme == 'file': + stream = open(url, 'rb') + # network file, such as HTTP(HDFS/OSS/S3)/HTTPS/SCP + else: + cmd = f'curl -s -L {url}' + process = Popen(cmd, shell=True, stdout=PIPE) + sample.update(process=process) + stream = process.stdout + sample.update(stream=stream) + yield sample + except Exception as ex: + logging.warning('Failed to open {}'.format(url)) + + +def tar_file_and_group(data): + """ Expand a stream of open tar files into a stream of tar file contents. + And groups the file with same prefix + + Args: + data: Iterable[{src, stream}] + + Returns: + Iterable[{key, wav, txt, sample_rate}] + """ + for sample in data: + assert 'stream' in sample + stream = tarfile.open(fileobj=sample['stream'], mode="r|*") + prev_prefix = None + example = {} + valid = True + for tarinfo in stream: + name = tarinfo.name + pos = name.rfind('.') + assert pos > 0 + prefix, postfix = name[:pos], name[pos + 1:] + if prev_prefix is not None and prefix != prev_prefix: + example['key'] = prev_prefix + if valid: + yield example + example = {} + valid = True + with stream.extractfile(tarinfo) as file_obj: + try: + if postfix == 'txt': + example['txt'] = file_obj.read().decode('utf8').strip() + elif postfix in AUDIO_FORMAT_SETS: + waveform, sample_rate = torchaudio.load(file_obj) + example['wav'] = waveform + example['sample_rate'] = sample_rate + else: + example[postfix] = file_obj.read() + except Exception as ex: + valid = False + logging.warning('error to parse {}'.format(name)) + prev_prefix = prefix + if prev_prefix is not None: + example['key'] = prev_prefix + yield example + stream.close() + if 'process' in sample: + sample['process'].communicate() + sample['stream'].close() + + +def parse_raw(data): + """ Parse key/wav/txt from json line + + Args: + data: Iterable[str], str is a json line has key/wav/txt + + Returns: + Iterable[{key, wav, txt, sample_rate}] + """ + for sample in data: + assert 'src' in sample + json_line = sample['src'] + obj = json.loads(json_line) + assert 'key' in obj + assert 'wav' in obj + assert 'txt' in obj + key = obj['key'] + wav_file = obj['wav'] + txt = obj['txt'] + try: + if 'start' in obj: + assert 'end' in obj + sample_rate = torchaudio.backend.sox_io_backend.info( + wav_file).sample_rate + start_frame = int(obj['start'] * sample_rate) + end_frame = int(obj['end'] * sample_rate) + waveform, _ = torchaudio.backend.sox_io_backend.load( + filepath=wav_file, + num_frames=end_frame - start_frame, + frame_offset=start_frame) + else: + waveform, sample_rate = torchaudio.load(wav_file) + example = dict(key=key, + txt=txt, + wav=waveform, + sample_rate=sample_rate) + yield example + except Exception as ex: + logging.warning('Failed to read {}'.format(wav_file)) + + +def filter(data, + max_length=10240, + min_length=10, + token_max_length=200, + token_min_length=1, + min_output_input_ratio=0.0005, + max_output_input_ratio=1): + """ Filter sample according to feature and label length + Inplace operation. + + Args:: + data: Iterable[{key, wav, label, sample_rate}] + max_length: drop utterance which is greater than max_length(10ms) + min_length: drop utterance which is less than min_length(10ms) + token_max_length: drop utterance which is greater than + token_max_length, especially when use char unit for + english modeling + token_min_length: drop utterance which is + less than token_max_length + min_output_input_ratio: minimal ration of + token_length / feats_length(10ms) + max_output_input_ratio: maximum ration of + token_length / feats_length(10ms) + + Returns: + Iterable[{key, wav, label, sample_rate}] + """ + for sample in data: + assert 'sample_rate' in sample + assert 'wav' in sample + assert 'label' in sample + # sample['wav'] is torch.Tensor, we have 100 frames every second + num_frames = sample['wav'].size(1) / sample['sample_rate'] * 100 + if num_frames < min_length: + continue + if num_frames > max_length: + continue + if len(sample['label']) < token_min_length: + continue + if len(sample['label']) > token_max_length: + continue + if num_frames != 0: + if len(sample['label']) / num_frames < min_output_input_ratio: + continue + if len(sample['label']) / num_frames > max_output_input_ratio: + continue + yield sample + + +def resample(data, resample_rate=16000): + """ Resample data. + Inplace operation. + + Args: + data: Iterable[{key, wav, label, sample_rate}] + resample_rate: target resample rate + + Returns: + Iterable[{key, wav, label, sample_rate}] + """ + for sample in data: + assert 'sample_rate' in sample + assert 'wav' in sample + sample_rate = sample['sample_rate'] + waveform = sample['wav'] + if sample_rate != resample_rate: + sample['sample_rate'] = resample_rate + sample['wav'] = torchaudio.transforms.Resample( + orig_freq=sample_rate, new_freq=resample_rate)(waveform) + yield sample + + +def speed_perturb(data, speeds=None): + """ Apply speed perturb to the data. + Inplace operation. + + Args: + data: Iterable[{key, wav, label, sample_rate}] + speeds(List[float]): optional speed + + Returns: + Iterable[{key, wav, label, sample_rate}] + """ + if speeds is None: + speeds = [0.9, 1.0, 1.1] + for sample in data: + assert 'sample_rate' in sample + assert 'wav' in sample + sample_rate = sample['sample_rate'] + waveform = sample['wav'] + speed = random.choice(speeds) + if speed != 1.0: + wav, _ = torchaudio.sox_effects.apply_effects_tensor( + waveform, sample_rate, + [['speed', str(speed)], ['rate', str(sample_rate)]]) + sample['wav'] = wav + + yield sample + + +def compute_fbank(data, + num_mel_bins=23, + frame_length=25, + frame_shift=10, + dither=0.0): + """ Extract fbank + + Args: + data: Iterable[{key, wav, label, sample_rate}] + + Returns: + Iterable[{key, feat, label}] + """ + for sample in data: + assert 'sample_rate' in sample + assert 'wav' in sample + assert 'key' in sample + assert 'label' in sample + sample_rate = sample['sample_rate'] + waveform = sample['wav'] + waveform = waveform * (1 << 15) + # Only keep key, feat, label + mat = kaldi.fbank(waveform, + num_mel_bins=num_mel_bins, + frame_length=frame_length, + frame_shift=frame_shift, + dither=dither, + energy_floor=0.0, + sample_frequency=sample_rate) + yield dict(key=sample['key'], label=sample['label'], feat=mat) + + +def __tokenize_by_bpe_model(sp, txt): + tokens = [] + # CJK(China Japan Korea) unicode range is [U+4E00, U+9FFF], ref: + # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) + pattern = re.compile(r'([\u4e00-\u9fff])') + # Example: + # txt = "你好 ITS'S OKAY 的" + # chars = ["你", "好", " ITS'S OKAY ", "的"] + chars = pattern.split(txt.upper()) + mix_chars = [w for w in chars if len(w.strip()) > 0] + for ch_or_w in mix_chars: + # ch_or_w is a single CJK charater(i.e., "你"), do nothing. + if pattern.fullmatch(ch_or_w) is not None: + tokens.append(ch_or_w) + # ch_or_w contains non-CJK charaters(i.e., " IT'S OKAY "), + # encode ch_or_w using bpe_model. + else: + for p in sp.encode_as_pieces(ch_or_w): + tokens.append(p) + + return tokens + + +def tokenize(data, symbol_table, bpe_model=None, non_lang_syms=None, + split_with_space=False): + """ Decode text to chars or BPE + Inplace operation + + Args: + data: Iterable[{key, wav, txt, sample_rate}] + + Returns: + Iterable[{key, wav, txt, tokens, label, sample_rate}] + """ + if non_lang_syms is not None: + non_lang_syms_pattern = re.compile(r"(\[[^\[\]]+\]|<[^<>]+>|{[^{}]+})") + else: + non_lang_syms = {} + non_lang_syms_pattern = None + + if bpe_model is not None: + import sentencepiece as spm + sp = spm.SentencePieceProcessor() + sp.load(bpe_model) + else: + sp = None + + for sample in data: + assert 'txt' in sample + txt = sample['txt'].strip() + if non_lang_syms_pattern is not None: + parts = non_lang_syms_pattern.split(txt.upper()) + parts = [w for w in parts if len(w.strip()) > 0] + else: + parts = [txt] + + label = [] + tokens = [] + for part in parts: + if part in non_lang_syms: + tokens.append(part) + else: + if bpe_model is not None: + tokens.extend(__tokenize_by_bpe_model(sp, part)) + else: + if split_with_space: + part = part.split(" ") + for ch in part: + if ch == ' ': + ch = "▁" + tokens.append(ch) + + for ch in tokens: + if ch in symbol_table: + label.append(symbol_table[ch]) + elif '' in symbol_table: + label.append(symbol_table['']) + + sample['tokens'] = tokens + sample['label'] = label + yield sample + + +def spec_aug(data, num_t_mask=2, num_f_mask=2, max_t=50, max_f=10, max_w=80): + """ Do spec augmentation + Inplace operation + + Args: + data: Iterable[{key, feat, label}] + num_t_mask: number of time mask to apply + num_f_mask: number of freq mask to apply + max_t: max width of time mask + max_f: max width of freq mask + max_w: max width of time warp + + Returns + Iterable[{key, feat, label}] + """ + for sample in data: + assert 'feat' in sample + x = sample['feat'] + assert isinstance(x, torch.Tensor) + y = x.clone().detach() + max_frames = y.size(0) + max_freq = y.size(1) + # time mask + for i in range(num_t_mask): + start = random.randint(0, max_frames - 1) + length = random.randint(1, max_t) + end = min(max_frames, start + length) + y[start:end, :] = 0 + # freq mask + for i in range(num_f_mask): + start = random.randint(0, max_freq - 1) + length = random.randint(1, max_f) + end = min(max_freq, start + length) + y[:, start:end] = 0 + sample['feat'] = y + yield sample + + +def shuffle(data, shuffle_size=10000): + """ Local shuffle the data + + Args: + data: Iterable[{key, feat, label}] + shuffle_size: buffer size for shuffle + + Returns: + Iterable[{key, feat, label}] + """ + buf = [] + for sample in data: + buf.append(sample) + if len(buf) >= shuffle_size: + random.shuffle(buf) + for x in buf: + yield x + buf = [] + # The sample left over + random.shuffle(buf) + for x in buf: + yield x + + +def sort(data, sort_size=500): + """ Sort the data by feature length. + Sort is used after shuffle and before batch, so we can group + utts with similar lengths into a batch, and `sort_size` should + be less than `shuffle_size` + + Args: + data: Iterable[{key, feat, label}] + sort_size: buffer size for sort + + Returns: + Iterable[{key, feat, label}] + """ + + buf = [] + for sample in data: + buf.append(sample) + if len(buf) >= sort_size: + buf.sort(key=lambda x: x['feat'].size(0)) + for x in buf: + yield x + buf = [] + # The sample left over + buf.sort(key=lambda x: x['feat'].size(0)) + for x in buf: + yield x + + +def static_batch(data, batch_size=16): + """ Static batch the data by `batch_size` + + Args: + data: Iterable[{key, feat, label}] + batch_size: batch size + + Returns: + Iterable[List[{key, feat, label}]] + """ + buf = [] + for sample in data: + buf.append(sample) + if len(buf) >= batch_size: + yield buf + buf = [] + if len(buf) > 0: + yield buf + + +def dynamic_batch(data, max_frames_in_batch=12000): + """ Dynamic batch the data until the total frames in batch + reach `max_frames_in_batch` + + Args: + data: Iterable[{key, feat, label}] + max_frames_in_batch: max_frames in one batch + + Returns: + Iterable[List[{key, feat, label}]] + """ + buf = [] + longest_frames = 0 + for sample in data: + assert 'feat' in sample + assert isinstance(sample['feat'], torch.Tensor) + new_sample_frames = sample['feat'].size(0) + longest_frames = max(longest_frames, new_sample_frames) + frames_after_padding = longest_frames * (len(buf) + 1) + if frames_after_padding > max_frames_in_batch: + yield buf + buf = [sample] + longest_frames = new_sample_frames + else: + buf.append(sample) + if len(buf) > 0: + yield buf + + +def batch(data, batch_type='static', batch_size=16, max_frames_in_batch=12000): + """ Wrapper for static/dynamic batch + """ + if batch_type == 'static': + return static_batch(data, batch_size) + elif batch_type == 'dynamic': + return dynamic_batch(data, max_frames_in_batch) + else: + logging.fatal('Unsupported batch type {}'.format(batch_type)) + + +def padding(data): + """ Padding the data into training data + + Args: + data: Iterable[List[{key, feat, label}]] + + Returns: + Iterable[Tuple(keys, feats, labels, feats lengths, label lengths)] + """ + for sample in data: + assert isinstance(sample, list) + feats_length = torch.tensor([x['feat'].size(0) for x in sample], + dtype=torch.int32) + order = torch.argsort(feats_length, descending=True) + feats_lengths = torch.tensor( + [sample[i]['feat'].size(0) for i in order], dtype=torch.int32) + sorted_feats = [sample[i]['feat'] for i in order] + sorted_keys = [sample[i]['key'] for i in order] + sorted_labels = [ + torch.tensor(sample[i]['label'], dtype=torch.int64) for i in order + ] + label_lengths = torch.tensor([x.size(0) for x in sorted_labels], + dtype=torch.int32) + + padded_feats = pad_sequence(sorted_feats, + batch_first=True, + padding_value=0) + padding_labels = pad_sequence(sorted_labels, + batch_first=True, + padding_value=-1) + + yield (sorted_keys, padded_feats, padding_labels, feats_lengths, + label_lengths) -- Gitee From 41d6ae60f6464fd5a7fdfd9a41544e16844abffb Mon Sep 17 00:00:00 2001 From: majorli Date: Thu, 15 Aug 2024 11:14:27 +0800 Subject: [PATCH 6/7] update conformer readme format and result Signed-off-by: majorli --- .../conformer/ixrt/README.md | 30 +++++++++++++------ 1 file changed, 21 insertions(+), 9 deletions(-) diff --git a/models/speech/speech_recognition/conformer/ixrt/README.md b/models/speech/speech_recognition/conformer/ixrt/README.md index 2d4d98f3..2ad0e26a 100644 --- a/models/speech/speech_recognition/conformer/ixrt/README.md +++ b/models/speech/speech_recognition/conformer/ixrt/README.md @@ -9,29 +9,41 @@ Conformer is a speech recognition model proposed by Google in 2020. It combines ### Install ```bash +# Install libGL +## CentOS +yum install -y mesa-libGL +## Ubuntu +apt install -y libgl1-mesa-glx + pip3 install tqdm pip3 install onnx pip3 install typeguard==2.13.3 pip3 install onnxsim +pip3 install pycuda ``` ### Download -Pretrained model: +Pretrained model: Dataset: to download the Aishell dataset. -download and put model in conformer_checkpoints, put data in aishell_test_data. +Download and put model in conformer_checkpoints. + +```bash +ln -s /home/deepspark/datasets/INFER/conformer/20210601_u2++_conformer_exp_aishell ./conformer_checkpoints +``` ### Prepare Data + ```bash # Accuracy -DATA_DIR=./aishell_test_data -Tool_DIR=./tools -bash scripts/aishell_data_prepare.sh ${DATA_DIR} ${Tool_DIR} +DATA_DIR=/PATH/to/data_aishell +TOOL_DIR="$(pwd)/tools" +bash scripts/aishell_data_prepare.sh ${DATA_DIR} ${TOOL_DIR} ``` -### Model Conversion And Inference +## Model Conversion And Inference ### FP16 @@ -44,6 +56,6 @@ bash scripts/infer_conformer_fp16_performance_ixrt.sh ## Results -Model |BatchSize |Precision |QPS |CER | ------------|-----------|----------|----------|----------| -Conformer | 24 | FP16 | 380.00 | 0.051 | +| Model | BatchSize | Precision | QPS | CER | +| --------- | --------- | --------- | ------- | ------ | +| Conformer | 24 | FP16 | 387.821 | 0.0517 | -- Gitee From f7b5fb5de3e0fbaf0b949aac62a057a000137feb Mon Sep 17 00:00:00 2001 From: majorli Date: Thu, 15 Aug 2024 11:19:36 +0800 Subject: [PATCH 7/7] yolov6 acc results not correct Signed-off-by: majorli --- models/cv/detection/yolov6/ixrt/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/models/cv/detection/yolov6/ixrt/README.md b/models/cv/detection/yolov6/ixrt/README.md index 166c2fa0..66258563 100644 --- a/models/cv/detection/yolov6/ixrt/README.md +++ b/models/cv/detection/yolov6/ixrt/README.md @@ -76,7 +76,7 @@ bash scripts/infer_yolov6s_int8_performance.sh | Model | BatchSize | Precision | FPS | MAP@0.5 | | ------ | --------- | --------- | -------- | ------- | -| YOLOv6 | 32 | FP16 | 1107.511 | 0.355 | +| YOLOv6 | 32 | FP16 | 1107.511 | - | | YOLOv6 | 32 | INT8 | 2080.475 | - | ## Reference -- Gitee