From eca06956346472b0b34c8a623bdebf063e2e0068 Mon Sep 17 00:00:00 2001 From: sunyiran Date: Wed, 28 Feb 2024 11:27:08 +0800 Subject: [PATCH 1/2] SVD transfer to npu --- .../diffusion/diffusers0.25.0/README.md | 78 +++++++++ .../requirements_svd.txt | 11 ++ .../test_stable_video_diffusion.py | 159 ++++++++++++++++++ .../diffusers/models/attention_processor.py | 120 +++++++++++++ .../src/diffusers/models/modeling_utils.py | 21 +++ .../src/diffusers/pipelines/pipeline_utils.py | 22 +++ .../pipeline_stable_video_diffusion.py | 5 + .../test/infer_full_8p_svd_fp16.sh | 112 ++++++++++++ 8 files changed, 528 insertions(+) create mode 100644 PyTorch/built-in/diffusion/diffusers0.25.0/examples/stable_video_diffusion/requirements_svd.txt create mode 100644 PyTorch/built-in/diffusion/diffusers0.25.0/examples/stable_video_diffusion/test_stable_video_diffusion.py create mode 100644 PyTorch/built-in/diffusion/diffusers0.25.0/test/infer_full_8p_svd_fp16.sh diff --git a/PyTorch/built-in/diffusion/diffusers0.25.0/README.md b/PyTorch/built-in/diffusion/diffusers0.25.0/README.md index 844efd4198..5a09c80be2 100644 --- a/PyTorch/built-in/diffusion/diffusers0.25.0/README.md +++ b/PyTorch/built-in/diffusion/diffusers0.25.0/README.md @@ -214,8 +214,84 @@ --validation_image //验证使用的图片(仅controlnet微调脚本使用) ``` +# StableVideoDiffusion +## Image to Video任务推理 +### 准备环境 +- 当前模型支持的 PyTorch 版本和已知三方库依赖如下表所示。 + + **表 1** 版本支持表 + + | Torch_Version | 三方库依赖版本 | + | :--------: |:-------------------------------------------------------:| + | PyTorch 2.1 | torchvision==0.16.0 | + +- 环境准备指导。 + + 请参考《[Pytorch框架训练环境准备](https://www.hiascend.com/document/detail/zh/ModelZoo/pytorchframework/ptes)》搭建torch环境。 + +- 安装依赖。 + + 在模型根目录下执行命令,安装模型对应PyTorch版本需要的依赖。 + ```shell + pip install -e . # 安装diffusers + cd examples/stable_video_diffusion/ # 根据下游任务安装对应依赖 + pip install -r requirements_svd.txt + ``` +### 获取stable-video-diffusion-img2vid-xt权重 + + - 联网情况下,预训练模型会自动下载。 + + - 无网络时,用户可访问huggingface官网自行下载,文件namespace如下: + + ``` + stabilityai/stable-video-diffusion-img2vid-xt + ``` + + - 获取对应的预训练模型后,在shell启动脚本中将`ckpt_path`参数,设置为本地预训练模型路径,填写一级目录。 + + +### 获取测试数据集 + - 准备好i2vgen-xl测试集,放到模型目录下,并重命名为`svd_testdata` + + - 准备数据路径txt文件,存放测试图片的文件名 + + - 参考目录结构为 + + ``` + ├── svd_testdata + ├── img_0001.jpg + ├── img_0002.jpg + ├── ... 
+ ├── imglist.txt + ``` + > **说明:** + >该数据集的推理过程脚本只作为一种参考示例。 + +### 在模型目录下运行推理脚本。 + ```shell + bash test/infer_full_8p_svd_fp16.sh --ckpt_path=xxx --test_data_dir=xxx --test_file=xxx # 八卡推理 + ``` + + 模型推理脚本参数说明如下。 + + ``` + 公共参数: + --ckpt_path // 模型权重加载地址 + --batch_size // 推理的图像全局批大小 + --test_file // 测试图片路径文件 + --test_data_dir // 测试数据集存放目录 + ``` +## 训练结果展示 + +**表 2** 训练结果展示表 + +| NAME | fp16 | Denoise FPS | Batch Size | +|--------|:----:|:-----------:|---------| +| 8p-NPU | fp16 | 3.75 | 8 | +| 8p-竞品A | fp16 | 4.88 | 8 | + # 公网地址说明 代码涉及公网地址参考 public_address_statement.md @@ -226,6 +302,8 @@ 2024.01.30:首次发布。 +2024.02.29:新增StableVideoDiffusion推理。 + ## FAQ diff --git a/PyTorch/built-in/diffusion/diffusers0.25.0/examples/stable_video_diffusion/requirements_svd.txt b/PyTorch/built-in/diffusion/diffusers0.25.0/examples/stable_video_diffusion/requirements_svd.txt new file mode 100644 index 0000000000..c37611926f --- /dev/null +++ b/PyTorch/built-in/diffusion/diffusers0.25.0/examples/stable_video_diffusion/requirements_svd.txt @@ -0,0 +1,11 @@ +accelerate>=0.22.0 +torchvision==0.16.0 +transformers>=4.25.1 +ftfy +tensorboard +Jinja2 +datasets +peft==0.7.0 +toml +voluptuous +opencv-python-headless \ No newline at end of file diff --git a/PyTorch/built-in/diffusion/diffusers0.25.0/examples/stable_video_diffusion/test_stable_video_diffusion.py b/PyTorch/built-in/diffusion/diffusers0.25.0/examples/stable_video_diffusion/test_stable_video_diffusion.py new file mode 100644 index 0000000000..1c88cbeb48 --- /dev/null +++ b/PyTorch/built-in/diffusion/diffusers0.25.0/examples/stable_video_diffusion/test_stable_video_diffusion.py @@ -0,0 +1,159 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
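+
+# Image-to-video inference entry point for StableVideoDiffusion on Ascend NPU.
+# Loads stable-video-diffusion-img2vid-xt in fp16, swaps in the NPU attention
+# processor, shards the test images across ranks over hccl when WORLD_SIZE > 1,
+# saves the generated frames as PNGs (optionally an .mp4), and, with --eval-metrics,
+# reports the mean cosine distance against a benchmark directory on rank 0.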
+import os +import numpy as np +from numpy.linalg import norm +import time +import argparse +from PIL import Image +import torch +import torch_npu +from torch_npu.contrib import transfer_to_npu + +import torch.distributed as dist +from torch.utils.data import DataLoader +from torch.utils.data.distributed import DistributedSampler +from torch.utils.data import Dataset + +from diffusers import StableVideoDiffusionPipeline +from diffusers.utils import load_image, export_to_video + + +class AliDataset(Dataset): + def __init__(self, root, anno_file): + super().__init__() + self.root = root + self.anno_file = anno_file + + self.path_list = [] + with open(self.anno_file, 'r') as fp: + for line in fp: + self.path_list.append(line.rstrip()) + + def __getitem__(self, index): + img_name = self.path_list[index] + imgpath = os.path.join(self.root, img_name) + return imgpath + + def __len__(self): + return len(self.path_list) + + +def make_test_data_sampler(dataset, distributed, rank): + if distributed: + return DistributedSampler( + dataset, + num_replicas=dist.get_world_size(), + rank=rank, + shuffle=False + ) + else: + return torch.utils.data.sampler.SequentialSampler(dataset) + + +def numpy_cosine_similarity_distance(a, b): + if a.dtype == 'uint8' or b.dtype == 'uint8': + similarity = np.dot(a.astype('float'), b.astype('float')) / (norm(a) * norm(b)) + else: + similarity = np.dot(a, b) / (norm(a) * norm(b)) + + distance = 1.0 - similarity.mean() + + return distance + +def main(args): + seed = args.seed + generator = torch.Generator(device="cpu").manual_seed(seed) + device = torch.device("npu") + batch_size = args.global_batch_size + rank = 0 + num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1 + distributed = num_gpus > 1 + if distributed: + dist.init_process_group("hccl") + rank = dist.get_rank() + device = rank % torch.npu.device_count() + torch.npu.set_device(device) + generator = torch.Generator(device="cpu").manual_seed(seed + rank) + print(f"Starting rank={rank}, seed={seed + rank}, world_size={dist.get_world_size()}.") + batch_size = int(args.global_batch_size // dist.get_world_size()) + + pipe = StableVideoDiffusionPipeline.from_pretrained( + args.ckpt, torch_dtype=torch.float16, variant="fp16" + ) + pipe.enable_model_cpu_offload(device=device) + pipe.enable_npu_svd_attention() + + test_dataset = AliDataset(args.test_data_dir, args.test_file) + sampler = make_test_data_sampler(test_dataset, distributed, rank) + loader = DataLoader( + test_dataset, + batch_size=batch_size, + shuffle=False, + sampler=sampler, + num_workers=args.num_workers, + pin_memory=True, + drop_last=True + ) + + for image_path in loader: + for img in image_path: + image = load_image(img) + image = image.resize(args.image_size) + start_time = time.time() + frames = pipe(image, decode_chunk_size=8, generator=generator, num_frames=args.num_frames, num_inference_steps=25, + output_type="pil").frames[0] + step_time = time.time() + if not distributed or dist.get_rank() == 0: + print(f'infer step time: {step_time - start_time}') + if not os.path.exists(os.path.join(args.output_dir, img.split('/')[-1].split('.')[0])): + os.makedirs(os.path.join(args.output_dir, img.split('/')[-1].split('.')[0])) + for i in range(len(frames)): + frames[i].save(os.path.join(args.output_dir, img.split('/')[-1].split('.')[0], f"frames_{i}.png")) + if args.export_video: + export_to_video(frames, os.path.join(args.output_dir, img.split('/')[-1].split('.')[0]+".mp4"), fps=7) + + if args.eval_metrics: + if not distributed or 
dist.get_rank() == 0: + cos_distances = [] + output_paths = sorted(os.listdir(args.output_dir)) + for subdir in output_paths: + for i in range(args.num_frames): + img_benchmark = Image.open(os.path.join(args.benchmark_dir, subdir, f"frames_{i}.png")) + img_output = Image.open(os.path.join(args.output_dir, subdir, f"frames_{i}.png")) + img_benchmark = np.array(img_benchmark) + img_output = np.array(img_output) + cos_distance = numpy_cosine_similarity_distance(img_output.flatten(), img_benchmark.flatten()) + cos_distances.append(cos_distance) + + print(f"mean cos dis: {sum(cos_distances) / len(cos_distances)}") + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument("--test-data-dir", type=str, default="", help='the path to testset') + parser.add_argument("--test-file", type=str, default="", help='the path to testdata file') + parser.add_argument("--global-batch-size", type=int, default=8) + parser.add_argument("--image-size", type=tuple, default=(1024, 576)) + parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--num-workers", type=int, default=0) + parser.add_argument("--ckpt", type=str, default="stabilityai/stable-video-diffusion-img2vid-xt", + help='the path to a SVD checkpoint') + parser.add_argument("--output-dir", type=str, default="", help='the path to save outputs') + parser.add_argument("--eval-metrics", type=bool, default=False, help='whether or not to eval metrics') + parser.add_argument("--benchmark-dir", type=str, default="") + parser.add_argument("--export-video", type=bool, default=False) + parser.add_argument("--num-frames", type=int, default=25) + args = parser.parse_args() + main(args) diff --git a/PyTorch/built-in/diffusion/diffusers0.25.0/src/diffusers/models/attention_processor.py b/PyTorch/built-in/diffusion/diffusers0.25.0/src/diffusers/models/attention_processor.py index 9eece3947a..51cf8abe07 100644 --- a/PyTorch/built-in/diffusion/diffusers0.25.0/src/diffusers/models/attention_processor.py +++ b/PyTorch/built-in/diffusion/diffusers0.25.0/src/diffusers/models/attention_processor.py @@ -213,6 +213,12 @@ class Attention(nn.Module): ) self.set_processor(processor) + def set_use_npu_svd_attention( + self, use_npu_svd_attention: bool, attention_op: Optional[Callable] = None + ) -> None: + processor = NpuSVDAttnProcessor2_0() + self.set_processor(processor) + def set_use_memory_efficient_attention_xformers( self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None ) -> None: @@ -1296,6 +1302,120 @@ class AttnProcessor2_0: return hidden_states +class NpuSVDAttnProcessor2_0: + r""" + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). 
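+
+    This NPU-adapted variant additionally splits short sequences (< 1024) into
+    batch chunks for F.scaled_dot_product_attention, routes fp16/bf16 inputs with
+    head_dim < 512 to torch_npu.npu_fusion_attention, and falls back to plain
+    scaled_dot_product_attention for all other shapes.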
+ """ + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("NpuSVDAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + + def __call__( + self, + attn: Attention, + hidden_states: torch.FloatTensor, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + temb: Optional[torch.FloatTensor] = None, + scale: float = 1.0, + ) -> torch.FloatTensor: + residual = hidden_states + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + args = () if USE_PEFT_BACKEND else (scale,) + query = attn.to_q(hidden_states, *args) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states, *args) + value = attn.to_v(encoder_hidden_states, *args) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + seq_length = max(query.shape[-2], key.shape[-2]) + if seq_length < 1024: + chunk_size = min(256, query.shape[0]) + chunk_num = query.shape[0] // chunk_size + key_chunk = torch.chunk(key, chunk_num, dim=0) + value_chunk = torch.chunk(value, chunk_num, dim=0) + query_chunk = torch.chunk(query, chunk_num, dim=0) + + hidden_states_chunk = [] + for i in range(chunk_num): + hidden_states_i = F.scaled_dot_product_attention( + query_chunk[i], key_chunk[i], value_chunk[i], + attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + hidden_states_chunk.append(hidden_states_i) + hidden_states = torch.cat(hidden_states_chunk, dim=0) + elif query.shape[-1] < 512 and query.dtype in (torch.float16, torch.bfloat16): + hidden_states = torch_npu.npu_fusion_attention( + query, key, value, attn.heads, input_layout="BNSD", + pse=None, + atten_mask=attention_mask, + scale=1.0 / math.sqrt(query.shape[-1]), + pre_tockens=65536, + next_tockens=65536, + keep_prob=1., + sync=False, + inner_precise=0, + )[0] + + else: + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # linear proj + hidden_states = attn.to_out[0](hidden_states, *args) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if 
input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + class FusedAttnProcessor2_0: r""" Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). diff --git a/PyTorch/built-in/diffusion/diffusers0.25.0/src/diffusers/models/modeling_utils.py b/PyTorch/built-in/diffusion/diffusers0.25.0/src/diffusers/models/modeling_utils.py index 546c5b20f9..f14526178f 100644 --- a/PyTorch/built-in/diffusion/diffusers0.25.0/src/diffusers/models/modeling_utils.py +++ b/PyTorch/built-in/diffusion/diffusers0.25.0/src/diffusers/models/modeling_utils.py @@ -1,3 +1,4 @@ +# Copyright 2024 Huawei Technologies Co., Ltd # coding=utf-8 # Copyright 2023 The HuggingFace Inc. team. # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. @@ -261,6 +262,26 @@ class ModelMixin(torch.nn.Module, PushToHubMixin): if isinstance(module, torch.nn.Module): fn_recursive_set_mem_eff(module) + def set_use_npu_svd_attention( + self, valid: bool, attention_op: Optional[Callable] = None + ) -> None: + # Recursively walk through all the children. + # Any children which exposes the set_use_npu_svd_attention method + # gets the message + def fn_recursive_set_mem_eff(module: torch.nn.Module): + if hasattr(module, "set_use_npu_svd_attention"): + module.set_use_npu_svd_attention(valid, attention_op) + + for child in module.children(): + fn_recursive_set_mem_eff(child) + + for module in self.children(): + if isinstance(module, torch.nn.Module): + fn_recursive_set_mem_eff(module) + + def enable_npu_svd_attention(self, attention_op: Optional[Callable] = None): + self.set_use_npu_svd_attention(True, attention_op) + def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None) -> None: r""" Enable memory efficient attention from [xFormers](https://facebookresearch.github.io/xformers/). diff --git a/PyTorch/built-in/diffusion/diffusers0.25.0/src/diffusers/pipelines/pipeline_utils.py b/PyTorch/built-in/diffusion/diffusers0.25.0/src/diffusers/pipelines/pipeline_utils.py index e7a795365a..71028502be 100644 --- a/PyTorch/built-in/diffusion/diffusers0.25.0/src/diffusers/pipelines/pipeline_utils.py +++ b/PyTorch/built-in/diffusion/diffusers0.25.0/src/diffusers/pipelines/pipeline_utils.py @@ -1,3 +1,4 @@ +# Copyright 2024 Huawei Technologies Co., Ltd # coding=utf-8 # Copyright 2023 The HuggingFace Inc. team. # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. @@ -2045,6 +2046,27 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin): for module in modules: fn_recursive_set_mem_eff(module) + def enable_npu_svd_attention(self, attention_op: Optional[Callable] = None): + self.set_use_npu_svd_attention(True, attention_op) + + def set_use_npu_svd_attention(self, valid: bool, attention_op: Optional[Callable] = None): + # Recursively walk through all the children. 
+ # Any children which exposes the set_use_npu_svd_attention method + # gets the message + def fn_recursive_set_mem_eff(module: torch.nn.Module): + if hasattr(module, "set_use_npu_svd_attention"): + module.set_use_npu_svd_attention(valid, attention_op) + + for child in module.children(): + fn_recursive_set_mem_eff(child) + + module_names, _ = self._get_signature_keys(self) + modules = [getattr(self, n, None) for n in module_names] + modules = [m for m in modules if isinstance(m, torch.nn.Module)] + + for module in modules: + fn_recursive_set_mem_eff(module) + def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): r""" Enable sliced attention computation. When this option is enabled, the attention module splits the input tensor diff --git a/PyTorch/built-in/diffusion/diffusers0.25.0/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py b/PyTorch/built-in/diffusion/diffusers0.25.0/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py index 8b4c7bdd08..2797c3911c 100644 --- a/PyTorch/built-in/diffusion/diffusers0.25.0/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +++ b/PyTorch/built-in/diffusion/diffusers0.25.0/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py @@ -1,3 +1,4 @@ +# Copyright 2024 Huawei Technologies Co., Ltd # Copyright 2023 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -12,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import time import inspect from dataclasses import dataclass from typing import Callable, Dict, List, Optional, Union @@ -491,6 +493,7 @@ class StableVideoDiffusionPipeline(DiffusionPipeline): self._num_timesteps = len(timesteps) with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): + denoise_start = time.time() # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) @@ -525,6 +528,8 @@ class StableVideoDiffusionPipeline(DiffusionPipeline): if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() + denoise_end = time.time() + print('denoise time: ', denoise_end - denoise_start) if not output_type == "latent": # cast back to fp16 if needed diff --git a/PyTorch/built-in/diffusion/diffusers0.25.0/test/infer_full_8p_svd_fp16.sh b/PyTorch/built-in/diffusion/diffusers0.25.0/test/infer_full_8p_svd_fp16.sh new file mode 100644 index 0000000000..bc2f110bad --- /dev/null +++ b/PyTorch/built-in/diffusion/diffusers0.25.0/test/infer_full_8p_svd_fp16.sh @@ -0,0 +1,112 @@ +# 微调生成的ckpt路径 +Network="StableVideoDiffusion" +BATCH_SIZE=8 +ckpt_path="stabilityai/stable-video-diffusion-img2vid-xt" +test_data_dir="svd_testdata" +test_file="svd_testdata/imglist.txt" +benchmark_dir="benchmark_output" +export WORLD_SIZE=8 +export MASTER_PORT=29500 +export MASTER_ADDR=127.0.0.1 + +for para in $* +do + if [[ $para == --batch_size* ]]; then + BATCH_SIZE=$(echo ${para#*=}) + elif [[ $para == --ckpt_path* ]]; then + ckpt_path=$(echo ${para#*=}) + elif [[ $para == --test_data_dir* ]]; then + test_data_dir=$(echo ${para#*=}) + elif [[ $para == --test_file* ]]; then + test_file=$(echo ${para#*=}) + fi +done + +# 
cd到与test文件夹同层级目录下执行脚本,提高兼容性;test_path_dir为包含test文件夹的路径 +cur_path=$(pwd) +cur_path_last_dirname=${cur_path##*/} +if [ x"${cur_path_last_dirname}" == x"test" ]; then + test_path_dir=${cur_path} + cd .. + cur_path=$(pwd) +else + test_path_dir=${cur_path}/test +fi + +source ${test_path_dir}/env_npu.sh + +ASCEND_DEVICE_ID=0 +#创建DeviceID输出目录,不需要修改 +if [ -d ${test_path_dir}/output/${ASCEND_DEVICE_ID} ];then + rm -rf ${test_path_dir}/output/${ASCEND_DEVICE_ID} + mkdir -p ${test_path_dir}/output/${ASCEND_DEVICE_ID} +else + mkdir -p ${test_path_dir}/output/${ASCEND_DEVICE_ID} +fi + +#推理开始时间,不需要修改 +start_time=$(date +%s) +echo "start_time: ${start_time}" + +RANK_ID_START=0 +RANK_SIZE=8 +KERNEL_NUM=$(($(nproc)/8)) +for((RANK_ID=$RANK_ID_START;RANK_ID<$((RANK_SIZE+RANK_ID_START));RANK_ID++)); +do + export RANK=$RANK_ID + export LOCAL_RANK=$RANK_ID + PID_START=$((KERNEL_NUM * RANK_ID)) + PID_END=$((PID_START + KERNEL_NUM - 1)) + nohup taskset -c $PID_START-$PID_END python3 examples/stable_video_diffusion/test_stable_video_diffusion.py \ + --test-data-dir ${test_data_dir} \ + --test-file ${test_file} \ + --ckpt ${ckpt_path} \ + --output-dir ${test_path_dir}/output/${ASCEND_DEVICE_ID}/output \ + --global-batch-size ${BATCH_SIZE} \ + --benchmark-dir ${benchmark_dir} \ + --eval-metrics True >> ${test_path_dir}/output/$ASCEND_DEVICE_ID/infer_svd_${ASCEND_DEVICE_ID}.log 2>&1 & +done +wait + +#训练结束时间,不需要修改 +end_time=$(date +%s) +e2e_time=$(($end_time - $start_time)) + + +# 训练用例信息,不需要修改 +BatchSize=${BATCH_SIZE} +DeviceType=$(uname -m) +CaseName=${Network}_bs${BatchSize}_${WORLD_SIZE}'p'_'acc' + +# 结果打印,不需要修改 +echo "------------------ Final result ------------------" +# 输出性能FPS,需要模型审视修改 +denoise_time=`grep -a 'denoise time:' ${test_path_dir}/output/${ASCEND_DEVICE_ID}/infer_svd_${ASCEND_DEVICE_ID}.log|awk -F "time: " '{print $2}' | sort -n | head -40 | awk '{a+=$1} END {if (NR != 0) printf("%.4f",a/NR)}'` +FPS=`awk 'BEGIN{printf "%.2f\n", '$BATCH_SIZE'/'$denoise_time'}'` +# 打印,不需要修改 +echo "Final Performance images/sec : $FPS" + +# 输出训练精度,需要模型审视修改 +train_accuracy=$(grep -a "mean cos dis" ${test_path_dir}/output/${ASCEND_DEVICE_ID}/infer_svd_${ASCEND_DEVICE_ID}.log |tail -1 |awk -F ": " '{print $2}' |awk '{a+=$1} END {printf("%.4f",a)}') +# 打印,不需要修改 +echo "Final Train Accuracy : ${train_accuracy}" +echo "E2E Training Duration sec : $e2e_time" + + +# 性能看护结果汇总 +# 获取性能数据,不需要修改 +# 吞吐量 +ActualFPS=${FPS} +# 训练总时长 +TrainingTime=`grep -a 'infer step time' ${test_path_dir}/output/${ASCEND_DEVICE_ID}/infer_svd_${ASCEND_DEVICE_ID}.log |awk -F "infer step time: " '{print $2}' | awk '{a+=$1} END {printf("%.3f",a)}'` + +# 关键信息打印到${CaseName}.log中,不需要修改 +echo "Network = ${Network}" >${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "RankSize = ${WORLD_SIZE}" >>${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "BatchSize = ${BatchSize}" >>${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "DeviceType = ${DeviceType}" >>${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "CaseName = ${CaseName}" >>${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "TrainAccuracy = ${train_accuracy}" >>${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "ActualFPS = ${ActualFPS}" >>${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}_perf_report.log +echo "TrainingTime = ${TrainingTime}" >>${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}_perf_report.log +echo "E2ETrainingTime = ${e2e_time}" >>${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}_perf_report.log \ No 
newline at end of file -- Gitee From 9e81aadd04eec4f9958a4bcea95c1864c29b1a7d Mon Sep 17 00:00:00 2001 From: sunyiran Date: Fri, 1 Mar 2024 17:51:04 +0800 Subject: [PATCH 2/2] add --- .../test/infer_full_1p_svd_fp16.sh | 109 ++++++++++++++++++ 1 file changed, 109 insertions(+) create mode 100644 PyTorch/built-in/diffusion/diffusers0.25.0/test/infer_full_1p_svd_fp16.sh diff --git a/PyTorch/built-in/diffusion/diffusers0.25.0/test/infer_full_1p_svd_fp16.sh b/PyTorch/built-in/diffusion/diffusers0.25.0/test/infer_full_1p_svd_fp16.sh new file mode 100644 index 0000000000..3277635c82 --- /dev/null +++ b/PyTorch/built-in/diffusion/diffusers0.25.0/test/infer_full_1p_svd_fp16.sh @@ -0,0 +1,109 @@ +# 微调生成的ckpt路径 +Network="StableVideoDiffusion" +BATCH_SIZE=1 +ckpt_path="stabilityai/stable-video-diffusion-img2vid-xt" +test_data_dir="svd_testdata" +test_file="svd_testdata/imglist.txt" +benchmark_dir="benchmark_output" +export WORLD_SIZE=1 +export MASTER_PORT=29500 +export MASTER_ADDR=127.0.0.1 + +for para in $* +do + if [[ $para == --batch_size* ]]; then + BATCH_SIZE=$(echo ${para#*=}) + elif [[ $para == --ckpt_path* ]]; then + ckpt_path=$(echo ${para#*=}) + elif [[ $para == --test_data_dir* ]]; then + test_data_dir=$(echo ${para#*=}) + elif [[ $para == --test_file* ]]; then + test_file=$(echo ${para#*=}) + fi +done + +# cd到与test文件夹同层级目录下执行脚本,提高兼容性;test_path_dir为包含test文件夹的路径 +cur_path=$(pwd) +cur_path_last_dirname=${cur_path##*/} +if [ x"${cur_path_last_dirname}" == x"test" ]; then + test_path_dir=${cur_path} + cd .. + cur_path=$(pwd) +else + test_path_dir=${cur_path}/test +fi + +source ${test_path_dir}/env_npu.sh + +ASCEND_DEVICE_ID=0 +#创建DeviceID输出目录,不需要修改 +if [ -d ${test_path_dir}/output/${ASCEND_DEVICE_ID} ];then + rm -rf ${test_path_dir}/output/${ASCEND_DEVICE_ID} + mkdir -p ${test_path_dir}/output/${ASCEND_DEVICE_ID} +else + mkdir -p ${test_path_dir}/output/${ASCEND_DEVICE_ID} +fi + +#推理开始时间,不需要修改 +start_time=$(date +%s) +echo "start_time: ${start_time}" + +RANK_ID=0 +KERNEL_NUM=$(($(nproc)/8)) +export RANK=$RANK_ID +export LOCAL_RANK=$RANK_ID +PID_START=$((KERNEL_NUM * RANK_ID)) +PID_END=$((PID_START + KERNEL_NUM - 1)) +nohup taskset -c $PID_START-$PID_END python3 examples/stable_video_diffusion/test_stable_video_diffusion.py \ + --test-data-dir ${test_data_dir} \ + --test-file ${test_file} \ + --ckpt ${ckpt_path} \ + --output-dir ${test_path_dir}/output/${ASCEND_DEVICE_ID}/output \ + --global-batch-size ${BATCH_SIZE} \ + --benchmark-dir ${benchmark_dir} \ + --eval-metrics True >> ${test_path_dir}/output/$ASCEND_DEVICE_ID/infer_svd_${ASCEND_DEVICE_ID}.log 2>&1 & + +wait + +#训练结束时间,不需要修改 +end_time=$(date +%s) +e2e_time=$(($end_time - $start_time)) + + +# 训练用例信息,不需要修改 +BatchSize=${BATCH_SIZE} +DeviceType=$(uname -m) +CaseName=${Network}_bs${BatchSize}_${WORLD_SIZE}'p'_'acc' + +# 结果打印,不需要修改 +echo "------------------ Final result ------------------" +# 输出性能FPS,需要模型审视修改 +denoise_time=`grep -a 'denoise time:' ${test_path_dir}/output/${ASCEND_DEVICE_ID}/infer_svd_${ASCEND_DEVICE_ID}.log|awk -F "time: " '{print $2}' | sort -n | head -40 | awk '{a+=$1} END {if (NR != 0) printf("%.4f",a/NR)}'` +FPS=`awk 'BEGIN{printf "%.2f\n", '$BATCH_SIZE'/'$denoise_time'}'` +# 打印,不需要修改 +echo "Final Performance images/sec : $FPS" + +# 输出训练精度,需要模型审视修改 +train_accuracy=$(grep -a "mean cos dis" ${test_path_dir}/output/${ASCEND_DEVICE_ID}/infer_svd_${ASCEND_DEVICE_ID}.log |tail -1 |awk -F ": " '{print $2}' |awk '{a+=$1} END {printf("%.4f",a)}') +# 打印,不需要修改 +echo "Final Train Accuracy : ${train_accuracy}" +echo "E2E Training 
Duration sec : $e2e_time" + + +# 性能看护结果汇总 +# 获取性能数据,不需要修改 +# 吞吐量 +ActualFPS=${FPS} +# 训练总时长 +TrainingTime=`grep -a 'infer step time' ${test_path_dir}/output/${ASCEND_DEVICE_ID}/infer_svd_${ASCEND_DEVICE_ID}.log |awk -F "infer step time: " '{print $2}' | awk '{a+=$1} END {printf("%.3f",a)}'` + +# 关键信息打印到${CaseName}.log中,不需要修改 +echo "Network = ${Network}" >${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "RankSize = ${WORLD_SIZE}" >>${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "BatchSize = ${BatchSize}" >>${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "DeviceType = ${DeviceType}" >>${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "CaseName = ${CaseName}" >>${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "TrainAccuracy = ${train_accuracy}" >>${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "ActualFPS = ${ActualFPS}" >>${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}_perf_report.log +echo "TrainingTime = ${TrainingTime}" >>${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}_perf_report.log +echo "E2ETrainingTime = ${e2e_time}" >>${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}_perf_report.log \ No newline at end of file -- Gitee
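For quick reference, below is a minimal single-image sketch of the NPU inference path introduced by this patch, distilled from `examples/stable_video_diffusion/test_stable_video_diffusion.py`; the checkpoint reference, input image, and output file name are placeholders, and the 8p/1p shell scripts above drive these same calls across ranks.

```python
# Minimal single-image sketch of the NPU inference path added by this patch.
# The checkpoint reference, input image, and output name below are placeholders.
import torch
import torch_npu
from torch_npu.contrib import transfer_to_npu  # noqa: F401  (maps CUDA calls onto the NPU)

from diffusers import StableVideoDiffusionPipeline
from diffusers.utils import load_image, export_to_video

device = torch.device("npu")
pipe = StableVideoDiffusionPipeline.from_pretrained(
    "stabilityai/stable-video-diffusion-img2vid-xt",  # or a local checkpoint directory
    torch_dtype=torch.float16,
    variant="fp16",
)
pipe.enable_model_cpu_offload(device=device)
pipe.enable_npu_svd_attention()  # swaps Attention modules to NpuSVDAttnProcessor2_0

image = load_image("svd_testdata/img_0001.jpg").resize((1024, 576))
generator = torch.Generator(device="cpu").manual_seed(42)
frames = pipe(
    image,
    decode_chunk_size=8,
    generator=generator,
    num_frames=25,
    num_inference_steps=25,
    output_type="pil",
).frames[0]
export_to_video(frames, "img_0001.mp4", fps=7)
```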