diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_xl/convert_lora_safetensors_to_diffusers.py b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_xl/convert_lora_safetensors_to_diffusers.py index 2f449cc28dac1c7f527b46acb4a62927ff5acc20..940bc0c89352e2ba3aab417a102065e0152b0526 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_xl/convert_lora_safetensors_to_diffusers.py +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_xl/convert_lora_safetensors_to_diffusers.py @@ -20,6 +20,7 @@ import torch from safetensors.torch import load_file from diffusers import StableDiffusionXLPipeline +from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear def convert(base_model_path, checkpoint_path, LORA_PREFIX_UNET, LORA_PREFIX_TEXT_ENCODER, alpha): @@ -78,24 +79,17 @@ def convert(base_model_path, checkpoint_path, LORA_PREFIX_UNET, LORA_PREFIX_TEXT pair_keys.append(key.replace("lora_up", "lora_down")) # update weight - if len(state_dict[pair_keys[0]].shape) == 4: - weight_up = state_dict[pair_keys[0]].squeeze(3).squeeze(2).to(torch.float32) - weight_down = state_dict[pair_keys[1]].squeeze(3).squeeze(2).to(torch.float32) - try: - if len(curr_layer.weight.shape) == 2: - curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down) # for SD2.1 - else: - curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3) - except Exception: - shape4failed += 1 - else: - weight_up = state_dict[pair_keys[0]].to(torch.float32) - weight_down = state_dict[pair_keys[1]].to(torch.float32) - try: - curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down) - except Exception: - shapeno4failed += 1 - + if isinstance(curr_layer, LoRACompatibleConv): + upmat = state_dict[pair_keys[0]].to(torch.float32).flatten(start_dim=1) + downmat = state_dict[pair_keys[1]].to(torch.float32).flatten(start_dim=1) + fusionupdown = torch.mm(upmat, downmat) + fusionupdown = fusionupdown.reshape(curr_layer.weight.data.shape) + curr_layer.weight.data += alpha * fusionupdown + elif isinstance(curr_layer, LoRACompatibleLinear): + upmat = state_dict[pair_keys[0]].to(torch.float32)[None, :] + downmat = state_dict[pair_keys[1]].to(torch.float32)[None, :] + fusion = torch.bmm(upmat, downmat)[0] + curr_layer.weight.data += alpha * fusion # update visited list for item in pair_keys: visited.append(item) diff --git a/OWNERS b/OWNERS index 05a0867100908b9e55041e78e6b97f9cdae78705..9ee948465573ed6f1e31f85e654f30093b17fe98 100644 --- a/OWNERS +++ b/OWNERS @@ -19,6 +19,7 @@ approvers: - fighting_zhen - kezhan1 - jyoung6652 +- Gongen reviewers: - ginray0215 - matrixplayer @@ -53,4 +54,16 @@ reviewers: - han_yifeng - jyoung6652 - maoxx241 -- chenchuw \ No newline at end of file +- chenchuw +- guowenna +- lanwangli +- mazhixin00 +- zhou-wenxue +- huanghao7 +- Gongen +- kezhan1 +- terrychen1982 +- fan2956 +- Yansifu +- yliuls +- shaopeng666 \ No newline at end of file