diff --git a/MindIE/MultiModal/OpenSora-1.2/README.md b/MindIE/MultiModal/OpenSora-1.2/README.md index bf72ce59342fca81339748e43a265a22925cf4bc..a439a60331bb934f297f893274c505e04ec47fe2 100644 --- a/MindIE/MultiModal/OpenSora-1.2/README.md +++ b/MindIE/MultiModal/OpenSora-1.2/README.md @@ -34,6 +34,10 @@ source /usr/local/Ascend/ascend-toolkit/set_env.sh pip3 install -r requirements.txt ``` +安装colossalai +```shell +pip3 install colossalai==0.4.4 --no-deps +``` ### 1.4 MindIE安装 ```shell # 增加软件包可执行权限,{version}表示软件版本号,{arch}表示CPU架构。 diff --git a/MindIE/MultiModal/OpenSora-1.2/opensora/stdit3/stdit3.py b/MindIE/MultiModal/OpenSora-1.2/opensora/stdit3/stdit3.py index e58992062851337a42458e45ced5aeeef0b7093a..348a21df5ce506217d846308c9d7042156f05597 100644 --- a/MindIE/MultiModal/OpenSora-1.2/opensora/stdit3/stdit3.py +++ b/MindIE/MultiModal/OpenSora-1.2/opensora/stdit3/stdit3.py @@ -354,6 +354,15 @@ class STDiT3(DiffusionModel): # cast to float32 for better accuracy x = x.to(torch.float32) return x + + def load_weights(self, state_dict, shard=False): + with torch.no_grad(): + if not shard: + self.load_state_dict(state_dict) + return {} + else: + self.load_state_dict(state_dict, strict=False) + return state_dict.keys() def _init_embedding(self, config): self.x_embedder = PatchEmbed3D(self.patch_size, config.in_channels, self.hidden_size) diff --git a/MindIE/MultiModal/OpenSora-1.2/opensora/vae/VideoAutoencoder.py b/MindIE/MultiModal/OpenSora-1.2/opensora/vae/VideoAutoencoder.py index 2db644efc603f53530adde8a71c79cf4be8968e7..4bffccdd030e88a2d0bccb426408d6a41b720c66 100644 --- a/MindIE/MultiModal/OpenSora-1.2/opensora/vae/VideoAutoencoder.py +++ b/MindIE/MultiModal/OpenSora-1.2/opensora/vae/VideoAutoencoder.py @@ -196,4 +196,13 @@ class VideoAutoencoder(DiffusionModel): x_z = torch.cat(x_z_list, dim=2) x = self.spatial_vae.decode(x_z) - return x \ No newline at end of file + return x + + def load_weights(self, state_dict, shard=False): + with torch.no_grad(): + if not shard: + self.load_state_dict(state_dict) + return {} + else: + self.load_state_dict(state_dict, strict=False) + return state_dict.keys() diff --git a/MindIE/MultiModal/OpenSora-1.2/requirents.txt b/MindIE/MultiModal/OpenSora-1.2/requirents.txt index 3d3a4381585ec6552ad217dd5089ae3d68bafea1..b6d5654a77cdf2452303b0dc18fb5d4e67b994b5 100644 --- a/MindIE/MultiModal/OpenSora-1.2/requirents.txt +++ b/MindIE/MultiModal/OpenSora-1.2/requirents.txt @@ -1,4 +1,3 @@ -colossalai==0.3.7 setuptools==57.5.0 torch==2.1.0 diffusers==0.26.3