diff --git a/ACL_PyTorch/built-in/foundation_models/ControlNet_for_PyTorch/ControlNet_infer.py b/ACL_PyTorch/built-in/foundation_models/ControlNet_for_PyTorch/ControlNet_infer.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b07552bad5c574b98c69c21c8ae1df6226676af
--- /dev/null
+++ b/ACL_PyTorch/built-in/foundation_models/ControlNet_for_PyTorch/ControlNet_infer.py
@@ -0,0 +1,245 @@
+# Copyright 2023 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import os
+import random
+
+from PIL import Image
+from ais_bench.infer.interface import InferSession
+import cv2
+import einops
+import numpy as np
+from pytorch_lightning import seed_everything
+import torch
+
+from annotator.util import resize_image, HWC3
+from annotator.canny import CannyDetector
+from cldm.model import create_model, load_state_dict
+from cldm.ddim_hacked import DDIMSampler
+import config
+
+
+def parse_arguments():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--model",
+        type=str,
+        default="./models/control_sd15_canny.pth",
+        help="Path of the pre-trained ControlNet checkpoint."
+    )
+    parser.add_argument(
+        "--image",
+        type=str,
+        default="./test_imgs/dog.png",
+        help="Path of the input image."
+    )
+    parser.add_argument(
+        "--prompt",
+        type=str,
+        default="Prompt",
+        help="Prompt text."
+    )
+    parser.add_argument(
+        "--a_prompt",
+        type=str,
+        default="best quality, extremely detailed",
+        help="Added prompt."
+    )
+    parser.add_argument(
+        "--n_prompt",
+        type=str,
+        default="longbody, lowres, bad anatomy, bad hands, missing fingers, "
+                "extra digit, fewer digits, cropped, worst quality, low quality",
+        help="Negative prompt."
+    )
+    parser.add_argument(
+        "--num_samples",
+        type=int,
+        default=1,
+        help="Number of images to generate."
+    )
+    parser.add_argument(
+        "--image_resolution",
+        type=int,
+        default=512,
+        help="Image resolution."
+    )
+    parser.add_argument(
+        "--guess_mode",
+        action="store_true",
+        help="Enable guess mode."
+    )
+    parser.add_argument(
+        "--strength",
+        type=float,
+        default=1.0,
+        help="Control strength."
+    )
+    parser.add_argument(
+        "--scale",
+        type=float,
+        default=9.0,
+        help="Guidance scale."
+    )
+    parser.add_argument(
+        "--seed",
+        type=int,
+        default=200,
+        help="Random seed (-1 for a random seed)."
+    )
+    parser.add_argument(
+        "--eta",
+        type=float,
+        default=0.0,
+        help="DDIM eta."
+    )
+    parser.add_argument(
+        "--low_threshold",
+        type=int,
+        default=100,
+        help="Canny low threshold."
+    )
+    parser.add_argument(
+        "--high_threshold",
+        type=int,
+        default=200,
+        help="Canny high threshold."
+    )
+    parser.add_argument(
+        "--control_model_dir",
+        type=str,
+        default="./models",
+        help="Path of the control OM model."
+    )
+    parser.add_argument(
+        "--sd_model_dir",
+        type=str,
+        default="./models",
+        help="Path of the sd (UNet) OM model."
+    )
+    parser.add_argument(
+        "--save_dir",
+        type=str,
+        default="./results",
+        help="Path to save result images."
+    )
+    parser.add_argument(
+        "--ddim_steps",
+        type=int,
+        default=20,
+        help="Number of DDIM sampling steps."
+    )
+    parser.add_argument(
+        "--device",
+        type=int,
+        default=0,
+        help="NPU device id."
+    )
+
+    return parser.parse_args()
+
+
+def process(model, ddim_sampler, sd_session, control_session, input_image,
+            prompt, a_prompt, n_prompt, num_samples, image_resolution,
+            ddim_steps, guess_mode, strength, scale, seed, eta, low_threshold,
+            high_threshold):
+    with torch.no_grad():
+        img = resize_image(HWC3(input_image), image_resolution)
+        H, W, C = img.shape
+
+        # Extract the Canny edge map that conditions the control model.
+        apply_canny = CannyDetector()
+        detected_map = apply_canny(img, low_threshold, high_threshold)
+        detected_map = HWC3(detected_map)
+
+        # The OM models are exported with a static batch size of 1.
+        control = torch.from_numpy(detected_map.copy()).float().cpu() / 255.0
+        control = torch.stack([control for _ in range(1)], dim=0)
+        control = einops.rearrange(control, 'b h w c -> b c h w').clone()
+
+        if seed == -1:
+            seed = random.randint(0, 65535)
+        seed_everything(seed)
+
+        if config.save_memory:
+            model.low_vram_shift(is_diffusing=False)
+
+        cond = {"c_concat": [control],
+                "c_crossattn": [model.get_learned_conditioning(
+                    [prompt + ', ' + a_prompt] * num_samples)]}
+        un_cond = {"c_concat": None if guess_mode else [control],
+                   "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]}
+        shape = (4, H // 8, W // 8)
+
+        if config.save_memory:
+            model.low_vram_shift(is_diffusing=True)
+
+        model.control_scales = [
+            strength * (0.825 ** float(12 - i)) for i in range(13)
+        ] if guess_mode else ([strength] * 13)
+
+        samples, intermediates = ddim_sampler.sample(
+            ddim_steps, num_samples,
+            shape, sd_session, control_session,
+            cond, verbose=False, eta=eta,
+            unconditional_guidance_scale=scale,
+            unconditional_conditioning=un_cond
+        )
+
+        if config.save_memory:
+            model.low_vram_shift(is_diffusing=False)
+
+        x_samples = model.decode_first_stage(samples)
+        x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5
+                     + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
+
+        results = [x_samples[i] for i in range(num_samples)]
+
+    return [255 - detected_map] + results
+
+
+def main():
+    args = parse_arguments()
+
+    model = create_model("./models/cldm_v15.yaml").cpu()
+    model.load_state_dict(load_state_dict(args.model, location="cpu"))
+    model = model.cpu()
+
+    ddim_sampler = DDIMSampler(model)
+
+    # One ais_bench session per OM model: the control part and the sd (UNet) part.
+    sd_session = InferSession(args.device, args.sd_model_dir)
+    control_session = InferSession(args.device, args.control_model_dir)
+
+    input_image = cv2.imread(args.image)
+    output = process(model, ddim_sampler, sd_session, control_session, input_image,
+                     args.prompt, args.a_prompt, args.n_prompt, args.num_samples,
+                     args.image_resolution, args.ddim_steps, args.guess_mode,
+                     args.strength, args.scale, args.seed, args.eta, args.low_threshold,
+                     args.high_threshold)
+
+    if not os.path.exists(args.save_dir):
+        os.makedirs(args.save_dir, mode=0o744)
+    img0 = Image.fromarray(output[0])
+    img1 = Image.fromarray(output[1])
+    img0.save(os.path.join(args.save_dir, "cannyimg.png"))
+    img1.save(os.path.join(args.save_dir, "diffusionimg.png"))
+
+
+if __name__ == "__main__":
+    main()
diff --git a/ACL_PyTorch/built-in/foundation_models/ControlNet_for_PyTorch/ControlNet_pth2onnx.py b/ACL_PyTorch/built-in/foundation_models/ControlNet_for_PyTorch/ControlNet_pth2onnx.py
new file mode 100644
index 0000000000000000000000000000000000000000..3ffa007f481309ec051938524c714a0327c19f98
--- /dev/null
+++ b/ACL_PyTorch/built-in/foundation_models/ControlNet_for_PyTorch/ControlNet_pth2onnx.py
@@ -0,0 +1,96 @@
+# Copyright 2023 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import os
+
+import torch
+from cldm.model import create_model, load_state_dict
+
+
+def parse_arguments():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--model",
+        type=str,
+        default="./models/control_sd15_canny.pth",
+        help="Path of the pre-trained ControlNet checkpoint.",
+    )
+    parser.add_argument(
+        "--control_path",
+        type=str,
+        default="./onnx/control",
+        help="Output directory for the control ONNX model.",
+    )
+    parser.add_argument(
+        "--sd_path",
+        type=str,
+        default="./onnx/sd",
+        help="Output directory for the sd (UNet) ONNX model.",
+    )
+
+    return parser.parse_args()
+
+
+def export_control(model, control_path):
+    model = model.control_model.eval()
+    dummy_input = (
+        torch.randn(1, 4, 64, 72),
+        torch.randn(1, 3, 512, 576),
+        torch.tensor([1]),
+        torch.randn(1, 77, 768)
+    )
+
+    torch.onnx.export(model, dummy_input, control_path,
+                      input_names=["text", "hint", "t", "cond_text"],
+                      output_names=["text_outs"], verbose=False,
+                      export_params=True, opset_version=13)
+
+
+def export_sd(model, sd_path):
+    model = model.model.diffusion_model.eval()
+    dummy_input = (
+        torch.randn(1, 4, 64, 72),
+        torch.tensor([1]),
+        torch.randn(1, 77, 768),
+        [torch.randn([1, 320, 64, 72]), torch.randn([1, 320, 64, 72]),
+         torch.randn([1, 320, 64, 72]), torch.randn([1, 320, 32, 36]),
+         torch.randn([1, 640, 32, 36]), torch.randn([1, 640, 32, 36]),
+         torch.randn([1, 640, 16, 18]), torch.randn([1, 1280, 16, 18]),
+         torch.randn([1, 1280, 16, 18]), torch.randn([1, 1280, 8, 9]),
+         torch.randn([1, 1280, 8, 9]), torch.randn([1, 1280, 8, 9]),
+         torch.randn([1, 1280, 8, 9])]
+    )
+
+    torch.onnx.export(model, dummy_input, sd_path,
+                      input_names=["text", "t", "cond_text",
+                                   "input1", "input2", "input3", "input4",
+                                   "input5", "input6", "input7", "input8",
+                                   "input9", "input10", "input11", "input12",
+                                   "input13"],
+                      output_names=["text_outs"],
+                      verbose=False, export_params=True, opset_version=13)
+
+
+def main():
+    args = parse_arguments()
+    model = create_model("./models/cldm_v15.yaml").cpu()
+    model.load_state_dict(load_state_dict(args.model, location='cpu'))
+    if not os.path.exists(args.control_path):
+        os.makedirs(args.control_path, mode=0o744)
+    if not os.path.exists(args.sd_path):
+        os.makedirs(args.sd_path, mode=0o744)
+    export_control(model, os.path.join(args.control_path, "control.onnx"))
+    print("control model done")
+    export_sd(model, os.path.join(args.sd_path, "sd.onnx"))
+    print("sd model done")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/ACL_PyTorch/built-in/foundation_models/ControlNet_for_PyTorch/README.md b/ACL_PyTorch/built-in/foundation_models/ControlNet_for_PyTorch/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..194116606e21d90399f445f05b49700b2c2bf9f9
--- /dev/null
+++ b/ACL_PyTorch/built-in/foundation_models/ControlNet_for_PyTorch/README.md
@@ -0,0 +1,254 @@
+# ControlNet Model Inference Guide
+
+
+- [Overview](#ZH-CN_TOPIC_0000001172161501)
+
+  - [Input and Output Data](#section540883920406)
+
+- [Inference Environment Setup](#ZH-CN_TOPIC_0000001126281702)
+
+- [Quick Start](#ZH-CN_TOPIC_0000001126281700)
+
+  - [Getting the Source Code](#section4622531142816)
+  - [Preparing the Dataset](#section183221994411)
+  - [Model Inference](#section741711594517)
+
+- [Inference Performance & Accuracy](#ZH-CN_TOPIC_0000001172201573)
+
+
+# Overview
+
+  ControlNet is a neural network structure that adds extra conditions to control a diffusion model. It copies the weights of each network block into a "locked" copy and a "trainable" copy: the trainable copy learns the new condition, while the locked copy preserves the original model, so fine-tuning on a small dataset of image pairs does not damage the production-ready diffusion model. The "zero convolutions" are 1x1 convolutions whose weights and biases are initialized to zero; before training, every zero convolution outputs zero, so ControlNet introduces no distortion at the start. This makes training feasible at small scale or even on personal devices, and also makes it convenient to merge/replace/offset models, weights, blocks and layers. A minimal sketch of a zero convolution is shown after the reference implementation below.
+
+- Reference implementation:
+  ```
+  url=https://github.com/lllyasviel/ControlNet
+  branch=main
+  commit_id=ed85cd1e25a5ed592f7d8178495b4483de0331bf
+  ```
+
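+  The following is a minimal PyTorch sketch (for illustration only; it is not part of this repository or of the conversion flow) of the "zero convolution" idea described above: a 1x1 convolution whose weights and biases start at zero, so the trainable branch initially contributes nothing to the locked model.
+
+  ```python
+  import torch
+  import torch.nn as nn
+
+
+  def zero_conv(channels: int) -> nn.Conv2d:
+      # 1x1 convolution with weights and biases initialized to zero.
+      conv = nn.Conv2d(channels, channels, kernel_size=1)
+      nn.init.zeros_(conv.weight)
+      nn.init.zeros_(conv.bias)
+      return conv
+
+
+  x = torch.randn(1, 320, 64, 72)
+  print(zero_conv(320)(x).abs().max().item())  # 0.0 before any training
+  ```
+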
+## Input and Output Data
+
+- Input data
+
+  | Input | Shape | Data Type | Format |
+  | -------- | -------- | ------------------------- | ------------ |
+  | text | 1 x 4 x 64 x 72 | FLOAT32 | NCHW |
+  | hint | 1 x 3 x 512 x 576 | FLOAT32 | NCHW |
+  | t | 1 | INT64 | ND |
+  | cond_text | 1 x 77 x 768 | FLOAT32 | ND |
+
+
+- Output data
+
+  | Output | Shape | Data Type | Format |
+  | -------- | -------- | -------- | ------------ |
+  | text_outs | 1 x 4 x 64 x 72 | FLOAT32 | NCHW |
+
+# Inference Environment Setup
+
+- The model requires the following plugins and drivers.
+
+  **Table 1** Version compatibility
+  | Component | Version | Setup Guide |
+  | ------------------------------------------------------------ | ------- | ------------------------------------------------------------ |
+  | Firmware & driver | 23.0.rc3 | [PyTorch inference environment setup](https://www.hiascend.com/document/detail/zh/ModelZoo/pytorchframework/pies) |
+  | CANN | 7.0.0 | - |
+  | Python | 3.8.5 | - |
+
+
+# Quick Start
+
+## Getting the Source Code
+
+1. Get the source code.
+   ```
+   git clone https://github.com/lllyasviel/ControlNet
+   mv ControlNet_pth2onnx.py ControlNet_infer.py pipeline.py ControlNet/
+   ```
+
+2. Install the dependencies.
+   ```
+   pip3 install -r requirements.txt
+   ```
+
+3. Apply the code changes.
+
+   Run:
+
+   ```
+   patch -p1 < differences.patch
+   ```
+
+4. Install the Ascend Inference Tools (AIT).
+
+   Follow the README of the [AIT repository](https://gitee.com/ascend/ait/tree/master/ait#ait) to install the tool.
+
+   Only the benchmark component is required; the other components are optional.
+
+## Preparing the Dataset
+
+1. Get the original dataset.
+
+   This model generates images from an input image and a text prompt, so no dataset is required.
+
+## Model Inference
+
+1. Convert the model.
+
+   Use PyTorch to convert the .pth weight file into .onnx files, then use the ATC tool to convert the .onnx files into offline .om models.
+
+   1. Get the weights.
+
+      Download the checkpoint from https://huggingface.co/lllyasviel/ControlNet/blob/main/models/control_sd15_canny.pth
+      and place it in the `ControlNet/models` working directory.
+
+   2. Export the ONNX models.
+
+      The OpenCLIP weights can be downloaded into `ControlNet/openai/clip-vit-large-patch14` in advance to avoid download failures in later steps:
+
+      ```
+      # git-lfs is required (https://git-lfs.com)
+      git lfs install
+      cd ControlNet
+      git clone https://huggingface.co/openai/clip-vit-large-patch14 openai/clip-vit-large-patch14
+      ```
+
+      Run:
+
+      ```
+      python3 ControlNet_pth2onnx.py --model ./models/control_sd15_canny.pth --control_path onnx/control/ --sd_path onnx/sd/
+      ```
+
+      Parameters:
+      - --model: path of the local checkpoint.
+      - --control_path: output directory for the control ONNX model.
+      - --sd_path: output directory for the sd ONNX model.
+
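+      Optionally, the exported models can be sanity-checked before the ATC conversion. The sketch below is for illustration only and assumes the `onnx` Python package is installed (it is not listed in requirements.txt); it prints the input names and static shapes, which should match the [Input and Output Data](#section540883920406) table and the `input1`...`input13` control feature maps of the sd part:
+
+      ```python
+      import onnx
+
+      for path in ("onnx/control/control.onnx", "onnx/sd/sd.onnx"):
+          # Parse only the graph structure; do not resolve external weight files.
+          model = onnx.load(path, load_external_data=False)
+          print(path)
+          for inp in model.graph.input:
+              dims = [d.dim_value for d in inp.type.tensor_type.shape.dim]
+              print("  ", inp.name, dims)
+      ```
+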
+   3. Use the ATC tool to convert the ONNX models into OM models.
+
+      1. Set up the environment variables.
+
+         ```
+         source /usr/local/Ascend/ascend-toolkit/set_env.sh
+         ```
+
+         > **Note:**
+         > The environment variables in this script are for reference only; configure them according to your actual installation. For details, see the [CANN auxiliary development tool guide (inference)](https://support.huawei.com/enterprise/zh/ascend-computing/cann-pid-251168373?category=developer-documents&subcategory=auxiliary-development-tools).
+
+      2. Run the following command to query the chip name (${chip_name}).
+
+         ```
+         npu-smi info
+         # In this example the chip name is Ascend310P3 (replace it with your own).
+         Sample output:
+         +-------------------+-----------------+------------------------------------------------------+
+         | NPU Name          | Health          | Power(W)   Temp(C)           Hugepages-Usage(page)   |
+         | Chip Device       | Bus-Id          | AICore(%)  Memory-Usage(MB)                          |
+         +===================+=================+======================================================+
+         | 0     310P3       | OK              | 15.8       42                0   / 0                 |
+         | 0     0           | 0000:82:00.0    | 0          1074 / 21534                              |
+         +===================+=================+======================================================+
+         | 1     310P3       | OK              | 15.4       43                0   / 0                 |
+         | 0     1           | 0000:89:00.0    | 0          1070 / 21534                              |
+         +===================+=================+======================================================+
+         ```
+
+      3. Run the ATC commands.
+
+         ```
+         # control
+         atc --framework=5 \
+             --model=./onnx/control/control.onnx \
+             --output=./om/control/control \
+             --input_format=ND \
+             --log=error \
+             --soc_version=Ascend${chip_name}
+
+         # sd
+         atc --framework=5 \
+             --model=./onnx/sd/sd.onnx \
+             --output=./om/sd/sd \
+             --input_format=ND \
+             --log=error \
+             --soc_version=Ascend${chip_name}
+         ```
+
+         Parameters:
+         - --model: the ONNX model file.
+         - --output: the output OM model.
+         - --framework: 5 stands for ONNX.
+         - --log: log level.
+         - --soc_version: processor model.
+
+2. Run the inference.
+
+   1. Run the inference script.
+
+      ```
+      python3 ControlNet_infer.py \
+          --model ./models/control_sd15_canny.pth \
+          --image test_imgs/dog.png \
+          --prompt "cute dog" \
+          --device 0 \
+          --control_model_dir om/control/control.om \
+          --sd_model_dir om/sd/sd.om \
+          --save_dir ./results \
+          --ddim_steps 20
+      ```
+
+      Parameters:
+      - --model: path of the local checkpoint.
+      - --prompt: text prompt.
+      - --save_dir: directory where the generated images are saved.
+      - --ddim_steps: number of DDIM sampling steps.
+      - --image: input image.
+      - --control_model_dir: path of the control OM model.
+      - --sd_model_dir: path of the sd OM model.
+      - --device: inference device ID.
+
+      After the script finishes, the images are written to `./results`: one edge-map image extracted from the input and one image regenerated from the input image and the text prompt.
+
+3. Result images.
+
+   Online (PyTorch) inference result:
+
+   ![](./test_results/dog0.PNG)
+
+   Offline (OM) inference result:
+
+   ![](./test_results/dog1.PNG)
+
+
+4. Performance validation.
+
+   The pure-inference mode of the ais_bench tool can be used to measure the performance of the OM models, for example:
+
+   ```
+   python3 -m ais_bench --model=${om_model_path} --loop=10 --batchsize=${batch_size} --device 0,1
+   ```
+
+   Parameters:
+   - --model: model path.
+   - --batchsize: number of samples per batch.
+   - --device: specify both chip IDs when testing on a Duo card.
+
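+   For reference, the throughput figures in the table below can be derived from the average per-inference latency reported by ais_bench. The following is only a sketch of the arithmetic (it assumes a Duo card running both chips in parallel and an average latency in milliseconds; the 59.5 ms value is an illustrative number, not a measured result):
+
+   ```python
+   def throughput_fps(avg_latency_ms: float, batch_size: int, num_chips: int = 2) -> float:
+       # Images per second aggregated over all chips running in parallel.
+       return num_chips * batch_size * 1000.0 / avg_latency_ms
+
+   print(round(throughput_fps(59.5, 1), 1))  # ~33.6 fps, cf. the control row below
+   ```
+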
+
+# Inference Performance & Accuracy
+
+Inference is performed through the ACL interface; refer to the following performance data.
+
+
+| Hardware | Performance | Model |
+| :------: | :--------: | :--------: |
+| Duo (dual-chip parallel) | 33.6 fps | control |
+| Duo (dual-chip parallel) | 13.6 fps | sd |
diff --git a/ACL_PyTorch/built-in/foundation_models/ControlNet_for_PyTorch/differences.patch b/ACL_PyTorch/built-in/foundation_models/ControlNet_for_PyTorch/differences.patch
new file mode 100644
index 0000000000000000000000000000000000000000..5a1b189ba4a3f75f4f6431bc99a5a62fd0526e01
--- /dev/null
+++ b/ACL_PyTorch/built-in/foundation_models/ControlNet_for_PyTorch/differences.patch
@@ -0,0 +1,125 @@
+diff -Naru a/ControlNet/cldm/ddim_hacked.py b/ControlNet/cldm/ddim_hacked.py
+--- a/ControlNet/cldm/ddim_hacked.py	2023-11-17 23:27:48.688000000 +0800
++++ b/ControlNet/cldm/ddim_hacked.py	2023-11-17 23:15:45.548000000 +0800
+@@ -16,8 +16,8 @@
+ 
+     def register_buffer(self, name, attr):
+         if type(attr) == torch.Tensor:
+-            if attr.device != torch.device("cuda"):
+-                attr = attr.to(torch.device("cuda"))
++            if attr.device != torch.device("cpu"):
++                attr = attr.to(torch.device("cpu"))
+             setattr(self, name, attr)
+ 
+     def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
+@@ -56,6 +56,8 @@
+                S,
+                batch_size,
+                shape,
++               sd_session,
++               control_session,
+                conditioning=None,
+                callback=None,
+                normals_sequence=None,
+@@ -101,6 +103,8 @@
+         print(f'Data shape for DDIM sampling is {size}, eta {eta}')
+ 
+         samples, intermediates = self.ddim_sampling(conditioning, size,
++                                                    sd_session,
++                                                    control_session,
+                                                     callback=callback,
+                                                     img_callback=img_callback,
+                                                     quantize_denoised=quantize_x0,
+@@ -120,7 +124,7 @@
+         return samples, intermediates
+ 
+     @torch.no_grad()
+-    def ddim_sampling(self, cond, shape,
++    def ddim_sampling(self, cond, shape, sd_session, control_session,
+                       x_T=None, ddim_use_original_steps=False,
+                       callback=None, timesteps=None, quantize_denoised=False,
+                       mask=None, x0=None, img_callback=None, log_every_t=100,
+@@ -160,7 +164,7 @@
+                 assert len(ucg_schedule) == len(time_range)
+                 unconditional_guidance_scale = ucg_schedule[i]
+ 
+-            outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
++            outs = self.p_sample_ddim(img, cond, ts, sd_session, control_session, index=index, use_original_steps=ddim_use_original_steps,
+                                       quantize_denoised=quantize_denoised, temperature=temperature,
+                                       noise_dropout=noise_dropout, score_corrector=score_corrector,
+                                       corrector_kwargs=corrector_kwargs,
+@@ -178,7 +182,7 @@
+         return img, intermediates
+ 
+     @torch.no_grad()
+-    def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
++    def p_sample_ddim(self, x, c, t, sd_session, control_session, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
+                       temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
+                       unconditional_guidance_scale=1., unconditional_conditioning=None,
+                       dynamic_threshold=None):
+@@ -187,8 +191,8 @@
+ 
+         if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
+             model_output = self.model.apply_model(x, t, c)
+         else:
+-            model_t = self.model.apply_model(x, t, c)
+-            model_uncond = self.model.apply_model(x, t, unconditional_conditioning)
++            model_t = self.model.apply_model(x, t, c, sd_session, control_session)
++            model_uncond = self.model.apply_model(x, t, unconditional_conditioning, sd_session, control_session)
+             model_output = model_uncond + unconditional_guidance_scale * (model_t - model_uncond)
+ 
+         if self.model.parameterization == "v":
+diff -Naru a/ControlNet/ldm/modules/attention.py b/ControlNet/ldm/modules/attention.py
+--- a/ControlNet/ldm/modules/attention.py	2023-11-17 23:27:49.192000000 +0800
++++ b/ControlNet/ldm/modules/attention.py	2023-11-17 23:17:33.896000000 +0800
+@@ -174,7 +174,7 @@
+         if _ATTN_PRECISION =="fp32":
+             with torch.autocast(enabled=False, device_type = 'cuda'):
+                 q, k = q.float(), k.float()
+-                sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
++                sim = einsum('bid,bjd->bij', q, k) * self.scale
+         else:
+             sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
+ 
+@@ -189,7 +189,7 @@
+         # attention, what we cannot get enough of
+         sim = sim.softmax(dim=-1)
+ 
+-        out = einsum('b i j, b j d -> b i d', sim, v)
++        out = einsum('bij,bjd->bid', sim, v)
+         out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
+         return self.to_out(out)
+ 
+diff -Naru a/ControlNet/ldm/modules/diffusionmodules/util.py b/ControlNet/ldm/modules/diffusionmodules/util.py
+--- a/ControlNet/ldm/modules/diffusionmodules/util.py	2023-11-17 23:27:49.192000000 +0800
++++ b/ControlNet/ldm/modules/diffusionmodules/util.py	2023-11-17 23:23:05.404000000 +0800
+@@ -109,7 +109,7 @@
+                    explicitly take as arguments.
+     :param flag: if False, disable gradient checkpointing.
+     """
+-    if flag:
++    if flag and not torch.onnx.is_in_onnx_export():
+         args = tuple(inputs) + tuple(params)
+         return CheckpointFunction.apply(func, len(inputs), *args)
+     else:
+diff -Naru a/ControlNet/ldm/modules/encoders/modules.py b/ControlNet/ldm/modules/encoders/modules.py
+--- a/ControlNet/ldm/modules/encoders/modules.py	2023-11-17 23:27:49.192000000 +0800
++++ b/ControlNet/ldm/modules/encoders/modules.py	2023-11-17 23:20:02.000000000 +0800
+@@ -92,7 +92,7 @@
+         "pooled",
+         "hidden"
+     ]
+-    def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77,
++    def __init__(self, version="openai/clip-vit-large-patch14", device="cpu", max_length=77,
+                  freeze=True, layer="last", layer_idx=None):  # clip-vit-base-patch32
+         super().__init__()
+         assert layer in self.LAYERS
+diff -Naru a/ControlNet/models/cldm_v15.yaml b/ControlNet/models/cldm_v15.yaml
+--- a/ControlNet/models/cldm_v15.yaml	2023-11-17 23:27:49.196000000 +0800
++++ b/ControlNet/models/cldm_v15.yaml	2023-11-17 23:18:39.812000000 +0800
+@@ -1,5 +1,5 @@
+ model:
+-  target: cldm.cldm.ControlLDM
++  target: pipeline.AscendControlNet
+   params:
+     linear_start: 0.00085
+     linear_end: 0.0120
diff --git a/ACL_PyTorch/built-in/foundation_models/ControlNet_for_PyTorch/pipeline.py b/ACL_PyTorch/built-in/foundation_models/ControlNet_for_PyTorch/pipeline.py
new file mode 100644
index 0000000000000000000000000000000000000000..388df82f2f1c50399a50122aed814c4a075a69b6
--- /dev/null
+++ b/ACL_PyTorch/built-in/foundation_models/ControlNet_for_PyTorch/pipeline.py
@@ -0,0 +1,46 @@
+# Copyright 2023 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+
+from cldm.cldm import ControlLDM
+
+
+class AscendControlNet(ControlLDM):
+    def apply_model(self, x_noisy, t, cond, sd_session, control_session):
+        assert isinstance(cond, dict)
+        cond_txt = torch.cat(cond["c_crossattn"], 1)
+        mode = "static"
+
+        # Run the control part on the NPU; it returns the 13 control feature
+        # maps that are fed into the sd (UNet) part.
+        control = control_session.infer(
+            [
+                x_noisy.numpy(),
+                torch.cat(cond["c_concat"], 1).numpy(),
+                t.numpy(),
+                cond_txt.numpy()
+            ], mode
+        )
+
+        control = [c * scale for c, scale in zip(control, self.control_scales)]
+
+        # Run the sd (UNet) part on the NPU to predict the noise residual.
+        eps = torch.from_numpy(
+            sd_session.infer(
+                [
+                    x_noisy.numpy(),
+                    t.numpy(),
+                    cond_txt.numpy()
+                ] + control, mode
+            )[0]
+        )
+
+        return eps
diff --git a/ACL_PyTorch/built-in/foundation_models/ControlNet_for_PyTorch/requirements.txt b/ACL_PyTorch/built-in/foundation_models/ControlNet_for_PyTorch/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..b53c8a9faa68136e7f0f23d36fa8ab8985ac5118
--- /dev/null
+++ b/ACL_PyTorch/built-in/foundation_models/ControlNet_for_PyTorch/requirements.txt
@@ -0,0 +1,24 @@
+torch==1.12.1
+torchvision==0.13.1
+numpy==1.23.1
+albumentations==1.3.0
+opencv-python-headless==4.8.1.78
+imageio==2.9.0
+imageio-ffmpeg==0.4.2
+pytorch-lightning==1.5.0
+omegaconf==2.1.1
+test-tube>=0.7.5
+streamlit==1.12.1
+einops==0.3.0
+transformers==4.19.2
+webdataset==0.2.5
+kornia==0.6
+open_clip_torch==2.0.2
+invisible-watermark>=0.1.5
+streamlit-drawable-canvas==0.8.0
+torchmetrics==0.6.0
+timm==0.6.12
+addict==2.4.0
+yapf==0.32.0
+prettytable==3.6.0
+safetensors==0.2.7
diff --git a/ACL_PyTorch/built-in/foundation_models/ControlNet_for_PyTorch/test_results/dog0.PNG b/ACL_PyTorch/built-in/foundation_models/ControlNet_for_PyTorch/test_results/dog0.PNG
new file mode 100644
index 0000000000000000000000000000000000000000..3cfe8346ddf339ef4a1cfb276a752f3d5a05e1b1
Binary files /dev/null and b/ACL_PyTorch/built-in/foundation_models/ControlNet_for_PyTorch/test_results/dog0.PNG differ
diff --git a/ACL_PyTorch/built-in/foundation_models/ControlNet_for_PyTorch/test_results/dog1.PNG b/ACL_PyTorch/built-in/foundation_models/ControlNet_for_PyTorch/test_results/dog1.PNG
new file mode 100644
index 0000000000000000000000000000000000000000..3cfe8346ddf339ef4a1cfb276a752f3d5a05e1b1
Binary files /dev/null and b/ACL_PyTorch/built-in/foundation_models/ControlNet_for_PyTorch/test_results/dog1.PNG differ