diff --git a/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/README.md b/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/README.md
index cfd8e5c65f55a82f222751ee17f80e7bc3645819..c7b572a9d30f310fd8efe11b9d82f0901a85aed2 100644
--- a/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/README.md
+++ b/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/README.md
@@ -108,7 +108,23 @@ pip3 install -r requirements.txt
 |    |---- model weights
 ```
 
-### 3.2 Single-card single-prompt functional test
+### 3.2 Building the RoPE operator
+Enter the operator directory and run the build command:
+```shell
+cd pta_plugin
+bash build.sh
+```
+After a successful build, an operator library ending in .so is generated under the build folder.
+
+Add the path of the compiled operator library in the cogvideox_5b/models/attention_processor.py script:
+```python
+torch.ops.load_library("./pta_plugin/build/libPTAExtensionOPS.so")
+```
+Note: the RoPE operator is loaded on the first call, so perform a warmup run before formal inference.
+
+### 3.3 Single-card single-prompt functional test
 Set the weight path:
 ```shell
 model_path='data/CogVideoX-5b'
diff --git a/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/cogvideox_5b/models/attention_processor.py b/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/cogvideox_5b/models/attention_processor.py
index d49abec2e5dcd8decbe334bc0273ba705320e7c6..c748dcde4d5673109b2b193be7fc01026b83971e 100644
--- a/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/cogvideox_5b/models/attention_processor.py
+++ b/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/cogvideox_5b/models/attention_processor.py
@@ -19,6 +19,7 @@ import torch
 import torch_npu
 import torch.nn.functional as F
 from torch import nn
+torch.ops.load_library("./pta_plugin/build/libPTAExtensionOPS.so")
 
 from diffusers.image_processor import IPAdapterMaskProcessor
 from diffusers.utils import deprecate, logging
@@ -1917,11 +1918,10 @@ class CogVideoXAttnProcessor2_0:
 
         # Apply RoPE if needed
         if image_rotary_emb is not None:
-            from .embeddings import apply_rotary_emb
-
-            query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
+            cos, sin = image_rotary_emb
+            query[:, :, text_seq_length:] = torch.ops.mindie.rope_mindie_sd(query[:, :, text_seq_length:], cos[None, None], sin[None, None], mode=1)
             if not attn.is_cross_attention:
-                key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
+                key[:, :, text_seq_length:] = torch.ops.mindie.rope_mindie_sd(key[:, :, text_seq_length:], cos[None, None], sin[None, None], mode=1)
 
         if get_sp_world_size() == 1:
             hidden_states = torch_npu.npu_prompt_flash_attention(
diff --git a/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/cogvideox_5b/pipelines/pipeline_cogvideox.py b/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/cogvideox_5b/pipelines/pipeline_cogvideox.py
index 6fba7a36091851589d3b88dd563413af6ee8c034..39766b00d29c96e24da4405dc15e8b1f1ca5d568 100644
--- a/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/cogvideox_5b/pipelines/pipeline_cogvideox.py
+++ b/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/cogvideox_5b/pipelines/pipeline_cogvideox.py
@@ -454,8 +454,8 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
             temporal_size=num_frames,
         )
 
-        freqs_cos = freqs_cos.to(device=device)
-        freqs_sin = freqs_sin.to(device=device)
+        freqs_cos = freqs_cos.to(device=device).to(torch.bfloat16)
+        freqs_sin = freqs_sin.to(device=device).to(torch.bfloat16)
 
         return freqs_cos, freqs_sin
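
The bfloat16 cast above keeps the rotary cos/sin tables in the same dtype as the bf16 query/key slices handed to the fused NPU RoPE kernel. A minimal warmup sketch along the lines of the README's advice (the shapes are illustrative, borrowed from the plugin's test script rather than the model's real ones):

```python
import torch
import torch_npu

# load the compiled plugin so torch.ops.mindie.* is registered (path as in the README)
torch.ops.load_library("./pta_plugin/build/libPTAExtensionOPS.so")

# illustrative (batch, heads, seq, head_dim) shapes; real values come from the model
query = torch.randn(2, 48, 128, 64, dtype=torch.bfloat16, device="npu")
cos = torch.randn(128, 64, dtype=torch.bfloat16, device="npu")
sin = torch.randn(128, 64, dtype=torch.bfloat16, device="npu")

# one throwaway call loads and caches the RoPE kernel before timed inference
_ = torch.ops.mindie.rope_mindie_sd(query, cos[None, None], sin[None, None], mode=1)
torch.npu.synchronize()
```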
diff --git a/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/inference.py b/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/inference.py
index 71417fc7e57de2f03375cdc55bba8478c6a28d9c..06b4365ab7c2e7cfd72b66be2cdd9cfb87666937 100644
--- a/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/inference.py
+++ b/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/inference.py
@@ -2,8 +2,10 @@ import os
 import argparse
 import time
 import random
 
-import numpy as np
+from typing import Literal
+
+import numpy as np
 import torch
 import torch_npu
 from torch_npu.contrib import transfer_to_npu
@@ -11,7 +13,7 @@ from torch_npu.contrib import transfer_to_npu
 
 from diffusers import CogVideoXDPMScheduler
 from diffusers.utils import export_to_video
 
-from cogvideox_5b import CogVideoXPipeline, CogVideoXTransformer3DModel, get_rank, get_world_size, all_gather
+from cogvideox_5b import CogVideoXPipeline, CogVideoXTransformer3DModel, get_rank, get_world_size, all_gather, parallelize_transformer
 
 from mindiesd.pipeline.sampling_optm import AdaStep
diff --git a/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/pta_plugin/CMakeLists.txt b/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/pta_plugin/CMakeLists.txt
new file mode 100644
index 0000000000000000000000000000000000000000..ff66b7724ef9a8e6adcf33033067dd631f88dd63
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/pta_plugin/CMakeLists.txt
@@ -0,0 +1,30 @@
+cmake_minimum_required(VERSION 3.10)
+
+project(PTAExtensionOPS)
+
+execute_process(
+    COMMAND python3 -c "import site; print(site.getsitepackages()[0])"
+    OUTPUT_VARIABLE python_site_packages_path
+)
+string(STRIP "${python_site_packages_path}" python_site_packages_path)
+
+set(CMAKE_CXX_FLAGS "-fstack-protector-all -Wl,-z,relro,-z,now,-z,noexecstack -fPIE -pie ${CMAKE_CXX_FLAGS}")
+set(CMAKE_CXX_FLAGS "-fabi-version=11 ${CMAKE_CXX_FLAGS}")
+set(PYTORCH_INSTALL_PATH ${python_site_packages_path}/torch)
+set(PYTORCH_NPU_INSTALL_PATH ${python_site_packages_path}/torch_npu)
+
+link_directories(${PYTORCH_INSTALL_PATH}/lib)
+link_directories(${PYTORCH_NPU_INSTALL_PATH}/lib)
+
+add_library(PTAExtensionOPS SHARED extension_ops.cpp)
+
+target_compile_features(PTAExtensionOPS PRIVATE cxx_std_17)
+target_compile_options(PTAExtensionOPS PRIVATE -D_GLIBCXX_USE_CXX11_ABI=0)
+
+include_directories(${PYTORCH_NPU_INSTALL_PATH}/include/third_party/acl/inc)
+include_directories(${PYTORCH_NPU_INSTALL_PATH}/include)
+include_directories(${PYTORCH_INSTALL_PATH}/include)
+include_directories(${PYTORCH_INSTALL_PATH}/include/torch/csrc/distributed)
+include_directories(${PYTORCH_INSTALL_PATH}/include/torch/csrc/api/include)
+
+target_link_libraries(PTAExtensionOPS PUBLIC c10 torch torch_cpu torch_npu)
\ No newline at end of file
diff --git a/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/pta_plugin/build.sh b/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/pta_plugin/build.sh
new file mode 100644
index 0000000000000000000000000000000000000000..95d55f5ff205ac2f76dcdc2183092b606f039daa
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/pta_plugin/build.sh
@@ -0,0 +1,19 @@
+#!/bin/bash
+if [ -n "$ASCEND_INSTALL_PATH" ]; then
+    _ASCEND_INSTALL_PATH=$ASCEND_INSTALL_PATH
+elif [ -n "$ASCEND_HOME_PATH" ]; then
+    _ASCEND_INSTALL_PATH=$ASCEND_HOME_PATH
+else
+    if [ -d "$HOME/Ascend/ascend-toolkit/latest" ]; then
+        _ASCEND_INSTALL_PATH=$HOME/Ascend/ascend-toolkit/latest
+    else
+        _ASCEND_INSTALL_PATH=/usr/local/Ascend/ascend-toolkit/latest
+    fi
+fi
+source $_ASCEND_INSTALL_PATH/bin/setenv.bash
+
+set -e
+rm -rf build
+mkdir -p build
+cmake -B build
+cmake --build build -j
\ No newline at end of file
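
Once `bash build.sh` succeeds, the Meta-device registration in extension_ops.cpp allows a quick sanity check of the schema without touching the NPU; a sketch, run from the pta_plugin directory (torch_npu must still be importable, since the library links against it):

```python
import torch
import torch_npu  # required: libPTAExtensionOPS.so links against torch_npu's libraries

torch.ops.load_library("./build/libPTAExtensionOPS.so")

# Meta tensors carry only shape/dtype, so this exercises rope_mindie_sd_impl_meta
x = torch.empty(2, 48, 128, 64, device="meta")
cos = torch.empty(1, 1, 128, 64, device="meta")
sin = torch.empty(1, 1, 128, 64, device="meta")
out = torch.ops.mindie.rope_mindie_sd(x, cos, sin, mode=1)
print(out.shape)  # torch.Size([2, 48, 128, 64]), mirroring the input per the Meta impl
```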
diff --git a/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/pta_plugin/extension_ops.cpp b/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/pta_plugin/extension_ops.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..548a9a7365c1c926e593ac6209d525db10e96d78
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/pta_plugin/extension_ops.cpp
@@ -0,0 +1,69 @@
+/**
+ * @file extension_ops.cpp
+ *
+ * Copyright (C) 2024. Huawei Technologies Co., Ltd. All rights reserved.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
+ */
+#include <torch/library.h>
+#include <torch/csrc/autograd/custom_function.h>
+
+#include "torch_npu/csrc/framework/utils/OpAdapter.h"
+#include "torch_npu/csrc/core/npu/NPUFormat.h"
+
+using torch::autograd::AutogradContext;
+using torch::autograd::Function;
+using npu_preparation = at_npu::native::OpPreparation;
+using npu_utils = at_npu::native::NpuUtils;
+using namespace at;
+
+// forward implementation for the NPU device: dispatches the Ascend
+// RotaryPositionEmbedding kernel on x with the given cos/sin tables
+at::Tensor rope_mindie_sd_impl_npu(const at::Tensor &x, const at::Tensor &cos, const at::Tensor &sin, int64_t mode = 1)
+{
+    // allocate the output with the same shape, dtype and NPU format as the input
+    at::Tensor result = at_npu::native::empty_with_format(x.sizes(), x.options(), at_npu::native::get_npu_format(x));
+
+    at_npu::native::OpCommand cmd;
+    cmd.Name("RotaryPositionEmbedding")
+        .Input(x)
+        .Input(cos)
+        .Input(sin)
+        .Output(result)
+        .Attr("mode", mode)
+        .Run();
+
+    return result;
+}
+
+// forward implementation for the Meta device (shape and dtype inference only)
+at::Tensor rope_mindie_sd_impl_meta(const at::Tensor &x, const at::Tensor &cos, const at::Tensor &sin, int64_t mode)
+{
+    return empty_like(x);
+}
+
+// register the schema for rope_mindie_sd in the mindie namespace
+TORCH_LIBRARY(mindie, m)
+{
+    m.def("rope_mindie_sd(Tensor x, Tensor cos, Tensor sin, int mode) -> Tensor");
+}
+
+// register the forward implementation for the NPU device.
+// the dispatch key used by the NPU device in PyTorch 2.1 and above is PrivateUse1;
+// in versions below 2.1 it is XLA, so PrivateUse1 needs to be changed to XLA there.
+TORCH_LIBRARY_IMPL(mindie, PrivateUse1, m)
+{
+    m.impl("rope_mindie_sd", &rope_mindie_sd_impl_npu);
+}
+
+// register the forward implementation for the Meta device
+TORCH_LIBRARY_IMPL(mindie, Meta, m)
+{
+    m.impl("rope_mindie_sd", &rope_mindie_sd_impl_meta);
+}
\ No newline at end of file
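
For checking the kernel's numerics, here is a pure-PyTorch reference of the rotary embedding that diffusers' apply_rotary_emb computes for CogVideoX (interleaved even/odd channel pairing); that mode=1 selects exactly this pairing is an assumption to verify against the kernel's output:

```python
import torch

def rope_reference(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # x: (..., seq, head_dim); cos/sin broadcast over the leading dims of x
    x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)     # pairs (x1, x2), (x3, x4), ...
    x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(-2)  # (-x2, x1, -x4, x3, ...)
    return x * cos + x_rotated * sin
```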
diff --git a/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/pta_plugin/test/test_rope.py b/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/pta_plugin/test/test_rope.py
new file mode 100644
index 0000000000000000000000000000000000000000..5d6f3425b44b6c94066aa084b3266a68db145853
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/foundation/CogVideoX-5b/pta_plugin/test/test_rope.py
@@ -0,0 +1,25 @@
+#!/usr/bin/python3
+# coding=utf-8
+#
+# Copyright (C) 2024-2024. Huawei Technologies Co., Ltd. All rights reserved.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
+# ===============================================================================
+
+import torch
+import torch_npu
+
+torch.ops.load_library("../build/libPTAExtensionOPS.so")
+
+if __name__ == "__main__":
+    torch.npu.set_device(0)
+    x = torch.randn((2, 48, 128, 64), device="npu")
+    cos = torch.randn((1, 1, 128, 64), device="npu")
+    sin = torch.randn((1, 1, 128, 64), device="npu")
+
+    # repeatedly invoke the custom RoPE operator; the first call doubles as warmup
+    count = 5
+    for _ in range(count):
+        output = torch.ops.mindie.rope_mindie_sd(x, cos, sin, mode=1)
\ No newline at end of file
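
The loop above only exercises the call path (and serves as the warmup the README recommends); to also validate values, one could compare against the pure-PyTorch reference sketched after extension_ops.cpp, again assuming mode=1 is the interleaved pairing, e.g. appended to the script:

```python
ref = rope_reference(x, cos, sin)
out = torch.ops.mindie.rope_mindie_sd(x, cos, sin, mode=1)
# loose tolerances, since the kernel may compute in reduced precision
print(torch.allclose(out.cpu(), ref.cpu(), rtol=1e-3, atol=1e-3))
```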