diff --git a/ACL_PyTorch/built-in/foundation_models/ControlNet_for_PyTorch/ControlNet_infer.py b/ACL_PyTorch/built-in/foundation_models/ControlNet_for_PyTorch/ControlNet_infer.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b07552bad5c574b98c69c21c8ae1df6226676af
--- /dev/null
+++ b/ACL_PyTorch/built-in/foundation_models/ControlNet_for_PyTorch/ControlNet_infer.py
@@ -0,0 +1,245 @@
+# Copyright 2023 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import os
+import random
+import time
+
+from PIL import Image
+from ais_bench.infer.interface import InferSession
+import cv2
+import einops
+import numpy as np
+from pytorch_lightning import seed_everything
+import torch
+
+from annotator.util import resize_image, HWC3
+from annotator.canny import CannyDetector
+from cldm.model import create_model, load_state_dict
+from cldm.ddim_hacked import DDIMSampler
+import config
+
+
+
+def parse_arguments():
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--model",
+ type=str,
+ default="./models/control_sd15_canny.pth",
+ help="Path or name of the pre-trained model."
+ )
+ parser.add_argument(
+ "--image",
+ type=str,
+ default="./test_imgs/dog.png",
+ help="Path or name of the image."
+ )
+ parser.add_argument(
+ "--prompt",
+ type=str,
+ default="Prompt",
+ help="label=Prompt"
+ )
+ parser.add_argument(
+ "--a_prompt",
+ type=str,
+ default="best quality, extremely detailed",
+ help="added prompt"
+ )
+ parser.add_argument(
+ "--n_prompt",
+ type=str,
+ default="longbody, lowres, bad anatomy, bad hands, missing fingers, "
+ "extra digit, fewer digits, cropped, worst quality, low quality",
+ help="negative prompt"
+ )
+ parser.add_argument(
+ "--num_samples",
+ type=int,
+ default=1,
+ help="image_num"
+ )
+ parser.add_argument(
+ "--image_resolution",
+ type=int,
+ default=512,
+ help="image resolution"
+ )
+ parser.add_argument(
+ "--guess_mode",
+        action="store_true",
+        default=False,
+        help="Enable guess mode."
+ )
+ parser.add_argument(
+ "--strength",
+ type=float,
+ default=1.0,
+ help="control strength"
+ )
+ parser.add_argument(
+ "--scale",
+ type=float,
+ default=9.0,
+ help="guidance scale"
+ )
+ parser.add_argument(
+ "--seed",
+ type=int,
+ default=200,
+ help="seed"
+ )
+ parser.add_argument(
+ "--eta",
+ type=float,
+ default=0.0,
+ help="eta"
+ )
+ parser.add_argument(
+ "--low_threshold",
+ type=int,
+ default=100,
+ help="canny low threshold"
+ )
+ parser.add_argument(
+ "--high_threshold",
+ type=int,
+ default=200,
+ help="canny high threshold"
+ )
+ parser.add_argument(
+ "--control_model_dir",
+ type=str,
+ default="./models",
+ help="Base path of om models "
+ )
+ parser.add_argument(
+ "--sd_model_dir",
+ type=str,
+ default="./models",
+ help="Base path of om models."
+ )
+ parser.add_argument(
+ "--save_dir",
+ type=str,
+ default="./results",
+ help="Path to save result images."
+ )
+ parser.add_argument(
+ "--ddim_steps",
+ type=int,
+ default=20,
+ help="Number of inference steps."
+ )
+ parser.add_argument(
+ "--device",
+ type=int,
+ default=0,
+ help="NPU device id."
+ )
+
+ return parser.parse_args()
+
+def process(model, ddim_sampler, sd_session, control_session, input_image,
+ prompt, a_prompt, n_prompt, num_samples, image_resolution,
+ ddim_steps, guess_mode, strength, scale, seed, eta, low_threshold,
+ high_threshold,
+ ):
+ with torch.no_grad():
+ img = resize_image(HWC3(input_image), image_resolution)
+ H, W, C = img.shape
+
+ apply_canny = CannyDetector()
+ detected_map = apply_canny(img, low_threshold, high_threshold)
+        detected_map = HWC3(detected_map)
+
+ control = torch.from_numpy(detected_map.copy()).float().cpu() / 255.0
+        control = torch.stack([control for _ in range(num_samples)], dim=0)
+ control = einops.rearrange(control, 'b h w c -> b c h w').clone()
+
+
+ if seed == -1:
+ seed = random.randint(0, 65535)
+ seed_everything(seed)
+
+ if config.save_memory:
+ model.low_vram_shift(is_diffusing=False)
+
+ cond = {"c_concat": [control], "c_crossattn":
+ [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]}
+ un_cond = {"c_concat": None if guess_mode else [control],
+ "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]}
+ shape = (4, H // 8, W // 8)
+
+ if config.save_memory:
+ model.low_vram_shift(is_diffusing=True)
+
+        # In guess mode, scale the 13 control outputs with increasing strength;
+        # otherwise apply a uniform strength to all of them.
+        model.control_scales = (
+            [strength * (0.825 ** float(12 - i)) for i in range(13)]
+            if guess_mode else [strength] * 13)
+
+ samples, intermediates = ddim_sampler.sample(
+ ddim_steps, num_samples,
+ shape, sd_session, control_session,
+ cond, verbose=False, eta=eta,
+ unconditional_guidance_scale=scale,
+ unconditional_conditioning=un_cond
+ )
+
+ if config.save_memory:
+ model.low_vram_shift(is_diffusing=False)
+
+ x_samples = model.decode_first_stage(samples)
+        x_samples = (
+            einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5
+        ).cpu().numpy().clip(0, 255).astype(np.uint8)
+
+ results = [x_samples[i] for i in range(num_samples)]
+
+ return [255 - detected_map] + results
+
+def main():
+ args = parse_arguments()
+
+ model = create_model("./models/cldm_v15.yaml").cpu()
+ model.load_state_dict(load_state_dict(args.model, location="cpu"))
+ model = model.cpu()
+
+ ddim_sampler = DDIMSampler(model)
+
+ sd_om = args.sd_model_dir
+ control_om = args.control_model_dir
+
+ sd_session = InferSession(args.device, sd_om)
+ control_session = InferSession(args.device, control_om)
+
+ input_image = cv2.imread(args.image)
+ output = process(model, ddim_sampler, sd_session, control_session, input_image,
+ args.prompt, args.a_prompt, args.n_prompt, args.num_samples,
+ args.image_resolution, args.ddim_steps, args.guess_mode,
+ args.strength, args.scale, args.seed, args.eta, args.low_threshold,
+ args.high_threshold)
+
+ if not os.path.exists(args.save_dir):
+ os.makedirs(args.save_dir, mode=0o744)
+ img0 = Image.fromarray(output[0])
+ img1 = Image.fromarray(output[1])
+ img0.save(os.path.join(args.save_dir, "cannyimg.png"))
+ img1.save(os.path.join(args.save_dir, "diffusionimg.png"))
+
+if __name__=="__main__":
+ main()
\ No newline at end of file
diff --git a/ACL_PyTorch/built-in/foundation_models/ControlNet_for_PyTorch/ControlNet_pth2onnx.py b/ACL_PyTorch/built-in/foundation_models/ControlNet_for_PyTorch/ControlNet_pth2onnx.py
new file mode 100644
index 0000000000000000000000000000000000000000..3ffa007f481309ec051938524c714a0327c19f98
--- /dev/null
+++ b/ACL_PyTorch/built-in/foundation_models/ControlNet_for_PyTorch/ControlNet_pth2onnx.py
@@ -0,0 +1,96 @@
+# Copyright 2023 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import os
+
+import torch
+from cldm.model import create_model, load_state_dict
+
+def parse_arguments():
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--model",
+ type=str,
+ default="./models/control_sd15_canny.pth",
+ help="Path or name of the pre-trained model.",
+ )
+ parser.add_argument(
+ "--control_path",
+ type=str,
+ default="./control.onnx",
+ help="path or name of the control.",
+ )
+ parser.add_argument(
+ "--sd_path",
+ type=str,
+ default="./sd.onnx",
+ help="Path or name of the sd.",
+ )
+
+ return parser.parse_args()
+
+def export_control(model, control_path):
+ model = model.control_model.eval()
+    dummy_input = (
+ torch.randn(1, 4, 64, 72),
+ torch.randn(1, 3, 512, 576),
+ torch.tensor([1]),
+ torch.randn(1, 77, 768)
+ )
+
+    torch.onnx.export(model, dummy_input, control_path,
+                      input_names=["text", "hint", "t", "cond_text"],
+                      output_names=["text_outs"], verbose=False,
+                      export_params=True, opset_version=13)
+
+def export_sd(model, sd_path):
+ model = model.model.diffusion_model.eval()
+ dummy_input = (
+ torch.randn(1, 4, 64, 72),
+ torch.tensor([1]),
+ torch.randn(1, 77, 768),
+ [torch.randn([1, 320, 64, 72]), torch.randn([1, 320, 64, 72]),
+ torch.randn([1, 320, 64, 72]), torch.randn([1, 320, 32, 36]),
+ torch.randn([1, 640, 32, 36]), torch.randn([1, 640, 32, 36]),
+ torch.randn([1, 640, 16, 18]), torch.randn([1, 1280, 16, 18]),
+ torch.randn([1, 1280, 16, 18]), torch.randn([1, 1280, 8, 9]),
+ torch.randn([1, 1280, 8, 9]), torch.randn([1, 1280, 8, 9]),
+ torch.randn([1, 1280, 8, 9])]
+ )
+
+    torch.onnx.export(model, dummy_input, sd_path,
+                      input_names=["text", "t", "cond_text",
+                                   "input1", "input2", "input3", "input4",
+                                   "input5", "input6", "input7", "input8",
+                                   "input9", "input10", "input11", "input12",
+                                   "input13"],
+                      output_names=["text_outs"],
+                      verbose=False, export_params=True, opset_version=13)
+
+def main():
+ args = parse_arguments()
+ model = create_model("./models/cldm_v15.yaml" ).cpu()
+ model.load_state_dict(load_state_dict(args.model, location='cpu'))
+ if not os.path.exists(args.control_path):
+ os.makedirs(args.control_path, mode=0o744)
+ if not os.path.exists(args.sd_path):
+ os.makedirs(args.sd_path, mode=0o744)
+ export_control(model, os.path.join(args.control_path, "control.onnx"))
+ print("control model done")
+ export_sd(model, os.path.join(args.sd_path, "sd.onnx"))
+ print("sd model done")
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file
diff --git a/ACL_PyTorch/built-in/foundation_models/ControlNet_for_PyTorch/README.md b/ACL_PyTorch/built-in/foundation_models/ControlNet_for_PyTorch/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..194116606e21d90399f445f05b49700b2c2bf9f9
--- /dev/null
+++ b/ACL_PyTorch/built-in/foundation_models/ControlNet_for_PyTorch/README.md
@@ -0,0 +1,254 @@
+# ControlNet Model - Inference Guide
+
+
+- [Overview](#ZH-CN_TOPIC_0000001172161501)
+
+  - [Input and Output Data](#section540883920406)
+
+- [Inference Environment Setup](#ZH-CN_TOPIC_0000001126281702)
+
+- [Quick Start](#ZH-CN_TOPIC_0000001126281700)
+
+  - [Get the Source Code](#section4622531142816)
+  - [Prepare the Dataset](#section183221994411)
+  - [Model Inference](#section741711594517)
+
+- [Inference Performance & Accuracy](#ZH-CN_TOPIC_0000001172201573)
+
+
+# Overview
+
+ ControlNet is a neural network structure that controls a diffusion model by adding extra conditions. It copies the weights of the network blocks into a "locked" copy and a "trainable" copy: the trainable copy learns the added condition, while the locked copy preserves the original model. As a result, training on a small dataset of image pairs does not destroy the production-ready diffusion model. The "zero convolution" is a 1×1 convolution whose weight and bias are both initialized to zero; before training, every zero convolution outputs zero, so ControlNet introduces no distortion. This allows training on small-scale or even personal devices, and is also friendly to merging, replacing, or offsetting models, weights, blocks, and layers.
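+
+To illustrate the idea only (a minimal PyTorch sketch, not part of this repository): a zero convolution is a 1×1 convolution whose weight and bias start at zero, so the ControlNet branch initially contributes nothing to the locked model.
+
+```
+import torch
+import torch.nn as nn
+
+
+def zero_conv(channels: int) -> nn.Conv2d:
+    """1x1 convolution with weight and bias initialized to zero."""
+    conv = nn.Conv2d(channels, channels, kernel_size=1)
+    nn.init.zeros_(conv.weight)
+    nn.init.zeros_(conv.bias)
+    return conv
+
+
+x = torch.randn(1, 320, 64, 72)
+print(zero_conv(320)(x).abs().max())  # tensor(0.): no distortion before training
+```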
+
+- Reference implementation:
+ ```
+ url=https://github.com/lllyasviel/ControlNet
+ branch=main
+ commit_id=ed85cd1e25a5ed592f7d8178495b4483de0331bf
+ ```
+
+## Input and Output Data
+
+- Input data
+
+  | Input | Shape | Data Type | Format |
+  | -------- | -------- | ------------------------- | ------------ |
+  | text | 1 x 4 x 64 x 72 | FLOAT32 | NCHW |
+  | hint | 1 x 3 x 512 x 576 | FLOAT32 | NCHW |
+  | t | 1 | INT64 | ND |
+  | cond_text | 1 x 77 x 768 | FLOAT32 | ND |
+
+
+- Output data
+
+  | Output | Shape | Data Type | Format |
+  | -------- | -------- | -------- | ------------ |
+  | text_outs | 1 x 4 x 64 x 72 | FLOAT32 | NCHW |
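+
+For reference only, a minimal sketch (the device id and OM path are assumptions; the OM file is produced in the ATC step below) of feeding tensors with the shapes listed above to the control OM through the `ais_bench` `InferSession` used by this sample:
+
+```
+import numpy as np
+from ais_bench.infer.interface import InferSession
+
+session = InferSession(0, "om/control/control.om")  # device id 0, assumed OM path
+outputs = session.infer(
+    [
+        np.random.randn(1, 4, 64, 72).astype(np.float32),    # text (latent)
+        np.random.randn(1, 3, 512, 576).astype(np.float32),  # hint (canny control image)
+        np.array([1], dtype=np.int64),                        # t (timestep)
+        np.random.randn(1, 77, 768).astype(np.float32),       # cond_text (CLIP embedding)
+    ],
+    "static",
+)
+print([out.shape for out in outputs])
+```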
+
+# Inference Environment Setup
+
+- This model requires the following plugins and drivers.
+
+  **Table 1** Version compatibility
+  | Component | Version | Environment Setup Guide |
+  | ------------------------------------------------------------ | ------- | ------------------------------------------------------------ |
+  | Firmware & Driver | 23.0.rc3 | [PyTorch framework inference environment setup](https://www.hiascend.com/document/detail/zh/ModelZoo/pytorchframework/pies) |
+  | CANN | 7.0.0 | - |
+  | Python | 3.8.5 | - |
+
+
+# Quick Start
+
+## Get the Source Code
+
+1. Get the source code.
+ ```
+ git clone https://github.com/lllyasviel/ControlNet
+ mv ControlNet_pth2onnx.py ControlNet_infer.py pipeline.py ControlNet/
+ ```
+
+2. Install dependencies.
+ ```
+ pip3 install -r requirements.txt
+ ```
+
+3. Apply the code modifications.
+
+   Run:
+
+ ```
+ patch -p1 < differences.patch
+ ```
+
+4. Install the Ascend Inference Tools (AIT).
+
+   Visit the [AIT repository](https://gitee.com/ascend/ait/tree/master/ait#ait) and install the tool by following its README.
+
+   When installing AIT, only the benchmark component is needed; the other components are optional.
+
+## Prepare the Dataset
+
+1. Obtain the original dataset.
+
+   This model generates images from an input image and a text prompt, so no dataset is required.
+
+## Model Inference
+
+1. Model conversion.
+   Use PyTorch to convert the .pth weight file to an .onnx file, then use the ATC tool to convert the .onnx file into an offline inference .om model.
+
+   1. Get the weights.
+
+
+      Download the pretrained weights from https://huggingface.co/lllyasviel/ControlNet/blob/main/models/control_sd15_canny.pth
+      and place the file in the `ControlNet/models` working directory.
+
+   2. Export the ONNX models.
+
+
+      You can download the OpenCLIP weights into `ControlNet/openai/clip-vit-large-patch14` in advance to avoid download failures in later steps.
+
+      ```
+      # git-lfs is required (https://git-lfs.com)
+      git lfs install
+      cd ControlNet
+      git clone https://huggingface.co/openai/clip-vit-large-patch14 openai/clip-vit-large-patch14
+      ```
+
+
+
+      Run:
+
+ ```
+ python ControlNet_pth2onnx.py --model ./models/control_sd15_canny.pth --control_path onnx/control/ --sd_path onnx/sd/
+ ```
+
+      Parameter description:
+      - --model: path of the pretrained model weights (.pth).
+      - --control_path: output directory for the control ONNX model.
+      - --sd_path: output directory for the SD ONNX model.
+
+
+
+   3. Convert the ONNX models to OM models with the ATC tool.
+
+      1. Configure environment variables.
+
+ ```
+ source /usr/local/Ascend/ascend-toolkit/set_env.sh
+ ```
+
+         > **Note:**
+         > The environment variables in this script are for reference only; configure them according to your actual installation. For details, see the [CANN Auxiliary Development Tool Guide (Inference)](https://support.huawei.com/enterprise/zh/ascend-computing/cann-pid-251168373?category=developer-documents&subcategory=auxiliary-development-tools).
+
+      2. Run the following command to query the chip name ($\{chip\_name\}).
+
+ ```
+ npu-smi info
+      # The chip name of this device is Ascend310P3 (replace with your own)
+      # Example output:
+ +-------------------+-----------------+------------------------------------------------------+
+ | NPU Name | Health | Power(W) Temp(C) Hugepages-Usage(page) |
+ | Chip Device | Bus-Id | AICore(%) Memory-Usage(MB) |
+ +===================+=================+======================================================+
+ | 0 310P3 | OK | 15.8 42 0 / 0 |
+ | 0 0 | 0000:82:00.0 | 0 1074 / 21534 |
+ +===================+=================+======================================================+
+ | 1 310P3 | OK | 15.4 43 0 / 0 |
+ | 0 1 | 0000:89:00.0 | 0 1070 / 21534 |
+ +===================+=================+======================================================+
+ ```
+
+      3. Run the ATC commands.
+
+ ```
+ # control
+ atc --framework=5 \
+ --model=./onnx/control/control.onnx \
+ --output=./om/control/control \
+ --input_format=ND \
+ --log=error \
+ --soc_version=Ascend${chip_name}
+
+ # sd
+ atc --framework=5 \
+ --model=./onnx/sd/sd.onnx \
+ --output=./om/sd/sd \
+ --input_format=ND \
+ --log=error \
+ --soc_version=Ascend${chip_name}
+
+ ```
+
+         Parameter description:
+         - --model: the ONNX model file.
+         - --output: the output OM model.
+         - --framework: 5 indicates an ONNX model.
+         - --log: log level.
+         - --soc_version: processor model.
+
+2. Run inference and verification.
+
+   1. Run the inference script.
+ ```
+
+ python3 ControlNet_infer.py \
+ --model ./models/control_sd15_canny.pth \
+      --image test_imgs/dog.png \
+ --prompt "cute dog" \
+ --device 0 \
+ --control_model_dir om/control/control.om \
+ --sd_model_dir om/sd/sd.om \
+ --save_dir ./results \
+ --ddim_steps 20
+
+ ```
+
+      Parameter description:
+      - --model: path of the pretrained model weights (.pth).
+      - --prompt: text prompt.
+      - --save_dir: directory where the generated images are saved.
+      - --ddim_steps: number of DDIM sampling steps.
+      - --image: input image.
+      - --control_model_dir: path of the control OM model.
+      - --sd_model_dir: path of the SD OM model.
+      - --device: inference device ID.
+
+      After the script finishes, the result images are written to `./results`: for each input image, one Canny edge map and one image regenerated from the input image and the text prompt (see the sketch below for the edge-map preprocessing).
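+
+      For reference only, a minimal sketch of the Canny preprocessing performed by the script (`CannyDetector` wraps `cv2.Canny`; the thresholds correspond to `--low_threshold`/`--high_threshold`):
+
+      ```
+      import cv2
+      import numpy as np
+
+      # Note: ControlNet_infer.py also resizes the input to --image_resolution first.
+      img = cv2.imread("test_imgs/dog.png")
+      edges = cv2.Canny(img, 100, 200)                # --low_threshold / --high_threshold defaults
+      edge_map = 255 - np.stack([edges] * 3, axis=2)  # inverted 3-channel map, as saved to cannyimg.png
+      cv2.imwrite("canny_preview.png", edge_map)
+      ```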
+
+3. Result images.
+
+   Online (PyTorch) model result:
+
+   
+
+   Offline (OM) model result:
+
+   
+
+
+4. Performance validation.
+
+   The pure-inference mode of the ais_bench tool can be used to measure the performance of the OM models. Reference command:
+
+ ```
+
+ python -m ais_bench --model=${om_model_path} --loop=10 --batchsize=${batch_size} --device 0,1
+ ```
+
+   - Parameter description:
+     - --model: model path.
+     - --batchsize: number of samples per batch.
+     - --device: when testing on a Duo card, specify both chips.
+
+
+# Inference Performance & Accuracy
+
+Inference is performed through the ACL interface; refer to the following performance data.
+
+
+| Hardware | Performance | Model |
+| :------: | :--------: | :--------: |
+| Duo (parallel) | 33.6 fps | control |
+| Duo (parallel) | 13.6 fps | sd |
\ No newline at end of file
diff --git a/ACL_PyTorch/built-in/foundation_models/ControlNet_for_PyTorch/differences.patch b/ACL_PyTorch/built-in/foundation_models/ControlNet_for_PyTorch/differences.patch
new file mode 100644
index 0000000000000000000000000000000000000000..5a1b189ba4a3f75f4f6431bc99a5a62fd0526e01
--- /dev/null
+++ b/ACL_PyTorch/built-in/foundation_models/ControlNet_for_PyTorch/differences.patch
@@ -0,0 +1,125 @@
+diff -Naru a/ControlNet/cldm/ddim_hacked.py b/ControlNet/cldm/ddim_hacked.py
+--- a/ControlNet/cldm/ddim_hacked.py 2023-11-17 23:27:48.688000000 +0800
++++ b/ControlNet/cldm/ddim_hacked.py 2023-11-17 23:15:45.548000000 +0800
+@@ -16,8 +16,8 @@
+
+ def register_buffer(self, name, attr):
+ if type(attr) == torch.Tensor:
+- if attr.device != torch.device("cuda"):
+- attr = attr.to(torch.device("cuda"))
++ if attr.device != torch.device("cpu"):
++ attr = attr.to(torch.device("cpu"))
+ setattr(self, name, attr)
+
+ def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
+@@ -56,6 +56,8 @@
+ S,
+ batch_size,
+ shape,
++ sd_session,
++ control_session,
+ conditioning=None,
+ callback=None,
+ normals_sequence=None,
+@@ -101,6 +103,8 @@
+ print(f'Data shape for DDIM sampling is {size}, eta {eta}')
+
+ samples, intermediates = self.ddim_sampling(conditioning, size,
++ sd_session,
++ control_session,
+ callback=callback,
+ img_callback=img_callback,
+ quantize_denoised=quantize_x0,
+@@ -120,7 +124,7 @@
+ return samples, intermediates
+
+ @torch.no_grad()
+- def ddim_sampling(self, cond, shape,
++ def ddim_sampling(self, cond, shape, sd_session, control_session,
+ x_T=None, ddim_use_original_steps=False,
+ callback=None, timesteps=None, quantize_denoised=False,
+ mask=None, x0=None, img_callback=None, log_every_t=100,
+@@ -160,7 +164,7 @@
+ assert len(ucg_schedule) == len(time_range)
+ unconditional_guidance_scale = ucg_schedule[i]
+
+- outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
++ outs = self.p_sample_ddim(img, cond, ts, sd_session, control_session, index=index, use_original_steps=ddim_use_original_steps,
+ quantize_denoised=quantize_denoised, temperature=temperature,
+ noise_dropout=noise_dropout, score_corrector=score_corrector,
+ corrector_kwargs=corrector_kwargs,
+@@ -178,7 +182,7 @@
+ return img, intermediates
+
+ @torch.no_grad()
+- def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
++ def p_sample_ddim(self, x, c, t, sd_session, control_session, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
+ unconditional_guidance_scale=1., unconditional_conditioning=None,
+ dynamic_threshold=None):
+@@ -187,8 +191,8 @@
+ if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
+ model_output = self.model.apply_model(x, t, c)
+ else:
+- model_t = self.model.apply_model(x, t, c)
+- model_uncond = self.model.apply_model(x, t, unconditional_conditioning)
++ model_t = self.model.apply_model(x, t, c, sd_session, control_session)
++ model_uncond = self.model.apply_model(x, t, unconditional_conditioning, sd_session, control_session)
+ model_output = model_uncond + unconditional_guidance_scale * (model_t - model_uncond)
+
+ if self.model.parameterization == "v":
+diff -Naru a/ControlNet/ldm/modules/attention.py b/ControlNet/ldm/modules/attention.py
+--- a/ControlNet/ldm/modules/attention.py 2023-11-17 23:27:49.192000000 +0800
++++ b/ControlNet/ldm/modules/attention.py 2023-11-17 23:17:33.896000000 +0800
+@@ -174,7 +174,7 @@
+ if _ATTN_PRECISION =="fp32":
+ with torch.autocast(enabled=False, device_type = 'cuda'):
+ q, k = q.float(), k.float()
+- sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
++ sim = einsum('bid,bjd->bij', q, k) * self.scale
+ else:
+ sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
+
+@@ -189,7 +189,7 @@
+ # attention, what we cannot get enough of
+ sim = sim.softmax(dim=-1)
+
+- out = einsum('b i j, b j d -> b i d', sim, v)
++ out = einsum('bij,bjd->bid', sim, v)
+ out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
+ return self.to_out(out)
+
+diff -Naru a/ControlNet/ldm/modules/diffusionmodules/util.py b/ControlNet/ldm/modules/diffusionmodules/util.py
+--- a/ControlNet/ldm/modules/diffusionmodules/util.py 2023-11-17 23:27:49.192000000 +0800
++++ b/ControlNet/ldm/modules/diffusionmodules/util.py 2023-11-17 23:23:05.404000000 +0800
+@@ -109,7 +109,7 @@
+ explicitly take as arguments.
+ :param flag: if False, disable gradient checkpointing.
+ """
+- if flag:
++ if flag and not torch.onnx.is_in_onnx_export():
+ args = tuple(inputs) + tuple(params)
+ return CheckpointFunction.apply(func, len(inputs), *args)
+ else:
+diff -Naru a/ControlNet/ldm/modules/encoders/modules.py b/ControlNet/ldm/modules/encoders/modules.py
+--- a/ControlNet/ldm/modules/encoders/modules.py 2023-11-17 23:27:49.192000000 +0800
++++ b/ControlNet/ldm/modules/encoders/modules.py 2023-11-17 23:20:02.000000000 +0800
+@@ -92,7 +92,7 @@
+ "pooled",
+ "hidden"
+ ]
+- def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77,
++ def __init__(self, version="openai/clip-vit-large-patch14", device="cpu", max_length=77,
+ freeze=True, layer="last", layer_idx=None): # clip-vit-base-patch32
+ super().__init__()
+ assert layer in self.LAYERS
+diff -Naru a/ControlNet/models/cldm_v15.yaml b/ControlNet/models/cldm_v15.yaml
+--- a/ControlNet/models/cldm_v15.yaml 2023-11-17 23:27:49.196000000 +0800
++++ b/ControlNet/models/cldm_v15.yaml 2023-11-17 23:18:39.812000000 +0800
+@@ -1,5 +1,5 @@
+ model:
+- target: cldm.cldm.ControlLDM
++ target: pipeline.AscendControlNet
+ params:
+ linear_start: 0.00085
+ linear_end: 0.0120
diff --git a/ACL_PyTorch/built-in/foundation_models/ControlNet_for_PyTorch/pipeline.py b/ACL_PyTorch/built-in/foundation_models/ControlNet_for_PyTorch/pipeline.py
new file mode 100644
index 0000000000000000000000000000000000000000..388df82f2f1c50399a50122aed814c4a075a69b6
--- /dev/null
+++ b/ACL_PyTorch/built-in/foundation_models/ControlNet_for_PyTorch/pipeline.py
@@ -0,0 +1,46 @@
+# Copyright 2023 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+from cldm.cldm import ControlLDM
+
+class AscendControlNet(ControlLDM):
+ def apply_model(self, x_noisy, t, cond, sd_session, control_session):
+ assert isinstance(cond, dict)
+ cond_txt = torch.cat(cond["c_crossattn"], 1)
+ mode = "static"
+
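+        # Run the control OM on the NPU; it returns the 13 ControlNet feature maps.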
+ control = control_session.infer(
+ [
+ x_noisy.numpy(),
+ torch.cat(cond["c_concat"], 1).numpy(),
+ t.numpy(),
+ cond_txt.numpy()
+ ], mode
+ )
+
+ control = [c * scale for c, scale in zip(control, self.control_scales)]
+
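+        # Run the SD UNet OM with the scaled control features to predict the noise (eps).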
+ eps = torch.from_numpy(
+ sd_session.infer(
+ [
+ x_noisy.numpy(),
+ t.numpy(),
+                cond_txt.numpy()
+ ]
+ + control, mode
+ )[0]
+ )
+
+ return eps
diff --git a/ACL_PyTorch/built-in/foundation_models/ControlNet_for_PyTorch/requirements.txt b/ACL_PyTorch/built-in/foundation_models/ControlNet_for_PyTorch/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..b53c8a9faa68136e7f0f23d36fa8ab8985ac5118
--- /dev/null
+++ b/ACL_PyTorch/built-in/foundation_models/ControlNet_for_PyTorch/requirements.txt
@@ -0,0 +1,24 @@
+torch==1.12.1
+torchvision==0.13.1
+numpy==1.23.1
+albumentations==1.3.0
+opencv-python-headless==4.8.1.78
+imageio==2.9.0
+imageio-ffmpeg==0.4.2
+pytorch-lightning==1.5.0
+omegaconf==2.1.1
+test-tube>=0.7.5
+streamlit==1.12.1
+einops==0.3.0
+transformers==4.19.2
+webdataset==0.2.5
+kornia==0.6
+open_clip_torch==2.0.2
+invisible-watermark>=0.1.5
+streamlit-drawable-canvas==0.8.0
+torchmetrics==0.6.0
+timm==0.6.12
+addict==2.4.0
+yapf==0.32.0
+prettytable==3.6.0
+safetensors==0.2.7
\ No newline at end of file
diff --git a/ACL_PyTorch/built-in/foundation_models/ControlNet_for_PyTorch/test_results/dog0.PNG b/ACL_PyTorch/built-in/foundation_models/ControlNet_for_PyTorch/test_results/dog0.PNG
new file mode 100644
index 0000000000000000000000000000000000000000..3cfe8346ddf339ef4a1cfb276a752f3d5a05e1b1
Binary files /dev/null and b/ACL_PyTorch/built-in/foundation_models/ControlNet_for_PyTorch/test_results/dog0.PNG differ
diff --git a/ACL_PyTorch/built-in/foundation_models/ControlNet_for_PyTorch/test_results/dog1.PNG b/ACL_PyTorch/built-in/foundation_models/ControlNet_for_PyTorch/test_results/dog1.PNG
new file mode 100644
index 0000000000000000000000000000000000000000..3cfe8346ddf339ef4a1cfb276a752f3d5a05e1b1
Binary files /dev/null and b/ACL_PyTorch/built-in/foundation_models/ControlNet_for_PyTorch/test_results/dog1.PNG differ