From ec029100e06f827887460e3e473ff251df6fe3af Mon Sep 17 00:00:00 2001
From: shikang
Date: Mon, 1 Jul 2024 19:06:29 +0800
Subject: [PATCH 1/8] add quant script param

---
 .../foundation_models/stable_diffusionxl/README_quant.md | 2 ++
 .../foundation_models/stable_diffusionxl/quant_unet.py   | 8 ++++++++
 2 files changed, 10 insertions(+)

diff --git a/ACL_PyTorch/built-in/foundation_models/stable_diffusionxl/README_quant.md b/ACL_PyTorch/built-in/foundation_models/stable_diffusionxl/README_quant.md
index 19c3a9e795..366b2ef927 100644
--- a/ACL_PyTorch/built-in/foundation_models/stable_diffusionxl/README_quant.md
+++ b/ACL_PyTorch/built-in/foundation_models/stable_diffusionxl/README_quant.md
@@ -25,6 +25,7 @@ python3 quant_unet.py \
     --model_dir ./models \
     --prompt_file ./prompts.txt \
     --save_path unet_quant \
+    --save_quant_param quant_param.npy \
     --data_free
 ```
 Parameter description:
@@ -32,6 +33,7 @@ python3 quant_unet.py \
 - --model_dir: directory where the exported models are stored.
 - --prompt_file: input text file, one prompt per line.
 - --save_path: storage directory for the quantized model, a subfolder name under model_dir.
+- --save_quant_param: save the post-quantization quant param data in npy format; this data file is used during graph modification.
 - --data_free: use dummy data.
 
 After successful execution, a `models_bs${bs}/unet_quant` folder is generated, containing the unet.onnx model and weights.
diff --git a/ACL_PyTorch/built-in/foundation_models/stable_diffusionxl/quant_unet.py b/ACL_PyTorch/built-in/foundation_models/stable_diffusionxl/quant_unet.py
index 141f73607b..670b576fef 100644
--- a/ACL_PyTorch/built-in/foundation_models/stable_diffusionxl/quant_unet.py
+++ b/ACL_PyTorch/built-in/foundation_models/stable_diffusionxl/quant_unet.py
@@ -314,6 +314,12 @@ def parse_arguments():
         action='store_true',
         help="do not use real data"
     )
+    parser.add_argument(
+        "--save_quant_param",
+        type=str,
+        default="quant_param.npy",
+        help="Path to save quant weight."
+    )
 
     return parser.parse_args()
 
@@ -429,6 +435,8 @@ def main():
     os.makedirs(quant_path, mode=0o744)
     quant_onnx = os.path.join(quant_path, 'unet.onnx')
     calib.export_quant_onnx(quant_onnx, use_external=True)
+    quant_numpy = calib._get_quant_params()
+    np.save(args.save_quant_param, quant_numpy)
 
 
 if __name__ == "__main__":
-- Gitee

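Patch 1 persists the calibrator's quantization parameters so that a later graph rewrite can consume them. As a rough sketch of reading the file back (assuming, as patch 2's `modify_quant_fuse` below does, that the saved object exposes `input_scale`, `weight_scale`, `input_offset` and `quant_weight` mappings keyed by layer name; the exact return type of `calib._get_quant_params()` is not documented here):

```python
# Hypothetical reader for quant_param.npy; the attribute names mirror how
# patch 2's modify_quant_fuse() indexes the same object, and are assumptions.
import numpy as np

param = np.load("quant_param.npy", allow_pickle=True).item()

for layer in list(param.input_scale)[:3]:
    # Per-layer dequant scale that the graph rewrite later folds into
    # the fused QuantBatchMatMul node.
    deq_scale = param.input_scale[layer] * param.weight_scale[layer]
    print(layer, np.squeeze(deq_scale).shape, param.input_offset[layer])
```
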
From 474948e33ed2d1e28c6d85b047c842a728dc1cd3 Mon Sep 17 00:00:00 2001
From: shikang
Date: Tue, 2 Jul 2024 02:15:30 +0000
Subject: [PATCH 2/8] add sd script

---
 .../stable_diffusionxl/README.md       |  2 -
 .../stable_diffusionxl/README_quant.md |  6 +-
 .../stable_diffusionxl/modify_onnx.py  | 23 +++---
 .../stable_diffusionxl/quant_unet.py   | 77 ++++++++++++++---
 4 files changed, 81 insertions(+), 27 deletions(-)

diff --git a/ACL_PyTorch/built-in/foundation_models/stable_diffusionxl/README.md b/ACL_PyTorch/built-in/foundation_models/stable_diffusionxl/README.md
index 004bb94fb9..374a62afb8 100644
--- a/ACL_PyTorch/built-in/foundation_models/stable_diffusionxl/README.md
+++ b/ACL_PyTorch/built-in/foundation_models/stable_diffusionxl/README.md
@@ -186,7 +186,6 @@
     The FA, TOME and fused-Gelu operators are obtained by installing the inference engine package (MindIE) matching the CANN version in use. If the inference engine is not installed, or the installed version does not support the FA, TOME and SliceGelu operators, keep the default settings for FA_soc and TOME_num and do not set the faster_gelu parameter.
     - Multi-batch limitation: FA operator optimization is not yet supported in the A2 scenario; set the FA_soc parameter to None.
-
 3. Adapt the cache scheme (optional; can improve performance but may reduce accuracy)
@@ -414,7 +413,6 @@
    # Weights required by both Clip Score and HPSv2
    GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/laion/CLIP-ViT-H-14-laion2B-s32B-b79K
-   cd ./CLIP-ViT-H-14-laion2B-s32B-b79K
 
    # HPSv2 weights
    wget https://huggingface.co/spaces/xswu/HPSv2/resolve/main/HPS_v2_compressed.pt --no-check-certificate
diff --git a/ACL_PyTorch/built-in/foundation_models/stable_diffusionxl/README_quant.md b/ACL_PyTorch/built-in/foundation_models/stable_diffusionxl/README_quant.md
index 366b2ef927..7d602bae08 100644
--- a/ACL_PyTorch/built-in/foundation_models/stable_diffusionxl/README_quant.md
+++ b/ACL_PyTorch/built-in/foundation_models/stable_diffusionxl/README_quant.md
@@ -25,7 +25,6 @@ python3 quant_unet.py \
     --model_dir ./models \
     --prompt_file ./prompts.txt \
     --save_path unet_quant \
-    --save_quant_param quant_param.npy \
     --data_free
 ```
 Parameter description:
@@ -33,11 +32,10 @@ python3 quant_unet.py \
 - --model_dir: directory where the exported models are stored.
 - --prompt_file: input text file, one prompt per line.
 - --save_path: storage directory for the quantized model, a subfolder name under model_dir.
-- --save_quant_param: save the post-quantization quant param data in npy format; this data file is used during graph modification.
 - --data_free: use dummy data.
 
-After successful execution, a `models_bs${bs}/unet_quant` folder is generated, containing the unet.onnx model and weights.
-
+After successful execution, a `models_bs${bs}/unet_quant` folder is generated, containing the unet.onnx model, the unet_fuse.onnx model (MatMul and dequant operators fused), and weights.
+
 ### Calibration with real data
 1. Use the ATC tool to convert the ONNX model to an OM model.
diff --git a/ACL_PyTorch/built-in/foundation_models/stable_diffusionxl/modify_onnx.py b/ACL_PyTorch/built-in/foundation_models/stable_diffusionxl/modify_onnx.py
index c214e602c1..48c017f905 100644
--- a/ACL_PyTorch/built-in/foundation_models/stable_diffusionxl/modify_onnx.py
+++ b/ACL_PyTorch/built-in/foundation_models/stable_diffusionxl/modify_onnx.py
@@ -38,7 +38,7 @@ def add_flash_attention(model, fa_name, soc_type):
             matmul = model[name[:-3] + 'to_q/MatMul']
             reshape = model[name[:-3] + 'Reshape']
             seqlen = 4096
-            if soc_type == 2 and model[reshape.inputs[1]].value[1] != seqlen:
+            if soc_type == 3 and model[reshape.inputs[1]].value[1] != seqlen:
                 continue
             softmax_node = model.get_next_nodes(node.outputs[0])[0]
             if soc_type == 1:
@@ -49,11 +49,18 @@ def add_flash_attention(model, fa_name, soc_type):
 
             # add flashattention
             new_node = model.add_node(name[:-3] + fa_name, fa_name)
+            if soc_type == 3:
+                new_node.attrs = {
+                    'input_layout': 'BSH',
+                    'num_head': 10,
+                    'scale_value': 0.125,
+                    'next_tokens': 65535
+                }
             inputs = [None, None, None]
             # input 0: q
             if soc_type == 1:
                 matmul_node = model.get_prev_node(softmax_node.inputs[0])
-            if soc_type == 2:
+            if soc_type == 3:
                 matmul_node = model.get_prev_node(node.inputs[0])
             inputs[0] = matmul_node.inputs[0]
             # input 1: k
@@ -82,10 +89,6 @@ def add_flash_attention(model, fa_name, soc_type):
                 model.remove(prev_node.name)
             next_node = model.get_next_nodes(node.outputs[0])[0]
             model.remove(next_node.name)
-            if soc_type == 2:
-                name = node.name.replace(fa_name, 'Cast')
-                cast = model.add_node(name, 'Cast', attrs={'to': 1})
-                model.insert_node(node.name, cast)
 
 
 def change_input(model, bs):
@@ -162,7 +165,8 @@ def build_index(h, w, sy=2, sx=2):
 def get_block(model):
     # find self-attention block
     norms = []
-    for node in model.get_nodes('Add'):
+    nodes = model.get_nodes('Add') + model.get_nodes('QuantBatchMatMul')
+    for node in nodes:
         next_nodes = model.get_next_nodes(node.outputs[0])
         if len(next_nodes) != 3:
             continue
@@ -475,10 +479,7 @@ def main():
     if args.FA_soc == 'Duo':
         add_flash_attention(model, 'FlashAttentionTik', soc_type=1)
    elif args.FA_soc == 'A2':
-        if batch_size > 2:
-            print('A2 does not support FA in multi-batch case! 
The FA modification does not effect.') - else: - add_flash_attention(model, 'UnpadFlashAttentionMix', soc_type=2) + add_flash_attention(model, 'NPUPromptFlashAttention', soc_type=3) if args.TOME_num: insert_tome_block(model, args.TOME_num) replace_slice(model, args.faster_gelu) diff --git a/ACL_PyTorch/built-in/foundation_models/stable_diffusionxl/quant_unet.py b/ACL_PyTorch/built-in/foundation_models/stable_diffusionxl/quant_unet.py index 670b576fef..b5bf88d9f0 100644 --- a/ACL_PyTorch/built-in/foundation_models/stable_diffusionxl/quant_unet.py +++ b/ACL_PyTorch/built-in/foundation_models/stable_diffusionxl/quant_unet.py @@ -20,6 +20,7 @@ from ais_bench.infer.interface import InferSession from diffusers import DPMSolverMultistepScheduler, EulerDiscreteScheduler, DDIMScheduler from modelslim.onnx.squant_ptq.onnx_quant_tools import OnnxCalibrator from modelslim.onnx.squant_ptq.quant_config import QuantConfig +from auto_optimizer import OnnxGraph import numpy as np import onnx import torch @@ -258,6 +259,67 @@ class StableDiffusionXLDumpPipeline(AscendStableDiffusionXLPipeline): return dump_data +def get_quant_data(node, param, graph): + input_scale = param.input_scale + weight_scale = param.weight_scale + input_offset = param.input_offset + quant_weight = param.quant_weight + node_name = '_'.join(node.inputs[1].split('_')[:-1]) + scale = input_scale[node_name] * weight_scale[node_name] + packed_weight_np_data = scale.squeeze() + float32_scale_deq = np.array(packed_weight_np_data, np.float32) + uint32_scale_deq = np.frombuffer(float32_scale_deq, np.uint32) + uint64_result = np.zeros(float32_scale_deq.shape, np.int64) + if len(uint64_result.shape) == 0: + uint64_result = np.expand_dims(uint64_result, axis=0) + uint64_result |= np.int64(uint32_scale_deq) + graph.add_initializer('_'.join(node.name, 'scale'), uint64_result) + graph.add_initializer('_'.join(node.name, 'offset'), np.array(0).astype(np.float32)) + correction = quant_weight[node_name].astype(np.float32).sum(axis=0)*input_offset[node_name].astype(np.float32) + return scale, correction + + +def modify_quant_fuse(unet, quant, param): + quant_graph = OnnxGraph.parse(quant) + unet_graph = OnnxGraph.parse(unet) + quant_op_type = "AscendDequant" + quant_list = quant_graph.get_nodes(quant_op_type) + input_scale = param.input_scale + weight_scale = param.weight_scale + input_offset = param.input_offset + quant_weight = param.quant_weight + for node in quant_list: + pre_node = quant_graph.get_prev_node(node.inputs[0]) + if pre_node.op_type == "MatMul": + _, _ = get_quant_data(pre_node, param, quant_graph) + x = pre_node.inputs[1] + w = quant_graph[x].value + quant_graph[x].value = w.transpose(1,0) + quant_graph.add_node('_'.join([pre_node.name, 'quant']), "QuantBatchMatMul", \ + inputs=[pre_node.inputs[0], x, '_'.join([pre_node.name, 'scale']), '_'.join([pre_node.name, 'offset'])], \ + outputs=[node.outputs[0]], attrs={"dtype":0, "transpose_x2":True}) + quant_graph.remove(pre_node.name, mapping={}) + quant_graph.remove(node.name, mapping={}) + quant_graph.update_map() + elif pre_node.op_type == "Add": + matmul_node = quant_graph.get_prev_node(pre_node.inputs[0]) + scale, correction = get_quant_data(matmul_node, param, quant_graph) + x = matmul_node.inputs[1] + w = quant_graph[x].value + quant_graph[x].value = w.transpose(1,0) + ori_bias = np.round(unet_graph[unet_graph[pre_node.name].inputs[0]].value / scale - correction).astype(np.int32) + quant_graph.add_initializer('_'.join([matmul_node.name, 'bias']), ori_bias) + 
quant_graph.add_node('_'.join([matmul_node.name, 'quant']), "QuantBatchMatMul", \ + inputs=[matmul_node.inputs[0], x, '_'.join([matmul_node.name, 'scale']), '_'.join([matmul_node.name, 'offset']), + '_'.join([matmul_node.name, 'bias'])], outputs=[node.outputs[0]], attrs={"dtype":0, "transpose_x2":True}) + graph.remove(pre_node.name, mapping={}) + graph.remove(matmul_node.name, mapping={}) + graph.remove(node.name, mapping={}) + quant_graph.update_map() + + return quant_graph + + def parse_arguments(): parser = argparse.ArgumentParser() parser.add_argument( @@ -314,13 +376,7 @@ def parse_arguments(): action='store_true', help="do not use real data" ) - parser.add_argument( - "--save_quant_param", - type=str, - default="quant_param.npy", - help="Path to save quant weight." - ) - + return parser.parse_args() @@ -435,9 +491,10 @@ def main(): os.makedirs(quant_path, mode=0o744) quant_onnx = os.path.join(quant_path, 'unet.onnx') calib.export_quant_onnx(quant_onnx, use_external=True) - quant_numpy = calib._get_quant_params() - np.save(args.save_quant_param, quant_numpy) - + quant_numpy = calib._get_quant_params() + graph = modify_quant_fuse(unet_onnx, quant_onnx, quant_numpy) + fuse_path = os.path.join(quant_path, 'unet_fuse.onnx') + graph.save(fuse_path) if __name__ == "__main__": main() -- Gitee From 29b6fc998dedb4a37f8f5763e5ce0a9a66fceffd Mon Sep 17 00:00:00 2001 From: shikang Date: Fri, 12 Jul 2024 06:06:11 +0000 Subject: [PATCH 3/8] bug fix --- .../stable_diffusionxl/README_quant.md | 2 +- .../stable_diffusionxl/modify_onnx.py | 7 ++++--- .../foundation_models/stable_diffusionxl/quant_unet.py | 10 +++++----- 3 files changed, 10 insertions(+), 9 deletions(-) diff --git a/ACL_PyTorch/built-in/foundation_models/stable_diffusionxl/README_quant.md b/ACL_PyTorch/built-in/foundation_models/stable_diffusionxl/README_quant.md index 7d602bae08..613d2d7525 100644 --- a/ACL_PyTorch/built-in/foundation_models/stable_diffusionxl/README_quant.md +++ b/ACL_PyTorch/built-in/foundation_models/stable_diffusionxl/README_quant.md @@ -82,7 +82,7 @@ python3 quant_unet.py \ # unet cd ../unet/ atc --framework=5 \ - --model=./unet.onnx \ + --model=./unet_fuse.onnx \ --output=./unet \ --input_format=NCHW \ --log=error \ diff --git a/ACL_PyTorch/built-in/foundation_models/stable_diffusionxl/modify_onnx.py b/ACL_PyTorch/built-in/foundation_models/stable_diffusionxl/modify_onnx.py index 48c017f905..e0a82946be 100644 --- a/ACL_PyTorch/built-in/foundation_models/stable_diffusionxl/modify_onnx.py +++ b/ACL_PyTorch/built-in/foundation_models/stable_diffusionxl/modify_onnx.py @@ -52,7 +52,7 @@ def add_flash_attention(model, fa_name, soc_type): if soc_type == 3: new_node.attrs = { 'input_layout': 'BSH', - 'num_head': 10, + 'num_heads': 10, 'scale_value': 0.125, 'next_tokens': 65535 } @@ -165,9 +165,10 @@ def build_index(h, w, sy=2, sx=2): def get_block(model): # find self-attention block norms = [] - nodes = model.get_nodes('Add') + model.get_nodes('QuantBatchMatMul') - for node in nodes: + for node in model.get_nodes('Add'): next_nodes = model.get_next_nodes(node.outputs[0]) + if next_nodes[0].op_type == 'AscendQuant': + next_nodes = model.get_next_nodes(next_nodes[0].outputs[0]) if len(next_nodes) != 3: continue op_type = set(n.op_type for n in next_nodes) diff --git a/ACL_PyTorch/built-in/foundation_models/stable_diffusionxl/quant_unet.py b/ACL_PyTorch/built-in/foundation_models/stable_diffusionxl/quant_unet.py index b5bf88d9f0..0774210089 100644 --- a/ACL_PyTorch/built-in/foundation_models/stable_diffusionxl/quant_unet.py 
+++ b/ACL_PyTorch/built-in/foundation_models/stable_diffusionxl/quant_unet.py @@ -273,8 +273,8 @@ def get_quant_data(node, param, graph): if len(uint64_result.shape) == 0: uint64_result = np.expand_dims(uint64_result, axis=0) uint64_result |= np.int64(uint32_scale_deq) - graph.add_initializer('_'.join(node.name, 'scale'), uint64_result) - graph.add_initializer('_'.join(node.name, 'offset'), np.array(0).astype(np.float32)) + graph.add_initializer('_'.join([node.name, 'scale']), uint64_result) + graph.add_initializer('_'.join([node.name, 'offset']), np.array(0).astype(np.float32)) correction = quant_weight[node_name].astype(np.float32).sum(axis=0)*input_offset[node_name].astype(np.float32) return scale, correction @@ -312,9 +312,9 @@ def modify_quant_fuse(unet, quant, param): quant_graph.add_node('_'.join([matmul_node.name, 'quant']), "QuantBatchMatMul", \ inputs=[matmul_node.inputs[0], x, '_'.join([matmul_node.name, 'scale']), '_'.join([matmul_node.name, 'offset']), '_'.join([matmul_node.name, 'bias'])], outputs=[node.outputs[0]], attrs={"dtype":0, "transpose_x2":True}) - graph.remove(pre_node.name, mapping={}) - graph.remove(matmul_node.name, mapping={}) - graph.remove(node.name, mapping={}) + quant_graph.remove(pre_node.name, mapping={}) + quant_graph.remove(matmul_node.name, mapping={}) + quant_graph.remove(node.name, mapping={}) quant_graph.update_map() return quant_graph -- Gitee From ab9bfbe54118d26dd9956532c6b5210f8123c31c Mon Sep 17 00:00:00 2001 From: shikang Date: Fri, 12 Jul 2024 06:11:07 +0000 Subject: [PATCH 4/8] bug fix --- .../stable_diffusionxl/quant_unet.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/ACL_PyTorch/built-in/foundation_models/stable_diffusionxl/quant_unet.py b/ACL_PyTorch/built-in/foundation_models/stable_diffusionxl/quant_unet.py index 0774210089..5dea3ca7ed 100644 --- a/ACL_PyTorch/built-in/foundation_models/stable_diffusionxl/quant_unet.py +++ b/ACL_PyTorch/built-in/foundation_models/stable_diffusionxl/quant_unet.py @@ -276,6 +276,7 @@ def get_quant_data(node, param, graph): graph.add_initializer('_'.join([node.name, 'scale']), uint64_result) graph.add_initializer('_'.join([node.name, 'offset']), np.array(0).astype(np.float32)) correction = quant_weight[node_name].astype(np.float32).sum(axis=0)*input_offset[node_name].astype(np.float32) + return scale, correction @@ -295,9 +296,8 @@ def modify_quant_fuse(unet, quant, param): x = pre_node.inputs[1] w = quant_graph[x].value quant_graph[x].value = w.transpose(1,0) - quant_graph.add_node('_'.join([pre_node.name, 'quant']), "QuantBatchMatMul", \ - inputs=[pre_node.inputs[0], x, '_'.join([pre_node.name, 'scale']), '_'.join([pre_node.name, 'offset'])], \ - outputs=[node.outputs[0]], attrs={"dtype":0, "transpose_x2":True}) + quant_graph.add_node('_'.join([pre_node.name, 'quant']), "QuantBatchMatMul", inputs=[pre_node.inputs[0], x, '_'.join([pre_node.name, 'scale']), \ + '_'.join([pre_node.name, 'offset'])], outputs=[node.outputs[0]], attrs={"dtype":0, "transpose_x2":True}) quant_graph.remove(pre_node.name, mapping={}) quant_graph.remove(node.name, mapping={}) quant_graph.update_map() @@ -309,9 +309,8 @@ def modify_quant_fuse(unet, quant, param): quant_graph[x].value = w.transpose(1,0) ori_bias = np.round(unet_graph[unet_graph[pre_node.name].inputs[0]].value / scale - correction).astype(np.int32) quant_graph.add_initializer('_'.join([matmul_node.name, 'bias']), ori_bias) - quant_graph.add_node('_'.join([matmul_node.name, 'quant']), "QuantBatchMatMul", \ - 
inputs=[matmul_node.inputs[0], x, '_'.join([matmul_node.name, 'scale']), '_'.join([matmul_node.name, 'offset']),
-            '_'.join([matmul_node.name, 'bias'])], outputs=[node.outputs[0]], attrs={"dtype":0, "transpose_x2":True})
+        quant_graph.add_node('_'.join([matmul_node.name, 'quant']), "QuantBatchMatMul", inputs=[matmul_node.inputs[0], x, '_'.join([matmul_node.name, 'scale']), \
+            '_'.join([matmul_node.name, 'offset']), '_'.join([matmul_node.name, 'bias'])], outputs=[node.outputs[0]], attrs={"dtype":0, "transpose_x2":True})
         quant_graph.remove(pre_node.name, mapping={})
         quant_graph.remove(matmul_node.name, mapping={})
         quant_graph.remove(node.name, mapping={})
-- Gitee

From 40020eccf82cbd7d345c12b58fc15094245171d7 Mon Sep 17 00:00:00 2001
From: shikang
Date: Mon, 15 Jul 2024 06:20:32 +0000
Subject: [PATCH 5/8] debug fix

---
 .../foundation_models/stable_diffusionxl/README_quant.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/ACL_PyTorch/built-in/foundation_models/stable_diffusionxl/README_quant.md b/ACL_PyTorch/built-in/foundation_models/stable_diffusionxl/README_quant.md
index 613d2d7525..7d602bae08 100644
--- a/ACL_PyTorch/built-in/foundation_models/stable_diffusionxl/README_quant.md
+++ b/ACL_PyTorch/built-in/foundation_models/stable_diffusionxl/README_quant.md
@@ -82,7 +82,7 @@ python3 quant_unet.py \
 # unet
 cd ../unet/
 atc --framework=5 \
-    --model=./unet_fuse.onnx \
+    --model=./unet.onnx \
     --output=./unet \
     --input_format=NCHW \
     --log=error \
-- Gitee

From bfe922633c6d38e65453d1039a28da71d92065c2 Mon Sep 17 00:00:00 2001
From: shikang
Date: Mon, 15 Jul 2024 06:26:21 +0000
Subject: [PATCH 6/8] debug fix

---
 .../built-in/foundation_models/stable_diffusionxl/README.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/ACL_PyTorch/built-in/foundation_models/stable_diffusionxl/README.md b/ACL_PyTorch/built-in/foundation_models/stable_diffusionxl/README.md
index 374a62afb8..ccde19b6dd 100644
--- a/ACL_PyTorch/built-in/foundation_models/stable_diffusionxl/README.md
+++ b/ACL_PyTorch/built-in/foundation_models/stable_diffusionxl/README.md
@@ -152,7 +152,7 @@
 ```bash
 bs=1
 # quantized model
-unet_model="models/unet_quant/unet.onnx"
+unet_model="models/unet_quant/unet_fuse.onnx"
 # non-quantized model
 unet_model="models/unet/unet.onnx"
-- Gitee

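The least obvious step in `get_quant_data` (patches 2-4 above) is the scale packing: the per-channel float32 dequant scale `input_scale * weight_scale` is reinterpreted bit-for-bit as uint32 and widened into an int64 initializer, with the fp32 bits landing in the low word of the value the dequant/QuantBatchMatMul scale input consumes. A self-contained NumPy sketch of the trick, with no Ascend dependencies:

```python
# Reinterpret a float32 dequant scale (s_in * s_w) as raw bits and widen to
# int64, mirroring get_quant_data(); the fp32 bits sit in the low 32 bits.
import numpy as np

scale = np.array([0.0123, 0.0456], dtype=np.float32)
bits32 = scale.view(np.uint32)                 # same bytes, viewed as uint32
packed = np.zeros(scale.shape, dtype=np.int64)
packed |= bits32.astype(np.int64)

# Round trip: the low word still decodes to the original float32 scale.
low = (packed & 0xFFFFFFFF).astype(np.uint32)
assert np.array_equal(low.view(np.float32), scale)
```

Packing the raw bits rather than casting keeps the scale lossless; a numeric cast of 0.0123 to int64 would simply yield 0.
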
From 07db030d9321f9db065e4b0c5a6237aec8a1fae6 Mon Sep 17 00:00:00 2001
From: shikang
Date: Tue, 16 Jul 2024 06:21:46 +0000
Subject: [PATCH 7/8] bug fixed

---
 .../foundation_models/stable_diffusionxl/README_quant.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/ACL_PyTorch/built-in/foundation_models/stable_diffusionxl/README_quant.md b/ACL_PyTorch/built-in/foundation_models/stable_diffusionxl/README_quant.md
index 7d602bae08..150d1d5036 100644
--- a/ACL_PyTorch/built-in/foundation_models/stable_diffusionxl/README_quant.md
+++ b/ACL_PyTorch/built-in/foundation_models/stable_diffusionxl/README_quant.md
@@ -60,7 +60,7 @@ python3 quant_unet.py \
 2. Run the ATC command.
 
    ```bash
-   # To reduce quantization time, it is recommended to quantize with bs=1
+   # To reduce quantization time, quantization must be done with bs=1
    bs=1
    # text_encoder
    cd ./models/text_encoder
-- Gitee

From 46511a7f6081ac762012c9d57f3776f28c2d67ec Mon Sep 17 00:00:00 2001
From: shikang
Date: Wed, 17 Jul 2024 09:11:27 +0000
Subject: [PATCH 8/8] bug fix

---
 .../foundation_models/stable_diffusionxl/modify_onnx.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/ACL_PyTorch/built-in/foundation_models/stable_diffusionxl/modify_onnx.py b/ACL_PyTorch/built-in/foundation_models/stable_diffusionxl/modify_onnx.py
index e0a82946be..fee6db54e6 100644
--- a/ACL_PyTorch/built-in/foundation_models/stable_diffusionxl/modify_onnx.py
+++ b/ACL_PyTorch/built-in/foundation_models/stable_diffusionxl/modify_onnx.py
@@ -35,7 +35,7 @@ def add_flash_attention(model, fa_name, soc_type):
         else:
             flag = 'attn1' in name
         if flag:
-            matmul = model[name[:-3] + 'to_q/MatMul']
+            matmul = model[name[:-3] + 'to_q/MatMul_quant']
             reshape = model[name[:-3] + 'Reshape']
             seqlen = 4096
-- Gitee
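As a closing note, the bias folding in `modify_quant_fuse` (`ori_bias = round(bias / scale - correction)` with `correction = quant_weight.sum(axis=0) * input_offset`) can be checked numerically. The sketch below assumes an asymmetric input quantizer (`x_q = round(x / s_in) + offset`) and a symmetric weight quantizer (`w_q = round(w / s_w)`); the scales and offset are made up for illustration:

```python
# Numeric sanity check of the bias folding: y = (x_q @ w_q + bias_i32) * s
# with s = s_in * s_w and bias_i32 = round(b/s - o * w_q.sum(axis=0)),
# where the second term cancels the input-offset contribution.
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((4, 8)).astype(np.float32)
w = rng.standard_normal((8, 3)).astype(np.float32)
b = rng.standard_normal(3).astype(np.float32)

s_in, s_w, o = 0.05, 0.02, 7                      # assumed quant params
x_q = np.round(x / s_in).astype(np.int64) + o     # asymmetric input quant
w_q = np.round(w / s_w).astype(np.int64)          # symmetric weight quant

s = s_in * s_w
bias_i32 = np.round(b / s - o * w_q.sum(axis=0)).astype(np.int64)

y_fused = (x_q @ w_q + bias_i32) * s              # fused integer-matmul path
print(np.max(np.abs(y_fused - (x @ w + b))))      # only rounding noise remains
```

Only quantization rounding noise remains, which is what lets the MatMul + Add + dequant subgraph be rewritten as a single fused integer matmul.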