diff --git a/ACL_PyTorch/built-in/foundation_models/stable_diffusionxl/README.md b/ACL_PyTorch/built-in/foundation_models/stable_diffusionxl/README.md
index 004bb94fb9c5a52b0bc074a93d93cbc6ff9c653b..ccde19b6ddc6307d705c3e5cc2b5d43997ba8006 100644
--- a/ACL_PyTorch/built-in/foundation_models/stable_diffusionxl/README.md
+++ b/ACL_PyTorch/built-in/foundation_models/stable_diffusionxl/README.md
@@ -152,7 +152,7 @@
    ```bash
    bs=1
    # quantized model
-   unet_model="models/unet_quant/unet.onnx"
+   unet_model="models/unet_quant/unet_fuse.onnx"
    # non-quantized model
    unet_model="models/unet/unet.onnx"
 
@@ -186,7 +186,6 @@
 
    The FA, TOME, and fused Gelu operators are obtained by installing the inference engine package (MindIE) matching the CANN version. If the inference engine is not installed, or the installed version does not support the FA, TOME, or SliceGelu operators, keep the FA_soc and TOME_num parameters at their defaults and do not set the faster_gelu parameter.
 
-   Multi-batch limitation: FA operator optimization is not yet supported in the A2 scenario; set the FA_soc parameter to None.
 
 3. Apply the cache scheme (optional; improves performance but may reduce accuracy)
 
@@ -414,7 +413,6 @@
 
    # weights required by both Clip Score and HPSv2
    GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/laion/CLIP-ViT-H-14-laion2B-s32B-b79K
-   cd ./CLIP-ViT-H-14-laion2B-s32B-b79K
 
    # HPSv2 weights
    wget https://huggingface.co/spaces/xswu/HPSv2/resolve/main/HPS_v2_compressed.pt --no-check-certificate
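The README change above points the quantized pipeline at unet_fuse.onnx instead of unet.onnx. A minimal sketch for double-checking which variant a given ONNX file actually is, by counting the node types that the fusion pass touches; the model path is an assumption, adjust it to your layout:

```python
from collections import Counter

import onnx

# Hypothetical path; point this at unet.onnx or unet_fuse.onnx as needed.
model_path = "models/unet_quant/unet_fuse.onnx"
# Skip external weight files; the graph structure alone is enough for this check.
model = onnx.load(model_path, load_external_data=False)

ops = Counter(node.op_type for node in model.graph.node)
# The fused model should report QuantBatchMatMul nodes where the unfused one
# reports MatMul + AscendDequant pairs (plus Add for biased projections).
for op in ("MatMul", "Add", "AscendQuant", "AscendDequant", "QuantBatchMatMul"):
    print(f"{op}: {ops.get(op, 0)}")
```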
diff --git a/ACL_PyTorch/built-in/foundation_models/stable_diffusionxl/README_quant.md b/ACL_PyTorch/built-in/foundation_models/stable_diffusionxl/README_quant.md
index 19c3a9e795dafc52656f268b6824ba31331e737d..150d1d5036447d14ad90b88ac18bc08befa49fc6 100644
--- a/ACL_PyTorch/built-in/foundation_models/stable_diffusionxl/README_quant.md
+++ b/ACL_PyTorch/built-in/foundation_models/stable_diffusionxl/README_quant.md
@@ -34,8 +34,8 @@ python3 quant_unet.py \
 - --save_path: storage directory for the quantized model, a subfolder name under model_dir.
 - --data_free: use synthetic data.
 
-After successful execution, a `models_bs${bs}/unet_quant` folder is generated, containing the unet.onnx model and its weights.
-
+After successful execution, a `models_bs${bs}/unet_quant` folder is generated, containing the unet.onnx model, the unet_fuse.onnx model (with the MatMul and dequant operators fused), and their weights.
+
 ### Real-data calibration
 
 1. Convert the ONNX models to OM models with the ATC tool.
@@ -60,7 +60,7 @@ python3 quant_unet.py \
 2. Run the ATC command.
 
    ```bash
-   # To reduce quantization time, running quantization at bs=1 is recommended
+   # To reduce quantization time, quantization must be run at bs=1
    bs=1
    # text_encoder
    cd ./models/text_encoder
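For intuition about why unet_fuse.onnx can replace each MatMul + AscendDequant pair without hurting accuracy, here is a toy numpy check of the underlying identity. All scales, the zero point, and the shapes are made-up illustration values, and the plain integer matmul stands in for what QuantBatchMatMul computes on the NPU:

```python
import numpy as np

rng = np.random.default_rng(0)
s_in, s_w = np.float32(0.05), np.float32(0.01)   # toy activation/weight scales
zp = np.int32(4)                                 # toy activation zero point
x = rng.standard_normal((3, 8)).astype(np.float32)
w_q = rng.integers(-128, 128, size=(8, 16), dtype=np.int32)  # toy INT8 weight
fp_bias = rng.standard_normal(16).astype(np.float32)

# Unfused path: quantize, integer matmul, dequantize, then add the float bias.
x_q = np.round(x / s_in).astype(np.int32) + zp
ref = (x_q - zp).astype(np.float32) @ w_q.astype(np.float32) * (s_in * s_w) + fp_bias

# Fused path: one integer matmul with a folded int32 bias, one dequant scale.
scale = s_in * s_w
correction = w_q.sum(axis=0) * zp                # zero-point correction term
int_bias = np.round(fp_bias / scale - correction).astype(np.int32)
fused = (x_q @ w_q + int_bias).astype(np.float32) * scale

print(np.abs(ref - fused).max())  # only bias-rounding error, at most ~scale/2
```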
diff --git a/ACL_PyTorch/built-in/foundation_models/stable_diffusionxl/modify_onnx.py b/ACL_PyTorch/built-in/foundation_models/stable_diffusionxl/modify_onnx.py
index c214e602c1cbdd538470546b7672e487bcebd546..fee6db54e667adbbaeef6908272f48e35eb74268 100644
--- a/ACL_PyTorch/built-in/foundation_models/stable_diffusionxl/modify_onnx.py
+++ b/ACL_PyTorch/built-in/foundation_models/stable_diffusionxl/modify_onnx.py
@@ -35,10 +35,10 @@ def add_flash_attention(model, fa_name, soc_type):
         else:
             flag = 'attn1' in name
         if flag:
-            matmul = model[name[:-3] + 'to_q/MatMul']
+            matmul = model[name[:-3] + 'to_q/MatMul_quant']
             reshape = model[name[:-3] + 'Reshape']
             seqlen = 4096
-            if soc_type == 2 and model[reshape.inputs[1]].value[1] != seqlen:
+            if soc_type == 3 and model[reshape.inputs[1]].value[1] != seqlen:
                 continue
             softmax_node = model.get_next_nodes(node.outputs[0])[0]
             if soc_type == 1:
@@ -49,11 +49,18 @@ def add_flash_attention(model, fa_name, soc_type):
 
             # add flashattention
             new_node = model.add_node(name[:-3] + fa_name, fa_name)
+            if soc_type == 3:
+                new_node.attrs = {
+                    'input_layout': 'BSH',
+                    'num_heads': 10,
+                    'scale_value': 0.125,
+                    'next_tokens': 65535
+                }
             inputs = [None, None, None]
             # input 0: q
             if soc_type == 1:
                 matmul_node = model.get_prev_node(softmax_node.inputs[0])
-            if soc_type == 2:
+            if soc_type == 3:
                 matmul_node = model.get_prev_node(node.inputs[0])
             inputs[0] = matmul_node.inputs[0]
             # input 1: k
@@ -82,10 +89,6 @@ def add_flash_attention(model, fa_name, soc_type):
                 model.remove(prev_node.name)
             next_node = model.get_next_nodes(node.outputs[0])[0]
             model.remove(next_node.name)
-            if soc_type == 2:
-                name = node.name.replace(fa_name, 'Cast')
-                cast = model.add_node(name, 'Cast', attrs={'to': 1})
-                model.insert_node(node.name, cast)
 
 
 def change_input(model, bs):
@@ -164,6 +167,8 @@ def get_block(model):
     norms = []
     for node in model.get_nodes('Add'):
         next_nodes = model.get_next_nodes(node.outputs[0])
+        if next_nodes[0].op_type == 'AscendQuant':
+            next_nodes = model.get_next_nodes(next_nodes[0].outputs[0])
         if len(next_nodes) != 3:
             continue
         op_type = set(n.op_type for n in next_nodes)
@@ -475,10 +480,7 @@ def main():
     if args.FA_soc == 'Duo':
         add_flash_attention(model, 'FlashAttentionTik', soc_type=1)
     elif args.FA_soc == 'A2':
-        if batch_size > 2:
-            print('A2 does not support FA in multi-batch case! The FA modification does not effect.')
-        else:
-            add_flash_attention(model, 'UnpadFlashAttentionMix', soc_type=2)
+        add_flash_attention(model, 'NPUPromptFlashAttention', soc_type=3)
     if args.TOME_num:
         insert_tome_block(model, args.TOME_num)
     replace_slice(model, args.faster_gelu)
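The NPUPromptFlashAttention node inserted above fuses the MatMul → Softmax → MatMul chain into one operator with a 'BSH' layout (hidden size H = num_heads × head_dim, here assumed to be 10 × 64 = 640) and scale_value = 1/√64 = 0.125. The numpy sketch below reproduces that computation so the attribute choices can be sanity-checked; the batch and sequence length are illustrative stand-ins, not the exporter's actual values (the real pass gates on seqlen 4096):

```python
import numpy as np

batch, seqlen = 2, 256           # illustrative; the pass targets seqlen 4096
num_heads, head_dim = 10, 64     # matches num_heads=10; hidden size = 640
scale = 1.0 / np.sqrt(head_dim)  # = 0.125, matches scale_value

rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((batch, seqlen, num_heads * head_dim), dtype=np.float32)
           for _ in range(3))

def attention_bsh(q, k, v):
    """softmax(q @ k^T * scale) @ v, computed per head on BSH-layout tensors."""
    def split(x):  # BSH -> (B, heads, S, D)
        b, s, _ = x.shape
        return x.reshape(b, s, num_heads, head_dim).transpose(0, 2, 1, 3)
    qh, kh, vh = split(q), split(k), split(v)
    logits = qh @ kh.transpose(0, 1, 3, 2) * scale
    logits -= logits.max(axis=-1, keepdims=True)   # numerically stable softmax
    probs = np.exp(logits)
    probs /= probs.sum(axis=-1, keepdims=True)
    out = probs @ vh                               # (B, heads, S, D)
    return out.transpose(0, 2, 1, 3).reshape(q.shape)  # back to BSH

print(attention_bsh(q, k, v).shape)  # (2, 256, 640)
```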
diff --git a/ACL_PyTorch/built-in/foundation_models/stable_diffusionxl/quant_unet.py b/ACL_PyTorch/built-in/foundation_models/stable_diffusionxl/quant_unet.py
index 141f73607b431d13498281648d8729550f684fde..5dea3ca7ed907fa3f53bdfccb7d3636524e0905e 100644
--- a/ACL_PyTorch/built-in/foundation_models/stable_diffusionxl/quant_unet.py
+++ b/ACL_PyTorch/built-in/foundation_models/stable_diffusionxl/quant_unet.py
@@ -20,6 +20,7 @@ from ais_bench.infer.interface import InferSession
 from diffusers import DPMSolverMultistepScheduler, EulerDiscreteScheduler, DDIMScheduler
 from modelslim.onnx.squant_ptq.onnx_quant_tools import OnnxCalibrator
 from modelslim.onnx.squant_ptq.quant_config import QuantConfig
+from auto_optimizer import OnnxGraph
 import numpy as np
 import onnx
 import torch
@@ -258,6 +259,66 @@ class StableDiffusionXLDumpPipeline(AscendStableDiffusionXLPipeline):
         return dump_data
 
 
+def get_quant_data(node, param, graph):
+    input_scale = param.input_scale
+    weight_scale = param.weight_scale
+    input_offset = param.input_offset
+    quant_weight = param.quant_weight
+    node_name = '_'.join(node.inputs[1].split('_')[:-1])
+    scale = input_scale[node_name] * weight_scale[node_name]
+    packed_weight_np_data = scale.squeeze()
+    float32_scale_deq = np.array(packed_weight_np_data, np.float32)
+    uint32_scale_deq = np.frombuffer(float32_scale_deq, np.uint32)
+    uint64_result = np.zeros(float32_scale_deq.shape, np.int64)
+    if len(uint64_result.shape) == 0:
+        uint64_result = np.expand_dims(uint64_result, axis=0)
+    uint64_result |= np.int64(uint32_scale_deq)
+    graph.add_initializer('_'.join([node.name, 'scale']), uint64_result)
+    graph.add_initializer('_'.join([node.name, 'offset']), np.array(0).astype(np.float32))
+    correction = quant_weight[node_name].astype(np.float32).sum(axis=0) * input_offset[node_name].astype(np.float32)
+
+    return scale, correction
+
+
+def modify_quant_fuse(unet, quant, param):
+    quant_graph = OnnxGraph.parse(quant)
+    unet_graph = OnnxGraph.parse(unet)
+    quant_op_type = "AscendDequant"
+    quant_list = quant_graph.get_nodes(quant_op_type)
+    input_scale = param.input_scale
+    weight_scale = param.weight_scale
+    input_offset = param.input_offset
+    quant_weight = param.quant_weight
+    for node in quant_list:
+        pre_node = quant_graph.get_prev_node(node.inputs[0])
+        if pre_node.op_type == "MatMul":
+            _, _ = get_quant_data(pre_node, param, quant_graph)
+            x = pre_node.inputs[1]
+            w = quant_graph[x].value
+            quant_graph[x].value = w.transpose(1, 0)
+            quant_graph.add_node('_'.join([pre_node.name, 'quant']), "QuantBatchMatMul", inputs=[pre_node.inputs[0], x, '_'.join([pre_node.name, 'scale']), \
+                '_'.join([pre_node.name, 'offset'])], outputs=[node.outputs[0]], attrs={"dtype": 0, "transpose_x2": True})
+            quant_graph.remove(pre_node.name, mapping={})
+            quant_graph.remove(node.name, mapping={})
+            quant_graph.update_map()
+        elif pre_node.op_type == "Add":
+            matmul_node = quant_graph.get_prev_node(pre_node.inputs[0])
+            scale, correction = get_quant_data(matmul_node, param, quant_graph)
+            x = matmul_node.inputs[1]
+            w = quant_graph[x].value
+            quant_graph[x].value = w.transpose(1, 0)
+            ori_bias = np.round(unet_graph[unet_graph[pre_node.name].inputs[0]].value / scale - correction).astype(np.int32)
+            quant_graph.add_initializer('_'.join([matmul_node.name, 'bias']), ori_bias)
+            quant_graph.add_node('_'.join([matmul_node.name, 'quant']), "QuantBatchMatMul", inputs=[matmul_node.inputs[0], x, '_'.join([matmul_node.name, 'scale']), \
+                '_'.join([matmul_node.name, 'offset']), '_'.join([matmul_node.name, 'bias'])], outputs=[node.outputs[0]], attrs={"dtype": 0, "transpose_x2": True})
+            quant_graph.remove(pre_node.name, mapping={})
+            quant_graph.remove(matmul_node.name, mapping={})
+            quant_graph.remove(node.name, mapping={})
+            quant_graph.update_map()
+
+    return quant_graph
+
+
 def parse_arguments():
     parser = argparse.ArgumentParser()
     parser.add_argument(
@@ -314,7 +375,7 @@
         action='store_true',
         help="do not use real data"
     )
-    
+
     return parser.parse_args()
 
 
@@ -429,7 +490,10 @@
         os.makedirs(quant_path, mode=0o744)
     quant_onnx = os.path.join(quant_path, 'unet.onnx')
     calib.export_quant_onnx(quant_onnx, use_external=True)
-    
+    quant_numpy = calib._get_quant_params()
+    graph = modify_quant_fuse(unet_onnx, quant_onnx, quant_numpy)
+    fuse_path = os.path.join(quant_path, 'unet_fuse.onnx')
+    graph.save(fuse_path)
 
 if __name__ == "__main__":
     main()
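Two numeric details in get_quant_data/modify_quant_fuse are easy to miss: the combined float32 dequant scale is bit-packed into an int64 initializer (the low 32 bits hold the raw float bits), and the float bias is folded into an int32 bias that also corrects for the activation zero point. A toy sketch of both steps, with made-up calibration values rather than real quantizer output:

```python
import numpy as np

input_scale = np.float32(0.02)     # toy activation scale
weight_scale = np.float32(0.005)   # toy weight scale
input_offset = np.float32(3.0)     # toy activation zero point
quant_weight = np.array([[5, -2], [1, 7]], dtype=np.int8)  # toy INT8 weight
fp_bias = np.array([0.1, -0.3], dtype=np.float32)          # toy float bias

# 1) Bit-pack the combined dequant scale: reinterpret the float32 bits as
#    uint32 and widen them into the low half of an int64 initializer.
scale = np.float32(input_scale * weight_scale)
packed = np.zeros(1, dtype=np.int64)
packed |= np.int64(np.frombuffer(scale.tobytes(), np.uint32)[0])
print(hex(int(packed[0])))  # low 32 bits == raw float32 bits of `scale`

# 2) Fold the float bias into int32, subtracting the zero-point correction
#    sum(W_q, axis=0) * input_offset, mirroring the ori_bias computation above.
correction = quant_weight.astype(np.float32).sum(axis=0) * input_offset
int_bias = np.round(fp_bias / scale - correction).astype(np.int32)
print(int_bias)
```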