diff --git a/PyTorch/built-in/mlm/PIDM/LICENSE b/PyTorch/built-in/mlm/PIDM/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..31e2ee5ab9abc854863028b0ce8c53d98d987958
--- /dev/null
+++ b/PyTorch/built-in/mlm/PIDM/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2023 Ankan Kumar Bhunia
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/PyTorch/built-in/mlm/PIDM/README.md b/PyTorch/built-in/mlm/PIDM/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..fa71e811dbba47bbaaffa601840c1b7ad7718291
--- /dev/null
+++ b/PyTorch/built-in/mlm/PIDM/README.md
@@ -0,0 +1,170 @@
+# Person Image Synthesis via Denoising Diffusion Model [Open in Colab](https://colab.research.google.com/github/ankanbhunia/PIDM/blob/main/PIDM_demo.ipynb)
+
+
+
+ ArXiv
+ |
+ Project
+ |
+ Demo
+ |
+ Youtube
+
+
+
+
+
+## News
+
+- **2023.02** A demo is available through Google Colab:
+
+ :rocket:
+ [Demo on Colab](https://colab.research.google.com/github/ankanbhunia/PIDM/blob/main/PIDM_demo.ipynb)
+
+
+
+
+## Generated Results
+
+
+
+You can directly download our test results from Google Drive: (1) [PIDM.zip](https://drive.google.com/file/d/1zcyTF37UrOmUqtRwwq1kgkyxnNX3oaQN/view?usp=share_link) (2) [PIDM_vs_Others.zip](https://drive.google.com/file/d/1iu75RVQBjR-TbB4ZQUns1oalzYZdNqGS/view?usp=share_link)
+
+The [PIDM_vs_Others.zip](https://drive.google.com/file/d/1iu75RVQBjR-TbB4ZQUns1oalzYZdNqGS/view?usp=share_link) file compares our method with several state-of-the-art methods, e.g. ADGAN [14], PISE [24], GFLA [20], DPTN [25], CASD [29], and
+NTED [19]. Each row contains target_pose, source_image, ground_truth, ADGAN, PISE, GFLA, DPTN, CASD, NTED, and PIDM (ours), respectively.
+
+
+
+
+## Dataset
+
+- Download `img_highres.zip` of the DeepFashion Dataset from [In-shop Clothes Retrieval Benchmark](https://drive.google.com/drive/folders/0B7EVK8r0v71pYkd5TzBiclMzR00).
+
+- Unzip `img_highres.zip`. You will need to request the password from the [dataset maintainers](http://mmlab.ie.cuhk.edu.hk/projects/DeepFashion/InShopRetrieval.html). Then rename the extracted folder to **img** and put it under the `./dataset/deepfashion` directory.
+
+- We split the train/test set following [GFLA](https://github.com/RenYurui/Global-Flow-Local-Attention). Several images with significant occlusions are removed from the training set. Download the train/test pairs and the keypoints `pose.zip` extracted with [Openpose](https://github.com/CMU-Perceptual-Computing-Lab/openpose) as follows:
+
+
+
+ - Download the train/test pairs from [Google Drive](https://drive.google.com/drive/folders/1PhnaFNg9zxMZM-ccJAzLIt2iqWFRzXSw?usp=sharing) including **train_pairs.txt**, **test_pairs.txt**, **train.lst**, **test.lst**. Put these files under the `./dataset/deepfashion` directory.
+  - Download the keypoints `pose.rar` extracted with Openpose from [Google Drive](https://drive.google.com/file/d/1waNzq-deGBKATXMU9JzMDWdGsF4YkcW_/view?usp=sharing). Unzip it and put the obtained folder under the `./dataset/deepfashion` directory.
+
+- Run the following command to save the images into an lmdb dataset.
+
+ ```bash
+ python data/prepare_data.py \
+ --root ./dataset/deepfashion \
+ --out ./dataset/deepfashion
+ ```
+## Custom Dataset
+
+The folder structure of any custom dataset should be as follows:
+
+- dataset/
+  - `<dataset_name>`/
+    - img/
+    - pose/
+    - train_pairs.txt
+    - test_pairs.txt
+
+All your images go inside the ```img``` folder; you can keep them in subfolders or put them all directly inside ```img```. The corresponding poses are stored inside the ```pose``` folder (as ```.txt``` files if you use Openpose; in our project we use 18-point keypoint estimation). ```train_pairs.txt``` and ```test_pairs.txt``` list all possible pairs of image paths, separated by a comma ```,```, as in the example below.
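+
+For illustration, based on how `data/fashion_data.py` parses these files, each line begins with the target image path followed by one or more source image paths (typically of the same person), separated by commas. The paths below are placeholders in the DeepFashion naming style, not real files:
+
+```
+img/WOMEN/Dresses/id_00000001/01_1_front.jpg,img/WOMEN/Dresses/id_00000001/01_2_side.jpg
+img/MEN/Jackets_Vests/id_00000002/02_1_front.jpg,img/MEN/Jackets_Vests/id_00000002/02_2_side.jpg,img/MEN/Jackets_Vests/id_00000002/02_4_full.jpg
+```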
+
+After that, run the following command to process the data:
+
+```bash
+python data/prepare_data.py \
+--root ./dataset/<dataset_name> \
+--out ./dataset/<dataset_name> \
+--sizes ((256,256),)
+```
+
+This will create an lmdb dataset at ```./dataset/<dataset_name>/256-256/```.
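+
+As a quick sanity check, the resulting lmdb can be opened and a record decoded as sketched below (based on the loader in `data/fashion_data.py`; keys are image paths stored with a `.png` extension, and the key shown here is a placeholder):
+
+```python
+import lmdb
+from io import BytesIO
+from PIL import Image
+
+env = lmdb.open('./dataset/<dataset_name>/256-256', max_readers=32, readonly=True,
+                lock=False, readahead=False, meminit=False)
+with env.begin(write=False) as txn:
+    key = 'img/<subfolder>/<image_name>.png'.encode('utf-8')  # placeholder key
+    img = Image.open(BytesIO(txn.get(key)))
+    print(img.size)
+```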
+
+
+
+
+## Conda Installation
+
+``` bash
+# 1. Create a conda virtual environment.
+conda create -n PIDM python=3.7
+conda activate PIDM
+conda install pytorch torchvision torchaudio pytorch-cuda=11.7 -c pytorch -c nvidia
+
+# 2. Clone the Repo and Install dependencies
+git clone https://github.com/ankanbhunia/PIDM
+cd PIDM
+pip install -r requirements.txt
+
+```
+## Method
+
+
+
+## Training
+
+This code supports multi-GPU training. Full training takes 5 days with 8 A100 GPUs and a batch size of 8 on the DeepFashion dataset. The model is trained for 300 epochs; however, it already generates high-quality, usable samples after around 200 epochs. We also trained with V100 GPUs, and training takes a similar amount of time.
+
+ ```bash
+python -m torch.distributed.launch --nproc_per_node=8 --master_port 48949 train.py \
+--dataset_path "./dataset/deepfashion" --batch_size 8 --exp_name "pidm_deepfashion"
+
+ ```
+
+
+## Inference
+
+Download the pretrained model from [here](https://drive.google.com/file/d/1WkV5Pn-_fBdiZlvVHHx_S97YESBkx4lD/view?usp=share_link) and place it in the ```checkpoints``` folder.
+For pose control, use ```obj.predict_pose``` as in the following code snippet.
+
+ ```python
+from predict import Predictor
+obj = Predictor()
+
+# replace <path_to_source_image> with the path to your source image
+obj.predict_pose(image='<path_to_source_image>', sample_algorithm='ddim', num_poses=4, nsteps=50)
+
+ ```
+
+For appearance control, use ```obj.predict_appearance```.
+
+ ```python
+from predict import Predictor
+obj = Predictor()
+
+# the paths below are placeholders for your own files
+src = '<path_to_source_image>'
+ref_img = '<path_to_reference_image>'
+ref_mask = '<path_to_reference_mask>'
+ref_pose = '<path_to_reference_pose>'
+
+obj.predict_appearance(image=src, ref_img=ref_img, ref_mask=ref_mask, ref_pose=ref_pose, sample_algorithm='ddim', nsteps=50)
+
+ ```
+
+The output will be saved as ```output.png```.
+
+
+## Citation
+
+If you use the results and code for your research, please cite our paper:
+
+```
+@article{bhunia2022pidm,
+ title={Person Image Synthesis via Denoising Diffusion Model},
+ author={Bhunia, Ankan Kumar and Khan, Salman and Cholakkal, Hisham and Anwer, Rao Muhammad and Laaksonen, Jorma and Shah, Mubarak and Khan, Fahad Shahbaz},
+ journal={CVPR},
+ year={2023}
+}
+```
+
+[Ankan Kumar Bhunia](https://scholar.google.com/citations?user=2leAc3AAAAAJ&hl=en),
+[Salman Khan](https://scholar.google.com/citations?user=M59O9lkAAAAJ&hl=en),
+[Hisham Cholakkal](https://scholar.google.com/citations?user=bZ3YBRcAAAAJ&hl=en),
+[Rao Anwer](https://scholar.google.fi/citations?user=_KlvMVoAAAAJ&hl=en),
+[Jorma Laaksonen](https://scholar.google.com/citations?user=qQP6WXIAAAAJ&hl=en),
+[Mubarak Shah](https://scholar.google.com/citations?user=p8gsO3gAAAAJ&hl=en) &
+[Fahad Khan](https://scholar.google.ch/citations?user=zvaeYnUAAAAJ&hl=en&oi=ao)
+
diff --git a/PyTorch/built-in/mlm/PIDM/config/__init__.py b/PyTorch/built-in/mlm/PIDM/config/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8e0bfaed508414956366b6014645b3afef5ae358
--- /dev/null
+++ b/PyTorch/built-in/mlm/PIDM/config/__init__.py
@@ -0,0 +1,2 @@
+from .diffconfig import *
+from .dataconfig import *
\ No newline at end of file
diff --git a/PyTorch/built-in/mlm/PIDM/config/data.yaml b/PyTorch/built-in/mlm/PIDM/config/data.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..800fb63f4c24d896627f149d489fa3847cc9469f
--- /dev/null
+++ b/PyTorch/built-in/mlm/PIDM/config/data.yaml
@@ -0,0 +1,18 @@
+distributed: False
+
+data:
+ type: data.fashion_data::Dataset
+ preprocess_mode: resize_and_crop
+ path: /home/ankanbhunia/Downloads/pidm-demo/dataset/deepfashion
+ num_workers: 8
+ sub_path: 256-256
+ resolution: 256
+ scale_param: 0.05
+ train:
+ batch_size: 8 # real_batch_size: 4 * 2 (source-->target & target --> source) * 2 (GPUs) = 16
+ distributed: False
+ val:
+ batch_size: 8
+ distributed: False
+
+
diff --git a/PyTorch/built-in/mlm/PIDM/config/dataconfig.py b/PyTorch/built-in/mlm/PIDM/config/dataconfig.py
new file mode 100644
index 0000000000000000000000000000000000000000..e97c96cd9b15546743b3669ef55bdcc521400084
--- /dev/null
+++ b/PyTorch/built-in/mlm/PIDM/config/dataconfig.py
@@ -0,0 +1,203 @@
+import collections
+import functools
+import os
+import re
+
+import yaml
+#from util.distributed import master_only_print as print
+
+class AttrDict(dict):
+ """Dict as attribute trick."""
+
+ def __init__(self, *args, **kwargs):
+ super(AttrDict, self).__init__(*args, **kwargs)
+ self.__dict__ = self
+ for key, value in self.__dict__.items():
+ if isinstance(value, dict):
+ self.__dict__[key] = AttrDict(value)
+ elif isinstance(value, (list, tuple)):
+ if isinstance(value[0], dict):
+ self.__dict__[key] = [AttrDict(item) for item in value]
+ else:
+ self.__dict__[key] = value
+
+ def yaml(self):
+ """Convert object to yaml dict and return."""
+ yaml_dict = {}
+ for key, value in self.__dict__.items():
+ if isinstance(value, AttrDict):
+ yaml_dict[key] = value.yaml()
+ elif isinstance(value, list):
+ if isinstance(value[0], AttrDict):
+ new_l = []
+ for item in value:
+ new_l.append(item.yaml())
+ yaml_dict[key] = new_l
+ else:
+ yaml_dict[key] = value
+ else:
+ yaml_dict[key] = value
+ return yaml_dict
+
+ def __repr__(self):
+ """Print all variables."""
+ ret_str = []
+ for key, value in self.__dict__.items():
+ if isinstance(value, AttrDict):
+ ret_str.append('{}:'.format(key))
+ child_ret_str = value.__repr__().split('\n')
+ for item in child_ret_str:
+ ret_str.append(' ' + item)
+ elif isinstance(value, list):
+ if isinstance(value[0], AttrDict):
+ ret_str.append('{}:'.format(key))
+ for item in value:
+ # Treat as AttrDict above.
+ child_ret_str = item.__repr__().split('\n')
+ for item in child_ret_str:
+ ret_str.append(' ' + item)
+ else:
+ ret_str.append('{}: {}'.format(key, value))
+ else:
+ ret_str.append('{}: {}'.format(key, value))
+ return '\n'.join(ret_str)
+
+
+class Config(AttrDict):
+ r"""Configuration class. This should include every human specifiable
+ hyperparameter values for your training."""
+
+ def __init__(self, filename=None, args=None, verbose=False, is_train=True):
+ super(Config, self).__init__()
+ # Set default parameters.
+ # Logging.
+
+ large_number = 1000000000
+ self.snapshot_save_iter = large_number
+ self.snapshot_save_epoch = large_number
+ self.snapshot_save_start_iter = 0
+ self.snapshot_save_start_epoch = 0
+ self.image_save_iter = large_number
+        self.eval_epoch = large_number
+        self.start_eval_epoch = large_number
+        self.max_epoch = large_number
+ self.max_iter = large_number
+ self.logging_iter = 100
+ self.image_to_tensorboard=False
+ # self.which_iter = args.which_iter
+ # self.resume = not args.no_resume
+
+
+ # self.checkpoints_dir = args.checkpoints_dir
+ # self.name = args.name
+ self.phase = 'train' if is_train else 'test'
+
+ # Networks.
+ self.gen = AttrDict(type='generators.dummy')
+ self.dis = AttrDict(type='discriminators.dummy')
+
+ # Optimizers.
+ self.gen_optimizer = AttrDict(type='adam',
+ lr=0.0001,
+ adam_beta1=0.0,
+ adam_beta2=0.999,
+ eps=1e-8,
+ lr_policy=AttrDict(iteration_mode=False,
+ type='step',
+ step_size=large_number,
+ gamma=1))
+ self.dis_optimizer = AttrDict(type='adam',
+ lr=0.0001,
+ adam_beta1=0.0,
+ adam_beta2=0.999,
+ eps=1e-8,
+ lr_policy=AttrDict(iteration_mode=False,
+ type='step',
+ step_size=large_number,
+ gamma=1))
+ # Data.
+ self.data = AttrDict(name='dummy',
+ type='datasets.images',
+ num_workers=0)
+ self.test_data = AttrDict(name='dummy',
+ type='datasets.images',
+ num_workers=0,
+ test=AttrDict(is_lmdb=False,
+ roots='',
+ batch_size=1))
+ self.trainer = AttrDict(
+ image_to_tensorboard=False,
+ hparam_to_tensorboard=False)
+
+ # Cudnn.
+ self.cudnn = AttrDict(deterministic=False,
+ benchmark=True)
+
+ # Others.
+ self.pretrained_weight = ''
+ self.inference_args = AttrDict()
+
+
+ # Update with given configurations.
+ assert os.path.exists(filename), 'File {} not exist.'.format(filename)
+ loader = yaml.SafeLoader
+ loader.add_implicit_resolver(
+ u'tag:yaml.org,2002:float',
+ re.compile(u'''^(?:
+ [-+]?(?:[0-9][0-9_]*)\\.[0-9_]*(?:[eE][-+]?[0-9]+)?
+ |[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+)
+ |\\.[0-9_]+(?:[eE][-+][0-9]+)?
+ |[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\\.[0-9_]*
+ |[-+]?\\.(?:inf|Inf|INF)
+ |\\.(?:nan|NaN|NAN))$''', re.X),
+ list(u'-+0123456789.'))
+ try:
+ with open(filename, 'r') as f:
+ cfg_dict = yaml.load(f, Loader=loader)
+ except EnvironmentError:
+        print('Please check the file with name of "{}"'.format(filename))
+ recursive_update(self, cfg_dict)
+
+ # Put common opts in both gen and dis.
+ if 'common' in cfg_dict:
+ self.common = AttrDict(**cfg_dict['common'])
+ self.gen.common = self.common
+ self.dis.common = self.common
+
+
+ if verbose:
+ print(' config '.center(80, '-'))
+ print(self.__repr__())
+ print(''.center(80, '-'))
+
+
+def rsetattr(obj, attr, val):
+ """Recursively find object and set value"""
+ pre, _, post = attr.rpartition('.')
+ return setattr(rgetattr(obj, pre) if pre else obj, post, val)
+
+
+def rgetattr(obj, attr, *args):
+ """Recursively find object and return value"""
+
+ def _getattr(obj, attr):
+ r"""Get attribute."""
+ return getattr(obj, attr, *args)
+
+ return functools.reduce(_getattr, [obj] + attr.split('.'))
+
+
+def recursive_update(d, u):
+ """Recursively update AttrDict d with AttrDict u"""
+ for key, value in u.items():
+ if isinstance(value, collections.abc.Mapping):
+ d.__dict__[key] = recursive_update(d.get(key, AttrDict({})), value)
+ elif isinstance(value, (list, tuple)):
+ if isinstance(value[0], dict):
+ d.__dict__[key] = [AttrDict(item) for item in value]
+ else:
+ d.__dict__[key] = value
+ else:
+ d.__dict__[key] = value
+ return d
diff --git a/PyTorch/built-in/mlm/PIDM/config/diffconfig.py b/PyTorch/built-in/mlm/PIDM/config/diffconfig.py
new file mode 100755
index 0000000000000000000000000000000000000000..587bcdbeb041a484ba8c91b102690150fbf73f07
--- /dev/null
+++ b/PyTorch/built-in/mlm/PIDM/config/diffconfig.py
@@ -0,0 +1,77 @@
+from typing import Optional, List
+from pydantic import StrictStr, StrictInt, StrictFloat, StrictBool
+from tensorfn.config import (
+ MainConfig,
+ Config,
+ Optimizer,
+ Scheduler,
+ DataLoader,
+ Instance,
+)
+
+import diffusion
+import model
+from models.unet_autoenc import BeatGANsAutoencConfig
+
+
+class Diffusion(Config):
+ beta_schedule: Instance
+
+class Dataset(Config):
+ name: StrictStr
+ path: StrictStr
+ resolution: StrictInt
+
+class Training(Config):
+ ckpt_path: StrictStr
+ optimizer: Optimizer
+ scheduler: Optional[Scheduler]
+ dataloader: DataLoader
+
+
+class Eval(Config):
+ wandb: StrictBool
+ save_every: StrictInt
+ valid_every: StrictInt
+ log_every: StrictInt
+
+
+class DiffusionConfig(MainConfig):
+ diffusion: Diffusion
+ training: Training
+
+
+def get_model_conf():
+
+ return BeatGANsAutoencConfig(image_size=256,
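+                    # 23 input channels: 3 for the (noisy) RGB image plus 20 for the pose
+                    # conditioning (RGB skeleton map + 17 limb distance maps), which the
+                    # sampler concatenates with x_t along the channel dimension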
+ in_channels=3+20,
+ model_channels=128,
+ out_channels=3*2, # also learns sigma
+ num_res_blocks=2,
+ num_input_res_blocks=None,
+ embed_channels=512,
+ attention_resolutions=(32, 16, 8,),
+ time_embed_channels=None,
+ dropout=0.1,
+ channel_mult=(1, 1, 2, 2, 4, 4),
+ input_channel_mult=None,
+ conv_resample=True,
+ dims=2,
+ num_classes=None,
+ use_checkpoint=False,
+ num_heads=1,
+ num_head_channels=-1,
+ num_heads_upsample=-1,
+ resblock_updown=True,
+ use_new_attention_order=False,
+ resnet_two_cond=True,
+ resnet_cond_channels=None,
+ resnet_use_zero_module=True,
+ attn_checkpoint=False,
+ enc_out_channels=512,
+ enc_attn_resolutions=None,
+ enc_pool='adaptivenonzero',
+ enc_num_res_block=2,
+ enc_channel_mult=(1, 1, 2, 2, 4, 4, 4),
+ enc_grad_checkpoint=False,
+ latent_net_conf=None)
\ No newline at end of file
diff --git a/PyTorch/built-in/mlm/PIDM/config/diffusion.conf b/PyTorch/built-in/mlm/PIDM/config/diffusion.conf
new file mode 100755
index 0000000000000000000000000000000000000000..776983b0de89eb8eabe709ec3c2aeda2ecdca636
--- /dev/null
+++ b/PyTorch/built-in/mlm/PIDM/config/diffusion.conf
@@ -0,0 +1,30 @@
+diffusion: {
+ beta_schedule: {
+ __target: diffusion.make_beta_schedule
+ schedule: linear
+ n_timestep: 1000
+ linear_start: 1e-4
+ linear_end: 2e-2
+ }
+}
+
+training: {
+ ckpt_path = checkpoints/fashion/
+ optimizer: {
+ type: adam
+ lr: 2e-5
+ }
+ scheduler: {
+ type: cycle
+ lr: 2e-5
+ n_iter: 2400000
+ warmup: 5000
+ decay: [linear, flat]
+ }
+ dataloader: {
+ batch_size: 8
+ num_workers: 2
+ drop_last: true
+ }
+}
+
diff --git a/PyTorch/built-in/mlm/PIDM/data/__init__.py b/PyTorch/built-in/mlm/PIDM/data/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..775af5e1c8d61b5dddab9e01794442dfaea4a04d
--- /dev/null
+++ b/PyTorch/built-in/mlm/PIDM/data/__init__.py
@@ -0,0 +1,96 @@
+import importlib
+import torch.utils.data
+#from util.distributed import master_only_print as print
+from torch.utils.data import Dataset
+import numpy as np
+import glob
+
+def find_dataset_using_name(dataset_name):
+    # dataset_name has the form "module.path::ClassName"; import the module and
+    # return the class whose name matches ClassName exactly.
+    dataset_filename = dataset_name
+    module, target = dataset_name.split('::')
+    datasetlib = importlib.import_module(module)
+
+    dataset = None
+    for name, cls in datasetlib.__dict__.items():
+        if name == target:
+            dataset = cls
+
+    if dataset is None:
+        raise ValueError("In %s.py, there should be a class "
+                         "whose name matches %s." %
+                         (dataset_filename, target))
+
+    return dataset
+
+
+def get_option_setter(dataset_name):
+ dataset_class = find_dataset_using_name(dataset_name)
+ return dataset_class.modify_commandline_options
+
+
+def create_dataloader(opt, distributed, labels_required, is_inference):
+ dataset = find_dataset_using_name(opt.type)
+ instance = dataset(opt, is_inference, labels_required)
+ phase = 'val' if is_inference else 'training'
+ batch_size = opt.val.batch_size if is_inference else opt.train.batch_size
+ print("%s dataset [%s] of size %d was created" %
+ (phase, opt.type, len(instance)))
+
+ dataloader = torch.utils.data.DataLoader(
+ instance,
+ batch_size=batch_size,
+ sampler=data_sampler(instance, shuffle=not is_inference, distributed=distributed),
+ drop_last=not is_inference,
+ num_workers=getattr(opt, 'num_workers', 0),
+ )
+
+ return dataloader
+
+
+def data_sampler(dataset, shuffle, distributed):
+ if distributed:
+ return torch.utils.data.distributed.DistributedSampler(dataset, shuffle=shuffle)
+ if shuffle:
+ return torch.utils.data.RandomSampler(dataset)
+ else:
+ return torch.utils.data.SequentialSampler(dataset)
+
+
+def get_dataloader(opt, distributed, is_inference):
+    dataset = create_dataloader(opt, distributed, labels_required=False, is_inference=is_inference)
+    return dataset
+
+
+def get_train_val_dataloader(opt, labels_required=False, distributed = False):
+
+
+ val_dataset = create_dataloader(opt, distributed, labels_required = labels_required, is_inference=True,)
+ train_dataset = create_dataloader(opt, distributed, labels_required = labels_required, is_inference=False)
+
+ return val_dataset, train_dataset
+
+def pad_images(img_tensor, pad_value):
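+    # Symmetrically pad a (B, C, H, W) batch along the width with `pad_value` so the
+    # output is square (H x H); assumes H >= W with an even difference, and creates
+    # the padding on the default CUDA device.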
+
+ b,c,h,w = img_tensor.size()
+
+ pad = torch.ones((b,c,h,(h-w)//2)).cuda()*pad_value
+
+ return torch.cat([pad, img_tensor, pad], 3)
+
diff --git a/PyTorch/built-in/mlm/PIDM/data/demo_appearance_dataset.py b/PyTorch/built-in/mlm/PIDM/data/demo_appearance_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..13b4eae82be0707117dabc5ebffa429d54861dd6
--- /dev/null
+++ b/PyTorch/built-in/mlm/PIDM/data/demo_appearance_dataset.py
@@ -0,0 +1,31 @@
+
+from data.demo_dataset import DemoDataset
+
+class DemoAppearanceDataset(DemoDataset):
+ def __init__(self, data_root, opt, load_from_dataset=False):
+ super(DemoAppearanceDataset, self).__init__(data_root, opt, load_from_dataset)
+
+ def load_item(self, garment_img_path, reference_img_path, label_path=None):
+ if self.load_from_dataset:
+ reference_img_path = self.transfrom_2_real_path(reference_img_path)
+ garment_img_path = self.transfrom_2_real_path(garment_img_path)
+ label_path = self.transfrom_2_real_path(label_path)
+ else:
+ reference_img_path = self.transfrom_2_demo_path(reference_img_path)
+ garment_img_path = self.transfrom_2_demo_path(garment_img_path)
+ label_path = self.transfrom_2_demo_path(label_path)
+
+ label_path = self.img_to_label(label_path)
+ reference_img = self.get_image_tensor(reference_img_path)[None,:]
+ garment_img = self.get_image_tensor(garment_img_path)[None,:]
+
+ label, face_center = self.get_label_tensor(label_path)
+
+ garment_label_path = self.img_to_label(garment_img_path)
+ _, garment_face_center = self.get_label_tensor(garment_label_path)
+ return {'reference_image':reference_img,
+ 'garment_image':garment_img,
+ 'target_skeleton':label[None,:],
+ 'face_center':face_center[None,:],
+ 'garment_face_center':garment_face_center[None,:],
+ }
\ No newline at end of file
diff --git a/PyTorch/built-in/mlm/PIDM/data/demo_dataset.py b/PyTorch/built-in/mlm/PIDM/data/demo_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..85c311ea9c62d6a63e0c462fe6c9dfdb6e9b7216
--- /dev/null
+++ b/PyTorch/built-in/mlm/PIDM/data/demo_dataset.py
@@ -0,0 +1,162 @@
+
+import os
+import cv2
+import math
+import numpy as np
+from PIL import Image
+
+import torch
+import torchvision.transforms.functional as F
+
+class DemoDataset(object):
+ def __init__(self, data_root, opt, load_from_dataset=False):
+ super().__init__()
+ self.LIMBSEQ = [[2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9], [9, 10], \
+ [10, 11], [2, 12], [12, 13], [13, 14], [2, 1], [1, 15], [15, 17], \
+ [1, 16], [16, 18], [3, 17], [6, 18]]
+
+ self.COLORS = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0], \
+ [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], \
+ [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]]
+
+ self.img_size = tuple([int(item) for item in opt.sub_path.split('-')])
+ self.data_root = data_root
+ self.load_from_dataset = load_from_dataset # load from deepfashion dataset
+
+ def load_item(self, reference_img_path, label_path=None):
+ if self.load_from_dataset:
+ reference_img_path = self.transfrom_2_real_path(reference_img_path)
+ label_path = self.transfrom_2_real_path(label_path)
+ else:
+ reference_img_path = self.transfrom_2_demo_path(reference_img_path)
+ label_path = self.transfrom_2_demo_path(label_path)
+
+ label_path = self.img_to_label(label_path)
+ reference_img = self.get_image_tensor(reference_img_path)[None,:]
+ label, _ = self.get_label_tensor(label_path)
+ label = label[None,:]
+
+ return {'reference_image':reference_img, 'target_skeleton':label}
+
+ def get_image_tensor(self, path):
+ img = Image.open(path)
+ img = F.resize(img, self.img_size)
+ img = F.to_tensor(img)
+ img = F.normalize(img, (0.5, 0.5, 0.5),(0.5, 0.5, 0.5))
+ return img
+
+ def get_label_tensor(self, path, param={}):
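+        # Render the 18-point OpenPose skeleton onto an RGB canvas and build one
+        # distance-transform map per limb (17 limbs), giving a 3 + 17 = 20-channel
+        # label tensor. Also return the face-center coordinates taken from
+        # keypoints 14 and 15, or -1s when those keypoints are missing.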
+ canvas = np.zeros((self.img_size[0], self.img_size[1], 3)).astype(np.uint8)
+ keypoint = np.loadtxt(path)
+ keypoint = self.trans_keypoins(keypoint, param, (self.img_size[0], self.img_size[1]))
+ stickwidth = 4
+ for i in range(18):
+ x, y = keypoint[i, 0:2]
+ if x == -1 or y == -1:
+ continue
+ cv2.circle(canvas, (int(x), int(y)), 4, self.COLORS[i], thickness=-1)
+ joints = []
+ for i in range(17):
+ Y = keypoint[np.array(self.LIMBSEQ[i])-1, 0]
+ X = keypoint[np.array(self.LIMBSEQ[i])-1, 1]
+ cur_canvas = canvas.copy()
+ if -1 in Y or -1 in X:
+ joints.append(np.zeros_like(cur_canvas[:, :, 0]))
+ continue
+ mX = np.mean(X)
+ mY = np.mean(Y)
+ length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5
+ angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
+ polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0, 360, 1)
+ cv2.fillConvexPoly(cur_canvas, polygon, self.COLORS[i])
+ canvas = cv2.addWeighted(canvas, 0.4, cur_canvas, 0.6, 0)
+
+ joint = np.zeros_like(cur_canvas[:, :, 0])
+ cv2.fillConvexPoly(joint, polygon, 255)
+ joint = cv2.addWeighted(joint, 0.4, joint, 0.6, 0)
+ joints.append(joint)
+ pose = F.to_tensor(Image.fromarray(cv2.cvtColor(canvas, cv2.COLOR_BGR2RGB)))
+
+ tensors_dist = 0
+ e = 1
+ for i in range(len(joints)):
+ im_dist = cv2.distanceTransform(255-joints[i], cv2.DIST_L1, 3)
+ im_dist = np.clip((im_dist / 3), 0, 255).astype(np.uint8)
+ tensor_dist = F.to_tensor(Image.fromarray(im_dist))
+ tensors_dist = tensor_dist if e == 1 else torch.cat([tensors_dist, tensor_dist])
+ e += 1
+
+ label_tensor = torch.cat((pose, tensors_dist), dim=0)
+ if int(keypoint[14, 0]) != -1 and int(keypoint[15, 0]) != -1:
+ y0, x0 = keypoint[14, 0:2]
+ y1, x1 = keypoint[15, 0:2]
+ face_center = torch.tensor([y0, x0, y1, x1]).float()
+ else:
+ face_center = torch.tensor([-1, -1, -1, -1]).float()
+ return label_tensor, face_center
+
+ def transfrom_2_demo_path(self, item):
+ item = os.path.join(self.data_root ,item)
+ return item
+
+ def transfrom_2_real_path(self, item):
+ # item, ext = os.path.splitext(item)
+ if 'WOMEN' in item:
+ item = os.path.basename(item)
+ path = ['img/WOMEN']
+ name = item.split('WOMEN')[-1]
+ elif 'MEN' in item:
+ item = os.path.basename(item)
+ path = ['img/MEN']
+ name = item.split('MEN')[-1]
+ else:
+ return item
+ path.append(name.split('id0')[0])
+ path.append('id_0'+ name.split('id0')[-1][:7])
+ filename = name.split(name.split('id0')[-1][:7])[-1]
+ count=0
+ for i in filename.split('_')[-1]:
+ try:
+ int(i)
+ count+=1
+ except:
+ pass
+ filename = filename.split('_')[0]+'_' \
+ +filename.split('_')[-1][:count] + '_' \
+ +filename.split('_')[-1][count:]
+ path.append(filename)
+ path = os.path.join(*path)
+ return os.path.join(self.data_root, path)
+
+ def img_to_label(self, path):
+ return path.replace('img/', 'pose/').replace('.png', '.txt').replace('.jpg', '.txt')
+
+ def trans_keypoins(self, keypoints, param, img_size):
+ # find missing index
+ missing_keypoint_index = keypoints == -1
+
+ # crop the white line in the original dataset
+ keypoints[:,0] = (keypoints[:,0]-40)
+
+ # resize the dataset
+ img_h, img_w = img_size
+ scale_w = 1.0/176.0 * img_w
+ scale_h = 1.0/256.0 * img_h
+
+ if 'scale_size' in param and param['scale_size'] is not None:
+ new_h, new_w = param['scale_size']
+ scale_w = scale_w / img_w * new_w
+ scale_h = scale_h / img_h * new_h
+
+
+ if 'crop_param' in param and param['crop_param'] is not None:
+ w, h, _, _ = param['crop_param']
+ else:
+ w, h = 0, 0
+
+ keypoints[:,0] = keypoints[:,0]*scale_w - w
+ keypoints[:,1] = keypoints[:,1]*scale_h - h
+ keypoints[missing_keypoint_index] = -1
+ return keypoints
+
+
diff --git a/PyTorch/built-in/mlm/PIDM/data/fashion e-commerce images/s-anu1153-annu-paridhan-original-imagburgxue2fesm.jpeg b/PyTorch/built-in/mlm/PIDM/data/fashion e-commerce images/s-anu1153-annu-paridhan-original-imagburgxue2fesm.jpeg
new file mode 100644
index 0000000000000000000000000000000000000000..2e587dd6754d30bd3f4aef4c28b9ac5d5a78eb68
Binary files /dev/null and b/PyTorch/built-in/mlm/PIDM/data/fashion e-commerce images/s-anu1153-annu-paridhan-original-imagburgxue2fesm.jpeg differ
diff --git a/PyTorch/built-in/mlm/PIDM/data/fashion_base_function.py b/PyTorch/built-in/mlm/PIDM/data/fashion_base_function.py
new file mode 100644
index 0000000000000000000000000000000000000000..a6e4cb6560512aec0ef077b3d4a8565a55d3082c
--- /dev/null
+++ b/PyTorch/built-in/mlm/PIDM/data/fashion_base_function.py
@@ -0,0 +1,42 @@
+import random
+import numpy as np
+from PIL import Image
+
+import torchvision.transforms as transforms
+import torchvision.transforms.functional as F
+
+def get_random_params(size, scale_param):
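+    # Randomly enlarge the image by a factor in [1, 1 + scale_param] and draw a random
+    # crop offset, so that resizing to `scale_size` and cropping back to the original
+    # (w, h) acts as a mild scale-and-crop augmentation.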
+ w, h = size
+ scale = random.random() * scale_param
+
+ new_w = int( w * (1.0+scale) )
+ new_h = int( h * (1.0+scale) )
+ x = random.randint(0, np.maximum(0, new_w - w))
+ y = random.randint(0, np.maximum(0, new_h - h))
+ return {'crop_param': (x, y, w, h), 'scale_size':(new_h, new_w)}
+
+
+def get_transform(param, method=Image.BICUBIC, normalize=True, toTensor=True):
+ transform_list = []
+ if 'scale_size' in param and param['scale_size'] is not None:
+ osize = param['scale_size']
+ transform_list.append(transforms.Resize(osize, interpolation=method))
+
+ if 'crop_param' in param and param['crop_param'] is not None:
+ transform_list.append(transforms.Lambda(lambda img: __crop(img, param['crop_param'])))
+
+ if toTensor:
+ transform_list += [transforms.ToTensor()]
+
+ if normalize:
+ transform_list += [transforms.Normalize((0.5, 0.5, 0.5),
+ (0.5, 0.5, 0.5))]
+ return transforms.Compose(transform_list)
+
+def __crop(img, pos):
+ x1, y1, tw, th = pos
+ return img.crop((x1, y1, x1 + tw, y1 + th))
+
+def normalize():
+ return transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
+
diff --git a/PyTorch/built-in/mlm/PIDM/data/fashion_data.py b/PyTorch/built-in/mlm/PIDM/data/fashion_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..3cf393dae2f5f1c1074fff594edbd5474a09cd9a
--- /dev/null
+++ b/PyTorch/built-in/mlm/PIDM/data/fashion_data.py
@@ -0,0 +1,468 @@
+import os
+import cv2
+import math
+import lmdb
+import numpy as np
+from io import BytesIO
+from PIL import Image
+
+import torch
+import torchvision.transforms.functional as F
+from torch.utils.data import Dataset
+
+from data.fashion_base_function import get_random_params, get_transform
+
+def bin2dec(b, bits):
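+    # Interpret a binary vector of length `bits` (most significant bit first) as an
+    # integer, e.g. tensor([1, 0]) -> 2; used to pack the two guidance flags of
+    # Dataset_guide into a single class label.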
+ mask = 2 ** torch.arange(bits - 1, -1, -1).to(b.device, b.dtype)
+ return torch.sum(mask * b, -1)
+
+class Dataset(Dataset):
+ def __init__(self, opt, is_inference, labels_required = False):
+ self.root = opt.path
+ self.semantic_path = self.root
+ path = os.path.join(self.root, str(opt.sub_path))
+ self.path = path
+ self.labels_required = labels_required
+
+ # self.env = lmdb.open(
+ # path,
+ # max_readers=32,
+ # readonly=True,
+ # lock=False,
+ # readahead=False,
+ # meminit=False,
+ # )
+
+ # if not self.env:
+ # raise IOError('Cannot open lmdb dataset', path)
+
+ self.file_path = 'train_pairs.txt' if not is_inference else 'test_pairs.txt'
+ self.data = self.get_paths(self.root, self.file_path)
+ self.is_inference = is_inference
+ self.preprocess_mode = opt.preprocess_mode
+ self.scale_param = opt.scale_param if not is_inference else 0
+ self.limbSeq = [[2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9], [9, 10], \
+ [10, 11], [2, 12], [12, 13], [13, 14], [2, 1], [1, 15], [15, 17], \
+ [1, 16], [16, 18], [3, 17], [6, 18]]
+
+ self.colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0], \
+ [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], \
+ [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]]
+
+ def __len__(self):
+ return len(self.data)
+
+
+ def get_paths(self, root, path):
+ fd = open(os.path.join(root, path))
+ lines = fd.readlines()
+ fd.close()
+
+ image_paths = []
+ for item in lines:
+ dict_item={}
+ item = item.strip().split(',')
+ dict_item['source_image'] = [path.replace('.jpg', '.png') for path in item[1:]]
+ dict_item['source_label'] = [os.path.join(self.semantic_path, self.img_to_label(path)) for path in dict_item['source_image']]
+ dict_item['target_image'] = item[0].replace('.jpg', '.png')
+ dict_item['target_label'] = os.path.join(self.semantic_path, self.img_to_label(dict_item['target_image']))
+ image_paths.append(dict_item)
+ return image_paths
+
+ def open_lmdb(self):
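+        # Opened lazily from __getitem__ so that each dataloader worker process
+        # creates its own lmdb environment; lmdb handles are not safe to share
+        # across forked processes.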
+
+ self.env = lmdb.open(
+ self.path,
+ max_readers=32,
+ readonly=True,
+ lock=False,
+ readahead=False,
+ meminit=False,
+ )
+ self.txn = self.env.begin(buffers=True)
+
+ def __getitem__(self, index):
+
+ if not hasattr(self, 'txn'):
+ self.open_lmdb()
+
+
+ path_item = self.data[index]
+ i = np.random.choice(list(range(0, len(path_item['source_image']))))
+ source_image_path = path_item['source_image'][i]
+ source_label_path = path_item['source_label'][i]
+
+ target_image_tensor, param = self.get_image_tensor(path_item['target_image'])
+
+ if self.labels_required:
+ target_label_tensor, target_face_center = self.get_label_tensor(path_item['target_label'], target_image_tensor, param)
+
+ ref_tensor, param = self.get_image_tensor(source_image_path)
+
+ if self.labels_required:
+ label_ref_tensor, ref_face_center = self.get_label_tensor(source_label_path, ref_tensor, param)
+
+ image_path = self.get_image_path(source_image_path, path_item['target_image'])
+
+ if not self.is_inference:
+
+ if torch.rand(1) < 0.5:
+
+ target_image_tensor = F.hflip(target_image_tensor)
+ ref_tensor = F.hflip(ref_tensor)
+
+ if self.labels_required:
+
+ target_label_tensor = F.hflip(target_label_tensor)
+ label_ref_tensor = F.hflip(label_ref_tensor)
+
+
+ if self.labels_required:
+
+ input_dict = {'target_skeleton': target_label_tensor,
+ 'target_image': target_image_tensor,
+ 'target_face_center': target_face_center,
+
+ 'source_image': ref_tensor,
+ 'source_skeleton': label_ref_tensor,
+ 'source_face_center': ref_face_center,
+
+ 'path': image_path,
+ }
+
+ else:
+
+ input_dict = {'target_image': target_image_tensor,
+ 'source_image': ref_tensor,
+ }
+
+ return input_dict
+
+ def get_image_path(self, source_name, target_name):
+ source_name = self.path_to_fashion_name(source_name)
+ target_name = self.path_to_fashion_name(target_name)
+ image_path = os.path.splitext(source_name)[0] + '_2_' + os.path.splitext(target_name)[0]+'_vis.png'
+ return image_path
+
+ def path_to_fashion_name(self, path_in):
+ path_in = path_in.split('img/')[-1]
+ path_in = os.path.join('fashion', path_in)
+ path_names = path_in.split('/')
+ path_names[3] = path_names[3].replace('_', '')
+ path_names[4] = path_names[4].split('_')[0] + "_" + "".join(path_names[4].split('_')[1:])
+ path_names = "".join(path_names)
+ return path_names
+
+ def img_to_label(self, path):
+ return path.replace('img/', 'pose/').replace('.png', '.txt')
+
+ def get_image_tensor(self, path):
+ with self.env.begin(write=False) as txn:
+ key = f'{path}'.encode('utf-8')
+ img_bytes = txn.get(key)
+ buffer = BytesIO(img_bytes)
+ img = Image.open(buffer)
+ param = get_random_params(img.size, self.scale_param)
+ trans = get_transform(param, normalize=True, toTensor=True)
+ img = trans(img)
+ return img, param
+
+ def get_label_tensor(self, path, img, param):
+ canvas = np.zeros((img.shape[1], img.shape[2], 3)).astype(np.uint8)
+ keypoint = np.loadtxt(path)
+ keypoint = self.trans_keypoins(keypoint, param, img.shape[1:])
+ stickwidth = 4
+ for i in range(18):
+ x, y = keypoint[i, 0:2]
+ if x == -1 or y == -1:
+ continue
+ cv2.circle(canvas, (int(x), int(y)), 4, self.colors[i], thickness=-1)
+ joints = []
+ for i in range(17):
+ Y = keypoint[np.array(self.limbSeq[i])-1, 0]
+ X = keypoint[np.array(self.limbSeq[i])-1, 1]
+ cur_canvas = canvas.copy()
+ if -1 in Y or -1 in X:
+ joints.append(np.zeros_like(cur_canvas[:, :, 0]))
+ continue
+ mX = np.mean(X)
+ mY = np.mean(Y)
+ length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5
+ angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
+ polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0, 360, 1)
+ cv2.fillConvexPoly(cur_canvas, polygon, self.colors[i])
+ canvas = cv2.addWeighted(canvas, 0.4, cur_canvas, 0.6, 0)
+
+ joint = np.zeros_like(cur_canvas[:, :, 0])
+ cv2.fillConvexPoly(joint, polygon, 255)
+ joint = cv2.addWeighted(joint, 0.4, joint, 0.6, 0)
+ joints.append(joint)
+ pose = F.to_tensor(Image.fromarray(cv2.cvtColor(canvas, cv2.COLOR_BGR2RGB)))
+
+ tensors_dist = 0
+ e = 1
+ for i in range(len(joints)):
+ im_dist = cv2.distanceTransform(255-joints[i], cv2.DIST_L1, 3)
+ im_dist = np.clip((im_dist / 3), 0, 255).astype(np.uint8)
+ tensor_dist = F.to_tensor(Image.fromarray(im_dist))
+ tensors_dist = tensor_dist if e == 1 else torch.cat([tensors_dist, tensor_dist])
+ e += 1
+
+ label_tensor = torch.cat((pose, tensors_dist), dim=0)
+ if int(keypoint[14, 0]) != -1 and int(keypoint[15, 0]) != -1:
+ y0, x0 = keypoint[14, 0:2]
+ y1, x1 = keypoint[15, 0:2]
+ face_center = torch.tensor([y0, x0, y1, x1]).float()
+ else:
+ face_center = torch.tensor([-1, -1, -1, -1]).float()
+ return label_tensor, face_center
+
+
+ def trans_keypoins(self, keypoints, param, img_size):
+ missing_keypoint_index = keypoints == -1
+
+ # crop the white line in the original dataset
+ keypoints[:,0] = (keypoints[:,0]-40)
+
+ # resize the dataset
+ img_h, img_w = img_size
+ scale_w = 1.0/176.0 * img_w
+ scale_h = 1.0/256.0 * img_h
+
+ if 'scale_size' in param and param['scale_size'] is not None:
+ new_h, new_w = param['scale_size']
+ scale_w = scale_w / img_w * new_w
+ scale_h = scale_h / img_h * new_h
+
+
+ if 'crop_param' in param and param['crop_param'] is not None:
+ w, h, _, _ = param['crop_param']
+ else:
+ w, h = 0, 0
+
+ keypoints[:,0] = keypoints[:,0]*scale_w - w
+ keypoints[:,1] = keypoints[:,1]*scale_h - h
+ keypoints[missing_keypoint_index] = -1
+ return keypoints
+
+
+
+class Dataset_guide(Dataset):
+ def __init__(self, opt, is_inference, labels_required = False):
+ self.root = opt.path
+ self.semantic_path = self.root
+ path = os.path.join(self.root, str(opt.sub_path))
+ self.path = path
+ self.labels_required = labels_required
+
+
+ # self.env = lmdb.open(
+ # path,
+ # max_readers=32,
+ # readonly=True,
+ # lock=False,
+ # readahead=False,
+ # meminit=False,
+ # )
+
+ # if not self.env:
+ # raise IOError('Cannot open lmdb dataset', path)
+
+ self.file_path = 'train_pairs.txt' if not is_inference else 'test_pairs.txt'
+ self.data = self.get_paths(self.root, self.file_path)
+ self.len_ = len(self.data)
+ self.is_inference = is_inference
+ self.preprocess_mode = opt.preprocess_mode
+ self.scale_param = opt.scale_param if not is_inference else 0
+ self.limbSeq = [[2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9], [9, 10], \
+ [10, 11], [2, 12], [12, 13], [13, 14], [2, 1], [1, 15], [15, 17], \
+ [1, 16], [16, 18], [3, 17], [6, 18]]
+
+ self.colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0], \
+ [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], \
+ [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]]
+
+ def __len__(self):
+ return len(self.data)
+
+
+ def get_paths(self, root, path):
+ fd = open(os.path.join(root, path))
+ lines = fd.readlines()
+ fd.close()
+
+ image_paths = []
+ for item in lines:
+ dict_item={}
+ item = item.strip().split(',')
+ dict_item['source_image'] = [path.replace('.jpg', '.png') for path in item[1:]]
+ dict_item['source_label'] = [os.path.join(self.semantic_path, self.img_to_label(path)) for path in dict_item['source_image']]
+ dict_item['target_image'] = item[0].replace('.jpg', '.png')
+ dict_item['target_label'] = os.path.join(self.semantic_path, self.img_to_label(dict_item['target_image']))
+ image_paths.append(dict_item)
+ return image_paths
+
+ def open_lmdb(self):
+
+ self.env = lmdb.open(
+ self.path,
+ max_readers=32,
+ readonly=True,
+ lock=False,
+ readahead=False,
+ meminit=False,
+ )
+ self.txn = self.env.begin(buffers=True)
+
+ def __getitem__(self, index):
+
+ if not hasattr(self, 'txn'):
+ self.open_lmdb()
+
+ path_item = self.data[index]
+
+ guide_labels = (torch.rand(2) < 0.5)
+
+
+ if guide_labels[0]:
+ path_item_ref = path_item
+ else:
+ path_item_ref = self.data[np.random.choice(list(set(list(np.arange(self.len_)))-{index}))]
+
+ if guide_labels[1]:
+ path_item_pose = path_item
+ else:
+ path_item_pose = self.data[np.random.choice(list(set(list(np.arange(self.len_)))-{index}))]
+
+
+ i = np.random.choice(list(range(0, len(path_item_ref['source_image']))))
+
+ ref_image_path = path_item_ref['source_image'][i]
+ target_image_path = path_item['target_image']
+ pose_label_path = path_item_pose['target_label']
+
+ ref_image_tensor, param = self.get_image_tensor(ref_image_path)
+ target_image_tensor, param = self.get_image_tensor(target_image_path)
+ target_label_tensor, target_face_center = self.get_label_tensor(pose_label_path, target_image_tensor, param)
+
+ if torch.rand(1) < 0.5:
+
+ ref_image_tensor = F.hflip(ref_image_tensor)
+ target_image_tensor = F.hflip(target_image_tensor)
+ target_label_tensor = F.hflip(target_label_tensor)
+
+
+ return {'pose_skeleton': target_label_tensor,
+ 'target_image': target_image_tensor,
+ 'ref_image_tensor': ref_image_tensor,
+ 'guide_label' : bin2dec(guide_labels, 2)}
+
+ def get_image_path(self, source_name, target_name):
+ source_name = self.path_to_fashion_name(source_name)
+ target_name = self.path_to_fashion_name(target_name)
+ image_path = os.path.splitext(source_name)[0] + '_2_' + os.path.splitext(target_name)[0]+'_vis.png'
+ return image_path
+
+ def path_to_fashion_name(self, path_in):
+ path_in = path_in.split('img/')[-1]
+ path_in = os.path.join('fashion', path_in)
+ path_names = path_in.split('/')
+ path_names[3] = path_names[3].replace('_', '')
+ path_names[4] = path_names[4].split('_')[0] + "_" + "".join(path_names[4].split('_')[1:])
+ path_names = "".join(path_names)
+ return path_names
+
+ def img_to_label(self, path):
+ return path.replace('img/', 'pose/').replace('.png', '.txt')
+
+ def get_image_tensor(self, path):
+ with self.env.begin(write=False) as txn:
+ key = f'{path}'.encode('utf-8')
+ img_bytes = txn.get(key)
+ buffer = BytesIO(img_bytes)
+ img = Image.open(buffer)
+ param = get_random_params(img.size, self.scale_param)
+ trans = get_transform(param, normalize=True, toTensor=True)
+ img = trans(img)
+ return img, param
+
+ def get_label_tensor(self, path, img, param):
+ canvas = np.zeros((img.shape[1], img.shape[2], 3)).astype(np.uint8)
+ keypoint = np.loadtxt(path)
+ keypoint = self.trans_keypoins(keypoint, param, img.shape[1:])
+ stickwidth = 4
+ for i in range(18):
+ x, y = keypoint[i, 0:2]
+ if x == -1 or y == -1:
+ continue
+ cv2.circle(canvas, (int(x), int(y)), 4, self.colors[i], thickness=-1)
+ joints = []
+ for i in range(17):
+ Y = keypoint[np.array(self.limbSeq[i])-1, 0]
+ X = keypoint[np.array(self.limbSeq[i])-1, 1]
+ cur_canvas = canvas.copy()
+ if -1 in Y or -1 in X:
+ joints.append(np.zeros_like(cur_canvas[:, :, 0]))
+ continue
+ mX = np.mean(X)
+ mY = np.mean(Y)
+ length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5
+ angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
+ polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0, 360, 1)
+ cv2.fillConvexPoly(cur_canvas, polygon, self.colors[i])
+ canvas = cv2.addWeighted(canvas, 0.4, cur_canvas, 0.6, 0)
+
+ joint = np.zeros_like(cur_canvas[:, :, 0])
+ cv2.fillConvexPoly(joint, polygon, 255)
+ joint = cv2.addWeighted(joint, 0.4, joint, 0.6, 0)
+ joints.append(joint)
+ pose = F.to_tensor(Image.fromarray(cv2.cvtColor(canvas, cv2.COLOR_BGR2RGB)))
+
+ tensors_dist = 0
+ e = 1
+ for i in range(len(joints)):
+ im_dist = cv2.distanceTransform(255-joints[i], cv2.DIST_L1, 3)
+ im_dist = np.clip((im_dist / 3), 0, 255).astype(np.uint8)
+ tensor_dist = F.to_tensor(Image.fromarray(im_dist))
+ tensors_dist = tensor_dist if e == 1 else torch.cat([tensors_dist, tensor_dist])
+ e += 1
+
+ label_tensor = torch.cat((pose, tensors_dist), dim=0)
+ if int(keypoint[14, 0]) != -1 and int(keypoint[15, 0]) != -1:
+ y0, x0 = keypoint[14, 0:2]
+ y1, x1 = keypoint[15, 0:2]
+ face_center = torch.tensor([y0, x0, y1, x1]).float()
+ else:
+ face_center = torch.tensor([-1, -1, -1, -1]).float()
+ return label_tensor, face_center
+
+
+ def trans_keypoins(self, keypoints, param, img_size):
+ missing_keypoint_index = keypoints == -1
+
+ # crop the white line in the original dataset
+ keypoints[:,0] = (keypoints[:,0]-40)
+
+ # resize the dataset
+ img_h, img_w = img_size
+ scale_w = 1.0/176.0 * img_w
+ scale_h = 1.0/256.0 * img_h
+
+ if 'scale_size' in param and param['scale_size'] is not None:
+ new_h, new_w = param['scale_size']
+ scale_w = scale_w / img_w * new_w
+ scale_h = scale_h / img_h * new_h
+
+
+ if 'crop_param' in param and param['crop_param'] is not None:
+ w, h, _, _ = param['crop_param']
+ else:
+ w, h = 0, 0
+
+ keypoints[:,0] = keypoints[:,0]*scale_w - w
+ keypoints[:,1] = keypoints[:,1]*scale_h - h
+ keypoints[missing_keypoint_index] = -1
+ return keypoints
+
+
+
diff --git a/PyTorch/built-in/mlm/PIDM/data/prepare_data.py b/PyTorch/built-in/mlm/PIDM/data/prepare_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..90c5012fa2d1812e5f7a0085024a8e961699a3e8
--- /dev/null
+++ b/PyTorch/built-in/mlm/PIDM/data/prepare_data.py
@@ -0,0 +1,88 @@
+import os
+import lmdb
+import argparse
+import multiprocessing
+from PIL import Image
+from tqdm import tqdm
+from io import BytesIO
+
+from torchvision.transforms import functional as trans_fn
+
+
+def format_for_lmdb(*args):
+ key_parts = []
+ for arg in args:
+ if isinstance(arg, int):
+ arg = str(arg).zfill(5)
+ key_parts.append(arg)
+ return '-'.join(key_parts).encode('utf-8')
+
+class Resizer:
+ def __init__(self, *, size, root):
+ self.size = size
+ self.root = root
+
+ def get_resized_bytes(self, img):
+ img = trans_fn.resize(img, self.size, Image.BICUBIC)
+ buf = BytesIO()
+ img.save(buf, format='png')
+ img_bytes = buf.getvalue()
+ return img_bytes
+
+ def prepare(self, filename):
+ filename = os.path.join(self.root, filename)
+ img = Image.open(filename)
+ img = img.convert('RGB')
+ img_bytes = self.get_resized_bytes(img)
+ return img_bytes
+
+ def __call__(self, index_filename):
+ index, filename = index_filename
+ result_img = self.prepare(filename)
+ return index, result_img, filename
+
+
+def prepare_data(root, dataset, out, n_worker, sizes, chunksize):
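+    # Collect every image path referenced in train_pairs.txt and test_pairs.txt,
+    # resize each image to every requested size, and store the PNG bytes in an
+    # lmdb database named after the size (e.g. "256-256") under `out`, keyed by
+    # the image path with a ".png" extension (plus a "length" record).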
+ assert dataset in ['deepfashion']
+ if dataset == 'deepfashion':
+ file_txt = '{}/train_pairs.txt'.format(root)
+ filenames = []
+ with open(file_txt, 'r') as f:
+ lines = f.readlines()
+ for item in lines:
+ filenames.extend(item.strip().split(','))
+
+ file_txt = '{}/test_pairs.txt'.format(root)
+ with open(file_txt, 'r') as f:
+ lines = f.readlines()
+ for item in lines:
+ filenames.extend(item.strip().split(','))
+ filenames = list(set(filenames))
+
+
+ total = len(filenames)
+ os.makedirs(out, exist_ok=True)
+
+ for size in sizes:
+ lmdb_path = os.path.join(out, str('-'.join([str(item) for item in size])))
+ with lmdb.open(lmdb_path, map_size=1024 ** 4, readahead=False) as env:
+ with env.begin(write=True) as txn:
+ txn.put(format_for_lmdb('length'), format_for_lmdb(total))
+ resizer = Resizer(size=size, root=root)
+ with multiprocessing.Pool(n_worker) as pool:
+ for idx, result_img, filename in tqdm(
+ pool.imap_unordered(resizer, enumerate(filenames), chunksize=chunksize),
+ total=total):
+ filename = os.path.splitext(filename)[0] + '.png'
+ txn.put(format_for_lmdb(filename), result_img)
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+    parser.add_argument('--root', type=str, help='path to the dataset root directory')
+    parser.add_argument('--dataset', type=str, default='deepfashion', help='name of the dataset')
+    parser.add_argument('--out', type=str, help='path to the output directory')
+ parser.add_argument('--sizes', type=int, nargs='+', default=((256, 256),) )
+ parser.add_argument('--n_worker', type=int, help='number of worker processes', default=8)
+ parser.add_argument('--chunksize', type=int, help='approximate chunksize for each worker', default=10)
+ args = parser.parse_args()
+ prepare_data(**vars(args))
\ No newline at end of file
diff --git a/PyTorch/built-in/mlm/PIDM/diffusion.py b/PyTorch/built-in/mlm/PIDM/diffusion.py
new file mode 100755
index 0000000000000000000000000000000000000000..54cd1fc6f924543684641bef339270746bdd7b91
--- /dev/null
+++ b/PyTorch/built-in/mlm/PIDM/diffusion.py
@@ -0,0 +1,1149 @@
+"""
+This code started out as a PyTorch port of Ho et al's diffusion models:
+https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py
+Docstrings have been added, as well as DDIM sampling and a new collection of beta schedules.
+"""
+
+import enum
+import math
+
+import numpy as np
+import torch as th
+import torch
+from models.nn import mean_flat
+from models.losses import normal_kl, discretized_gaussian_log_likelihood
+from types import *
+
+import tqdm
+
+def compute_alpha(beta, t):
+ beta = torch.cat([torch.zeros(1).to(beta.device), beta], dim=0)
+ a = (1 - beta).cumprod(dim=0).index_select(0, t + 1).view(-1, 1, 1, 1)
+ return a
+
+
+def ddim_steps(x, seq, model, b, x_cond, diffusion = None, **kwargs):
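+    # DDIM sampling loop with classifier-free guidance:
+    #   x_cond[0]: source image, encoded here into the conditional style features and
+    #              the unconditional (zero-image) features used by forward_with_cond_scale;
+    #   x_cond[1]: 20-channel target pose map, concatenated with x_t along channels;
+    #   x_cond[2:4] (optional): reference image and mask; where the mask is 0, x_t is
+    #              overwritten with a noised version of the reference at the current step
+    #              (inpainting-style conditioning used for appearance control).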
+ x_cond[0] = [model.encode(x_cond[0])['cond'], model.encode(torch.zeros_like(x_cond[0]))['cond']]
+
+ with torch.no_grad():
+ n = x.size(0)
+ seq_next = [-1] + list(seq[:-1])
+ x0_preds = []
+ xs = [x]
+ xt = x
+ for i, j in tqdm.tqdm(zip(reversed(seq), reversed(seq_next))):
+ t = (torch.ones(n) * i).to(x.device)
+ next_t = (torch.ones(n) * j).to(x.device)
+ at = compute_alpha(b, t.long())
+ at_next = compute_alpha(b, next_t.long())
+
+ [cond, target_pose] = x_cond[:2]
+ et = model.forward_with_cond_scale(x = torch.cat([xt, target_pose],1), t = t, cond = cond, cond_scale = 2)[0]
+ et, model_var_values = torch.split(et, 3, dim=1)
+ x0_t = (xt - et * (1 - at).sqrt()) / at.sqrt()
+ #x0_preds.append(x0_t.to('cpu'))
+ c1 = (
+ kwargs.get("eta", 0) * ((1 - at / at_next) * (1 - at_next) / (1 - at)).sqrt()
+ )
+ c2 = ((1 - at_next) - c1 ** 2).sqrt()
+ xt = at_next.sqrt() * x0_t + c1 * torch.randn_like(x) + c2 * et
+
+ if len(x_cond) == 4:
+ [_,_,ref,mask] = x_cond
+ xt = xt*mask + diffusion.q_sample(ref, t.long())*(1-mask)
+ #xs.append(xt_next.to('cpu'))
+
+ return [xt], x0_preds
+
+
+
+
+def make_beta_schedule(
+ schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3
+):
+ if schedule == "quad":
+ betas = (
+ torch.linspace(
+ linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64
+ )
+ ** 2
+ )
+
+ elif schedule == "linear":
+ betas = torch.linspace(
+ linear_start, linear_end, n_timestep, dtype=torch.float64
+ )
+
+ elif schedule == "cosine":
+ timesteps = (
+ torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
+ )
+ alphas = timesteps / (1 + cosine_s) * math.pi / 2
+ alphas = torch.cos(alphas).pow(2)
+ alphas = alphas / alphas[0]
+ betas = 1 - alphas[1:] / alphas[:-1]
+        betas = betas.clamp(max=0.999)
+
+    else:
+        raise NotImplementedError(f"unknown beta schedule: {schedule}")
+
+    return betas
+
+
+
+def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
+ """
+ Create a beta schedule that discretizes the given alpha_t_bar function,
+ which defines the cumulative product of (1-beta) over time from t = [0,1].
+ :param num_diffusion_timesteps: the number of betas to produce.
+ :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
+ produces the cumulative product of (1-beta) up to that
+ part of the diffusion process.
+ :param max_beta: the maximum beta to use; use values lower than 1 to
+ prevent singularities.
+ """
+ betas = []
+ for i in range(num_diffusion_timesteps):
+ t1 = i / num_diffusion_timesteps
+ t2 = (i + 1) / num_diffusion_timesteps
+ betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
+ return np.array(betas)
+
+
+class ModelMeanType(enum.Enum):
+ """
+ Which type of output the model predicts.
+ """
+
+ PREVIOUS_X = enum.auto() # the model predicts x_{t-1}
+ START_X = enum.auto() # the model predicts x_0
+ EPSILON = enum.auto() # the model predicts epsilon
+
+
+class ModelVarType(enum.Enum):
+ """
+ What is used as the model's output variance.
+ The LEARNED_RANGE option has been added to allow the model to predict
+ values between FIXED_SMALL and FIXED_LARGE, making its job easier.
+ """
+
+ LEARNED = enum.auto()
+ FIXED_SMALL = enum.auto()
+ FIXED_LARGE = enum.auto()
+ LEARNED_RANGE = enum.auto()
+
+
+class LossType(enum.Enum):
+ MSE = enum.auto() # use raw MSE loss (and KL when learning variances)
+ RESCALED_MSE = (
+ enum.auto()
+ ) # use raw MSE loss (with RESCALED_KL when learning variances)
+ KL = enum.auto() # use the variational lower-bound
+ RESCALED_KL = enum.auto() # like KL, but rescale to estimate the full VLB
+
+ def is_vb(self):
+ return self == LossType.KL or self == LossType.RESCALED_KL
+
+
+class GaussianDiffusion:
+ """
+ Utilities for training and sampling diffusion models.
+ Ported directly from here, and then adapted over time to further experimentation.
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42
+ :param betas: a 1-D numpy array of betas for each diffusion timestep,
+ starting at T and going to 1.
+ :param model_mean_type: a ModelMeanType determining what the model outputs.
+ :param model_var_type: a ModelVarType determining how variance is output.
+ :param loss_type: a LossType determining the loss function to use.
+ :param rescale_timesteps: if True, pass floating point timesteps into the
+ model so that they are always scaled like in the
+ original paper (0 to 1000).
+ """
+
+ def __init__(
+ self,
+ *,
+ betas,
+ model_mean_type,
+ model_var_type,
+ loss_type,
+ rescale_timesteps=False,
+ ):
+ self.model_mean_type = model_mean_type
+ self.model_var_type = model_var_type
+ self.loss_type = loss_type
+ self.rescale_timesteps = rescale_timesteps
+
+ # Use float64 for accuracy.
+ betas = np.array(betas, dtype=np.float64)
+ self.betas = betas
+ assert len(betas.shape) == 1, "betas must be 1-D"
+ assert (betas > 0).all() and (betas <= 1).all()
+
+ self.num_timesteps = int(betas.shape[0])
+
+ alphas = 1.0 - betas
+ self.alphas_cumprod = np.cumprod(alphas, axis=0)
+ self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
+ self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0)
+ assert self.alphas_cumprod_prev.shape == (self.num_timesteps,)
+
+ # calculations for diffusion q(x_t | x_{t-1}) and others
+ self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
+ self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod)
+ self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod)
+ self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
+ self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)
+
+ # calculations for posterior q(x_{t-1} | x_t, x_0)
+ self.posterior_variance = (
+ betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
+ )
+ # log calculation clipped because the posterior variance is 0 at the
+ # beginning of the diffusion chain.
+ self.posterior_log_variance_clipped = np.log(
+ np.append(self.posterior_variance[1], self.posterior_variance[1:])
+ )
+ self.posterior_mean_coef1 = (
+ betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
+ )
+ self.posterior_mean_coef2 = (
+ (1.0 - self.alphas_cumprod_prev)
+ * np.sqrt(alphas)
+ / (1.0 - self.alphas_cumprod)
+ )
+
+ def q_mean_variance(self, x_start, t):
+ """
+ Get the distribution q(x_t | x_0).
+ :param x_start: the [N x C x ...] tensor of noiseless inputs.
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
+ :return: A tuple (mean, variance, log_variance), all of x_start's shape.
+ """
+ mean = (
+ _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
+ )
+ variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
+ log_variance = _extract_into_tensor(
+ self.log_one_minus_alphas_cumprod, t, x_start.shape
+ )
+ return mean, variance, log_variance
+
+ def q_sample(self, x_start, t, noise=None):
+ """
+ Diffuse the data for a given number of diffusion steps.
+ In other words, sample from q(x_t | x_0).
+ :param x_start: the initial data batch.
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
+ :param noise: if specified, the split-out normal noise.
+ :return: A noisy version of x_start.
+ """
+ if noise is None:
+ noise = th.randn_like(x_start)
+ assert noise.shape == x_start.shape
+ return (
+ _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
+ + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
+ * noise
+ )
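+
+    # Hedged usage sketch for q_sample (variable names are illustrative):
+    #
+    #   t = th.randint(0, diffusion.num_timesteps, (x_start.shape[0],),
+    #                  device=x_start.device)
+    #   x_t = diffusion.q_sample(x_start, t)
+    #   # x_t == sqrt(abar_t) * x_start + sqrt(1 - abar_t) * noise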
+
+ def q_posterior_mean_variance(self, x_start, x_t, t):
+ """
+ Compute the mean and variance of the diffusion posterior:
+ q(x_{t-1} | x_t, x_0)
+ """
+ assert x_start.shape == x_t.shape
+ posterior_mean = (
+ _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
+ + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
+ )
+ posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape)
+ posterior_log_variance_clipped = _extract_into_tensor(
+ self.posterior_log_variance_clipped, t, x_t.shape
+ )
+ assert (
+ posterior_mean.shape[0]
+ == posterior_variance.shape[0]
+ == posterior_log_variance_clipped.shape[0]
+ == x_start.shape[0]
+ )
+ return posterior_mean, posterior_variance, posterior_log_variance_clipped
+
+ def p_mean_variance(
+        self, model, x, t, x_cond=None, cond_scale=1, clip_denoised=True,
+        denoised_fn=None, model_kwargs=None, normalize=False
+ ):
+ """
+ Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
+ the initial x, x_0.
+ :param model: the model, which takes a signal and a batch of timesteps
+ as input.
+ :param x: the [N x C x ...] tensor at time t.
+        :param t: a 1-D Tensor of timesteps.
+        :param x_cond: if not None, a [cond, target_pose] pair; the model is
+            then queried via forward_with_cond_scale with `cond_scale`.
+        :param cond_scale: guidance scale for the conditional prediction.
+ :param clip_denoised: if True, clip the denoised signal into [-1, 1].
+ :param denoised_fn: if not None, a function which applies to the
+ x_start prediction before it is used to sample. Applies before
+ clip_denoised.
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
+ pass to the model. This can be used for conditioning.
+ :return: a dict with the following keys:
+ - 'mean': the model mean output.
+ - 'variance': the model variance output.
+ - 'log_variance': the log of 'variance'.
+ - 'pred_xstart': the prediction for x_0.
+ """
+ if model_kwargs is None:
+ model_kwargs = {}
+
+ def process_xstart(x):
+ if denoised_fn is not None:
+ x = denoised_fn(x)
+ if clip_denoised:
+ return x.clamp(-1, 1)
+ return x
+
+ def get_mean_var_from_eps(model_output):
+ model_output, model_var_values = th.split(model_output, C, dim=1)
+
+ min_log = _extract_into_tensor(
+ self.posterior_log_variance_clipped, t, x.shape
+ )
+ max_log = _extract_into_tensor(np.log(self.betas), t, x.shape)
+ # The model_var_values is [-1, 1] for [min_var, max_var].
+ frac = (model_var_values + 1) / 2
+ model_log_variance = frac * max_log + (1 - frac) * min_log
+ model_variance = th.exp(model_log_variance)
+ pred_xstart = process_xstart(
+ self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)
+ )
+ model_mean, _, _ = self.q_posterior_mean_variance(
+ x_start=pred_xstart, x_t=x, t=t
+ )
+ return model_mean, model_variance, model_log_variance, pred_xstart
+
+ B, C = x.shape[:2]
+ assert t.shape == (B,)
+        if x_cond is None:
+            model_output = model(x, self._scale_timesteps(t), **model_kwargs)
+        else:
+            [cond, target_pose] = x_cond
+            # Concatenate the noisy image with the target pose map and query
+            # the model with and without the appearance condition (guidance).
+            model_output, cond_output, uncond_output = model.forward_with_cond_scale(
+                x=torch.cat([x, target_pose], 1),
+                t=self._scale_timesteps(t),
+                cond=cond,
+                cond_scale=cond_scale,
+            )
+
+ model_mean, model_variance, model_log_variance, pred_xstart = get_mean_var_from_eps(model_output)
+
+ return {
+ "mean": model_mean,
+ "variance": model_variance,
+ "log_variance": model_log_variance,
+ "pred_xstart": pred_xstart,
+ }
+
+    # NOTE: the general guided-diffusion dispatch over ModelVarType
+    # (LEARNED / LEARNED_RANGE / FIXED_LARGE / FIXED_SMALL) and ModelMeanType
+    # (PREVIOUS_X / START_X / EPSILON) is not needed here: the model always
+    # predicts epsilon together with a learned-range variance, which
+    # get_mean_var_from_eps above handles directly.
+
+ def _predict_xstart_from_eps(self, x_t, t, eps):
+ assert x_t.shape == eps.shape
+ return (
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
+ - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
+ )
+
+ def _predict_xstart_from_xprev(self, x_t, t, xprev):
+ assert x_t.shape == xprev.shape
+ return ( # (xprev - coef2*x_t) / coef1
+ _extract_into_tensor(1.0 / self.posterior_mean_coef1, t, x_t.shape) * xprev
+ - _extract_into_tensor(
+ self.posterior_mean_coef2 / self.posterior_mean_coef1, t, x_t.shape
+ )
+ * x_t
+ )
+
+ def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
+ return (
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
+ - pred_xstart
+ ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
+
+ def _scale_timesteps(self, t):
+ if self.rescale_timesteps:
+ return t.float() * (1000.0 / self.num_timesteps)
+ return t
+
+ def p_sample(
+ self, model, x_cond, x, cond_scale, t, clip_denoised=True, denoised_fn=None, model_kwargs=None
+ ):
+ """
+ Sample x_{t-1} from the model at the given timestep.
+        :param model: the model to sample from.
+        :param x_cond: [condition, target_pose] pair forwarded to
+            p_mean_variance.
+        :param x: the current tensor x_t.
+        :param cond_scale: guidance scale for the conditional prediction.
+        :param t: the value of t, starting at 0 for the first diffusion step.
+ :param clip_denoised: if True, clip the x_start prediction to [-1, 1].
+ :param denoised_fn: if not None, a function which applies to the
+ x_start prediction before it is used to sample.
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
+ pass to the model. This can be used for conditioning.
+ :return: a dict containing the following keys:
+ - 'sample': a random sample from the model.
+ - 'pred_xstart': a prediction of x_0.
+ """
+ out = self.p_mean_variance(
+ model,
+ x,
+ t,
+ x_cond,
+ cond_scale,
+ clip_denoised=clip_denoised,
+ denoised_fn=denoised_fn,
+ model_kwargs=model_kwargs,
+ )
+ noise = th.randn_like(x)
+ nonzero_mask = (
+ (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
+ ) # no noise when t == 0
+ sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise
+
+ return {"sample": sample, "pred_xstart": out["pred_xstart"]}
+
+ def p_sample_loop(
+ self,
+ model,
+ x_cond,
+ cond_scale,
+ noise=None,
+ clip_denoised=True,
+ denoised_fn=None,
+ model_kwargs=None,
+ device=None,
+ progress=False,
+ history=False
+ ):
+ """
+ Generate samples from the model.
+        :param model: the model module.
+        :param x_cond: conditioning inputs; x_cond[0] is the source image
+            (its shape also defines the sample shape, (N, C, H, W)) and
+            x_cond[1] is the target pose map.
+        :param cond_scale: guidance scale used when combining the conditional
+            and unconditional predictions.
+        :param noise: if specified, the starting noise to sample from.
+            Should have the same shape as the samples.
+ :param clip_denoised: if True, clip x_start predictions to [-1, 1].
+ :param denoised_fn: if not None, a function which applies to the
+ x_start prediction before it is used to sample.
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
+ pass to the model. This can be used for conditioning.
+ :param device: if specified, the device to create the samples on.
+ If not specified, use a model parameter's device.
+        :param progress: if True, show a tqdm progress bar.
+        :param history: if True, also return intermediate samples (every
+            100th collected step) followed by the final sample.
+        :return: a non-differentiable batch of samples.
+ """
+ shape = x_cond[0].shape
+ final = None
+        # Encode the source image and an all-zero (null) image; the pair of
+        # feature sets is later used for guided / unguided predictions.
+        x_cond[0] = [model.encode(x_cond[0])['cond'], model.encode(torch.zeros_like(x_cond[0]))['cond']]
+
+ samples_list = []
+ for sample in self.p_sample_loop_progressive(
+ model,
+ x_cond,
+ cond_scale,
+ shape,
+ noise=noise,
+ clip_denoised=clip_denoised,
+ denoised_fn=denoised_fn,
+ model_kwargs=model_kwargs,
+ device=device,
+ progress=progress,
+ ):
+ final = sample
+
+ if history:
+ samples_list.append(sample["sample"])
+
+ if history:
+ return samples_list[::100] + [final["sample"]]
+
+ else:
+ return final["sample"]
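+
+    # Hedged usage sketch for p_sample_loop (argument values are illustrative,
+    # not taken from the PIDM configs). Note that the method replaces
+    # x_cond[0] in place with encoded features, so pass a fresh list per call:
+    #
+    #   x_cond = [source_img, target_pose]   # both [N, C, H, W] tensors
+    #   samples = diffusion.p_sample_loop(model, x_cond, cond_scale=2.0,
+    #                                     progress=True)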
+
+ def p_sample_loop_progressive(
+ self,
+ model,
+ x_cond,
+ cond_scale,
+ shape,
+ noise=None,
+ clip_denoised=True,
+ denoised_fn=None,
+ model_kwargs=None,
+ device=None,
+ progress=False,
+ ):
+ """
+ Generate samples from the model and yield intermediate samples from
+ each timestep of diffusion.
+ Arguments are the same as p_sample_loop().
+ Returns a generator over dicts, where each dict is the return value of
+ p_sample().
+ """
+ if device is None:
+ device = next(model.parameters()).device
+ assert isinstance(shape, (tuple, list))
+ if noise is not None:
+ img = noise
+ else:
+ img = th.randn(*shape, device=device)
+ indices = list(range(self.num_timesteps))[::-1]
+
+ if progress:
+ # Lazy import so that we don't depend on tqdm.
+ from tqdm.auto import tqdm
+
+ indices = tqdm(indices)
+
+ for i in indices:
+ t = th.tensor([i] * shape[0], device=device)
+ with th.no_grad():
+ out = self.p_sample(
+ model,
+ x_cond[:2],
+ img,
+ cond_scale,
+ t,
+ clip_denoised=clip_denoised,
+ denoised_fn=denoised_fn,
+ model_kwargs=model_kwargs,
+ )
+ yield out
+            if len(x_cond) == 4:
+                # Optional inpainting mode: keep the masked-out region from a
+                # re-noised reference image and only generate inside the mask.
+                [_, _, ref, mask] = x_cond
+                img = out["sample"] * mask + self.q_sample(ref, t) * (1 - mask)
+ else:
+ img = out["sample"]
+
+
+
+
+
+ def p_sample_cond(
+ self, model, cond_model, x_cond, x, cond_scale, t, clip_denoised=True, denoised_fn=None, model_kwargs=None
+ ):
+ """
+ Sample x_{t-1} from the model at the given timestep.
+        :param model: the diffusion model to sample from.
+        :param cond_model: an auxiliary classifier whose gradient guides the
+            sampling step.
+        :param x_cond: [source image, target pose] conditioning pair.
+        :param x: the current tensor x_t.
+        :param cond_scale: guidance scale for the conditional prediction.
+        :param t: the value of t, starting at 0 for the first diffusion step.
+ :param clip_denoised: if True, clip the x_start prediction to [-1, 1].
+ :param denoised_fn: if not None, a function which applies to the
+ x_start prediction before it is used to sample.
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
+ pass to the model. This can be used for conditioning.
+ :return: a dict containing the following keys:
+ - 'sample': a random sample from the model.
+ - 'pred_xstart': a prediction of x_0.
+ """
+ [img, target_pose] = x_cond
+
+ out = self.p_mean_variance(
+ model,
+ x,
+ t,
+ img,
+ cond_scale,
+ clip_denoised=clip_denoised,
+ denoised_fn=denoised_fn,
+ model_kwargs=model_kwargs,
+ )
+ with torch.enable_grad():
+
+ x_t = x.detach().requires_grad_(True)
+ logits = cond_model(xt=x_t, ref=img, pose=target_pose, t=t)
+ guide_label = torch.ones((logits.shape[0]))*3
+ loss = torch.nn.CrossEntropyLoss()(logits, guide_label.long().cuda())
+
+ grad = torch.autograd.grad(loss, x_t)[0].detach()
+
+ noise = th.randn_like(x)
+ nonzero_mask = (
+ (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
+ ) # no noise when t == 0
+
+ out['mean'] = out['mean'] + out["log_variance"]*grad # torch.clamp(grad, max = 1, min = -1)
+ sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise
+ return {"sample": sample, "pred_xstart": out["pred_xstart"]}
+
+ def p_sample_cond_loop(
+ self,
+ model,
+ cond_model,
+ x_cond,
+ cond_scale,
+ noise=None,
+ clip_denoised=True,
+ denoised_fn=None,
+ model_kwargs=None,
+ device=None,
+ progress=False,
+ ):
+ """
+ Generate samples from the model.
+        :param model: the model module.
+        :param cond_model: the auxiliary classifier used for gradient guidance.
+        :param x_cond: [source image, target pose] conditioning pair; the
+            source image's shape defines the sample shape, (N, C, H, W).
+        :param cond_scale: guidance scale for the conditional prediction.
+        :param noise: if specified, the starting noise to sample from.
+            Should have the same shape as the samples.
+ :param clip_denoised: if True, clip x_start predictions to [-1, 1].
+ :param denoised_fn: if not None, a function which applies to the
+ x_start prediction before it is used to sample.
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
+ pass to the model. This can be used for conditioning.
+ :param device: if specified, the device to create the samples on.
+ If not specified, use a model parameter's device.
+ :param progress: if True, show a tqdm progress bar.
+ :return: a non-differentiable batch of samples.
+ """
+ shape = x_cond[0].shape
+ final = None
+ for sample in self.p_sample_cond_loop_progressive(
+ model,
+ cond_model,
+ x_cond,
+ cond_scale,
+ shape,
+ noise=noise,
+ clip_denoised=clip_denoised,
+ denoised_fn=denoised_fn,
+ model_kwargs=model_kwargs,
+ device=device,
+ progress=progress,
+ ):
+ final = sample
+ return final["sample"]
+
+ def p_sample_cond_loop_progressive(
+ self,
+ model,
+ cond_model,
+ x_cond,
+ cond_scale,
+ shape,
+ noise=None,
+ clip_denoised=True,
+ denoised_fn=None,
+ model_kwargs=None,
+ device=None,
+ progress=False,
+ ):
+ """
+ Generate samples from the model and yield intermediate samples from
+ each timestep of diffusion.
+ Arguments are the same as p_sample_loop().
+ Returns a generator over dicts, where each dict is the return value of
+ p_sample().
+ """
+ if device is None:
+ device = next(model.parameters()).device
+ assert isinstance(shape, (tuple, list))
+ if noise is not None:
+ img = noise
+ else:
+ img = th.randn(*shape, device=device)
+ indices = list(range(self.num_timesteps))[::-1]
+
+ if progress:
+ # Lazy import so that we don't depend on tqdm.
+ from tqdm.auto import tqdm
+
+ indices = tqdm(indices)
+
+ for i in indices:
+ t = th.tensor([i] * shape[0], device=device)
+ with th.no_grad():
+ out = self.p_sample_cond(
+ model,
+ cond_model,
+ x_cond,
+ img,
+ cond_scale,
+ t,
+ clip_denoised=clip_denoised,
+ denoised_fn=denoised_fn,
+ model_kwargs=model_kwargs,
+ )
+ yield out
+ img = out["sample"]
+
+
+
+ def ddim_sample(
+ self,
+ model,
+ x_cond,
+ x,
+ cond_scale,
+ t,
+ clip_denoised=True,
+ denoised_fn=None,
+ model_kwargs=None,
+ eta=0.0,
+ ):
+ """
+ Sample x_{t-1} from the model using DDIM.
+ Same usage as p_sample().
+ """
+ out = self.p_mean_variance(
+ model,
+ x,
+ t,
+ x_cond,
+ cond_scale,
+ clip_denoised=clip_denoised,
+ denoised_fn=denoised_fn,
+ model_kwargs=model_kwargs,
+ )
+ # Usually our model outputs epsilon, but we re-derive it
+ # in case we used x_start or x_prev prediction.
+ eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"])
+ alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
+ alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
+ sigma = (
+ eta
+ * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
+ * th.sqrt(1 - alpha_bar / alpha_bar_prev)
+ )
+ # Equation 12.
+ noise = th.randn_like(x)
+ mean_pred = (
+ out["pred_xstart"] * th.sqrt(alpha_bar_prev)
+ + th.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps
+ )
+ nonzero_mask = (
+ (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
+ ) # no noise when t == 0
+ sample = mean_pred + nonzero_mask * sigma * noise
+ return {"sample": sample, "pred_xstart": out["pred_xstart"]}
+
+ def ddim_reverse_sample(
+ self,
+ model,
+ x,
+ t,
+ clip_denoised=True,
+ denoised_fn=None,
+ model_kwargs=None,
+ eta=0.0,
+ ):
+ """
+ Sample x_{t+1} from the model using DDIM reverse ODE.
+ """
+ assert eta == 0.0, "Reverse ODE only for deterministic path"
+ out = self.p_mean_variance(
+ model,
+ x,
+ t,
+ clip_denoised=clip_denoised,
+ denoised_fn=denoised_fn,
+ model_kwargs=model_kwargs,
+ )
+ # Usually our model outputs epsilon, but we re-derive it
+ # in case we used x_start or x_prev prediction.
+ eps = (
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x
+ - out["pred_xstart"]
+ ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape)
+ alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape)
+
+ # Equation 12. reversed
+ mean_pred = (
+ out["pred_xstart"] * th.sqrt(alpha_bar_next)
+ + th.sqrt(1 - alpha_bar_next) * eps
+ )
+
+ return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]}
+
+ def ddim_sample_loop(
+ self,
+ model,
+ x_cond,
+ cond_scale,
+ noise=None,
+ clip_denoised=True,
+ denoised_fn=None,
+ model_kwargs=None,
+ device=None,
+ progress=False,
+ eta=0.0,
+ ):
+ """
+ Generate samples from the model using DDIM.
+ Same usage as p_sample_loop().
+ """
+ shape = x_cond[0].shape
+ final = None
+        # Encode the source image and an all-zero (null) image for guidance.
+        x_cond[0] = [model.encode(x_cond[0])['cond'], model.encode(torch.zeros_like(x_cond[0]))['cond']]
+
+        # Linearly interpolate the appearance features between the first and
+        # last element of the batch, producing a smooth style transition.
+        x_cond[0][0] = [torch.stack([i * feat[0] + (1 - i) * feat[-1] for i in torch.linspace(0, 1, len(feat))], 0) for feat in x_cond[0][0]]
+
+ for sample in self.ddim_sample_loop_progressive(
+ model,
+ x_cond,
+ cond_scale,
+ shape,
+ noise=noise,
+ clip_denoised=clip_denoised,
+ denoised_fn=denoised_fn,
+ model_kwargs=model_kwargs,
+ device=device,
+ progress=progress,
+ eta=eta,
+ ):
+ final = sample
+ return final["sample"]
+
+ def ddim_sample_loop_progressive(
+ self,
+ model,
+ x_cond,
+ cond_scale,
+ shape,
+ noise=None,
+ clip_denoised=True,
+ denoised_fn=None,
+ model_kwargs=None,
+ device=None,
+ progress=False,
+ eta=0.0,
+ ):
+ """
+ Use DDIM to sample from the model and yield intermediate samples from
+ each timestep of DDIM.
+ Same usage as p_sample_loop_progressive().
+ """
+ if device is None:
+ device = next(model.parameters()).device
+ assert isinstance(shape, (tuple, list))
+ if noise is not None:
+ img = noise
+ else:
+ img = th.randn(*shape, device=device)
+ indices = list(range(self.num_timesteps))[::-1]
+
+ if progress:
+ # Lazy import so that we don't depend on tqdm.
+ from tqdm.auto import tqdm
+
+ indices = tqdm(indices)
+
+ for i in indices:
+ t = th.tensor([i] * shape[0], device=device)
+ with th.no_grad():
+ out = self.ddim_sample(
+ model,
+ x_cond,
+ img,
+ cond_scale,
+ t,
+ clip_denoised=clip_denoised,
+ denoised_fn=denoised_fn,
+ model_kwargs=model_kwargs,
+ eta=eta,
+ )
+ yield out
+ img = out["sample"]
+
+ def _vb_terms_bpd(
+ self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None
+ ):
+ """
+ Get a term for the variational lower-bound.
+ The resulting units are bits (rather than nats, as one might expect).
+ This allows for comparison to other papers.
+ :return: a dict with the following keys:
+ - 'output': a shape [N] tensor of NLLs or KLs.
+ - 'pred_xstart': the x_0 predictions.
+ """
+ true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(
+ x_start=x_start, x_t=x_t, t=t
+ )
+ out = self.p_mean_variance(
+ model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs
+ )
+ kl = normal_kl(
+ true_mean, true_log_variance_clipped, out["mean"], out["log_variance"]
+ )
+ kl = mean_flat(kl) / np.log(2.0)
+
+ decoder_nll = -discretized_gaussian_log_likelihood(
+ x_start, means=out["mean"], log_scales=0.5 * out["log_variance"]
+ )
+ assert decoder_nll.shape == x_start.shape
+ decoder_nll = mean_flat(decoder_nll) / np.log(2.0)
+
+ # At the first timestep return the decoder NLL,
+ # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
+ output = th.where((t == 0), decoder_nll, kl)
+ return {"output": output, "pred_xstart": out["pred_xstart"]}
+
+ def training_losses(self, model, x_start, cond_input, t, prob, model_kwargs=None, noise=None):
+ """
+ Compute training losses for a single timestep.
+ :param model: the model to evaluate loss on.
+        :param x_start: the [N x C x ...] tensor of inputs.
+        :param cond_input: [source image, target pose] conditioning pair.
+        :param t: a batch of timestep indices.
+        :param prob: a probability value forwarded to the model's forward pass.
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
+ pass to the model. This can be used for conditioning.
+ :param noise: if specified, the specific Gaussian noise to try to remove.
+ :return: a dict with the key "loss" containing a tensor of shape [N].
+ Some mean or variance settings may also have other keys.
+ """
+ if model_kwargs is None:
+ model_kwargs = {}
+ if noise is None:
+ noise = th.randn_like(x_start)
+ x_t = self.q_sample(x_start, t, noise=noise)
+
+ [img, target_pose] = cond_input
+
+
+ terms = {}
+
+ if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL:
+ terms["loss"] = self._vb_terms_bpd(
+ model=model,
+ x_start=x_start,
+ x_t=x_t,
+ t=t,
+ clip_denoised=False,
+ model_kwargs=model_kwargs,
+ )["output"]
+ if self.loss_type == LossType.RESCALED_KL:
+ terms["loss"] *= self.num_timesteps
+ elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE:
+            model_output = model(
+                x=torch.cat([x_t, target_pose], 1),
+                t=self._scale_timesteps(t),
+                x_cond=img,
+                prob=prob,
+            )
+
+
+ if self.model_var_type in [
+ ModelVarType.LEARNED,
+ ModelVarType.LEARNED_RANGE,
+ ]:
+ B, C = x_t.shape[:2]
+ assert model_output.shape == (B, C * 2, *x_t.shape[2:])
+ model_output, model_var_values = th.split(model_output, C, dim=1)
+ # Learn the variance using the variational bound, but don't let
+ # it affect our mean prediction.
+ frozen_out = th.cat([model_output.detach(), model_var_values], dim=1)
+ terms["vb"] = self._vb_terms_bpd(
+ model=lambda *args, r=frozen_out: r,
+ x_start=x_start,
+ x_t=x_t,
+ t=t,
+ clip_denoised=False,
+ )["output"]
+ if self.loss_type == LossType.RESCALED_MSE:
+ # Divide by 1000 for equivalence with initial implementation.
+ # Without a factor of 1/1000, the VB term hurts the MSE term.
+ terms["vb"] *= self.num_timesteps / 1000.0
+
+ target = {
+ ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance(
+ x_start=x_start, x_t=x_t, t=t
+ )[0],
+ ModelMeanType.START_X: x_start,
+ ModelMeanType.EPSILON: noise,
+ }[self.model_mean_type]
+ assert model_output.shape == target.shape == x_start.shape
+ terms["mse"] = mean_flat((target - model_output) ** 2)
+ if "vb" in terms:
+ terms["loss"] = terms["mse"] + terms["vb"]
+ else:
+ terms["loss"] = terms["mse"]
+ else:
+ raise NotImplementedError(self.loss_type)
+
+ return terms
+
+ def _prior_bpd(self, x_start):
+ """
+ Get the prior KL term for the variational lower-bound, measured in
+ bits-per-dim.
+ This term can't be optimized, as it only depends on the encoder.
+ :param x_start: the [N x C x ...] tensor of inputs.
+ :return: a batch of [N] KL values (in bits), one per batch element.
+ """
+ batch_size = x_start.shape[0]
+ t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
+ qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
+ kl_prior = normal_kl(
+ mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0
+ )
+ return mean_flat(kl_prior) / np.log(2.0)
+
+ def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None):
+ """
+ Compute the entire variational lower-bound, measured in bits-per-dim,
+ as well as other related quantities.
+ :param model: the model to evaluate loss on.
+ :param x_start: the [N x C x ...] tensor of inputs.
+ :param clip_denoised: if True, clip denoised samples.
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
+ pass to the model. This can be used for conditioning.
+ :return: a dict containing the following keys:
+ - total_bpd: the total variational lower-bound, per batch element.
+ - prior_bpd: the prior term in the lower-bound.
+ - vb: an [N x T] tensor of terms in the lower-bound.
+ - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep.
+ - mse: an [N x T] tensor of epsilon MSEs for each timestep.
+ """
+ device = x_start.device
+ batch_size = x_start.shape[0]
+
+ vb = []
+ xstart_mse = []
+ mse = []
+ for t in list(range(self.num_timesteps))[::-1]:
+ t_batch = th.tensor([t] * batch_size, device=device)
+ noise = th.randn_like(x_start)
+ x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise)
+ # Calculate VLB term at the current timestep
+ with th.no_grad():
+ out = self._vb_terms_bpd(
+ model,
+ x_start=x_start,
+ x_t=x_t,
+ t=t_batch,
+ clip_denoised=clip_denoised,
+ model_kwargs=model_kwargs,
+ )
+ vb.append(out["output"])
+ xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2))
+ eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"])
+ mse.append(mean_flat((eps - noise) ** 2))
+
+ vb = th.stack(vb, dim=1)
+ xstart_mse = th.stack(xstart_mse, dim=1)
+ mse = th.stack(mse, dim=1)
+
+ prior_bpd = self._prior_bpd(x_start)
+ total_bpd = vb.sum(dim=1) + prior_bpd
+ return {
+ "total_bpd": total_bpd,
+ "prior_bpd": prior_bpd,
+ "vb": vb,
+ "xstart_mse": xstart_mse,
+ "mse": mse,
+ }
+
+
+def _extract_into_tensor(arr, timesteps, broadcast_shape):
+ """
+ Extract values from a 1-D numpy array for a batch of indices.
+ :param arr: the 1-D numpy array.
+ :param timesteps: a tensor of indices into the array to extract.
+ :param broadcast_shape: a larger shape of K dimensions with the batch
+ dimension equal to the length of timesteps.
+ :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
+ """
+ res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
+ while len(res.shape) < len(broadcast_shape):
+ res = res[..., None]
+ return res.expand(broadcast_shape)
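+
+# Illustrative shape example for _extract_into_tensor: with arr of shape [T],
+# timesteps of shape [N] and broadcast_shape == (N, C, H, W), the result is
+# arr[timesteps] viewed as [N, 1, 1, 1] and expanded to (N, C, H, W):
+#
+#   coef = _extract_into_tensor(diffusion.sqrt_alphas_cumprod, t, x.shape)
+#   assert coef.shape == x.shape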
+
+
+
+def create_gaussian_diffusion(
+ betas,
+ learn_sigma=True,
+ sigma_small=False,
+ use_kl=False,
+ predict_xstart=False,
+ rescale_timesteps=False,
+ rescale_learned_sigmas=False,
+):
+
+ if use_kl:
+ loss_type = LossType.RESCALED_KL
+ elif rescale_learned_sigmas:
+ loss_type = LossType.RESCALED_MSE
+ else:
+ loss_type = LossType.MSE
+
+ model_mean_type=(
+ ModelMeanType.EPSILON if not predict_xstart else ModelMeanType.START_X
+ )
+
+ model_var_type=(
+ (
+ ModelVarType.FIXED_LARGE
+ if not sigma_small
+ else ModelVarType.FIXED_SMALL
+ )
+ if not learn_sigma
+ else ModelVarType.LEARNED_RANGE
+ )
+
+
+
+    return GaussianDiffusion(
+        betas=betas,
+        model_mean_type=model_mean_type,
+        model_var_type=model_var_type,
+        loss_type=loss_type,
+        rescale_timesteps=rescale_timesteps,
+    )
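+
+
+# Hedged usage sketch for create_gaussian_diffusion. The linear beta schedule
+# below is an illustrative assumption, not necessarily the schedule used by
+# the training configs in this repository:
+#
+#   betas = np.linspace(1e-4, 0.02, 1000, dtype=np.float64)
+#   diffusion = create_gaussian_diffusion(betas, learn_sigma=True)
+#   losses = diffusion.training_losses(model, x_start, [src, pose], t, prob)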
diff --git a/PyTorch/built-in/mlm/PIDM/model.py b/PyTorch/built-in/mlm/PIDM/model.py
new file mode 100755
index 0000000000000000000000000000000000000000..7733ff9fca2b9095ad153ab06146df421928a9c4
--- /dev/null
+++ b/PyTorch/built-in/mlm/PIDM/model.py
@@ -0,0 +1,448 @@
+import math
+from typing import List
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+from pydantic import StrictInt, StrictFloat, StrictBool
+
+
+# def swish(input):
+# return input * torch.sigmoid(input)
+
+swish = F.silu
+
+
+@torch.no_grad()
+def variance_scaling_init_(tensor, scale=1, mode="fan_avg", distribution="uniform"):
+ fan_in, fan_out = nn.init._calculate_fan_in_and_fan_out(tensor)
+
+ if mode == "fan_in":
+ scale /= fan_in
+
+ elif mode == "fan_out":
+ scale /= fan_out
+
+ else:
+ scale /= (fan_in + fan_out) / 2
+
+ if distribution == "normal":
+ std = math.sqrt(scale)
+
+ return tensor.normal_(0, std)
+
+ else:
+ bound = math.sqrt(3 * scale)
+
+ return tensor.uniform_(-bound, bound)
+
+
+def conv2d(
+ in_channel,
+ out_channel,
+ kernel_size,
+ stride=1,
+ padding=0,
+ bias=True,
+ scale=1,
+ mode="fan_avg",
+):
+ conv = nn.Conv2d(
+ in_channel, out_channel, kernel_size, stride=stride, padding=padding, bias=bias
+ )
+
+ variance_scaling_init_(conv.weight, scale, mode=mode)
+
+ if bias:
+ nn.init.zeros_(conv.bias)
+
+ return conv
+
+
+def linear(in_channel, out_channel, scale=1, mode="fan_avg"):
+ lin = nn.Linear(in_channel, out_channel)
+
+ variance_scaling_init_(lin.weight, scale, mode=mode)
+ nn.init.zeros_(lin.bias)
+
+ return lin
+
+
+class Swish(nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, input):
+ return swish(input)
+
+
+class Upsample(nn.Sequential):
+ def __init__(self, channel):
+ layers = [
+ nn.Upsample(scale_factor=2, mode="nearest"),
+ conv2d(channel, channel, 3, padding=1),
+ ]
+
+ super().__init__(*layers)
+
+
+class Downsample(nn.Sequential):
+ def __init__(self, channel):
+ layers = [conv2d(channel, channel, 3, stride=2, padding=1)]
+
+ super().__init__(*layers)
+
+
+class ResBlock(nn.Module):
+ def __init__(
+ self, in_channel, out_channel, time_dim, use_affine_time=False, dropout=0
+ ):
+ super().__init__()
+
+ self.use_affine_time = use_affine_time
+ time_out_dim = out_channel
+ time_scale = 1
+ norm_affine = True
+
+ if self.use_affine_time:
+ time_out_dim *= 2
+ time_scale = 1e-10
+ norm_affine = False
+
+ self.norm1 = nn.GroupNorm(32, in_channel)
+ self.activation1 = Swish()
+ self.conv1 = conv2d(in_channel, out_channel, 3, padding=1)
+
+ self.time = nn.Sequential(
+ Swish(), linear(time_dim, time_out_dim, scale=time_scale)
+ )
+
+ self.norm2 = nn.GroupNorm(32, out_channel, affine=norm_affine)
+ self.activation2 = Swish()
+ self.dropout = nn.Dropout(dropout)
+ self.conv2 = conv2d(out_channel, out_channel, 3, padding=1, scale=1e-10)
+
+ if in_channel != out_channel:
+ self.skip = conv2d(in_channel, out_channel, 1)
+
+ else:
+ self.skip = None
+
+ def forward(self, input, time):
+ batch = input.shape[0]
+
+ out = self.conv1(self.activation1(self.norm1(input)))
+
+ if self.use_affine_time:
+ gamma, beta = self.time(time).view(batch, -1, 1, 1).chunk(2, dim=1)
+ out = (1 + gamma) * self.norm2(out) + beta
+
+ else:
+ out = out + self.time(time).view(batch, -1, 1, 1)
+ out = self.norm2(out)
+
+ out = self.conv2(self.dropout(self.activation2(out)))
+
+ if self.skip is not None:
+ input = self.skip(input)
+
+ return out + input
+
+
+class SelfAttention(nn.Module):
+ def __init__(self, in_channel, n_head=1):
+ super().__init__()
+
+ self.n_head = n_head
+
+ self.norm = nn.GroupNorm(32, in_channel)
+ self.qkv = conv2d(in_channel, in_channel * 3, 1)
+ self.out = conv2d(in_channel, in_channel, 1, scale=1e-10)
+
+ def forward(self, input):
+ batch, channel, height, width = input.shape
+ n_head = self.n_head
+ head_dim = channel // n_head
+
+ norm = self.norm(input)
+ qkv = self.qkv(norm).view(batch, n_head, head_dim * 3, height, width)
+ query, key, value = qkv.chunk(3, dim=2) # bhdyx
+
+ attn = torch.einsum(
+ "bnchw, bncyx -> bnhwyx", query, key
+ ).contiguous() / math.sqrt(channel)
+ attn = attn.view(batch, n_head, height, width, -1)
+ attn = torch.softmax(attn, -1)
+ attn = attn.view(batch, n_head, height, width, height, width)
+
+ out = torch.einsum("bnhwyx, bncyx -> bnchw", attn, value).contiguous()
+ out = self.out(out.view(batch, channel, height, width))
+
+ return out + input
+
+
+class TimeEmbedding(nn.Module):
+ def __init__(self, dim):
+ super().__init__()
+
+ self.dim = dim
+
+ inv_freq = torch.exp(
+ torch.arange(0, dim, 2, dtype=torch.float32) * (-math.log(10000) / dim)
+ )
+
+ self.register_buffer("inv_freq", inv_freq)
+
+ def forward(self, input):
+ shape = input.shape
+ sinusoid_in = torch.ger(input.view(-1).float(), self.inv_freq)
+ pos_emb = torch.cat([sinusoid_in.sin(), sinusoid_in.cos()], dim=-1)
+ pos_emb = pos_emb.view(*shape, self.dim)
+
+ return pos_emb
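+
+    # The embedding concatenates sin and cos of t * 10000^(-k/dim) for
+    # k = 0, 2, ..., dim - 2, i.e. the standard sinusoidal timestep embedding;
+    # the output shape is input.shape + (dim,).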
+
+
+class ResBlockWithAttention(nn.Module):
+ def __init__(
+ self,
+ in_channel,
+ out_channel,
+ time_dim,
+ dropout,
+ use_attention=False,
+ attention_head=1,
+ use_affine_time=False,
+ ):
+ super().__init__()
+
+ self.resblocks = ResBlock(
+ in_channel, out_channel, time_dim, use_affine_time, dropout
+ )
+
+ if use_attention:
+ self.attention = SelfAttention(out_channel, n_head=attention_head)
+
+ else:
+ self.attention = None
+
+ def forward(self, input, time):
+ out = self.resblocks(input, time)
+
+ if self.attention is not None:
+ out = self.attention(out)
+
+ return out
+
+
+def spatial_fold(input, fold):
+ if fold == 1:
+ return input
+
+ batch, channel, height, width = input.shape
+ h_fold = height // fold
+ w_fold = width // fold
+
+ return (
+ input.view(batch, channel, h_fold, fold, w_fold, fold)
+ .permute(0, 1, 3, 5, 2, 4)
+ .reshape(batch, -1, h_fold, w_fold)
+ )
+
+
+def spatial_unfold(input, unfold):
+ if unfold == 1:
+ return input
+
+ batch, channel, height, width = input.shape
+ h_unfold = height * unfold
+ w_unfold = width * unfold
+
+ return (
+ input.view(batch, -1, unfold, unfold, height, width)
+ .permute(0, 1, 4, 2, 5, 3)
+ .reshape(batch, -1, h_unfold, w_unfold)
+ )
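+
+# Illustrative shape sketch for spatial_fold / spatial_unfold:
+#   spatial_fold(x, 2):   [B, C, H, W]       -> [B, 4 * C, H // 2, W // 2]
+#   spatial_unfold(y, 2): [B, 4 * C, H2, W2] -> [B, C, 2 * H2, 2 * W2]
+# so spatial_unfold(spatial_fold(x, f), f) recovers the original tensor.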
+
+
+class UNet(nn.Module):
+ def __init__(
+ self,
+ in_channel: StrictInt,
+ channel: StrictInt,
+ channel_multiplier: List[StrictInt],
+ n_res_blocks: StrictInt,
+ attn_strides: List[StrictInt],
+ attn_heads: StrictInt = 1,
+ use_affine_time: StrictBool = False,
+ dropout: StrictFloat = 0,
+ fold: StrictInt = 1,
+ ):
+ super().__init__()
+
+ self.fold = fold
+
+ time_dim = channel * 4
+
+ n_block = len(channel_multiplier)
+
+ self.time = nn.Sequential(
+ TimeEmbedding(channel),
+ linear(channel, time_dim),
+ Swish(),
+ linear(time_dim, time_dim),
+ )
+
+
+
+ cond_down_layers = [conv2d(3 * (fold ** 2), channel, 3, padding=1)]
+ down_layers = [conv2d(in_channel * (fold ** 2), channel, 3, padding=1)]
+
+ feat_channels = [channel]
+ in_channel = channel
+ for i in range(n_block):
+ for _ in range(n_res_blocks):
+ channel_mult = channel * channel_multiplier[i]
+
+ cond_down_layers.append(
+ ResBlockWithAttention(
+ in_channel,
+ channel_mult,
+ time_dim,
+ dropout,
+ use_attention=2 ** i in attn_strides,
+ attention_head=attn_heads,
+ use_affine_time=use_affine_time,
+ )
+ )
+
+ feat_channels.append(channel_mult)
+ in_channel = channel_mult
+
+ if i != n_block - 1:
+ cond_down_layers.append(Downsample(in_channel))
+ feat_channels.append(in_channel)
+
+ self.cond_down = nn.ModuleList(cond_down_layers)
+
+
+
+
+ feat_channels = [channel]
+ in_channel = channel
+ for i in range(n_block):
+ for _ in range(n_res_blocks):
+ channel_mult = channel * channel_multiplier[i]
+
+ down_layers.append(
+ ResBlockWithAttention(
+ in_channel*2,
+ channel_mult,
+ time_dim,
+ dropout,
+ use_attention=2 ** i in attn_strides,
+ attention_head=attn_heads,
+ use_affine_time=use_affine_time,
+ )
+ )
+
+ feat_channels.append(channel_mult)
+ in_channel = channel_mult
+
+ if i != n_block - 1:
+ down_layers.append(Downsample(in_channel))
+ feat_channels.append(in_channel)
+
+ self.down = nn.ModuleList(down_layers)
+
+ self.mid = nn.ModuleList(
+ [
+ ResBlockWithAttention(
+ in_channel,
+ in_channel,
+ time_dim,
+ dropout=dropout,
+ use_attention=True,
+ attention_head=attn_heads,
+ use_affine_time=use_affine_time,
+ ),
+ ResBlockWithAttention(
+ in_channel,
+ in_channel,
+ time_dim,
+ dropout=dropout,
+ use_affine_time=use_affine_time,
+ ),
+ ]
+ )
+
+ up_layers = []
+ for i in reversed(range(n_block)):
+ for _ in range(n_res_blocks + 1):
+ channel_mult = channel * channel_multiplier[i]
+
+ up_layers.append(
+ ResBlockWithAttention(
+ in_channel + feat_channels.pop(),
+ channel_mult,
+ time_dim,
+ dropout=dropout,
+ use_attention=2 ** i in attn_strides,
+ attention_head=attn_heads,
+ use_affine_time=use_affine_time,
+ )
+ )
+
+ in_channel = channel_mult
+
+ if i != 0:
+ up_layers.append(Upsample(in_channel))
+
+ self.up = nn.ModuleList(up_layers)
+
+ self.out = nn.Sequential(
+ nn.GroupNorm(32, in_channel),
+ Swish(),
+ conv2d(in_channel, 3 * (fold ** 2), 3, padding=1, scale=1e-10),
+ )
+
+ def forward(self, input, cond_input, time):
+ time_embed = self.time(time)
+
+ feats = []
+
+ out = spatial_fold(input, self.fold)
+ cond = spatial_fold(cond_input, self.fold)
+
+ for layer, cond_layer in zip(self.down, self.cond_down):
+
+ if isinstance(layer, ResBlockWithAttention):
+
+ out = torch.cat([out, cond], 1)
+
+ out = layer(out, time_embed)
+ cond = cond_layer(cond, time_embed)
+
+ else:
+ out = layer(out)
+ cond = cond_layer(cond)
+
+ #out = torch.cat([out, cond], 1)
+
+ feats.append(out)
+
+ for layer in self.mid:
+ out = layer(out, time_embed)
+
+ for layer in self.up:
+ if isinstance(layer, ResBlockWithAttention):
+ out = layer(torch.cat((out, feats.pop()), 1), time_embed)
+
+ else:
+ out = layer(out)
+
+ out = self.out(out)
+ out = spatial_unfold(out, self.fold)
+
+ return out
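+
+
+# Hedged construction sketch for this UNet (hyper-parameters are illustrative,
+# not the values used by the PIDM configs):
+#
+#   net = UNet(in_channel=3, channel=64, channel_multiplier=[1, 2, 4, 8],
+#              n_res_blocks=2, attn_strides=[8], attn_heads=4)
+#   out = net(x_t, cond_img, t)   # x_t, cond_img: [B, 3, 256, 256]; t: [B]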
diff --git a/PyTorch/built-in/mlm/PIDM/models/__init__.py b/PyTorch/built-in/mlm/PIDM/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..6a501aa19ef5bf4885fe99bec236b5addd07ed0d
--- /dev/null
+++ b/PyTorch/built-in/mlm/PIDM/models/__init__.py
@@ -0,0 +1,6 @@
+from typing import Union
+from .unet import BeatGANsUNetModel, BeatGANsUNetConfig
+from .unet_autoenc import BeatGANsAutoencConfig, BeatGANsAutoencModel
+
+Model = Union[BeatGANsUNetModel, BeatGANsAutoencModel]
+ModelConfig = Union[BeatGANsUNetConfig, BeatGANsAutoencConfig]
diff --git a/PyTorch/built-in/mlm/PIDM/models/blocks.py b/PyTorch/built-in/mlm/PIDM/models/blocks.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d810ba4af2e864010969e5cb36f814999c31105
--- /dev/null
+++ b/PyTorch/built-in/mlm/PIDM/models/blocks.py
@@ -0,0 +1,639 @@
+import math
+from abc import abstractmethod
+from dataclasses import dataclass
+from numbers import Number
+
+import numpy as np
+import torch as th
+import torch.nn.functional as F
+from .choices import *
+from .config_base import BaseConfig
+from torch import nn
+
+from .nn import (avg_pool_nd, conv_nd, linear, normalization,
+ timestep_embedding, torch_checkpoint, zero_module)
+
+
+class ScaleAt(Enum):
+ after_norm = 'afternorm'
+
+
+class TimestepBlock(nn.Module):
+ """
+ Any module where forward() takes timestep embeddings as a second argument.
+ """
+ @abstractmethod
+ def forward(self, x, emb=None, cond=None, lateral=None):
+ """
+ Apply the module to `x` given `emb` timestep embeddings.
+ """
+
+
+class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
+ """
+ A sequential module that passes timestep embeddings to the children that
+ support it as an extra input.
+ """
+ def forward(self, x, emb=None, cond=None, lateral=None):
+ for layer in self:
+ if isinstance(layer, TimestepBlock):
+ x = layer(x, emb=emb, cond=None, lateral=lateral)
+ elif isinstance(layer, AttentionBlock):
+ x = layer(x, cond)
+ else:
+ x = layer(x)
+ return x
+
+
+@dataclass
+class ResBlockConfig(BaseConfig):
+ channels: int
+ emb_channels: int
+ dropout: float
+ out_channels: int = None
+ # condition the resblock with time (and encoder's output)
+ use_condition: bool = True
+ # whether to use 3x3 conv for skip path when the channels aren't matched
+ use_conv: bool = False
+ # dimension of conv (always 2 = 2d)
+ dims: int = 2
+ # gradient checkpoint
+ use_checkpoint: bool = False
+ up: bool = False
+ down: bool = False
+ # whether to condition with both time & encoder's output
+ two_cond: bool = False
+ # number of encoders' output channels
+ cond_emb_channels: int = None
+ # suggest: False
+ has_lateral: bool = False
+ lateral_channels: int = None
+ # whether to init the convolution with zero weights
+ # this is default from BeatGANs and seems to help learning
+ use_zero_module: bool = True
+
+ def __post_init__(self):
+ self.out_channels = self.out_channels or self.channels
+ self.cond_emb_channels = self.cond_emb_channels or self.emb_channels
+
+ def make_model(self):
+ return ResBlock(self)
+
+
+class ResBlock(TimestepBlock):
+ """
+ A residual block that can optionally change the number of channels.
+
+ total layers:
+ in_layers
+ - norm
+ - act
+ - conv
+ out_layers
+ - norm
+ - (modulation)
+ - act
+ - conv
+ """
+ def __init__(self, conf: ResBlockConfig):
+ super().__init__()
+ self.conf = conf
+
+ #############################
+ # IN LAYERS
+ #############################
+ assert conf.lateral_channels is None
+ layers = [
+
+ normalization(conf.channels),
+ nn.SiLU(),
+ conv_nd(conf.dims, conf.channels, conf.out_channels, 3, padding=1)
+ ]
+ self.in_layers = nn.Sequential(*layers)
+
+ self.updown = conf.up or conf.down
+
+ if conf.up:
+ self.h_upd = Upsample(conf.channels, False, conf.dims)
+ self.x_upd = Upsample(conf.channels, False, conf.dims)
+ elif conf.down:
+ self.h_upd = Downsample(conf.channels, False, conf.dims)
+ self.x_upd = Downsample(conf.channels, False, conf.dims)
+ else:
+ self.h_upd = self.x_upd = nn.Identity()
+
+ #############################
+ # OUT LAYERS CONDITIONS
+ #############################
+ if conf.use_condition:
+ # condition layers for the out_layers
+ self.emb_layers = nn.Sequential(
+ nn.SiLU(),
+ linear(conf.emb_channels, 2 * conf.out_channels),
+ )
+
+ if conf.two_cond:
+ self.cond_emb_layers = nn.Sequential(
+ nn.SiLU(),
+ linear(conf.out_channels, conf.out_channels),
+ )
+ #############################
+ # OUT LAYERS (ignored when there is no condition)
+ #############################
+ # original version
+ conv = conv_nd(conf.dims,
+ conf.out_channels,
+ conf.out_channels,
+ 3,
+ padding=1)
+ if conf.use_zero_module:
+            # zero out the weights
+ # it seems to help training
+ conv = zero_module(conv)
+
+ # construct the layers
+ # - norm
+ # - (modulation)
+ # - act
+ # - dropout
+ # - conv
+
+ layers = []
+ layers += [
+ normalization(conf.out_channels),
+ nn.SiLU(),
+ nn.Dropout(p=conf.dropout),
+ conv,
+ ]
+ self.out_layers = nn.Sequential(*layers)
+
+ #############################
+ # SKIP LAYERS
+ #############################
+ if conf.out_channels == conf.channels:
+            # cannot be used with gatedconv; gatedconv is always used as the first block
+ self.skip_connection = nn.Identity()
+ else:
+ if conf.use_conv:
+ kernel_size = 3
+ padding = 1
+ else:
+ kernel_size = 1
+ padding = 0
+
+ self.skip_connection = conv_nd(conf.dims,
+ conf.channels,
+ conf.out_channels,
+ kernel_size,
+ padding=padding)
+
+ def forward(self, x, emb=None, cond=None, lateral=None):
+ """
+ Apply the block to a Tensor, conditioned on a timestep embedding.
+
+ Args:
+ x: input
+ lateral: lateral connection from the encoder
+ """
+ return torch_checkpoint(self._forward, (x, emb, cond, lateral),
+ self.conf.use_checkpoint)
+
+ def _forward(
+ self,
+ x,
+ emb=None,
+ cond=None,
+ lateral=None,
+ ):
+ """
+ Args:
+            lateral: required if "has_lateral" is set (non-gated); with gated convolutions it may be supplied optionally
+ """
+ if self.conf.has_lateral:
+            # lateral may be supplied even if it isn't required;
+            # the model uses it only when "has_lateral" is set
+ assert lateral is not None
+ x = th.cat([x, lateral], dim=1)
+
+ if self.updown:
+ in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
+ h = in_rest(x)
+ h = self.h_upd(h)
+ x = self.x_upd(x)
+ h = in_conv(h)
+ else:
+ h = self.in_layers(x)
+
+ if self.conf.use_condition:
+            # it's possible that the network does not receive the time emb;
+            # this happens with autoenc and the time_at setting
+ if emb is not None:
+ emb_out = self.emb_layers(emb).type(h.dtype)
+ else:
+ emb_out = None
+
+ if self.conf.two_cond:
+ # it's possible that the network is two_cond
+ # but it doesn't get the second condition
+ # in which case, we ignore the second condition
+ # and treat as if the network has one condition
+ if cond is None:
+ cond_out = None
+ else:
+ cond_out = self.cond_emb_layers(cond).type(h.dtype)
+
+ if cond_out is not None:
+ while len(cond_out.shape) < len(h.shape):
+ cond_out = cond_out[..., None]
+ else:
+ cond_out = None
+
+ # this is the new refactored code
+ h = apply_conditions(
+ h=h,
+ emb=emb_out,
+ cond=cond_out,
+ layers=self.out_layers,
+ scale_bias=1,
+ in_channels=self.conf.out_channels,
+ up_down_layer=None,
+ )
+
+ return self.skip_connection(x) + h
+
+
+def apply_conditions(
+ h,
+ emb=None,
+ cond=None,
+ layers: nn.Sequential = None,
+ scale_bias: float = 1,
+ in_channels: int = 512,
+ up_down_layer: nn.Module = None,
+):
+ """
+ apply conditions on the feature maps
+
+ Args:
+ emb: time conditional (ready to scale + shift)
+        cond: encoder's conditional (ready to scale + shift)
+ """
+ two_cond = emb is not None and cond is not None
+
+ if emb is not None:
+ # adjusting shapes
+ while len(emb.shape) < len(h.shape):
+ emb = emb[..., None]
+
+ if two_cond:
+ # adjusting shapes
+ while len(cond.shape) < len(h.shape):
+ cond = cond[..., None]
+ # time first
+ scale_shifts = [emb, cond]
+ else:
+ # "cond" is not used with single cond mode
+ scale_shifts = [emb]
+
+ # support scale, shift or shift only
+ for i, each in enumerate(scale_shifts):
+ if each is None:
+ # special case: the condition is not provided
+ a = None
+ b = None
+ else:
+ if each.shape[1] == in_channels * 2:
+ a, b = th.chunk(each, 2, dim=1)
+ else:
+ a = each
+ b = None
+ scale_shifts[i] = (a, b)
+
+ # condition scale bias could be a list
+ if isinstance(scale_bias, Number):
+ biases = [scale_bias] * len(scale_shifts)
+ else:
+ # a list
+ biases = scale_bias
+
+ # default, the scale & shift are applied after the group norm but BEFORE SiLU
+ pre_layers, post_layers = layers[0], layers[1:]
+
+    # split the post layers so we can scale up or down just before the final conv;
+    # post_layers then contains only the dropout and the conv
+ mid_layers, post_layers = post_layers[:-2], post_layers[-2:]
+
+ h = pre_layers(h)
+ # scale and shift for each condition
+ for i, (scale, shift) in enumerate(scale_shifts):
+ # if scale is None, it indicates that the condition is not provided
+ if scale is not None:
+ h = h * (biases[i] + scale)
+ if shift is not None:
+ h = h + shift
+ h = mid_layers(h)
+
+ # upscale or downscale if any just before the last conv
+ if up_down_layer is not None:
+ h = up_down_layer(h)
+ h = post_layers(h)
+ return h
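+
+# In effect, apply_conditions performs an AdaGN-style modulation: after the
+# group norm, each provided condition contributes
+#   h <- h * (bias_i + scale_i) + shift_i
+# (shift_i only when the condition supplies 2 * in_channels values), followed
+# by SiLU, the optional up/down layer, dropout and the final conv.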
+
+
+class Upsample(nn.Module):
+ """
+ An upsampling layer with an optional convolution.
+
+ :param channels: channels in the inputs and outputs.
+ :param use_conv: a bool determining if a convolution is applied.
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
+ upsampling occurs in the inner-two dimensions.
+ """
+ def __init__(self, channels, use_conv, dims=2, out_channels=None):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.dims = dims
+ if use_conv:
+ self.conv = conv_nd(dims,
+ self.channels,
+ self.out_channels,
+ 3,
+ padding=1)
+
+ def forward(self, x):
+ assert x.shape[1] == self.channels
+ if self.dims == 3:
+ x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2),
+ mode="nearest")
+ else:
+ x = F.interpolate(x, scale_factor=2, mode="nearest")
+ if self.use_conv:
+ x = self.conv(x)
+ return x
+
+
+class Downsample(nn.Module):
+ """
+ A downsampling layer with an optional convolution.
+
+ :param channels: channels in the inputs and outputs.
+ :param use_conv: a bool determining if a convolution is applied.
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
+ downsampling occurs in the inner-two dimensions.
+ """
+ def __init__(self, channels, use_conv, dims=2, out_channels=None):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.dims = dims
+ stride = 2 if dims != 3 else (1, 2, 2)
+ if use_conv:
+ self.op = conv_nd(dims,
+ self.channels,
+ self.out_channels,
+ 3,
+ stride=stride,
+ padding=1)
+ else:
+ assert self.channels == self.out_channels
+ self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
+
+ def forward(self, x):
+ assert x.shape[1] == self.channels
+ return self.op(x)
+
+
+class AttentionBlock(nn.Module):
+ """
+ An attention block that allows spatial positions to attend to each other.
+
+ Originally ported from here, but adapted to the N-d case.
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
+ """
+ def __init__(
+ self,
+ channels,
+ num_heads=1,
+ num_head_channels=-1,
+ use_checkpoint=False,
+ use_new_attention_order=False,
+ ):
+ super().__init__()
+ self.channels = channels
+ if num_head_channels == -1:
+ self.num_heads = num_heads
+ else:
+ assert (
+ channels % num_head_channels == 0
+ ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
+ self.num_heads = channels // num_head_channels
+ self.use_checkpoint = use_checkpoint
+ self.norm = normalization(channels)
+ self.qkv = conv_nd(1, channels, channels * 3, 1)
+ self.to_kv = conv_nd(1, channels, channels * 2, 1)
+ self.to_q = conv_nd(1, channels, channels * 1, 1)
+ if use_new_attention_order:
+ # split qkv before split heads
+ self.selfattention = QKVAttention(self.num_heads)
+ self.crossattention = QKVAttention(self.num_heads)
+ else:
+ # split heads before split qkv
+ self.selfattention = QKVAttentionLegacy(self.num_heads)
+ self.crossattention = QKVAttentionLegacy(self.num_heads)
+
+ self.proj_out1 = zero_module(conv_nd(1, channels, channels, 1))
+ self.proj_out2 = zero_module(conv_nd(1, channels, channels, 1))
+
+ def forward(self, x, cond):
+ return torch_checkpoint(self._forward, (x, cond,), self.use_checkpoint)
+
+ def _forward(self, x, cond):
+
+        # self-attention branch (disabled; only the cross-attention below is used):
+
+ # b, c, *spatial = x.shape
+ # x = x.reshape(b, c, -1)
+ # qkv = self.qkv(self.norm(x))
+ # h = self.selfattention(qkv)
+ # h = self.proj_out1(h)
+ # x = (x + h).reshape(b, c, *spatial)
+
+ # cross-attn
+
+ b, c, *spatial = x.shape
+ x = x.reshape(b, c, -1)
+ cond = cond.reshape(b, c, -1)
+ kv = self.to_kv(cond)
+ q = self.to_q(self.norm(x))
+ qkv = th.cat([q, kv], 1)
+ h = self.crossattention(qkv)
+ h = self.proj_out2(h)
+
+ return (x + h).reshape(b, c, *spatial)
+
+
+class AttentionBlock_self(nn.Module):
+ """
+ An attention block that allows spatial positions to attend to each other.
+
+ Originally ported from here, but adapted to the N-d case.
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
+ """
+ def __init__(
+ self,
+ channels,
+ num_heads=1,
+ num_head_channels=-1,
+ use_checkpoint=False,
+ use_new_attention_order=False,
+ ):
+ super().__init__()
+ self.channels = channels
+ if num_head_channels == -1:
+ self.num_heads = num_heads
+ else:
+ assert (
+ channels % num_head_channels == 0
+ ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
+ self.num_heads = channels // num_head_channels
+ self.use_checkpoint = use_checkpoint
+        self.norm = nn.LayerNorm(channels)  # normalization(channels) in the original block
+ self.qkv = conv_nd(1, channels, channels * 3, 1)
+ if use_new_attention_order:
+ # split qkv before split heads
+ self.attention = QKVAttention(self.num_heads)
+ else:
+ # split heads before split qkv
+ self.attention = QKVAttentionLegacy(self.num_heads)
+
+ self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
+
+ def forward(self, x):
+ return torch_checkpoint(self._forward, (x, ), self.use_checkpoint)
+
+ def _forward(self, x):
+ # b, c, *spatial = x.shape
+ # x = x.reshape(b, c, -1)
+ qkv = self.qkv(x)
+ h = self.attention(qkv)
+ h = self.proj_out(h+x)
+ return h #.reshape(b, c, *spatial)
+
+def count_flops_attn(model, _x, y):
+ """
+ A counter for the `thop` package to count the operations in an
+ attention operation.
+ Meant to be used like:
+ macs, params = thop.profile(
+ model,
+ inputs=(inputs, timestamps),
+ custom_ops={QKVAttention: QKVAttention.count_flops},
+ )
+ """
+ b, c, *spatial = y[0].shape
+ num_spatial = int(np.prod(spatial))
+ # We perform two matmuls with the same number of ops.
+ # The first computes the weight matrix, the second computes
+ # the combination of the value vectors.
+ matmul_ops = 2 * b * (num_spatial**2) * c
+ model.total_ops += th.DoubleTensor([matmul_ops])
+
+
+class QKVAttentionLegacy(nn.Module):
+ """
+    A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping
+ """
+ def __init__(self, n_heads):
+ super().__init__()
+ self.n_heads = n_heads
+
+ def forward(self, qkv):
+ """
+ Apply QKV attention.
+
+ :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
+ :return: an [N x (H * C) x T] tensor after attention.
+ """
+ bs, width, length = qkv.shape
+ assert width % (3 * self.n_heads) == 0
+ ch = width // (3 * self.n_heads)
+ q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch,
+ dim=1)
+ scale = 1 / math.sqrt(math.sqrt(ch))
+ weight = th.einsum(
+ "bct,bcs->bts", q * scale,
+ k * scale) # More stable with f16 than dividing afterwards
+ weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
+ a = th.einsum("bts,bcs->bct", weight, v)
+ return a.reshape(bs, -1, length)
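+
+    # Shape sketch (legacy ordering): qkv is [N, H * 3 * C, T]; it is reshaped
+    # to [N * H, 3 * C, T], split into q, k, v of shape [N * H, C, T], attended
+    # with softmax(q^T k / sqrt(C)) over T, and returned as [N, H * C, T].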
+
+ @staticmethod
+ def count_flops(model, _x, y):
+ return count_flops_attn(model, _x, y)
+
+
+class QKVAttention(nn.Module):
+ """
+ A module which performs QKV attention and splits in a different order.
+ """
+ def __init__(self, n_heads):
+ super().__init__()
+ self.n_heads = n_heads
+
+ def forward(self, qkv):
+ """
+ Apply QKV attention.
+
+ :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
+ :return: an [N x (H * C) x T] tensor after attention.
+ """
+ bs, width, length = qkv.shape
+ assert width % (3 * self.n_heads) == 0
+ ch = width // (3 * self.n_heads)
+ q, k, v = qkv.chunk(3, dim=1)
+ scale = 1 / math.sqrt(math.sqrt(ch))
+ weight = th.einsum(
+ "bct,bcs->bts",
+ (q * scale).view(bs * self.n_heads, ch, length),
+ (k * scale).view(bs * self.n_heads, ch, length),
+ ) # More stable with f16 than dividing afterwards
+ weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
+ a = th.einsum("bts,bcs->bct", weight,
+ v.reshape(bs * self.n_heads, ch, length))
+ return a.reshape(bs, -1, length)
+
+ @staticmethod
+ def count_flops(model, _x, y):
+ return count_flops_attn(model, _x, y)
+
+
+class AttentionPool2d(nn.Module):
+ """
+ Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
+ """
+ def __init__(
+ self,
+ spacial_dim: int,
+ embed_dim: int,
+ num_heads_channels: int,
+ output_dim: int = None,
+ ):
+ super().__init__()
+ self.positional_embedding = nn.Parameter(
+ th.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5)
+ self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
+ self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
+ self.num_heads = embed_dim // num_heads_channels
+ self.attention = QKVAttention(self.num_heads)
+
+ def forward(self, x):
+ b, c, *_spatial = x.shape
+ x = x.reshape(b, c, -1) # NC(HW)
+ x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1)
+ x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1)
+ x = self.qkv_proj(x)
+ x = self.attention(x)
+ x = self.c_proj(x)
+ return x[:, :, 0]
diff --git a/PyTorch/built-in/mlm/PIDM/models/choices.py b/PyTorch/built-in/mlm/PIDM/models/choices.py
new file mode 100644
index 0000000000000000000000000000000000000000..740552ae31243b6dd4318cddaf5e2f7f6b3a8f69
--- /dev/null
+++ b/PyTorch/built-in/mlm/PIDM/models/choices.py
@@ -0,0 +1,179 @@
+from enum import Enum
+from torch import nn
+
+
+class TrainMode(Enum):
+ # manipulate mode = training the classifier
+ manipulate = 'manipulate'
+    # default training mode
+    diffusion = 'diffusion'
+    # default latent training mode:
+    # fitting a DDPM to a given latent
+ latent_diffusion = 'latentdiffusion'
+
+ def is_manipulate(self):
+ return self in [
+ TrainMode.manipulate,
+ ]
+
+ def is_diffusion(self):
+ return self in [
+ TrainMode.diffusion,
+ TrainMode.latent_diffusion,
+ ]
+
+ def is_autoenc(self):
+ # the network possibly does autoencoding
+ return self in [
+ TrainMode.diffusion,
+ ]
+
+ def is_latent_diffusion(self):
+ return self in [
+ TrainMode.latent_diffusion,
+ ]
+
+ def use_latent_net(self):
+ return self.is_latent_diffusion()
+
+ def require_dataset_infer(self):
+ """
+        Whether training in this mode requires the latent variables to be available.
+ """
+ # this will precalculate all the latents before hand
+ # and the dataset will be all the predicted latents
+ return self in [
+ TrainMode.latent_diffusion,
+ TrainMode.manipulate,
+ ]
+
+
+class ManipulateMode(Enum):
+ """
+    How to train the classifier used for manipulation.
+ """
+ # train on whole celeba attr dataset
+ celebahq_all = 'celebahq_all'
+ # celeba with D2C's crop
+ d2c_fewshot = 'd2cfewshot'
+ d2c_fewshot_allneg = 'd2cfewshotallneg'
+
+ def is_celeba_attr(self):
+ return self in [
+ ManipulateMode.d2c_fewshot,
+ ManipulateMode.d2c_fewshot_allneg,
+ ManipulateMode.celebahq_all,
+ ]
+
+ def is_single_class(self):
+ return self in [
+ ManipulateMode.d2c_fewshot,
+ ManipulateMode.d2c_fewshot_allneg,
+ ]
+
+ def is_fewshot(self):
+ return self in [
+ ManipulateMode.d2c_fewshot,
+ ManipulateMode.d2c_fewshot_allneg,
+ ]
+
+ def is_fewshot_allneg(self):
+ return self in [
+ ManipulateMode.d2c_fewshot_allneg,
+ ]
+
+
+class ModelType(Enum):
+ """
+ Kinds of the backbone models
+ """
+
+ # unconditional ddpm
+ ddpm = 'ddpm'
+ # autoencoding ddpm cannot do unconditional generation
+ autoencoder = 'autoencoder'
+
+ def has_autoenc(self):
+ return self in [
+ ModelType.autoencoder,
+ ]
+
+ def can_sample(self):
+ return self in [ModelType.ddpm]
+
+
+class ModelName(Enum):
+ """
+ List of all supported model classes
+ """
+
+ beatgans_ddpm = 'beatgans_ddpm'
+ beatgans_autoenc = 'beatgans_autoenc'
+
+
+class ModelMeanType(Enum):
+ """
+ Which type of output the model predicts.
+ """
+
+ eps = 'eps' # the model predicts epsilon
+
+
+class ModelVarType(Enum):
+ """
+ What is used as the model's output variance.
+
+    Only the fixed (non-learned) variance options are supported here.
+ """
+
+ # posterior beta_t
+ fixed_small = 'fixed_small'
+ # beta_t
+ fixed_large = 'fixed_large'
+
+
+class LossType(Enum):
+ mse = 'mse' # use raw MSE loss (and KL when learning variances)
+ l1 = 'l1'
+
+
+class GenerativeType(Enum):
+ """
+    How a sample is generated.
+ """
+
+ ddpm = 'ddpm'
+ ddim = 'ddim'
+
+
+class OptimizerType(Enum):
+ adam = 'adam'
+ adamw = 'adamw'
+
+
+class Activation(Enum):
+ none = 'none'
+ relu = 'relu'
+ lrelu = 'lrelu'
+ silu = 'silu'
+ tanh = 'tanh'
+
+ def get_act(self):
+ if self == Activation.none:
+ return nn.Identity()
+ elif self == Activation.relu:
+ return nn.ReLU()
+ elif self == Activation.lrelu:
+ return nn.LeakyReLU(negative_slope=0.2)
+ elif self == Activation.silu:
+ return nn.SiLU()
+ elif self == Activation.tanh:
+ return nn.Tanh()
+ else:
+ raise NotImplementedError()
+
+
+class ManipulateLossType(Enum):
+ bce = 'bce'
+ mse = 'mse'
\ No newline at end of file
diff --git a/PyTorch/built-in/mlm/PIDM/models/config_base.py b/PyTorch/built-in/mlm/PIDM/models/config_base.py
new file mode 100644
index 0000000000000000000000000000000000000000..385f9eef8bf1fb39ab354c407f5ea681765936bb
--- /dev/null
+++ b/PyTorch/built-in/mlm/PIDM/models/config_base.py
@@ -0,0 +1,72 @@
+import json
+import os
+from copy import deepcopy
+from dataclasses import dataclass
+
+
+@dataclass
+class BaseConfig:
+ def clone(self):
+ return deepcopy(self)
+
+ def inherit(self, another):
+ """inherit common keys from a given config"""
+ common_keys = set(self.__dict__.keys()) & set(another.__dict__.keys())
+ for k in common_keys:
+ setattr(self, k, getattr(another, k))
+
+ def propagate(self):
+ """push down the configuration to all members"""
+ for k, v in self.__dict__.items():
+ if isinstance(v, BaseConfig):
+ v.inherit(self)
+ v.propagate()
+
+ def save(self, save_path):
+ """save config to json file"""
+ dirname = os.path.dirname(save_path)
+ if not os.path.exists(dirname):
+ os.makedirs(dirname)
+ conf = self.as_dict_jsonable()
+ with open(save_path, 'w') as f:
+ json.dump(conf, f)
+
+ def load(self, load_path):
+ """load json config"""
+ with open(load_path) as f:
+ conf = json.load(f)
+ self.from_dict(conf)
+
+ def from_dict(self, dict, strict=False):
+ for k, v in dict.items():
+ if not hasattr(self, k):
+ if strict:
+ raise ValueError(f"loading extra '{k}'")
+ else:
+ print(f"loading extra '{k}'")
+ continue
+ if isinstance(self.__dict__[k], BaseConfig):
+ self.__dict__[k].from_dict(v)
+ else:
+ self.__dict__[k] = v
+
+ def as_dict_jsonable(self):
+ conf = {}
+ for k, v in self.__dict__.items():
+ if isinstance(v, BaseConfig):
+ conf[k] = v.as_dict_jsonable()
+ else:
+ if jsonable(v):
+ conf[k] = v
+ else:
+ # ignore not jsonable
+ pass
+ return conf
+
+
+def jsonable(x):
+ try:
+ json.dumps(x)
+ return True
+ except TypeError:
+ return False
diff --git a/PyTorch/built-in/mlm/PIDM/models/latentnet.py b/PyTorch/built-in/mlm/PIDM/models/latentnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..169fa3ce449fa2b0d32f03e7735f5876235aebdc
--- /dev/null
+++ b/PyTorch/built-in/mlm/PIDM/models/latentnet.py
@@ -0,0 +1,193 @@
+import math
+from dataclasses import dataclass
+from enum import Enum
+from typing import NamedTuple, Tuple
+
+import torch
+from .choices import *
+from .config_base import BaseConfig
+from torch import nn
+from torch.nn import init
+
+from .blocks import *
+from .nn import timestep_embedding
+from .unet import *
+
+
+class LatentNetType(Enum):
+ none = 'none'
+ # injecting inputs into the hidden layers
+ skip = 'skip'
+
+
+class LatentNetReturn(NamedTuple):
+ pred: torch.Tensor = None
+
+
+@dataclass
+class MLPSkipNetConfig(BaseConfig):
+ """
+ default MLP for the latent DPM in the paper!
+ """
+ num_channels: int
+ skip_layers: Tuple[int]
+ num_hid_channels: int
+ num_layers: int
+ num_time_emb_channels: int = 64
+ activation: Activation = Activation.silu
+ use_norm: bool = True
+ condition_bias: float = 1
+ dropout: float = 0
+ last_act: Activation = Activation.none
+ num_time_layers: int = 2
+ time_last_act: bool = False
+
+ def make_model(self):
+ return MLPSkipNet(self)
+
+
+class MLPSkipNet(nn.Module):
+ """
+ concat x to hidden layers
+
+ default MLP for the latent DPM in the paper!
+ """
+ def __init__(self, conf: MLPSkipNetConfig):
+ super().__init__()
+ self.conf = conf
+
+ layers = []
+ for i in range(conf.num_time_layers):
+ if i == 0:
+ a = conf.num_time_emb_channels
+ b = conf.num_channels
+ else:
+ a = conf.num_channels
+ b = conf.num_channels
+ layers.append(nn.Linear(a, b))
+ if i < conf.num_time_layers - 1 or conf.time_last_act:
+ layers.append(conf.activation.get_act())
+ self.time_embed = nn.Sequential(*layers)
+
+ self.layers = nn.ModuleList([])
+ for i in range(conf.num_layers):
+ if i == 0:
+ act = conf.activation
+ norm = conf.use_norm
+ cond = True
+ a, b = conf.num_channels, conf.num_hid_channels
+ dropout = conf.dropout
+ elif i == conf.num_layers - 1:
+ act = Activation.none
+ norm = False
+ cond = False
+ a, b = conf.num_hid_channels, conf.num_channels
+ dropout = 0
+ else:
+ act = conf.activation
+ norm = conf.use_norm
+ cond = True
+ a, b = conf.num_hid_channels, conf.num_hid_channels
+ dropout = conf.dropout
+
+ if i in conf.skip_layers:
+ a += conf.num_channels
+
+ self.layers.append(
+ MLPLNAct(
+ a,
+ b,
+ norm=norm,
+ activation=act,
+ cond_channels=conf.num_channels,
+ use_cond=cond,
+ condition_bias=conf.condition_bias,
+ dropout=dropout,
+ ))
+ self.last_act = conf.last_act.get_act()
+
+ def forward(self, x, t, **kwargs):
+ t = timestep_embedding(t, self.conf.num_time_emb_channels)
+ cond = self.time_embed(t)
+ h = x
+ for i in range(len(self.layers)):
+ if i in self.conf.skip_layers:
+ # injecting input into the hidden layers
+ h = torch.cat([h, x], dim=1)
+ h = self.layers[i].forward(x=h, cond=cond)
+ h = self.last_act(h)
+ return LatentNetReturn(h)
+
+
+class MLPLNAct(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ norm: bool,
+ use_cond: bool,
+ activation: Activation,
+ cond_channels: int,
+ condition_bias: float = 0,
+ dropout: float = 0,
+ ):
+ super().__init__()
+ self.activation = activation
+ self.condition_bias = condition_bias
+ self.use_cond = use_cond
+
+ self.linear = nn.Linear(in_channels, out_channels)
+ self.act = activation.get_act()
+ if self.use_cond:
+ self.linear_emb = nn.Linear(cond_channels, out_channels)
+ self.cond_layers = nn.Sequential(self.act, self.linear_emb)
+ if norm:
+ self.norm = nn.LayerNorm(out_channels)
+ else:
+ self.norm = nn.Identity()
+
+ if dropout > 0:
+ self.dropout = nn.Dropout(p=dropout)
+ else:
+ self.dropout = nn.Identity()
+
+ self.init_weights()
+
+ def init_weights(self):
+ for module in self.modules():
+ if isinstance(module, nn.Linear):
+ if self.activation == Activation.relu:
+ init.kaiming_normal_(module.weight,
+ a=0,
+ nonlinearity='relu')
+ elif self.activation == Activation.lrelu:
+ init.kaiming_normal_(module.weight,
+ a=0.2,
+ nonlinearity='leaky_relu')
+                elif self.activation == Activation.silu:
+                    # kaiming init has no silu mode; relu is the closest available gain
+                    init.kaiming_normal_(module.weight,
+                                         a=0,
+                                         nonlinearity='relu')
+ else:
+ # leave it as default
+ pass
+
+ def forward(self, x, cond=None):
+ x = self.linear(x)
+ if self.use_cond:
+ # (n, c) or (n, c * 2)
+ cond = self.cond_layers(cond)
+ cond = (cond, None)
+
+ # scale shift first
+ x = x * (self.condition_bias + cond[0])
+ if cond[1] is not None:
+ x = x + cond[1]
+ # then norm
+ x = self.norm(x)
+ else:
+ # no condition
+ x = self.norm(x)
+ x = self.act(x)
+ x = self.dropout(x)
+ return x
\ No newline at end of file
diff --git a/PyTorch/built-in/mlm/PIDM/models/losses.py b/PyTorch/built-in/mlm/PIDM/models/losses.py
new file mode 100644
index 0000000000000000000000000000000000000000..539d112efeb187c52e3a64dd36eb04cb26990b08
--- /dev/null
+++ b/PyTorch/built-in/mlm/PIDM/models/losses.py
@@ -0,0 +1,75 @@
+"""
+Helpers for various likelihood-based losses. These are ported from the original
+Ho et al. diffusion models codebase:
+https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/utils.py
+"""
+
+import numpy as np
+
+import torch as th
+
+
+def normal_kl(mean1, logvar1, mean2, logvar2):
+ """
+ Compute the KL divergence between two gaussians.
+ Shapes are automatically broadcasted, so batches can be compared to
+ scalars, among other use cases.
+ """
+ tensor = None
+ for obj in (mean1, logvar1, mean2, logvar2):
+ if isinstance(obj, th.Tensor):
+ tensor = obj
+ break
+ assert tensor is not None, "at least one argument must be a Tensor"
+
+ # Force variances to be Tensors. Broadcasting helps convert scalars to
+ # Tensors, but it does not work for th.exp().
+ logvar1, logvar2 = [
+ x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor)
+ for x in (logvar1, logvar2)
+ ]
+
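+    # closed form:
+    # KL(N(m1, e^{lv1}) || N(m2, e^{lv2}))
+    #   = 0.5 * (lv2 - lv1 + e^{lv1 - lv2} + (m1 - m2)^2 * e^{-lv2} - 1)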
+ return 0.5 * (
+ -1.0
+ + logvar2
+ - logvar1
+ + th.exp(logvar1 - logvar2)
+ + ((mean1 - mean2) ** 2) * th.exp(-logvar2)
+ )
+
+
+def approx_standard_normal_cdf(x):
+ """
+ A fast approximation of the cumulative distribution function of the
+ standard normal.
+ """
+ return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3))))
+
+
+def discretized_gaussian_log_likelihood(x, *, means, log_scales):
+ """
+ Compute the log-likelihood of a Gaussian distribution discretizing to a
+ given image.
+ :param x: the target images. It is assumed that this was uint8 values,
+ rescaled to the range [-1, 1].
+ :param means: the Gaussian mean Tensor.
+ :param log_scales: the Gaussian log stddev Tensor.
+ :return: a tensor like x of log probabilities (in nats).
+ """
+ assert x.shape == means.shape == log_scales.shape
+ centered_x = x - means
+ inv_stdv = th.exp(-log_scales)
+ plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
+ cdf_plus = approx_standard_normal_cdf(plus_in)
+ min_in = inv_stdv * (centered_x - 1.0 / 255.0)
+ cdf_min = approx_standard_normal_cdf(min_in)
+ log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12))
+ log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12))
+ cdf_delta = cdf_plus - cdf_min
+ log_probs = th.where(
+ x < -0.999,
+ log_cdf_plus,
+ th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))),
+ )
+ assert log_probs.shape == x.shape
+ return log_probs
\ No newline at end of file
diff --git a/PyTorch/built-in/mlm/PIDM/models/nn.py b/PyTorch/built-in/mlm/PIDM/models/nn.py
new file mode 100755
index 0000000000000000000000000000000000000000..b3398c87ffa5de9282c99dc037efecfa54e9490f
--- /dev/null
+++ b/PyTorch/built-in/mlm/PIDM/models/nn.py
@@ -0,0 +1,136 @@
+"""
+Various utilities for neural networks.
+"""
+
+from enum import Enum
+import math
+from typing import Optional
+
+import torch as th
+import torch.nn as nn
+import torch.utils.checkpoint
+
+import torch.nn.functional as F
+
+
+# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
+class SiLU(nn.Module):
+ # @th.jit.script
+ def forward(self, x):
+ return x * th.sigmoid(x)
+
+def mean_flat(tensor):
+ """
+ Take the mean over all non-batch dimensions.
+ """
+    return tensor.mean(dim=list(range(1, len(tensor.shape))))
+
+
+class GroupNorm32(nn.GroupNorm):
+ def forward(self, x):
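+        # run the normalization in float32 for numerical stability, then cast back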
+ return super().forward(x.float()).type(x.dtype)
+
+
+def conv_nd(dims, *args, **kwargs):
+ """
+ Create a 1D, 2D, or 3D convolution module.
+ """
+ if dims == 1:
+ return nn.Conv1d(*args, **kwargs)
+ elif dims == 2:
+ return nn.Conv2d(*args, **kwargs)
+ elif dims == 3:
+ return nn.Conv3d(*args, **kwargs)
+ raise ValueError(f"unsupported dimensions: {dims}")
+
+
+def linear(*args, **kwargs):
+ """
+ Create a linear module.
+ """
+ return nn.Linear(*args, **kwargs)
+
+
+def avg_pool_nd(dims, *args, **kwargs):
+ """
+ Create a 1D, 2D, or 3D average pooling module.
+ """
+ if dims == 1:
+ return nn.AvgPool1d(*args, **kwargs)
+ elif dims == 2:
+ return nn.AvgPool2d(*args, **kwargs)
+ elif dims == 3:
+ return nn.AvgPool3d(*args, **kwargs)
+ raise ValueError(f"unsupported dimensions: {dims}")
+
+
+def update_ema(target_params, source_params, rate=0.99):
+ """
+ Update target parameters to be closer to those of source parameters using
+ an exponential moving average.
+
+ :param target_params: the target parameter sequence.
+ :param source_params: the source parameter sequence.
+ :param rate: the EMA rate (closer to 1 means slower).
+ """
+ for targ, src in zip(target_params, source_params):
+ targ.detach().mul_(rate).add_(src, alpha=1 - rate)
+
+
+def zero_module(module):
+ """
+ Zero out the parameters of a module and return it.
+ """
+ for p in module.parameters():
+ p.detach().zero_()
+ return module
+
+
+def scale_module(module, scale):
+ """
+ Scale the parameters of a module and return it.
+ """
+ for p in module.parameters():
+ p.detach().mul_(scale)
+ return module
+
+
+
+
+def normalization(channels):
+ """
+ Make a standard normalization layer.
+
+ :param channels: number of input channels.
+ :return: an nn.Module for normalization.
+ """
+ return GroupNorm32(min(32, channels), channels)
+
+
+def timestep_embedding(timesteps, dim, max_period=10000):
+ """
+ Create sinusoidal timestep embeddings.
+
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
+ These may be fractional.
+ :param dim: the dimension of the output.
+ :param max_period: controls the minimum frequency of the embeddings.
+ :return: an [N x dim] Tensor of positional embeddings.
+ """
+ half = dim // 2
+ freqs = th.exp(-math.log(max_period) *
+ th.arange(start=0, end=half, dtype=th.float32) /
+ half).to(device=timesteps.device)
+ args = timesteps[:, None].float() * freqs[None]
+ embedding = th.cat([th.cos(args), th.sin(args)], dim=-1)
+ if dim % 2:
+ embedding = th.cat(
+ [embedding, th.zeros_like(embedding[:, :1])], dim=-1)
+ return embedding
+
+
+def torch_checkpoint(func, args, flag, preserve_rng_state=False):
+ # torch's gradient checkpoint works with automatic mixed precision, given torch >= 1.8
+ if flag:
+ return torch.utils.checkpoint.checkpoint(
+ func, *args, preserve_rng_state=preserve_rng_state)
+ else:
+ return func(*args)
diff --git a/PyTorch/built-in/mlm/PIDM/models/pose_guide_network.py b/PyTorch/built-in/mlm/PIDM/models/pose_guide_network.py
new file mode 100644
index 0000000000000000000000000000000000000000..883a778f6f5806f7cae60d6deff48096951007f5
--- /dev/null
+++ b/PyTorch/built-in/mlm/PIDM/models/pose_guide_network.py
@@ -0,0 +1,110 @@
+from enum import Enum
+
+import torch
+from torch import Tensor
+from torch.nn.functional import silu
+from torch import nn
+
+from .unet import *
+from .choices import *
+from .blocks import *
+from .latentnet import *
+from einops import rearrange, reduce, repeat
+
+
+@dataclass
+class BeatGANsPoseGuideConfig(BeatGANsUNetConfig):
+ # number of style channels
+ enc_out_channels: int = 512
+ enc_attn_resolutions: Tuple[int] = None
+ enc_pool: str = 'depthconv'
+ enc_num_res_block: int = 2
+ enc_channel_mult: Tuple[int] = None
+ enc_grad_checkpoint: bool = False
+ latent_net_conf: MLPSkipNetConfig = None
+
+ def make_model(self):
+ return BeatGANsPoseGuideModel(self)
+
+
+class BeatGANsPoseGuideModel(nn.Module):
+
+ def __init__(self, conf: BeatGANsPoseGuideConfig):
+ super().__init__()
+ self.conf = conf
+
+ self.time_embed = TimeStyleSeperateEmbed(
+ time_channels=conf.model_channels,
+ time_out_channels=conf.embed_channels,
+ )
+
+ self.ref_encoder = BeatGANsEncoder(conf)
+ self.xt_encoder = BeatGANsEncoder(conf)
+
+ conf.in_channels = 20
+ self.pose_encoder = BeatGANsEncoder(conf)
+
+ self.cros_attn1 = AttentionBlock(channels = 512)
+ self.cros_attn2 = AttentionBlock(channels = 512)
+
+ self.self_attn = AttentionBlock_self(channels = 512)
+ self.token = nn.Parameter(torch.randn((512)))
+
+ self.linear = nn.Sequential(
+ nn.Linear(1024, 2048),
+ nn.ReLU(),
+ nn.Linear(2048, 4),
+ )
+
+ def forward(self,
+ xt,
+ ref,
+ pose,
+ t,
+ ):
+
+ emb_t = self.time_embed(timestep_embedding(t, self.conf.model_channels))
+ ref_feats = self.ref_encoder(ref, t = emb_t)
+ pose_feats = self.pose_encoder(pose, t = emb_t)
+ xt_feats = self.xt_encoder(xt, t = emb_t)
+
+ ref_out = self.cros_attn1(x = xt_feats[-1], cond = ref_feats[-1]).mean([2,3])
+
+ pose_out = self.cros_attn2(x = xt_feats[-1], cond = pose_feats[-1]).mean([2,3])
+
+ logits = self.linear(torch.cat([ref_out,pose_out],1))
+
+
+ # pose_out = rearrange(pose_out, 'b c h w -> b c (h w)')
+ # ref_out = rearrange(ref_out, 'b c h w -> b c (h w)')
+
+ # tkn = self.token.repeat(ref.shape[0],1).unsqueeze(-1)
+
+ # concat_out = torch.cat([tkn, pose_out, ref_out], 2)
+
+ # out = self.self_attn(concat_out)[:,:,0]
+
+ # logits = self.linear(out)
+
+ return logits
+
+
+class TimeStyleSeperateEmbed(nn.Module):
+ # embed only style
+ def __init__(self, time_channels, time_out_channels):
+ super().__init__()
+ self.time_embed = nn.Sequential(
+ linear(time_channels, time_out_channels),
+ nn.SiLU(),
+ linear(time_out_channels, time_out_channels),
+ )
+ self.style = nn.Identity()
+
+ def forward(self, time_emb=None, **kwargs):
+        if time_emb is not None:
+            time_emb = self.time_embed(time_emb)
+        # otherwise time_emb stays None (happens with autoenc training mode)
+
+ return time_emb
diff --git a/PyTorch/built-in/mlm/PIDM/models/unet.py b/PyTorch/built-in/mlm/PIDM/models/unet.py
new file mode 100644
index 0000000000000000000000000000000000000000..87df51d4e54a23635cea0b7783b0f63a3f9dee50
--- /dev/null
+++ b/PyTorch/built-in/mlm/PIDM/models/unet.py
@@ -0,0 +1,713 @@
+import math
+from dataclasses import dataclass
+from numbers import Number
+from typing import NamedTuple, Tuple, Union
+
+import numpy as np
+import torch as th
+from torch import nn
+import torch.nn.functional as F
+from .choices import *
+from .config_base import BaseConfig
+from .blocks import *
+
+from .nn import (conv_nd, linear, normalization, timestep_embedding,
+ torch_checkpoint, zero_module)
+
+
+@dataclass
+class BeatGANsUNetConfig(BaseConfig):
+ image_size: int = 64
+ in_channels: int = 3
+ # base channels, will be multiplied
+ model_channels: int = 64
+ # output of the unet
+ # suggest: 3
+ # you only need 6 if you also model the variance of the noise prediction (usually we use an analytical variance hence 3)
+ out_channels: int = 3
+ # how many repeating resblocks per resolution
+ # the decoding side would have "one more" resblock
+ # default: 2
+ num_res_blocks: int = 2
+ # you can also set the number of resblocks specifically for the input blocks
+ # default: None = above
+ num_input_res_blocks: int = None
+ # number of time embed channels and style channels
+ embed_channels: int = 512
+ # at what resolutions you want to do self-attention of the feature maps
+ # attentions generally improve performance
+ # default: [16]
+ # beatgans: [32, 16, 8]
+ attention_resolutions: Tuple[int] = (16, )
+ # number of time embed channels
+ time_embed_channels: int = None
+ # dropout applies to the resblocks (on feature maps)
+ dropout: float = 0.1
+ channel_mult: Tuple[int] = (1, 2, 4, 8)
+ input_channel_mult: Tuple[int] = None
+ conv_resample: bool = True
+ # always 2 = 2d conv
+ dims: int = 2
+ # don't use this, legacy from BeatGANs
+ num_classes: int = None
+ use_checkpoint: bool = False
+ # number of attention heads
+ num_heads: int = 1
+ # or specify the number of channels per attention head
+ num_head_channels: int = -1
+    # number of attention heads for the upsampling (decoder) blocks; -1 = reuse num_heads
+ num_heads_upsample: int = -1
+ # use resblock for upscale/downscale blocks (expensive)
+ # default: True (BeatGANs)
+ resblock_updown: bool = True
+ # never tried
+ use_new_attention_order: bool = False
+ resnet_two_cond: bool = False
+ resnet_cond_channels: int = None
+ # init the decoding conv layers with zero weights, this speeds up training
+    # default: True (BeatGANs)
+ resnet_use_zero_module: bool = True
+ # gradient checkpoint the attention operation
+ attn_checkpoint: bool = False
+
+ def make_model(self):
+ return BeatGANsUNetModel(self)
+
+
+class BeatGANsUNetModel(nn.Module):
+ def __init__(self, conf: BeatGANsUNetConfig):
+ super().__init__()
+ self.conf = conf
+
+ if conf.num_heads_upsample == -1:
+ self.num_heads_upsample = conf.num_heads
+
+ self.dtype = th.float32
+
+ self.time_emb_channels = conf.time_embed_channels or conf.model_channels
+ self.time_embed = nn.Sequential(
+ linear(self.time_emb_channels, conf.embed_channels),
+ nn.SiLU(),
+ linear(conf.embed_channels, conf.embed_channels),
+ )
+
+ if conf.num_classes is not None:
+ self.label_emb = nn.Embedding(conf.num_classes,
+ conf.embed_channels)
+
+ ch = input_ch = int(conf.channel_mult[0] * conf.model_channels)
+ self.input_blocks = nn.ModuleList([
+ TimestepEmbedSequential(
+ conv_nd(conf.dims, conf.in_channels, ch, 3, padding=1))
+ ])
+
+ kwargs = dict(
+ use_condition=True,
+ two_cond=conf.resnet_two_cond,
+ use_zero_module=conf.resnet_use_zero_module,
+ # style channels for the resnet block
+ cond_emb_channels=conf.resnet_cond_channels,
+ )
+
+ self._feature_size = ch
+
+ # input_block_chans = [ch]
+ input_block_chans = [[] for _ in range(len(conf.channel_mult))]
+ input_block_chans[0].append(ch)
+
+ # number of blocks at each resolution
+ self.input_num_blocks = [0 for _ in range(len(conf.channel_mult))]
+ self.input_num_blocks[0] = 1
+ self.output_num_blocks = [0 for _ in range(len(conf.channel_mult))]
+
+ ds = 1
+ resolution = conf.image_size
+ for level, mult in enumerate(conf.input_channel_mult
+ or conf.channel_mult):
+ for block_id in range(conf.num_input_res_blocks or conf.num_res_blocks):
+ layers = [
+ ResBlockConfig(
+ ch,
+ conf.embed_channels,
+ conf.dropout,
+ out_channels=int(mult * conf.model_channels),
+ dims=conf.dims,
+ use_checkpoint=conf.use_checkpoint,
+ **kwargs,
+ ).make_model()
+ ]
+ ch = int(mult * conf.model_channels)
+ if resolution in conf.attention_resolutions and block_id==conf.num_res_blocks-1:
+ layers.append(
+ AttentionBlock(
+ ch,
+ use_checkpoint=conf.use_checkpoint
+ or conf.attn_checkpoint,
+ num_heads=conf.num_heads,
+ num_head_channels=conf.num_head_channels,
+ use_new_attention_order=conf.
+ use_new_attention_order,
+ ))
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
+ self._feature_size += ch
+ # input_block_chans.append(ch)
+ input_block_chans[level].append(ch)
+ self.input_num_blocks[level] += 1
+ # print(input_block_chans)
+ if level != len(conf.channel_mult) - 1:
+ resolution //= 2
+ out_ch = ch
+ self.input_blocks.append(
+ TimestepEmbedSequential(
+ ResBlockConfig(
+ ch,
+ conf.embed_channels,
+ conf.dropout,
+ out_channels=out_ch,
+ dims=conf.dims,
+ use_checkpoint=conf.use_checkpoint,
+ down=True,
+ **kwargs,
+ ).make_model() if conf.
+ resblock_updown else Downsample(ch,
+ conf.conv_resample,
+ dims=conf.dims,
+ out_channels=out_ch)))
+ ch = out_ch
+ # input_block_chans.append(ch)
+ input_block_chans[level + 1].append(ch)
+ self.input_num_blocks[level + 1] += 1
+ ds *= 2
+ self._feature_size += ch
+ #ch = ch*2
+
+ self.middle_block = TimestepEmbedSequential(
+ ResBlockConfig(
+ ch,
+ conf.embed_channels,
+ conf.dropout,
+ out_channels=ch,
+ dims=conf.dims,
+ use_checkpoint=conf.use_checkpoint,
+ **kwargs,
+ ).make_model(),
+ AttentionBlock(
+ ch,
+ use_checkpoint=conf.use_checkpoint or conf.attn_checkpoint,
+ num_heads=conf.num_heads,
+ num_head_channels=conf.num_head_channels,
+ use_new_attention_order=conf.use_new_attention_order,
+ ),
+ ResBlockConfig(
+ ch,
+ conf.embed_channels,
+ conf.dropout,
+ out_channels=ch,
+ dims=conf.dims,
+ use_checkpoint=conf.use_checkpoint,
+ **kwargs,
+ ).make_model(),
+ )
+ self._feature_size += ch
+
+ self.output_blocks = nn.ModuleList([])
+ for level, mult in list(enumerate(conf.channel_mult))[::-1]:
+ for i in range(conf.num_res_blocks + 1):
+ # print(input_block_chans)
+ # ich = input_block_chans.pop()
+ try:
+ ich = input_block_chans[level].pop()
+ except IndexError:
+ # this happens only when num_res_block > num_enc_res_block
+                    # we will not have enough lateral (skip) connections for all decoder blocks
+ ich = 0
+ # print('pop:', ich)
+ layers = [
+ ResBlockConfig(
+ # only direct channels when gated
+ channels=ch + ich,
+ emb_channels=conf.embed_channels,
+ dropout=conf.dropout,
+ out_channels=int(conf.model_channels * mult),
+ dims=conf.dims,
+ use_checkpoint=conf.use_checkpoint,
+ # lateral channels are described here when gated
+ has_lateral=True if ich > 0 else False,
+ lateral_channels=None,
+ **kwargs,
+ ).make_model()
+ ]
+ ch = int(conf.model_channels * mult)
+ if resolution in conf.attention_resolutions and i==conf.num_res_blocks-1:
+ layers.append(
+ AttentionBlock(
+ ch,
+ use_checkpoint=conf.use_checkpoint
+ or conf.attn_checkpoint,
+ num_heads=self.num_heads_upsample,
+ num_head_channels=conf.num_head_channels,
+ use_new_attention_order=conf.
+ use_new_attention_order,
+ ))
+ if level and i == conf.num_res_blocks:
+ resolution *= 2
+ out_ch = ch
+ layers.append(
+ ResBlockConfig(
+ ch,
+ conf.embed_channels,
+ conf.dropout,
+ out_channels=out_ch,
+ dims=conf.dims,
+ use_checkpoint=conf.use_checkpoint,
+ up=True,
+ **kwargs,
+ ).make_model() if (
+ conf.resblock_updown
+ ) else Upsample(ch,
+ conf.conv_resample,
+ dims=conf.dims,
+ out_channels=out_ch))
+ ds //= 2
+ self.output_blocks.append(TimestepEmbedSequential(*layers))
+ self.output_num_blocks[level] += 1
+ self._feature_size += ch
+
+ # print(input_block_chans)
+ # print('inputs:', self.input_num_blocks)
+ # print('outputs:', self.output_num_blocks)
+
+ if conf.resnet_use_zero_module:
+ self.out = nn.Sequential(
+ normalization(ch),
+ nn.SiLU(),
+ zero_module(
+ conv_nd(conf.dims,
+ input_ch,
+ conf.out_channels,
+ 3,
+ padding=1)),
+ )
+ else:
+ self.out = nn.Sequential(
+ normalization(ch),
+ nn.SiLU(),
+ conv_nd(conf.dims, input_ch, conf.out_channels, 3, padding=1),
+ )
+
+ def forward(self, x, t, y=None, **kwargs):
+ """
+ Apply the model to an input batch.
+
+ :param x: an [N x C x ...] Tensor of inputs.
+        :param t: a 1-D batch of timesteps.
+ :param y: an [N] Tensor of labels, if class-conditional.
+ :return: an [N x C x ...] Tensor of outputs.
+ """
+ assert (y is not None) == (
+ self.conf.num_classes is not None
+ ), "must specify y if and only if the model is class-conditional"
+
+ # hs = []
+ hs = [[] for _ in range(len(self.conf.channel_mult))]
+ emb = self.time_embed(timestep_embedding(t, self.time_emb_channels))
+
+ if self.conf.num_classes is not None:
+ raise NotImplementedError()
+ # assert y.shape == (x.shape[0], )
+ # emb = emb + self.label_emb(y)
+
+ # new code supports input_num_blocks != output_num_blocks
+ h = x.type(self.dtype)
+ k = 0
+ for i in range(len(self.input_num_blocks)):
+ for j in range(self.input_num_blocks[i]):
+ h = self.input_blocks[k](h, emb=emb)
+ # print(i, j, h.shape)
+ hs[i].append(h)
+ k += 1
+ assert k == len(self.input_blocks)
+
+ h = self.middle_block(h, emb=emb)
+ k = 0
+ for i in range(len(self.output_num_blocks)):
+ for j in range(self.output_num_blocks[i]):
+                # take the lateral connection from the same layer (in reverse)
+ # until there is no more, use None
+ try:
+ lateral = hs[-i - 1].pop()
+ # print(i, j, lateral.shape)
+ except IndexError:
+ lateral = None
+ # print(i, j, lateral)
+ h = self.output_blocks[k](h, emb=emb, lateral=lateral)
+ k += 1
+
+ h = h.type(x.dtype)
+ pred = self.out(h)
+ return Return(pred=pred)
+
+
+class Return(NamedTuple):
+ pred: th.Tensor
+
+
+@dataclass
+class BeatGANsEncoderConfig(BaseConfig):
+ image_size: int
+ in_channels: int
+ model_channels: int
+ out_hid_channels: int
+ out_channels: int
+ num_res_blocks: int
+ attention_resolutions: Tuple[int]
+ dropout: float = 0
+ channel_mult: Tuple[int] = (1, 2, 4, 8)
+ use_time_condition: bool = True
+ conv_resample: bool = True
+ dims: int = 2
+ use_checkpoint: bool = False
+ num_heads: int = 1
+ num_head_channels: int = -1
+ resblock_updown: bool = False
+ use_new_attention_order: bool = False
+ pool: str = 'adaptivenonzero'
+
+ def make_model(self):
+ return BeatGANsEncoderModel(self)
+
+
+class BeatGANsEncoderModel(nn.Module):
+ """
+ The half UNet model with attention and timestep embedding.
+
+ For usage, see UNet.
+ """
+ def __init__(self, conf: BeatGANsEncoderConfig):
+ super().__init__()
+ self.conf = conf
+ self.dtype = th.float32
+
+
+ if conf.use_time_condition:
+ time_embed_dim = conf.model_channels * 4
+ self.time_embed = nn.Sequential(
+ linear(conf.model_channels, time_embed_dim),
+ nn.SiLU(),
+ linear(time_embed_dim, time_embed_dim),
+ )
+ else:
+ time_embed_dim = None
+
+ ch = int(conf.channel_mult[0] * conf.model_channels)
+ self.input_blocks = nn.ModuleList([
+ TimestepEmbedSequential(
+ conv_nd(conf.dims, conf.in_channels, ch, 3, padding=1))
+ ])
+ self._feature_size = ch
+ input_block_chans = [ch]
+ ds = 1
+ resolution = conf.image_size
+ for level, mult in enumerate(conf.channel_mult):
+ for _ in range(conf.num_res_blocks):
+ layers = [
+ ResBlockConfig(
+ ch,
+ time_embed_dim,
+ conf.dropout,
+ out_channels=int(mult * conf.model_channels),
+ dims=conf.dims,
+ use_condition=conf.use_time_condition,
+ use_checkpoint=conf.use_checkpoint,
+ ).make_model()
+ ]
+ ch = int(mult * conf.model_channels)
+ # if resolution in conf.attention_resolutions:
+ # layers.append(
+ # AttentionBlock(
+ # ch,
+ # use_checkpoint=conf.use_checkpoint,
+ # num_heads=conf.num_heads,
+ # num_head_channels=conf.num_head_channels,
+ # use_new_attention_order=conf.
+ # use_new_attention_order,
+ # ))
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
+ self._feature_size += ch
+ input_block_chans.append(ch)
+ if level != len(conf.channel_mult) - 1:
+ resolution //= 2
+ out_ch = ch
+ self.input_blocks.append(
+ TimestepEmbedSequential(
+ ResBlockConfig(
+ ch,
+ time_embed_dim,
+ conf.dropout,
+ out_channels=out_ch,
+ dims=conf.dims,
+ use_condition=conf.use_time_condition,
+ use_checkpoint=conf.use_checkpoint,
+ down=True,
+ ).make_model() if (
+ conf.resblock_updown
+ ) else Downsample(ch,
+ conf.conv_resample,
+ dims=conf.dims,
+ out_channels=out_ch)))
+ ch = out_ch
+ input_block_chans.append(ch)
+ ds *= 2
+ self._feature_size += ch
+
+ self.middle_block = TimestepEmbedSequential(
+ ResBlockConfig(
+ ch,
+ time_embed_dim,
+ conf.dropout,
+ dims=conf.dims,
+ use_condition=conf.use_time_condition,
+ use_checkpoint=conf.use_checkpoint,
+ ).make_model(),
+ # AttentionBlock(
+ # ch,
+ # use_checkpoint=conf.use_checkpoint,
+ # num_heads=conf.num_heads,
+ # num_head_channels=conf.num_head_channels,
+ # use_new_attention_order=conf.use_new_attention_order,
+ # ),
+ ResBlockConfig(
+ ch,
+ time_embed_dim,
+ conf.dropout,
+ dims=conf.dims,
+ use_condition=conf.use_time_condition,
+ use_checkpoint=conf.use_checkpoint,
+ ).make_model(),
+ )
+ self._feature_size += ch
+ if conf.pool == "adaptivenonzero":
+ self.out = nn.Sequential(
+ normalization(ch),
+ nn.SiLU(),
+ nn.AdaptiveAvgPool2d((1, 1)),
+ conv_nd(conf.dims, ch, conf.out_channels, 1),
+ nn.Flatten(),
+ )
+ else:
+ raise NotImplementedError(f"Unexpected {conf.pool} pooling")
+
+ def forward(self, x, t=None, return_2d_feature=False):
+ """
+ Apply the model to an input batch.
+
+ :param x: an [N x C x ...] Tensor of inputs.
+        :param t: a 1-D batch of timesteps.
+ :return: an [N x K] Tensor of outputs.
+ """
+ if self.conf.use_time_condition:
+            emb = self.time_embed(timestep_embedding(t, self.conf.model_channels))
+ else:
+ emb = None
+
+ results = []
+ h = x.type(self.dtype)
+ for module in self.input_blocks:
+ h = module(h, emb=emb)
+ if self.conf.pool.startswith("spatial"):
+ results.append(h.type(x.dtype).mean(dim=(2, 3)))
+ h = self.middle_block(h, emb=emb)
+ if self.conf.pool.startswith("spatial"):
+ results.append(h.type(x.dtype).mean(dim=(2, 3)))
+ h = th.cat(results, axis=-1)
+ else:
+ h = h.type(x.dtype)
+
+ h_2d = h
+ h = self.out(h)
+
+ if return_2d_feature:
+ return h, h_2d
+ else:
+ return h
+
+ def forward_flatten(self, x):
+ """
+ transform the last 2d feature into a flatten vector
+ """
+ h = self.out(x)
+ return h
+
+
+class SuperResModel(BeatGANsUNetModel):
+ """
+ A UNetModel that performs super-resolution.
+
+ Expects an extra kwarg `low_res` to condition on a low-resolution image.
+ """
+    def __init__(self, conf: BeatGANsUNetConfig):
+        # the low-res image is concatenated along the channel dim, so double in_channels
+        conf = conf.clone()
+        conf.in_channels = conf.in_channels * 2
+        super().__init__(conf)
+
+ def forward(self, x, timesteps, low_res=None, **kwargs):
+ _, _, new_height, new_width = x.shape
+ upsampled = F.interpolate(low_res, (new_height, new_width),
+ mode="bilinear")
+ x = th.cat([x, upsampled], dim=1)
+ return super().forward(x, timesteps, **kwargs)
+
+
+
+
+
+class BeatGANsEncoder(nn.Module):
+ def __init__(self, conf: BeatGANsUNetConfig):
+ super().__init__()
+ self.conf = conf
+
+ if conf.num_heads_upsample == -1:
+ self.num_heads_upsample = conf.num_heads
+
+ self.dtype = th.float32
+
+ self.time_emb_channels = conf.time_embed_channels or conf.model_channels
+ self.time_embed = nn.Sequential(
+ linear(self.time_emb_channels, conf.embed_channels),
+ nn.SiLU(),
+ linear(conf.embed_channels, conf.embed_channels),
+ )
+
+ if conf.num_classes is not None:
+ self.label_emb = nn.Embedding(conf.num_classes,
+ conf.embed_channels)
+
+ ch = input_ch = int(conf.channel_mult[0] * conf.model_channels)
+ self.input_blocks = nn.ModuleList([
+ TimestepEmbedSequential(
+ conv_nd(conf.dims, conf.in_channels, ch, 3, padding=1))
+ ])
+
+ kwargs = dict(
+ use_condition=True,
+ two_cond=conf.resnet_two_cond,
+ use_zero_module=conf.resnet_use_zero_module,
+ # style channels for the resnet block
+ cond_emb_channels=conf.resnet_cond_channels,
+ )
+
+ self._feature_size = [ch]
+
+ # input_block_chans = [ch]
+ input_block_chans = [[] for _ in range(len(conf.channel_mult))]
+ input_block_chans[0].append(ch)
+
+ # number of blocks at each resolution
+ self.input_num_blocks = [0 for _ in range(len(conf.channel_mult))]
+ self.input_num_blocks[0] = 1
+ self.output_num_blocks = [0 for _ in range(len(conf.channel_mult))]
+
+ ds = 1
+ resolution = conf.image_size
+ for level, mult in enumerate(conf.input_channel_mult
+ or conf.channel_mult):
+ for _ in range(conf.num_input_res_blocks or conf.num_res_blocks):
+ layers = [
+ ResBlockConfig(
+ ch,
+ conf.embed_channels,
+ conf.dropout,
+ out_channels=int(mult * conf.model_channels),
+ dims=conf.dims,
+ use_checkpoint=conf.use_checkpoint,
+ **kwargs,
+ ).make_model()
+ ]
+ ch = int(mult * conf.model_channels)
+ # if resolution in conf.attention_resolutions:
+ # layers.append(
+ # AttentionBlock(
+ # ch,
+ # use_checkpoint=conf.use_checkpoint
+ # or conf.attn_checkpoint,
+ # num_heads=conf.num_heads,
+ # num_head_channels=conf.num_head_channels,
+ # use_new_attention_order=conf.
+ # use_new_attention_order,
+ # ))
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
+ self._feature_size.append(ch)
+ # input_block_chans.append(ch)
+ input_block_chans[level].append(ch)
+ self.input_num_blocks[level] += 1
+ # print(input_block_chans)
+ if level != len(conf.channel_mult) - 1:
+ resolution //= 2
+ out_ch = ch
+ self.input_blocks.append(
+ TimestepEmbedSequential(
+ ResBlockConfig(
+ ch,
+ conf.embed_channels,
+ conf.dropout,
+ out_channels=out_ch,
+ dims=conf.dims,
+ use_checkpoint=conf.use_checkpoint,
+ down=True,
+ **kwargs,
+ ).make_model() if conf.
+ resblock_updown else Downsample(ch,
+ conf.conv_resample,
+ dims=conf.dims,
+ out_channels=out_ch)))
+ ch = out_ch
+ # input_block_chans.append(ch)
+ input_block_chans[level + 1].append(ch)
+ self.input_num_blocks[level + 1] += 1
+ ds *= 2
+ self._feature_size.append(ch)
+
+ # self._to_vector_layers = [nn.Sequential(
+ # normalization(ch),
+ # nn.SiLU(),
+ # nn.AdaptiveAvgPool2d((1, 1)),
+ # conv_nd(conf.dims, ch, ch, 1),
+ # nn.Flatten(),
+ # ).cuda() for ch in self._feature_size]
+
+ def forward(self, x, t=None, y=None, **kwargs):
+ """
+ Apply the model to an input batch.
+
+ :param x: an [N x C x ...] Tensor of inputs.
+        :param t: a 1-D batch of timesteps.
+ :param y: an [N] Tensor of labels, if class-conditional.
+ :return: an [N x C x ...] Tensor of outputs.
+ """
+ # hs = []
+ hs = [[] for _ in range(len(self.conf.channel_mult))]
+ #emb = self.time_embed(timestep_embedding(t, self.time_emb_channels))
+
+ if self.conf.num_classes is not None:
+ raise NotImplementedError()
+ # assert y.shape == (x.shape[0], )
+ # emb = emb + self.label_emb(y)
+
+ # new code supports input_num_blocks != output_num_blocks
+ h = x.type(self.dtype)
+ k = 0
+ results = []
+ for i in range(len(self.input_num_blocks)):
+ for j in range(self.input_num_blocks[i]):
+ h = self.input_blocks[k](h, emb=None)
+ # print(i, j, h.shape)
+ hs[i].append(h)
+ results.append(h)
+ #print (h.shape)
+ k += 1
+ assert k == len(self.input_blocks)
+
+ # vectors = []
+
+ # for i, feat in enumerate(results):
+ # vectors.append(self._to_vector_layers[i](feat))
+
+ return results
\ No newline at end of file
diff --git a/PyTorch/built-in/mlm/PIDM/models/unet_autoenc.py b/PyTorch/built-in/mlm/PIDM/models/unet_autoenc.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b9f2fccc2bfc3f5dc0350a1c7e5504bd1d1400c
--- /dev/null
+++ b/PyTorch/built-in/mlm/PIDM/models/unet_autoenc.py
@@ -0,0 +1,308 @@
+from enum import Enum
+
+import torch
+from torch import Tensor
+from torch.nn.functional import silu
+
+from .latentnet import *
+from .unet import *
+from .choices import *
+
+def prob_mask_like(shape, prob, device):
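+    # boolean mask whose entries are True with probability `prob`; used to randomly
+    # keep (True) or drop (False) the conditioning for classifier-free guidance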
+ if prob == 1:
+ return torch.ones(shape, device = device, dtype = torch.bool)
+ elif prob == 0:
+ return torch.zeros(shape, device = device, dtype = torch.bool)
+ else:
+ return torch.zeros(shape, device = device).float().uniform_(0, 1) < prob
+
+
+@dataclass
+class BeatGANsAutoencConfig(BeatGANsUNetConfig):
+ # number of style channels
+ enc_out_channels: int = 512
+ enc_attn_resolutions: Tuple[int] = None
+ enc_pool: str = 'depthconv'
+ enc_num_res_block: int = 2
+ enc_channel_mult: Tuple[int] = None
+ enc_grad_checkpoint: bool = False
+ latent_net_conf: MLPSkipNetConfig = None
+
+ def make_model(self):
+ return BeatGANsAutoencModel(self)
+
+
+class BeatGANsAutoencModel(BeatGANsUNetModel):
+ def __init__(self, conf: BeatGANsAutoencConfig):
+ super().__init__(conf)
+ self.conf = conf
+
+ # having only time, cond
+ self.time_embed = TimeStyleSeperateEmbed(
+ time_channels=conf.model_channels,
+ time_out_channels=conf.embed_channels,
+ )
+
+ conf.in_channels = 3
+ self.encoder = BeatGANsEncoder(conf)
+
+ if conf.latent_net_conf is not None:
+ self.latent_net = conf.latent_net_conf.make_model()
+
+ def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
+ """
+ Reparameterization trick to sample from N(mu, var) from
+ N(0,1).
+ :param mu: (Tensor) Mean of the latent Gaussian [B x D]
+        :param logvar: (Tensor) Log-variance of the latent Gaussian [B x D]
+ :return: (Tensor) [B x D]
+ """
+ assert self.conf.is_stochastic
+ std = torch.exp(0.5 * logvar)
+ eps = torch.randn_like(std)
+ return eps * std + mu
+
+ def sample_z(self, n: int, device):
+ assert self.conf.is_stochastic
+ return torch.randn(n, self.conf.enc_out_channels, device=device)
+
+ def noise_to_cond(self, noise: Tensor):
+ raise NotImplementedError()
+ assert self.conf.noise_net_conf is not None
+ return self.noise_net.forward(noise)
+
+ def encode(self, x):
+
+ cond = self.encoder.forward(x)
+
+ return {'cond': cond}
+
+ @property
+ def stylespace_sizes(self):
+ modules = list(self.input_blocks.modules()) + list(
+ self.middle_block.modules()) + list(self.output_blocks.modules())
+ sizes = []
+ for module in modules:
+ if isinstance(module, ResBlock):
+ linear = module.cond_emb_layers[-1]
+ sizes.append(linear.weight.shape[0])
+ return sizes
+
+ def encode_stylespace(self, x, return_vector: bool = True):
+ """
+ encode to style space
+ """
+ modules = list(self.input_blocks.modules()) + list(
+ self.middle_block.modules()) + list(self.output_blocks.modules())
+ # (n, c)
+ cond = self.encoder.forward(x)
+ S = []
+ for module in modules:
+ if isinstance(module, ResBlock):
+ # (n, c')
+ s = module.cond_emb_layers.forward(cond)
+ S.append(s)
+
+ if return_vector:
+ # (n, sum_c)
+ return torch.cat(S, dim=1)
+ else:
+ return S
+
+
+ def forward_with_cond_scale(
+ self,
+ x,
+ t,
+ cond,
+ cond_scale,
+ ):
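+        # classifier-free guidance: run a conditional pass (prob=1) and an
+        # unconditional pass (prob=0), then extrapolate their difference by cond_scale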
+ logits = self.forward(x, t, cond=cond, prob = 1)
+
+        if cond_scale == 1:
+            # no unconditional pass is needed when the guidance scale is 1
+            return [logits, logits, None]
+
+ null_logits = self.forward(x, t, cond=cond, prob = 0)
+
+ return [null_logits + (logits - null_logits) * cond_scale, logits, null_logits]
+
+ def forward(self,
+ x,
+ t,
+ x_cond=None,
+ prob=1,
+ y=None,
+ cond=None,
+ style=None,
+ noise=None,
+ t_cond=None,
+ **kwargs):
+ """
+ Apply the model to an input batch.
+
+ Args:
+ x_start: the original image to encode
+ cond: output of the encoder
+ noise: random noise (to predict the cond)
+ """
+ cond_mask = prob_mask_like((x.shape[0],), prob = prob, device = x.device)
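+        # False entries zero out the conditioning image below, i.e. classifier-free
+        # guidance dropout: the conditioning is kept with probability `prob`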
+
+
+ if t_cond is None:
+ t_cond = t
+
+ if noise is not None:
+ # if the noise is given, we predict the cond from noise
+ cond = self.noise_to_cond(noise)
+
+ if cond is None:
+ x_cond = (cond_mask.view(-1,1,1,1)*x_cond)
+ if x is not None:
+ assert len(x) == len(x_cond), f'{len(x)} != {len(x_cond)}'
+
+ tmp = self.encode(x_cond)
+ cond = tmp['cond']
+
+ else:
+ if prob==1:
+ cond = cond[0]
+ elif prob==0:
+ cond = cond[1]
+
+
+ if t is not None:
+ _t_emb = timestep_embedding(t, self.conf.model_channels)
+ _t_cond_emb = timestep_embedding(t_cond, self.conf.model_channels)
+ else:
+ # this happens when training only autoenc
+ _t_emb = None
+ _t_cond_emb = None
+
+ if self.conf.resnet_two_cond:
+ res = self.time_embed.forward(
+ time_emb=_t_emb,
+ cond=cond,
+ time_cond_emb=_t_cond_emb,
+ )
+ else:
+ raise NotImplementedError()
+
+ if self.conf.resnet_two_cond:
+ # two cond: first = time emb, second = cond_emb
+ emb = res.time_emb
+ cond_emb = res.emb
+ else:
+ # one cond = combined of both time and cond
+ emb = res.emb
+ cond_emb = None
+
+ # override the style if given
+ style = style or res.style
+
+ assert (y is not None) == (
+ self.conf.num_classes is not None
+ ), "must specify y if and only if the model is class-conditional"
+
+ if self.conf.num_classes is not None:
+ raise NotImplementedError()
+ # assert y.shape == (x.shape[0], )
+ # emb = emb + self.label_emb(y)
+
+ # where in the model to supply time conditions
+ enc_time_emb = emb
+ mid_time_emb = emb
+ dec_time_emb = emb
+ # where in the model to supply style conditions
+ enc_cond_emb = cond
+ mid_cond_emb = cond[-1]
+ dec_cond_emb = cond #+ [cond[-1]]
+
+ # hs = []
+ hs = [[] for _ in range(len(self.conf.channel_mult))]
+
+ if x is not None:
+ h = x.type(self.dtype)
+
+ # input blocks
+ k = 0
+ for i in range(len(self.input_num_blocks)):
+ for j in range(self.input_num_blocks[i]):
+ h = self.input_blocks[k](h,
+ emb=enc_time_emb,
+ cond=enc_cond_emb[k])
+
+ hs[i].append(h)
+ #h = th.concat([h, enc_cond_emb[k]], 1)
+
+ k += 1
+
+
+
+ assert k == len(self.input_blocks)
+
+ # middle blocks
+ h = self.middle_block(h, emb=mid_time_emb, cond=mid_cond_emb)
+ else:
+ # no lateral connections
+            # happens when training only the autoencoder
+ h = None
+ hs = [[] for _ in range(len(self.conf.channel_mult))]
+
+ # output blocks
+ k = 0
+ for i in range(len(self.output_num_blocks)):
+ for j in range(self.output_num_blocks[i]):
+                # take the lateral connection from the same layer (in reverse)
+ # until there is no more, use None
+ try:
+ lateral = hs[-i - 1].pop()
+ # print(i, j, lateral.shape)
+ except IndexError:
+ lateral = None
+ # print(i, j, lateral)
+
+ h = self.output_blocks[k](h,
+ emb=dec_time_emb,
+ cond=dec_cond_emb[-k-1],
+ lateral=lateral)
+
+
+ k += 1
+
+ pred = self.out(h)
+ return pred
+
+
+class AutoencReturn(NamedTuple):
+ pred: Tensor
+ cond: Tensor = None
+
+
+class EmbedReturn(NamedTuple):
+ # style and time
+ emb: Tensor = None
+ # time only
+ time_emb: Tensor = None
+ # style only (but could depend on time)
+ style: Tensor = None
+
+
+class TimeStyleSeperateEmbed(nn.Module):
+ # embed only style
+ def __init__(self, time_channels, time_out_channels):
+ super().__init__()
+ self.time_embed = nn.Sequential(
+ linear(time_channels, time_out_channels),
+ nn.SiLU(),
+ linear(time_out_channels, time_out_channels),
+ )
+ self.style = nn.Identity()
+
+ def forward(self, time_emb=None, cond=None, **kwargs):
+        if time_emb is not None:
+            time_emb = self.time_embed(time_emb)
+        # otherwise time_emb stays None (happens with autoenc training mode)
+ style = self.style(cond)
+ return EmbedReturn(emb=style, time_emb=time_emb, style=style)
diff --git a/PyTorch/built-in/mlm/PIDM/predict.py b/PyTorch/built-in/mlm/PIDM/predict.py
new file mode 100644
index 0000000000000000000000000000000000000000..34472ca27fcbc8df89caed8ee758ddc04641257f
--- /dev/null
+++ b/PyTorch/built-in/mlm/PIDM/predict.py
@@ -0,0 +1,126 @@
+# Prediction interface for Cog ⚙️
+# https://github.com/replicate/cog/blob/main/docs/python.md
+
+import warnings
+warnings.filterwarnings('ignore')
+import torch
+import torch.nn as nn
+from tqdm import tqdm
+from torchvision.utils import save_image
+from PIL import Image
+from tensorfn import load_config as DiffConfig
+import numpy as np
+from config.diffconfig import DiffusionConfig, get_model_conf
+import torch.distributed as dist
+import os, glob, cv2, time, shutil
+from models.unet_autoenc import BeatGANsAutoencConfig
+from diffusion import create_gaussian_diffusion, make_beta_schedule, ddim_steps
+import torchvision.transforms as transforms
+import torchvision
+
+class Predictor():
+ def __init__(self):
+ """Load the model into memory to make running multiple predictions efficient"""
+
+ conf = DiffConfig(DiffusionConfig, './config/diffusion.conf', show=False)
+
+ self.model = get_model_conf().make_model()
+ ckpt = torch.load("checkpoints/last.pt")
+ self.model.load_state_dict(ckpt["ema"])
+ self.model = self.model.cuda()
+ self.model.eval()
+
+ self.betas = conf.diffusion.beta_schedule.make()
+ self.diffusion = create_gaussian_diffusion(self.betas, predict_xstart = False)#.to(device)
+
+ self.pose_list = glob.glob('data/deepfashion_256x256/target_pose/*.npy')
+ self.transforms = transforms.Compose([transforms.Resize((256,256), interpolation=Image.BICUBIC),
+ transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5),
+ (0.5, 0.5, 0.5))])
+ def predict_pose(
+ self,
+ image,
+ num_poses=1,
+ sample_algorithm='ddim',
+ nsteps=100,
+
+ ):
+ """Run a single prediction on the model"""
+
+ src = Image.open(image)
+ src = self.transforms(src).unsqueeze(0).cuda()
+ tgt_pose = torch.stack([transforms.ToTensor()(np.load(ps)).cuda() for ps in np.random.choice(self.pose_list, num_poses)], 0)
+
+ src = src.repeat(num_poses,1,1,1)
+
+ if sample_algorithm == 'ddpm':
+ samples = self.diffusion.p_sample_loop(self.model, x_cond = [src, tgt_pose], progress = True, cond_scale = 2)
+ elif sample_algorithm == 'ddim':
+ noise = torch.randn(src.shape).cuda()
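+            # evenly spaced subset of the 1000 training timesteps used by DDIM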
+ seq = range(0, 1000, 1000//nsteps)
+ xs, x0_preds = ddim_steps(noise, seq, self.model, self.betas.cuda(), [src, tgt_pose])
+ samples = xs[-1].cuda()
+
+
+ samples_grid = torch.cat([src[0],torch.cat([samps for samps in samples], -1)], -1)
+ samples_grid = (torch.clamp(samples_grid, -1., 1.) + 1.0)/2.0
+ pose_grid = torch.cat([torch.zeros_like(src[0]),torch.cat([samps[:3] for samps in tgt_pose], -1)], -1)
+
+ output = torch.cat([1-pose_grid, samples_grid], -2)
+
+ numpy_imgs = output.unsqueeze(0).permute(0,2,3,1).detach().cpu().numpy()
+ fake_imgs = (255*numpy_imgs).astype(np.uint8)
+ Image.fromarray(fake_imgs[0]).save('output.png')
+
+
+ def predict_appearance(
+ self,
+ image,
+ ref_img,
+ ref_mask,
+ ref_pose,
+ sample_algorithm='ddim',
+ nsteps=100,
+
+ ):
+ """Run a single prediction on the model"""
+
+ src = Image.open(image)
+ src = self.transforms(src).unsqueeze(0).cuda()
+
+ ref = Image.open(ref_img)
+ ref = self.transforms(ref).unsqueeze(0).cuda()
+
+ mask = transforms.ToTensor()(Image.open(ref_mask)).unsqueeze(0).cuda()
+ pose = transforms.ToTensor()(np.load(ref_pose)).unsqueeze(0).cuda()
+
+
+ if sample_algorithm == 'ddpm':
+ samples = self.diffusion.p_sample_loop(self.model, x_cond = [src, pose, ref, mask], progress = True, cond_scale = 2)
+ elif sample_algorithm == 'ddim':
+ noise = torch.randn(src.shape).cuda()
+ seq = range(0, 1000, 1000//nsteps)
+ xs, x0_preds = ddim_steps(noise, seq, self.model, self.betas.cuda(), [src, pose, ref, mask], diffusion=self.diffusion)
+ samples = xs[-1].cuda()
+
+
+ samples = torch.clamp(samples, -1., 1.)
+
+ output = (torch.cat([src, ref, mask*2-1, samples], -1) + 1.0)/2.0
+
+ numpy_imgs = output.permute(0,2,3,1).detach().cpu().numpy()
+ fake_imgs = (255*numpy_imgs).astype(np.uint8)
+ Image.fromarray(fake_imgs[0]).save('output.png')
+
+if __name__ == "__main__":
+
+
+ obj = Predictor()
+
+ obj.predict_pose(image='test.jpg', num_poses=4, sample_algorithm = 'ddim', nsteps = 50)
+
+ # ref_img = "data/deepfashion_256x256/target_edits/reference_img_0.png"
+ # ref_mask = "data/deepfashion_256x256/target_mask/lower/reference_mask_0.png"
+ # ref_pose = "data/deepfashion_256x256/target_pose/reference_pose_0.npy"
+
+ # #obj.predict_appearance(image='test.jpg', ref_img = ref_img, ref_mask = ref_mask, ref_pose = ref_pose, sample_algorithm = 'ddim', nsteps = 50)
diff --git a/PyTorch/built-in/mlm/PIDM/requirements.txt b/PyTorch/built-in/mlm/PIDM/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..fc60c5053ddbf2a1113915f068719aea9e39e0dd
--- /dev/null
+++ b/PyTorch/built-in/mlm/PIDM/requirements.txt
@@ -0,0 +1,12 @@
+imageio==2.10.3
+lmdb==1.2.1
+opencv-python==4.5.4.58
+Pillow==8.3.2
+PyYAML==5.4.1
+scikit-image==0.17.2
+scipy==1.5.4
+tensorboard==2.6.0
+tqdm==4.62.3
+pydantic==1.10.14
+wandb
+tensorfn
diff --git a/PyTorch/built-in/mlm/PIDM/train.py b/PyTorch/built-in/mlm/PIDM/train.py
new file mode 100755
index 0000000000000000000000000000000000000000..19e77b2b8376c3186adf38a7bae30f7b05e8620b
--- /dev/null
+++ b/PyTorch/built-in/mlm/PIDM/train.py
@@ -0,0 +1,336 @@
+import os
+import warnings
+
+warnings.filterwarnings("ignore")
+
+import time, cv2, torch, wandb
+import torch.distributed as dist
+from config.diffconfig import DiffusionConfig, get_model_conf
+from config.dataconfig import Config as DataConfig
+from tensorfn import load_config as DiffConfig
+from diffusion import create_gaussian_diffusion, make_beta_schedule, ddim_steps
+from tensorfn.optim import lr_scheduler
+from torch import nn, optim
+from torch.utils import data
+from torchvision import transforms
+from tqdm import tqdm
+import numpy as np
+import data as deepfashion_data
+from model import UNet
+
+def init_distributed():
+
+    # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
+ dist_url = "env://" # default
+
+ # only works with torch.distributed.launch // torch.run
+ rank = int(os.environ["RANK"])
+ world_size = int(os.environ['WORLD_SIZE'])
+ local_rank = int(os.environ['LOCAL_RANK'])
+
+ dist.init_process_group(
+ backend="nccl",
+ init_method=dist_url,
+ world_size=world_size,
+ rank=rank)
+
+ # this will make all .cuda() calls work properly
+ torch.cuda.set_device(local_rank)
+    # synchronizes all processes to reach this point before moving on
+ dist.barrier()
+ setup_for_distributed(rank == 0)
+
+def setup_for_distributed(is_master):
+ """
+ This function disables printing when not in master process
+ """
+ import builtins as __builtin__
+ builtin_print = __builtin__.print
+
+ def print(*args, **kwargs):
+ force = kwargs.pop('force', False)
+ if is_master or force:
+ builtin_print(*args, **kwargs)
+
+ __builtin__.print = print
+
+def is_main_process():
+    try:
+        return dist.get_rank() == 0
+    except Exception:
+        # distributed process group not initialized -> treat as single-process run
+        return True
+
+def sample_data(loader):
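+    # endless iterator over the dataloader; restarts it on exhaustion and
+    # yields (epoch, batch) pairs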
+ loader_iter = iter(loader)
+ epoch = 0
+
+ while True:
+ try:
+ yield epoch, next(loader_iter)
+
+ except StopIteration:
+ epoch += 1
+ loader_iter = iter(loader)
+
+ yield epoch, next(loader_iter)
+
+
+def accumulate(model1, model2, decay=0.9999):
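+    # exponential moving average: pull model1's (EMA) parameters towards model2's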
+ par1 = dict(model1.named_parameters())
+ par2 = dict(model2.named_parameters())
+
+ for k in par1.keys():
+ par1[k].data.mul_(decay).add_(par2[k].data, alpha=1 - decay)
+
+
+
+
+def train(conf, loader, val_loader, model, ema, diffusion, betas, optimizer, scheduler, guidance_prob, cond_scale, device, wandb):
+
+ import time
+
+ i = 0
+
+ loss_list = []
+ loss_mean_list = []
+ loss_vb_list = []
+
+ for epoch in range(300):
+
+        if is_main_process(): print ('#Epoch - '+str(epoch))
+
+ start_time = time.time()
+
+ for batch in tqdm(loader):
+
+ i = i + 1
+
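+            # each batch is doubled so that source->target and target->source
+            # transfers are trained jointly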
+ img = torch.cat([batch['source_image'], batch['target_image']], 0)
+ target_img = torch.cat([batch['target_image'], batch['source_image']], 0)
+ target_pose = torch.cat([batch['target_skeleton'], batch['source_skeleton']], 0)
+
+ img = img.to(device)
+ target_img = target_img.to(device)
+ target_pose = target_pose.to(device)
+ time_t = torch.randint(
+ 0,
+ conf.diffusion.beta_schedule["n_timestep"],
+ (img.shape[0],),
+ device=device,
+ )
+
+ loss_dict = diffusion.training_losses(model, x_start = target_img, t = time_t, cond_input = [img, target_pose], prob = 1 - guidance_prob)
+
+ loss = loss_dict['loss'].mean()
+ loss_mse = loss_dict['mse'].mean()
+ loss_vb = loss_dict['vb'].mean()
+
+
+ optimizer.zero_grad()
+ loss.backward()
+ nn.utils.clip_grad_norm_(model.parameters(), 1)
+ scheduler.step()
+ optimizer.step()
+ loss = loss_dict['loss'].mean()
+
+ loss_list.append(loss.detach().item())
+ loss_mean_list.append(loss_mse.detach().item())
+ loss_vb_list.append(loss_vb.detach().item())
+
+ accumulate(
+ ema, model.module, 0 if i < conf.training.scheduler.warmup else 0.9999
+ )
+
+
+ if i%args.save_wandb_logs_every_iters == 0 and is_main_process():
+
+ wandb.log({'loss':(sum(loss_list)/len(loss_list)),
+ 'loss_vb':(sum(loss_vb_list)/len(loss_vb_list)),
+ 'loss_mean':(sum(loss_mean_list)/len(loss_mean_list)),
+ 'epoch':epoch,'steps':i})
+ loss_list = []
+ loss_mean_list = []
+ loss_vb_list = []
+
+
+ if i%args.save_checkpoints_every_iters == 0 and is_main_process():
+
+ if conf.distributed:
+ model_module = model.module
+
+ else:
+ model_module = model
+
+ torch.save(
+ {
+ "model": model_module.state_dict(),
+ "ema": ema.state_dict(),
+ "scheduler": scheduler.state_dict(),
+ "optimizer": optimizer.state_dict(),
+ "conf": conf,
+ },
+ conf.training.ckpt_path + f"/model_{str(i).zfill(6)}.pt"
+ )
+
+ if is_main_process():
+
+ print ('Epoch Time '+str(int(time.time()-start_time))+' secs')
+ print ('Model Saved Successfully for #epoch '+str(epoch)+' #steps '+str(i))
+
+ if conf.distributed:
+ model_module = model.module
+
+ else:
+ model_module = model
+
+ torch.save(
+ {
+ "model": model_module.state_dict(),
+ "ema": ema.state_dict(),
+ "scheduler": scheduler.state_dict(),
+ "optimizer": optimizer.state_dict(),
+ "conf": conf,
+ },
+ conf.training.ckpt_path + '/last.pt'
+
+ )
+
+ if (epoch)%args.save_wandb_images_every_epochs==0:
+
+ print ('Generating samples at epoch number ' + str(epoch))
+
+ val_batch = next(val_loader)
+ val_img = val_batch['source_image'].cuda()
+ val_pose = val_batch['target_skeleton'].cuda()
+
+ with torch.no_grad():
+
+ if args.sample_algorithm == 'ddpm':
+ print ('Sampling algorithm used: DDPM')
+ samples = diffusion.p_sample_loop(ema, x_cond = [val_img, val_pose], progress = True, cond_scale = cond_scale)
+ elif args.sample_algorithm == 'ddim':
+ print ('Sampling algorithm used: DDIM')
+ nsteps = 50
+ noise = torch.randn(val_img.shape).cuda()
+ seq = range(0, 1000, 1000//nsteps)
+ xs, x0_preds = ddim_steps(noise, seq, ema, betas.cuda(), [val_img, val_pose])
+ samples = xs[-1].cuda()
+
+
+ grid = torch.cat([val_img, val_pose[:,:3], samples], -1)
+
+ gathered_samples = [torch.zeros_like(grid) for _ in range(dist.get_world_size())]
+ dist.all_gather(gathered_samples, grid)
+
+
+ if is_main_process():
+
+ wandb.log({'samples':wandb.Image(torch.cat(gathered_samples, -2))})
+
+
+
+def main(settings, EXP_NAME):
+
+ [args, DiffConf, DataConf] = settings
+
+ if is_main_process(): wandb.init(project="person-synthesis", name = EXP_NAME, settings = wandb.Settings(code_dir="."))
+
+ if DiffConf.ckpt is not None:
+ DiffConf.training.scheduler.warmup = 0
+
+ DiffConf.distributed = True
+ local_rank = int(os.environ['LOCAL_RANK'])
+
+ DataConf.data.train.batch_size = args.batch_size//2 #src -> tgt , tgt -> src
+
+ val_dataset, train_dataset = deepfashion_data.get_train_val_dataloader(DataConf.data, labels_required = True, distributed = True)
+
+ def cycle(iterable):
+ while True:
+ for x in iterable:
+ yield x
+
+ val_dataset = iter(cycle(val_dataset))
+
+ model = get_model_conf().make_model()
+ model = model.to(args.device)
+ ema = get_model_conf().make_model()
+ ema = ema.to(args.device)
+
+ if DiffConf.distributed:
+ model = nn.parallel.DistributedDataParallel(
+ model,
+ device_ids=[local_rank],
+ find_unused_parameters=True
+ )
+
+ optimizer = DiffConf.training.optimizer.make(model.parameters())
+ scheduler = DiffConf.training.scheduler.make(optimizer)
+
+ if DiffConf.ckpt is not None:
+ ckpt = torch.load(DiffConf.ckpt, map_location=lambda storage, loc: storage)
+
+ if DiffConf.distributed:
+ model.module.load_state_dict(ckpt["model"])
+
+ else:
+ model.load_state_dict(ckpt["model"])
+
+ ema.load_state_dict(ckpt["ema"])
+ scheduler.load_state_dict(ckpt["scheduler"])
+
+ if is_main_process(): print ('model loaded successfully')
+
+ betas = DiffConf.diffusion.beta_schedule.make()
+ diffusion = create_gaussian_diffusion(betas, predict_xstart = False)
+
+ train(
+ DiffConf, train_dataset, val_dataset, model, ema, diffusion, betas, optimizer, scheduler, args.guidance_prob, args.cond_scale, args.device, wandb
+ )
+
+if __name__ == "__main__":
+
+ init_distributed()
+
+ import argparse
+
+ parser = argparse.ArgumentParser(description='help')
+ parser.add_argument('--exp_name', type=str, default='pidm_deepfashion')
+ parser.add_argument('--DiffConfigPath', type=str, default='./config/diffusion.conf')
+ parser.add_argument('--DataConfigPath', type=str, default='./config/data.yaml')
+ parser.add_argument('--dataset_path', type=str, default='./dataset/deepfashion')
+ parser.add_argument('--save_path', type=str, default='checkpoints')
+    parser.add_argument('--cond_scale', type=float, default=2.0)
+    parser.add_argument('--guidance_prob', type=float, default=0.1)
+ parser.add_argument('--sample_algorithm', type=str, default='ddim') # ddpm, ddim
+ parser.add_argument('--batch_size', type=int, default=2)
+ parser.add_argument('--save_wandb_logs_every_iters', type=int, default=50)
+ parser.add_argument('--save_checkpoints_every_iters', type=int, default=2000)
+ parser.add_argument('--save_wandb_images_every_epochs', type=int, default=10)
+ parser.add_argument('--device', type=str, default='cuda')
+ parser.add_argument('--n_gpu', type=int, default=8)
+ parser.add_argument('--n_machine', type=int, default=1)
+ parser.add_argument('--local_rank', type=int, default=0)
+ parser.add_argument("opts", default=None, nargs=argparse.REMAINDER)
+
+ args = parser.parse_args()
+
+ print ('Experiment: '+ args.exp_name)
+
+ DiffConf = DiffConfig(DiffusionConfig, args.DiffConfigPath, args.opts, False)
+ DataConf = DataConfig(args.DataConfigPath)
+
+ DiffConf.training.ckpt_path = os.path.join(args.save_path, args.exp_name)
+ DataConf.data.path = args.dataset_path
+
+ if is_main_process():
+
+ if not os.path.isdir(args.save_path): os.mkdir(args.save_path)
+ if not os.path.isdir(DiffConf.training.ckpt_path): os.mkdir(DiffConf.training.ckpt_path)
+
+ #DiffConf.ckpt = "checkpoints/last.pt"
+
+ main(settings = [args, DiffConf, DataConf], EXP_NAME = args.exp_name)
diff --git a/PyTorch/built-in/mlm/PIDM/utils/README.md b/PyTorch/built-in/mlm/PIDM/utils/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..7d1976e34988a8f0c5c7882df2ef5e451acdf987
--- /dev/null
+++ b/PyTorch/built-in/mlm/PIDM/utils/README.md
@@ -0,0 +1,27 @@
+
+### DDIM Sampling (Faster)
+```bash
+python -m torch.distributed.launch --nproc_per_node=8 --master_port 48109 utils/gen.py \
+ --exp_name="pidm_deepfashion" \
+ --checkpoint_name=last.pt \
+ --dataset_path "./dataset/deepfashion/" \
+ --sample_algorithm ddim
+```
+### DDPM Sampling
+```bash
+python -m torch.distributed.launch --nproc_per_node=8 --master_port 48109 utils/gen.py \
+ --exp_name="pidm_deepfashion" \
+ --checkpoint_name=last.pt \
+ --dataset_path "./dataset/deepfashion/" \
+ --sample_algorithm ddpm
+```
+
+Output images are saved inside the ```images``` folder, under ```images/<exp_name>/<checkpoint_name>-<sample_algorithm>-scale:<cond_scale>/```.
+
+### Folder structure for checkpoint files
+```
+PIDM/
+ checkpoints/
+    <exp_name>/            (e.g. pidm_deepfashion)
+      last.pt
+```
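+
+### Computing metrics
+
+The evaluation metrics (FID, LPIPS, SSIM/PSNR/L1/MAE) are implemented in ```utils/metrics.py```; its ```__main__``` block shows how they fit together. The snippet below is a minimal sketch of that usage, not a ready-made evaluation script: every path in angle brackets is a placeholder, and it assumes it is run from inside ```utils/``` so that the script's flat ```from inception import InceptionV3``` import resolves.
+
+```python
+from metrics import FID, LPIPS, Reconstruction_Metrics, preprocess_path_for_deform_task
+
+real_path = '<path to real training images>'                # FID reference statistics
+gt_path = '<path to ground-truth test images>'
+generated_path = '<output folder written by utils/gen.py>'
+
+# Pair generated files with their ground-truth counterparts.
+# Note: this helper assumes a specific filename convention; check its body first.
+gt_list, generated_list = preprocess_path_for_deform_task(gt_path, generated_path)
+
+fid_value = FID().calculate_from_disk(generated_path, real_path)
+lpips_value = LPIPS().calculate_from_disk(generated_list, gt_list, sort=False)
+rec_values = Reconstruction_Metrics().calculate_from_disk(generated_list, gt_list, generated_path, sort=False, debug=False)
+
+print('FID:', fid_value, 'LPIPS:', lpips_value, 'Reconstruction:', rec_values)
+```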
diff --git a/PyTorch/built-in/mlm/PIDM/utils/gen.py b/PyTorch/built-in/mlm/PIDM/utils/gen.py
new file mode 100644
index 0000000000000000000000000000000000000000..7f5eaaf2c48b94e674c040b097984a2c9ec735a5
--- /dev/null
+++ b/PyTorch/built-in/mlm/PIDM/utils/gen.py
@@ -0,0 +1,176 @@
+import os
+import warnings
+
+warnings.filterwarnings("ignore")
+
+import time, cv2, torch, wandb, shutil
+import torch.distributed as dist
+from config.diffconfig import DiffusionConfig, get_model_conf
+from config.dataconfig import Config as DataConfig
+from tensorfn import load_config as DiffConfig
+from diffusion import create_gaussian_diffusion, make_beta_schedule, ddim_steps
+from tensorfn.optim import lr_scheduler
+from torch import nn, optim
+from torch.utils import data
+from torchvision import transforms
+from tqdm import tqdm
+import numpy as np
+import data as deepfashion_data
+from model import UNet
+from PIL import Image
+
+def init_distributed():
+
+    # Initializes the distributed backend, which will take care of synchronizing nodes/GPUs
+ dist_url = "env://" # default
+
+ # only works with torch.distributed.launch // torch.run
+ rank = int(os.environ["RANK"])
+ world_size = int(os.environ['WORLD_SIZE'])
+ local_rank = int(os.environ['LOCAL_RANK'])
+
+ dist.init_process_group(
+ backend="nccl",
+ init_method=dist_url,
+ world_size=world_size,
+ rank=rank)
+
+ # this will make all .cuda() calls work properly
+ torch.cuda.set_device(local_rank)
+ # synchronizes all the threads to reach this point before moving on
+ dist.barrier()
+ setup_for_distributed(rank == 0)
+
+def setup_for_distributed(is_master):
+ """
+ This function disables printing when not in master process
+ """
+ import builtins as __builtin__
+ builtin_print = __builtin__.print
+
+ def print(*args, **kwargs):
+ force = kwargs.pop('force', False)
+ if is_master or force:
+ builtin_print(*args, **kwargs)
+
+ __builtin__.print = print
+
+def is_main_process():
+ try:
+ if dist.get_rank()==0:
+ return True
+ else:
+ return False
+ except:
+ return True
+
+
+
+if __name__ == "__main__":
+
+ init_distributed()
+ local_rank = int(os.environ['LOCAL_RANK'])
+
+ import argparse
+
+ parser = argparse.ArgumentParser(description='help')
+ parser.add_argument('--exp_name', type=str, default='pidm_deepfashion')
+ parser.add_argument('--DiffConfigPath', type=str, default='./config/diffusion.conf')
+ parser.add_argument('--DataConfigPath', type=str, default='./config/data.yaml')
+ parser.add_argument('--dataset_path', type=str, default='./dataset/deepfashion')
+ parser.add_argument('--save_path', type=str, default='checkpoints')
+ parser.add_argument('--sample_algorithm', type=str, default='ddim') # ddpm, ddim
+ parser.add_argument('--device', type=str, default='cuda')
+ parser.add_argument('--cond_scale', type=float, default=2.0)
+ parser.add_argument('--checkpoint_name', type=str, default="last.pt")
+ parser.add_argument('--batch_size', type=int, default=10)
+ parser.add_argument("--local_rank", type=int, default=0)
+ parser.add_argument("opts", default=None, nargs=argparse.REMAINDER)
+
+ args = parser.parse_args()
+
+ print ('Experiment: '+ args.exp_name)
+
+ cond_scale = args.cond_scale
+ sample_algorithm = args.sample_algorithm # options: DDPM, DDIM
+
+ _folder = args.checkpoint_name+'-'+sample_algorithm+'-'+'scale:'+str(cond_scale)
+
+ fake_folder = 'images/'+args.exp_name+'/'+_folder
+
+ if is_main_process():
+ if not os.path.isdir( 'images/'):
+ os.mkdir( 'images/')
+
+ if not os.path.isdir( 'images/'+args.exp_name):
+ os.mkdir( 'images/'+args.exp_name)
+
+ if os.path.isdir(fake_folder):
+ shutil.rmtree(fake_folder)
+
+        os.mkdir(fake_folder)
+
+    # every rank writes into fake_folder below, so wait until rank 0 has (re)created it
+    dist.barrier()
+
+ DiffConf = DiffConfig(DiffusionConfig, args.DiffConfigPath, args.opts, False)
+ DataConf = DataConfig(args.DataConfigPath)
+ DiffConf.training.ckpt_path = os.path.join(args.save_path, args.exp_name)
+ DataConf.data.path = args.dataset_path
+ DataConf.data.val.batch_size = args.batch_size
+ val_dataset, train_dataset = deepfashion_data.get_train_val_dataloader(DataConf.data, labels_required = True, distributed = True)
+ val_dataset = iter(val_dataset)
+
+ ckpt = torch.load(args.save_path+"/"+args.exp_name+'/'+args.checkpoint_name)
+
+ model = get_model_conf().make_model()
+ model = model.to(args.device)
+ model.load_state_dict(ckpt["ema"])
+ model = nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])
+ betas = DiffConf.diffusion.beta_schedule.make()
+ diffusion = create_gaussian_diffusion(betas, predict_xstart = False)
+ model.eval()
+
+ with torch.no_grad():
+
+ for batch_it in range(len(val_dataset)):
+
+ batch = next(val_dataset)
+
+ print ('batch_id-'+str(batch_it))
+
+ img = batch['source_image'].cuda()
+ target_pose = batch['target_skeleton'].cuda()
+
+            if args.sample_algorithm == 'DDPM' or args.sample_algorithm == 'ddpm':
+
+                sample_fn = diffusion.p_sample_loop  # full DDPM ancestral sampling (DDIM is handled below)
+
+ samples = sample_fn(model.module, x_cond = [img, target_pose], progress = True, cond_scale = cond_scale)
+
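+                # map the sampled tensor from [-1, 1] to uint8 RGB images before saving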
+ target_output = torch.clamp(samples, -1., 1.)
+ numpy_imgs = (target_output.permute(0,2,3,1).detach().cpu().numpy() + 1.0)/2.0
+ fake_imgs = (255*numpy_imgs).astype(np.uint8)
+
+ img_save_names = batch['path']
+
+ [Image.fromarray(im).save(os.path.join(fake_folder, img_save_names[idx])) for idx, im in enumerate(fake_imgs)]
+
+ elif args.sample_algorithm == 'DDIM' or args.sample_algorithm == 'ddim' :
+
+ nsteps = 100
+
+ noise = torch.randn(img.shape).cuda()
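+                # evenly strided subset of the 1000-step diffusion schedule used for DDIM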
+ seq = range(0, 1000, 1000//nsteps)
+ xs, x0_preds = ddim_steps(noise, seq, model.module, betas.cuda(), [img, target_pose], diffusion=diffusion, cond_scale=cond_scale)
+ samples = xs[-1].cuda()
+
+ target_output = torch.clamp(samples, -1., 1.)
+ numpy_imgs = (target_output.permute(0,2,3,1).detach().cpu().numpy() + 1.0)/2.0
+ fake_imgs = (255*numpy_imgs).astype(np.uint8)
+
+ img_save_names = batch['path']
+
+ [Image.fromarray(im).save(os.path.join(fake_folder, img_save_names[idx])) for idx, im in enumerate(fake_imgs)]
+
+ else:
+
+ print ('ERROR! Sample algorithm not defined.')
diff --git a/PyTorch/built-in/mlm/PIDM/utils/inception.py b/PyTorch/built-in/mlm/PIDM/utils/inception.py
new file mode 100644
index 0000000000000000000000000000000000000000..fb5aaedce7099a71bf19b2855ed102499c043611
--- /dev/null
+++ b/PyTorch/built-in/mlm/PIDM/utils/inception.py
@@ -0,0 +1,138 @@
+import torch.nn as nn
+import torch.nn.functional as F
+from torchvision import models
+
+
+class InceptionV3(nn.Module):
+ """Pretrained InceptionV3 network returning feature maps"""
+
+ # Index of default block of inception to return,
+ # corresponds to output of final average pooling
+ DEFAULT_BLOCK_INDEX = 3
+
+ # Maps feature dimensionality to their output blocks indices
+ BLOCK_INDEX_BY_DIM = {
+ 64: 0, # First max pooling features
+        192: 1,   # Second max pooling features
+ 768: 2, # Pre-aux classifier features
+ 2048: 3 # Final average pooling features
+ }
+
+ def __init__(self,
+ output_blocks=[DEFAULT_BLOCK_INDEX],
+ resize_input=True,
+ normalize_input=True,
+ requires_grad=False):
+ """Build pretrained InceptionV3
+ Parameters
+ ----------
+ output_blocks : list of int
+ Indices of blocks to return features of. Possible values are:
+ - 0: corresponds to output of first max pooling
+ - 1: corresponds to output of second max pooling
+ - 2: corresponds to output which is fed to aux classifier
+ - 3: corresponds to output of final average pooling
+ resize_input : bool
+ If true, bilinearly resizes input to width and height 299 before
+ feeding input to model. As the network without fully connected
+ layers is fully convolutional, it should be able to handle inputs
+ of arbitrary size, so resizing might not be strictly needed
+ normalize_input : bool
+ If true, normalizes the input to the statistics the pretrained
+ Inception network expects
+ requires_grad : bool
+ If true, parameters of the model require gradient. Possibly useful
+ for finetuning the network
+ """
+ super(InceptionV3, self).__init__()
+
+ self.resize_input = resize_input
+ self.normalize_input = normalize_input
+ self.output_blocks = sorted(output_blocks)
+ self.last_needed_block = max(output_blocks)
+
+ assert self.last_needed_block <= 3, \
+ 'Last possible output block index is 3'
+
+ self.blocks = nn.ModuleList()
+
+ inception = models.inception_v3(pretrained=True)
+
+ # Block 0: input to maxpool1
+ block0 = [
+ inception.Conv2d_1a_3x3,
+ inception.Conv2d_2a_3x3,
+ inception.Conv2d_2b_3x3,
+ nn.MaxPool2d(kernel_size=3, stride=2)
+ ]
+ self.blocks.append(nn.Sequential(*block0))
+
+ # Block 1: maxpool1 to maxpool2
+ if self.last_needed_block >= 1:
+ block1 = [
+ inception.Conv2d_3b_1x1,
+ inception.Conv2d_4a_3x3,
+ nn.MaxPool2d(kernel_size=3, stride=2)
+ ]
+ self.blocks.append(nn.Sequential(*block1))
+
+ # Block 2: maxpool2 to aux classifier
+ if self.last_needed_block >= 2:
+ block2 = [
+ inception.Mixed_5b,
+ inception.Mixed_5c,
+ inception.Mixed_5d,
+ inception.Mixed_6a,
+ inception.Mixed_6b,
+ inception.Mixed_6c,
+ inception.Mixed_6d,
+ inception.Mixed_6e,
+ ]
+ self.blocks.append(nn.Sequential(*block2))
+
+ # Block 3: aux classifier to final avgpool
+ if self.last_needed_block >= 3:
+ block3 = [
+ inception.Mixed_7a,
+ inception.Mixed_7b,
+ inception.Mixed_7c,
+ nn.AdaptiveAvgPool2d(output_size=(1, 1))
+ ]
+ self.blocks.append(nn.Sequential(*block3))
+
+ for param in self.parameters():
+ param.requires_grad = requires_grad
+
+ def forward(self, inp):
+ """Get Inception feature maps
+ Parameters
+ ----------
+ inp : torch.autograd.Variable
+ Input tensor of shape Bx3xHxW. Values are expected to be in
+ range (0, 1)
+ Returns
+ -------
+ List of torch.autograd.Variable, corresponding to the selected output
+ block, sorted ascending by index
+ """
+ outp = []
+ x = inp
+
+ if self.resize_input:
+            x = F.interpolate(x, size=(299, 299), mode='bilinear', align_corners=False)
+
+ if self.normalize_input:
+ x = x.clone()
+ x[:, 0] = x[:, 0] * (0.229 / 0.5) + (0.485 - 0.5) / 0.5
+ x[:, 1] = x[:, 1] * (0.224 / 0.5) + (0.456 - 0.5) / 0.5
+ x[:, 2] = x[:, 2] * (0.225 / 0.5) + (0.406 - 0.5) / 0.5
+
+ for idx, block in enumerate(self.blocks):
+ x = block(x)
+ if idx in self.output_blocks:
+ outp.append(x)
+
+ if idx == self.last_needed_block:
+ break
+
+ return outp
diff --git a/PyTorch/built-in/mlm/PIDM/utils/metrics.py b/PyTorch/built-in/mlm/PIDM/utils/metrics.py
new file mode 100644
index 0000000000000000000000000000000000000000..eddfd021bc2097dae4fb78b0b65e156f713d8929
--- /dev/null
+++ b/PyTorch/built-in/mlm/PIDM/utils/metrics.py
@@ -0,0 +1,664 @@
+import os
+import pathlib
+import torch
+import numpy as np
+from imageio import imread
+from scipy import linalg
+from torch.nn.functional import adaptive_avg_pool2d
+from skimage.metrics import structural_similarity as compare_ssim  # compare_ssim was removed from skimage.measure
+from skimage.metrics import peak_signal_noise_ratio as compare_psnr
+
+import glob
+import argparse
+import matplotlib.pyplot as plt
+from inception import InceptionV3
+#from scripts.PerceptualSimilarity.models import dist_model as dm
+import lpips
+import pandas as pd
+import json
+import imageio
+from skimage.draw import polygon, disk  # needed by produce_ma_mask; disk replaces the removed skimage.draw.circle
+import cv2
+
+
+class FID():
+    """Calculates the Frechet Inception Distance (FID) to evaluate GANs
+ The FID metric calculates the distance between two distributions of images.
+ Typically, we have summary statistics (mean & covariance matrix) of one
+ of these distributions, while the 2nd distribution is given by a GAN.
+ When run as a stand-alone program, it compares the distribution of
+ images that are stored as PNG/JPEG at a specified location with a
+ distribution given by summary statistics (in pickle format).
+ The FID is calculated by assuming that X_1 and X_2 are the activations of
+ the pool_3 layer of the inception net for generated samples and real world
+    samples respectively.
+    See --help for further details.
+    Code adapted from https://github.com/bioinf-jku/TTUR to use PyTorch instead
+ of Tensorflow
+ Copyright 2018 Institute of Bioinformatics, JKU Linz
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+ http://www.apache.org/licenses/LICENSE-2.0
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ """
+ def __init__(self):
+ self.dims = 2048
+ self.batch_size = 64
+ self.cuda = True
+ self.verbose=False
+
+ block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[self.dims]
+ self.model = InceptionV3([block_idx])
+ if self.cuda:
+ # TODO: put model into specific GPU
+ self.model.cuda()
+
+ def __call__(self, images, gt_path):
+        """ images: the generated images. The values must lie between 0 and 1.
+ gt_path: the path of the ground truth images. The values must lie between 0 and 1.
+ """
+ if not os.path.exists(gt_path):
+ raise RuntimeError('Invalid path: %s' % gt_path)
+
+
+ print('calculate gt_path statistics...')
+ m1, s1 = self.compute_statistics_of_path(gt_path, self.verbose)
+ print('calculate generated_images statistics...')
+ m2, s2 = self.calculate_activation_statistics(images, self.verbose)
+ fid_value = self.calculate_frechet_distance(m1, s1, m2, s2)
+ return fid_value
+
+
+ def calculate_from_disk(self, generated_path, gt_path):
+        """Compute the FID between the images stored at generated_path and gt_path."""
+ if not os.path.exists(gt_path):
+ raise RuntimeError('Invalid path: %s' % gt_path)
+ if not os.path.exists(generated_path):
+ raise RuntimeError('Invalid path: %s' % generated_path)
+
+ print ('exp-path - '+generated_path)
+
+ print('calculate gt_path statistics...')
+ m1, s1 = self.compute_statistics_of_path(gt_path, self.verbose)
+ print('calculate generated_path statistics...')
+ m2, s2 = self.compute_statistics_of_path(generated_path, self.verbose)
+ print('calculate frechet distance...')
+ fid_value = self.calculate_frechet_distance(m1, s1, m2, s2)
+ print('fid_distance %f' % (fid_value))
+ return fid_value
+
+
+ def compute_statistics_of_path(self, path, verbose):
+
+ npz_file = os.path.join(path, 'statistics.npz')
+ if os.path.exists(npz_file):
+ f = np.load(npz_file)
+ m, s = f['mu'][:], f['sigma'][:]
+ f.close()
+
+ else:
+
+ path = pathlib.Path(path)
+ files = list(path.glob('*.jpg')) + list(path.glob('*.png'))
+ imgs = np.array([(cv2.resize(imread(str(fn)).astype(np.float32),(176, 256))) for fn in files])#np.array([imread(str(fn)).astype(np.float32) for fn in files])
+
+ # Bring images to shape (B, 3, H, W)
+ imgs = imgs.transpose((0, 3, 1, 2))
+
+ # Rescale images to be between 0 and 1
+ imgs /= 255
+
+ m, s = self.calculate_activation_statistics(imgs, verbose)
+ np.savez(npz_file, mu=m, sigma=s)
+
+ return m, s
+
+ def calculate_activation_statistics(self, images, verbose):
+ """Calculation of the statistics used by the FID.
+ Params:
+ -- images : Numpy array of dimension (n_images, 3, hi, wi). The values
+ must lie between 0 and 1.
+ -- model : Instance of inception model
+ -- batch_size : The images numpy array is split into batches with
+ batch size batch_size. A reasonable batch size
+ depends on the hardware.
+ -- dims : Dimensionality of features returned by Inception
+ -- cuda : If set to True, use GPU
+ -- verbose : If set to True and parameter out_step is given, the
+ number of calculated batches is reported.
+ Returns:
+ -- mu : The mean over samples of the activations of the pool_3 layer of
+ the inception model.
+ -- sigma : The covariance matrix of the activations of the pool_3 layer of
+ the inception model.
+ """
+ act = self.get_activations(images, verbose)
+ mu = np.mean(act, axis=0)
+ sigma = np.cov(act, rowvar=False)
+ return mu, sigma
+
+
+
+ def get_activations(self, images, verbose=False):
+ """Calculates the activations of the pool_3 layer for all images.
+ Params:
+ -- images : Numpy array of dimension (n_images, 3, hi, wi). The values
+ must lie between 0 and 1.
+ -- model : Instance of inception model
+ -- batch_size : the images numpy array is split into batches with
+ batch size batch_size. A reasonable batch size depends
+ on the hardware.
+ -- dims : Dimensionality of features returned by Inception
+ -- cuda : If set to True, use GPU
+ -- verbose : If set to True and parameter out_step is given, the number
+ of calculated batches is reported.
+ Returns:
+ -- A numpy array of dimension (num images, dims) that contains the
+ activations of the given tensor when feeding inception with the
+ query tensor.
+ """
+ self.model.eval()
+
+ d0 = images.shape[0]
+ if self.batch_size > d0:
+ print(('Warning: batch size is bigger than the data size. '
+ 'Setting batch size to data size'))
+ self.batch_size = d0
+
+ n_batches = d0 // self.batch_size
+ n_used_imgs = n_batches * self.batch_size
+
+ pred_arr = np.empty((n_used_imgs, self.dims))
+ for i in range(n_batches):
+ if verbose:
+ print('\rPropagating batch %d/%d' % (i + 1, n_batches))
+ # end='', flush=True)
+ start = i * self.batch_size
+ end = start + self.batch_size
+
+ batch = torch.from_numpy(images[start:end]).type(torch.FloatTensor)
+ # batch = Variable(batch, volatile=True)
+ if self.cuda:
+ batch = batch.cuda()
+
+ pred = self.model(batch)[0]
+
+ # If model output is not scalar, apply global spatial average pooling.
+ # This happens if you choose a dimensionality not equal 2048.
+ if pred.shape[2] != 1 or pred.shape[3] != 1:
+ pred = adaptive_avg_pool2d(pred, output_size=(1, 1))
+
+ pred_arr[start:end] = pred.cpu().data.numpy().reshape(self.batch_size, -1)
+
+ if verbose:
+ print(' done')
+
+ return pred_arr
+
+
+ def calculate_frechet_distance(self, mu1, sigma1, mu2, sigma2, eps=1e-6):
+ """Numpy implementation of the Frechet Distance.
+ The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
+ and X_2 ~ N(mu_2, C_2) is
+ d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
+ Stable version by Dougal J. Sutherland.
+ Params:
+ -- mu1 : Numpy array containing the activations of a layer of the
+ inception net (like returned by the function 'get_predictions')
+ for generated samples.
+        -- mu2   : The sample mean over activations, precalculated on a
+               representative data set.
+        -- sigma1: The covariance matrix over activations for generated samples.
+        -- sigma2: The covariance matrix over activations, precalculated on a
+               representative data set.
+ Returns:
+ -- : The Frechet Distance.
+ """
+
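+        # d^2 = ||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2*sqrt(sigma1 @ sigma2)); the matrix sqrt is stabilised below if nearly singular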
+ mu1 = np.atleast_1d(mu1)
+ mu2 = np.atleast_1d(mu2)
+
+ sigma1 = np.atleast_2d(sigma1)
+ sigma2 = np.atleast_2d(sigma2)
+
+ assert mu1.shape == mu2.shape, \
+ 'Training and test mean vectors have different lengths'
+ assert sigma1.shape == sigma2.shape, \
+ 'Training and test covariances have different dimensions'
+
+ diff = mu1 - mu2
+
+ # Product might be almost singular
+ covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
+ if not np.isfinite(covmean).all():
+ msg = ('fid calculation produces singular product; '
+ 'adding %s to diagonal of cov estimates') % eps
+ print(msg)
+ offset = np.eye(sigma1.shape[0]) * eps
+ covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
+
+ # Numerical error might give slight imaginary component
+ if np.iscomplexobj(covmean):
+ if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
+ m = np.max(np.abs(covmean.imag))
+ raise ValueError('Imaginary component {}'.format(m))
+ covmean = covmean.real
+
+ tr_covmean = np.trace(covmean)
+
+ return (diff.dot(diff) + np.trace(sigma1) +
+ np.trace(sigma2) - 2 * tr_covmean)
+
+
+class Reconstruction_Metrics():
+ def __init__(self, metric_list=['ssim', 'psnr', 'l1', 'mae'], data_range=1, win_size=51, multichannel=True):
+ self.data_range = data_range
+ self.win_size = win_size
+ self.multichannel = multichannel
+ for metric in metric_list:
+ if metric in ['ssim', 'psnr', 'l1', 'mae']:
+ setattr(self, metric, True)
+ else:
+                print('unsupported reconstruction metric: %s' % metric)
+
+
+ def __call__(self, inputs, gts):
+ """
+ inputs: the generated image, size (b,c,w,h), data range(0, data_range)
+ gts: the ground-truth image, size (b,c,w,h), data range(0, data_range)
+ """
+ result = dict()
+ [b,n,w,h] = inputs.size()
+ inputs = inputs.view(b*n, w, h).detach().cpu().numpy().astype(np.float32).transpose(1,2,0)
+ gts = gts.view(b*n, w, h).detach().cpu().numpy().astype(np.float32).transpose(1,2,0)
+
+ if hasattr(self, 'ssim'):
+ ssim_value = compare_ssim(inputs, gts, data_range=self.data_range,
+ win_size=self.win_size, multichannel=self.multichannel)
+ result['ssim'] = ssim_value
+
+
+ if hasattr(self, 'psnr'):
+            psnr_value = compare_psnr(inputs, gts, data_range=self.data_range)
+ result['psnr'] = psnr_value
+
+ if hasattr(self, 'l1'):
+ l1_value = compare_l1(inputs, gts)
+ result['l1'] = l1_value
+
+ if hasattr(self, 'mae'):
+ mae_value = compare_mae(inputs, gts)
+ result['mae'] = mae_value
+ return result
+
+
+ def calculate_from_disk(self, inputs, gts, save_path=None, sort=True, debug=0):
+ """
+        inputs: .txt files, folders, image files (string), image files (list)
+        gts: .txt files, folders, image files (string), image files (list)
+ """
+ if sort:
+ input_image_list = sorted(get_image_list(inputs))
+ gt_image_list = sorted(get_image_list(gts))
+ else:
+ input_image_list = get_image_list(inputs)
+ gt_image_list = get_image_list(gts)
+ npz_file = os.path.join(save_path, 'metrics2.npz')
+ if os.path.exists(npz_file):
+ f = np.load(npz_file)
+ psnr,ssim,ssim_256,mae,l1=f['psnr'],f['ssim'],f['ssim_256'],f['mae'],f['l1']
+ else:
+ psnr = []
+ ssim = []
+ ssim_256 = []
+ mae = []
+ l1 = []
+ names = []
+
+ for index in range(len(input_image_list)):
+ name = os.path.basename(input_image_list[index])
+ names.append(name)
+
+ img_gt = (imread(str(gt_image_list[index]))).astype(np.float32) / 255.0
+ img_pred = (imread(str(input_image_list[index]))).astype(np.float32) / 255.0
+
+ if debug != 0:
+                    plt.subplot(121)
+                    plt.imshow(img_gt)
+                    plt.title('Ground truth')
+                    plt.subplot(122)
+                    plt.imshow(img_pred)
+                    plt.title('Output')
+ plt.show()
+
+ psnr.append(compare_psnr(img_gt, img_pred, data_range=self.data_range))
+ ssim.append(compare_ssim(img_gt, img_pred, data_range=self.data_range,
+ win_size=self.win_size,multichannel=self.multichannel))
+ mae.append(compare_mae(img_gt, img_pred))
+ l1.append(compare_l1(img_gt, img_pred))
+
+ img_gt_256 = img_gt*255.0
+ img_pred_256 = img_pred*255.0
+ ssim_256.append(compare_ssim(img_gt_256, img_pred_256, gaussian_weights=True, sigma=1.2,
+ use_sample_covariance=False, multichannel=True,
+ data_range=img_pred_256.max() - img_pred_256.min()))
+ if np.mod(index, 200) == 0:
+ print(
+ str(index) + ' images processed',
+ "PSNR: %.4f" % round(np.mean(psnr), 4),
+ "SSIM: %.4f" % round(np.mean(ssim), 4),
+ "SSIM_256: %.4f" % round(np.mean(ssim_256), 4),
+ "MAE: %.4f" % round(np.mean(mae), 4),
+ "l1: %.4f" % round(np.mean(l1), 4),
+ )
+
+ if save_path:
+ np.savez(save_path + '/metrics.npz', psnr=psnr, ssim=ssim, ssim_256=ssim_256, mae=mae, l1=l1, names=names)
+
+ print(
+ "PSNR: %.4f" % round(np.mean(psnr), 4),
+ "PSNR Variance: %.4f" % round(np.var(psnr), 4),
+ "SSIM: %.4f" % round(np.mean(ssim), 4),
+ "SSIM Variance: %.4f" % round(np.var(ssim), 4),
+ "SSIM_256: %.4f" % round(np.mean(ssim_256), 4),
+ "SSIM_256 Variance: %.4f" % round(np.var(ssim_256), 4),
+ "MAE: %.4f" % round(np.mean(mae), 4),
+ "MAE Variance: %.4f" % round(np.var(mae), 4),
+ "l1: %.4f" % round(np.mean(l1), 4),
+ "l1 Variance: %.4f" % round(np.var(l1), 4)
+ )
+
+ dic = {"psnr":[round(np.mean(psnr), 6)],
+ "psnr_variance": [round(np.var(psnr), 6)],
+ "ssim": [round(np.mean(ssim), 6)],
+ "ssim_variance": [round(np.var(ssim), 6)],
+ "ssim_256": [round(np.mean(ssim_256), 6)],
+ "ssim_256_variance": [round(np.var(ssim_256), 6)],
+ "mae": [round(np.mean(mae), 6)],
+ "mae_variance": [round(np.var(mae), 6)],
+ "l1": [round(np.mean(l1), 6)],
+ "l1_variance": [round(np.var(l1), 6)] }
+
+ return dic
+
+
+def get_image_list(flist):
+ if isinstance(flist, list):
+ return flist
+
+ # flist: image file path, image directory path, text file flist path
+ if isinstance(flist, str):
+ if os.path.isdir(flist):
+ flist = list(glob.glob(flist + '/*.jpg')) + list(glob.glob(flist + '/*.png'))
+ flist.sort()
+ return flist
+
+ if os.path.isfile(flist):
+ try:
+                return np.genfromtxt(flist, dtype=str)
+ except:
+ return [flist]
+    print('cannot read files from %s, returning an empty list' % flist)
+ return []
+
+def compare_l1(img_true, img_test):
+ img_true = img_true.astype(np.float32)
+ img_test = img_test.astype(np.float32)
+ return np.mean(np.abs(img_true - img_test))
+
+def compare_mae(img_true, img_test):
+ img_true = img_true.astype(np.float32)
+ img_test = img_test.astype(np.float32)
+ return np.sum(np.abs(img_true - img_test)) / np.sum(img_true + img_test)
+
+def preprocess_path_for_deform_task(gt_path, distorted_path):
+ distorted_image_list = sorted(get_image_list(distorted_path))
+ gt_list=[]
+ distorated_list=[]
+ # for distorted_image in distorted_image_list:
+ # image = os.path.basename(distorted_image)
+ # image = image.split('.jpg___')[-1]
+ # image = image.split('_vis')[0]
+ # gt_image = os.path.join(gt_path, image)
+ # if not os.path.isfile(gt_image):
+ # continue
+ # gt_list.append(gt_image)
+ # distorated_list.append(distorted_image)
+
+ for distorted_image in distorted_image_list:
+ image = os.path.basename(distorted_image)
+ image = image.split('_2_')[-1]
+ image = image.split('_vis')[0] +'.png'
+ gt_image = os.path.join(gt_path, 'deep'+image)
+ if not os.path.isfile(gt_image):
+ print(gt_image)
+ continue
+ gt_list.append(gt_image)
+ distorated_list.append(distorted_image)
+
+ return gt_list, distorated_list
+
+
+
+class LPIPS():
+ def __init__(self, use_gpu=True):
+ # self.model = dm.DistModel()
+ # self.model.initialize(model='net-lin', net='alex',use_gpu=use_gpu)
+ self.model = lpips.LPIPS(net='alex').cuda()
+ self.use_gpu=use_gpu
+
+ def __call__(self, image_1, image_2):
+ """
+ image_1: images with size (n, 3, w, h) with value [-1, 1]
+ image_2: images with size (n, 3, w, h) with value [-1, 1]
+ """
+ result = self.model.forward(image_1, image_2)
+ return result
+
+ def calculate_from_disk(self, path_1, path_2, batch_size=64, verbose=False, sort=True):
+
+ if sort:
+ files_1 = sorted(get_image_list(path_1))
+ files_2 = sorted(get_image_list(path_2))
+ else:
+ files_1 = get_image_list(path_1)
+ files_2 = get_image_list(path_2)
+
+ #imgs_1 = np.array([imread(str(fn)).astype(np.float32)/127.5-1 for fn in files_1])
+ #imgs_2 = np.array([imread(str(fn)).astype(np.float32)/127.5-1 for fn in files_2])
+
+ # Bring images to shape (B, 3, H, W)
+ #imgs_1 = imgs_1.transpose((0, 3, 1, 2))
+ #imgs_2 = imgs_2.transpose((0, 3, 1, 2))
+
+ result=[]
+
+
+ d0 = len(files_1)
+ if batch_size > d0:
+ print(('Warning: batch size is bigger than the data size. '
+ 'Setting batch size to data size'))
+ batch_size = d0
+
+ n_batches = d0 // batch_size
+ n_used_imgs = n_batches * batch_size
+
+ # imgs_1_arr = np.empty((n_used_imgs, self.dims))
+ # imgs_2_arr = np.empty((n_used_imgs, self.dims))
+
+ for i in range(n_batches):
+ if verbose:
+ print('\rPropagating batch %d/%d' % (i + 1, n_batches))
+ # end='', flush=True)
+ start = i * batch_size
+ end = start + batch_size
+
+ imgs_1 = np.array([cv2.resize(imread(str(fn)).astype(np.float32),(176, 256))/127.5-1 for fn in files_1[start:end]])
+ imgs_2 = np.array([cv2.resize(imread(str(fn)).astype(np.float32),(176, 256))/127.5-1 for fn in files_2[start:end]])
+
+ imgs_1 = imgs_1.transpose((0, 3, 1, 2))
+ imgs_2 = imgs_2.transpose((0, 3, 1, 2))
+
+ img_1_batch = torch.from_numpy(imgs_1).type(torch.FloatTensor)
+ img_2_batch = torch.from_numpy(imgs_2).type(torch.FloatTensor)
+
+ if self.use_gpu:
+ img_1_batch = img_1_batch.cuda()
+ img_2_batch = img_2_batch.cuda()
+
+ # img_1_batch = torch.nn.functional.interpolate(img_1_batch, size=(256, 176), mode='bilinear', align_corners=False)
+ # img_2_batch = torch.nn.functional.interpolate(img_2_batch, size=(256, 176), mode='bilinear', align_corners=False)
+
+
+ result.append(self.model.forward(img_1_batch, img_2_batch))
+
+
+ distance = torch.cat(result,0)[:,0,0,0].mean()
+
+ print('lpips: %.3f'%distance)
+ return distance
+
+ def calculate_mask_lpips(self, path_1, path_2, batch_size=64, verbose=False, sort=True):
+ if sort:
+ files_1 = sorted(get_image_list(path_1))
+ files_2 = sorted(get_image_list(path_2))
+ else:
+ files_1 = get_image_list(path_1)
+ files_2 = get_image_list(path_2)
+
+ imgs_1=[]
+ imgs_2=[]
+ bonesLst = './dataset/market_data/market-annotation-test.csv'
+ annotation_file = pd.read_csv(bonesLst, sep=':')
+ annotation_file = annotation_file.set_index('name')
+
+ for i in range(len(files_1)):
+ string = annotation_file.loc[os.path.basename(files_2[i])]
+ mask = np.tile(np.expand_dims(create_masked_image(string).astype(np.float32), -1), (1,1,3))#.repeat(1,1,3)
+ imgs_1.append((imread(str(files_1[i])).astype(np.float32)/127.5-1)*mask)
+ imgs_2.append((imread(str(files_2[i])).astype(np.float32)/127.5-1)*mask)
+
+ # Bring images to shape (B, 3, H, W)
+ imgs_1 = np.array(imgs_1)
+ imgs_2 = np.array(imgs_2)
+ imgs_1 = imgs_1.transpose((0, 3, 1, 2))
+ imgs_2 = imgs_2.transpose((0, 3, 1, 2))
+
+ result=[]
+
+
+ d0 = imgs_1.shape[0]
+ if batch_size > d0:
+ print(('Warning: batch size is bigger than the data size. '
+ 'Setting batch size to data size'))
+ batch_size = d0
+
+ n_batches = d0 // batch_size
+ n_used_imgs = n_batches * batch_size
+
+ for i in range(n_batches):
+ if verbose:
+ print('\rPropagating batch %d/%d' % (i + 1, n_batches))
+ # end='', flush=True)
+ start = i * batch_size
+ end = start + batch_size
+
+ img_1_batch = torch.from_numpy(imgs_1[start:end]).type(torch.FloatTensor)
+ img_2_batch = torch.from_numpy(imgs_2[start:end]).type(torch.FloatTensor)
+
+ if self.use_gpu:
+ img_1_batch = img_1_batch.cuda()
+ img_2_batch = img_2_batch.cuda()
+
+
+ result.append(self.model.forward(img_1_batch, img_2_batch))
+
+
+        distance = torch.cat(result, 0)[:, 0, 0, 0].mean()
+ print('lpips: %.3f'%distance)
+ return distance
+
+
+
+def produce_ma_mask(kp_array, img_size=(128, 64), point_radius=4):
+ MISSING_VALUE = -1
+ from skimage.morphology import dilation, erosion, square
+ mask = np.zeros(shape=img_size, dtype=bool)
+ limbs = [[2,3], [2,6], [3,4], [4,5], [6,7], [7,8], [2,9], [9,10],
+ [10,11], [2,12], [12,13], [13,14], [2,1], [1,15], [15,17],
+ [1,16], [16,18], [2,17], [2,18], [9,12], [12,6], [9,3], [17,18]]
+ limbs = np.array(limbs) - 1
+ for f, t in limbs:
+ from_missing = kp_array[f][0] == MISSING_VALUE or kp_array[f][1] == MISSING_VALUE
+ to_missing = kp_array[t][0] == MISSING_VALUE or kp_array[t][1] == MISSING_VALUE
+ if from_missing or to_missing:
+ continue
+
+ norm_vec = kp_array[f] - kp_array[t]
+ norm_vec = np.array([-norm_vec[1], norm_vec[0]])
+ norm_vec = point_radius * norm_vec / np.linalg.norm(norm_vec)
+
+
+        vertices = np.array([
+            kp_array[f] + norm_vec,
+            kp_array[f] - norm_vec,
+            kp_array[t] - norm_vec,
+            kp_array[t] + norm_vec
+        ])
+        yy, xx = polygon(vertices[:, 0], vertices[:, 1], shape=img_size)
+ mask[yy, xx] = True
+
+    for joint in kp_array:
+        if joint[0] == MISSING_VALUE or joint[1] == MISSING_VALUE:
+            continue
+        yy, xx = disk((joint[0], joint[1]), point_radius, shape=img_size)
+ mask[yy, xx] = True
+
+ mask = dilation(mask, square(5))
+ mask = erosion(mask, square(5))
+ return mask
+
+def load_pose_cords_from_strings(y_str, x_str):
+ y_cords = json.loads(y_str)
+ x_cords = json.loads(x_str)
+ return np.concatenate([np.expand_dims(y_cords, -1), np.expand_dims(x_cords, -1)], axis=1)
+
+def create_masked_image(ano_to):
+ kp_to = load_pose_cords_from_strings(ano_to['keypoints_y'], ano_to['keypoints_x'])
+ mask = produce_ma_mask(kp_to)
+ return mask
+
+if __name__ == "__main__":
+
+ fid = FID()
+ lpips_obj = LPIPS()
+ rec = Reconstruction_Metrics()
+
+ real_path = './deepfashion/train_256x256_pngs/'
+ gt_path = './deepfashion/test_256x256_pngs/'
+ distorated_path = './output_images_pngs/'
+
+ gt_list, distorated_list = preprocess_path_for_deform_task(gt_path, distorated_path)
+
+ FID = fid.calculate_from_disk(distorated_path, real_path)
+ LPIPS = lpips_obj.calculate_from_disk(distorated_list, gt_list, sort=False)
+ REC = rec.calculate_from_disk(distorated_list, gt_list, distorated_path, sort=False, debug=False)
+
+ print ("FID: "+str(FID)+"\nLPIPS: "+str(LPIPS)+"\nSSIM: "+str(REC))
+
+
+
+
+
+
+
+
+
+