diff --git a/multimodal/diffusion/DDPM/pytorch-ddpm/.gitignore b/multimodal/diffusion/DDPM/pytorch-ddpm/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..4b185d423fc0a4fcbb90ff875635bf2947dc71b8
--- /dev/null
+++ b/multimodal/diffusion/DDPM/pytorch-ddpm/.gitignore
@@ -0,0 +1,8 @@
+.vscode
+
+stats
+data
+logs
+
+.python-version
+__pycache__
\ No newline at end of file
diff --git a/multimodal/diffusion/DDPM/pytorch-ddpm/LICENSE b/multimodal/diffusion/DDPM/pytorch-ddpm/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..a1b7448fcf7a36254b33cfc4edf59fc682355166
--- /dev/null
+++ b/multimodal/diffusion/DDPM/pytorch-ddpm/LICENSE
@@ -0,0 +1,13 @@
+DO WHAT THE FUCK YOU WANT TO PUBLIC LICENSE
+                    Version 2, December 2004
+
+ Copyright (C) 2004 Sam Hocevar
+
+ Everyone is permitted to copy and distribute verbatim or modified
+ copies of this license document, and changing it is allowed as long
+ as the name is changed.
+
+            DO WHAT THE FUCK YOU WANT TO PUBLIC LICENSE
+   TERMS AND CONDITIONS FOR COPYING, DISTRIBUTION AND MODIFICATION
+
+  0. You just DO WHAT THE FUCK YOU WANT TO.
diff --git a/multimodal/diffusion/DDPM/pytorch-ddpm/README.md b/multimodal/diffusion/DDPM/pytorch-ddpm/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..1737cddb9750f56dc48000f7c65ea50c0a87d049
--- /dev/null
+++ b/multimodal/diffusion/DDPM/pytorch-ddpm/README.md
@@ -0,0 +1,85 @@
+# Denoising Diffusion Probabilistic Models
+
+## Model description
+
+Unofficial PyTorch implementation of Denoising Diffusion Probabilistic Models (DDPM). This implementation follows most of the details of the official TensorFlow implementation.
+
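+As a minimal sketch of how the two classes in `diffusion.py` fit together
+(assuming the CIFAR-10 hyperparameters from `config/CIFAR10.txt`):
+
+```python
+import torch
+
+from diffusion import GaussianDiffusionTrainer, GaussianDiffusionSampler
+from model import UNet
+
+model = UNet(T=1000, ch=128, ch_mult=[1, 2, 2, 2], attn=[1],
+             num_res_blocks=2, dropout=0.1)
+trainer = GaussianDiffusionTrainer(model, beta_1=1e-4, beta_T=0.02, T=1000)
+sampler = GaussianDiffusionSampler(model, beta_1=1e-4, beta_T=0.02, T=1000,
+                                   img_size=32, mean_type='epsilon',
+                                   var_type='fixedlarge')
+
+x_0 = torch.randn(4, 3, 32, 32)   # stand-in for a batch of training images
+loss = trainer(x_0).mean()        # Algorithm 1: noise-prediction MSE
+
+x_T = torch.randn(4, 3, 32, 32)
+samples = sampler(x_T)            # Algorithm 2: 1000 denoising steps (slow on CPU)
+```
+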
+## Step 1: Installation
+
+```
+cd pytorch-ddpm
+pip3 install -U pip setuptools
+pip3 install -r requirements.txt
+pip3 install protobuf==3.20.3
+yum install mesa-libGL
+pip3 install urllib3==1.26.6
+```
+
+## Step 2: Preparing datasets
+
+```
+mkdir -p stats && cd stats
+```
+
+Download the precalculated statistics for the dataset:
+
+[cifar10.train.npz](https://drive.google.com/drive/folders/1UBdzl6GtNMwNQ5U-4ESlIer43tNjiGJC)
+
+The dataset structure should look like:
+
+```
+stats
+└── cifar10.train.npz
+```
+
+## Step 3: Training
+
+```
+cd ..
+
+# 8 GPUs
+export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
+
+python3 main.py --train \
+    --flagfile ./config/CIFAR10.txt \
+    --parallel
+
+# 1 GPU
+export CUDA_VISIBLE_DEVICES=0
+
+python3 main.py --train \
+    --flagfile ./config/CIFAR10.txt
+```
+
+## Step 4: Evaluation
+
+```
+# 8 GPUs
+export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
+
+python3 main.py \
+    --flagfile ./logs/DDPM_CIFAR10_EPS/flagfile.txt \
+    --notrain \
+    --eval \
+    --parallel
+
+# 1 GPU
+export CUDA_VISIBLE_DEVICES=0
+
+python3 main.py \
+    --flagfile ./logs/DDPM_CIFAR10_EPS/flagfile.txt \
+    --notrain \
+    --eval
+```
+
+## Results
+
+| Model | FPS (BI x 8) | metric |
+| ------ | -------- | --------------: |
+| DDPM |  |  |
+
+## Reference
+
+- [DDPM](https://github.com/w86763777/pytorch-ddpm/tree/master)
\ No newline at end of file
diff --git a/multimodal/diffusion/DDPM/pytorch-ddpm/config/CIFAR10.txt b/multimodal/diffusion/DDPM/pytorch-ddpm/config/CIFAR10.txt
new file mode 100644
index 0000000000000000000000000000000000000000..44abefbf642af1901ae5647dc43cc638501eefae
--- /dev/null
+++ b/multimodal/diffusion/DDPM/pytorch-ddpm/config/CIFAR10.txt
@@ -0,0 +1,32 @@
+--T=1000
+--attn=1
+--batch_size=128
+--beta_1=0.0001
+--beta_T=0.02
+--ch=128
+--ch_mult=1
+--ch_mult=2
+--ch_mult=2
+--ch_mult=2
+--dropout=0.1
+--ema_decay=0.9999
+--noeval
+--eval_step=0
+--fid_cache=./stats/cifar10.train.npz
+--nofid_use_torch
+--grad_clip=1.0
+--img_size=32
+--logdir=./logs/DDPM_CIFAR10_EPS
+--lr=0.0002
+--mean_type=epsilon
+--num_images=50000
+--num_res_blocks=2
+--num_workers=4
+--noparallel
+--sample_size=64
+--sample_step=1000
+--save_step=5000
+--total_steps=800000
+--train
+--var_type=fixedlarge
+--warmup=5000
diff --git a/multimodal/diffusion/DDPM/pytorch-ddpm/diffusion.py b/multimodal/diffusion/DDPM/pytorch-ddpm/diffusion.py
new file mode 100644
index 0000000000000000000000000000000000000000..a73dcc5b3e8fb5c1c472d2685ee3243dc44970f9
--- /dev/null
+++ b/multimodal/diffusion/DDPM/pytorch-ddpm/diffusion.py
@@ -0,0 +1,178 @@
+# Copyright (c) 2023, Shanghai Iluvatar CoreX Semiconductor Co., Ltd.
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+def extract(v, t, x_shape):
+    """
+    Extract some coefficients at specified timesteps, then reshape to
+    [batch_size, 1, 1, 1, 1, ...] for broadcasting purposes.
+    """
+    out = torch.gather(v, index=t, dim=0).float()
+    return out.view([t.shape[0]] + [1] * (len(x_shape) - 1))
+
+
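+# Forward (noising) process, in closed form:
+#     q(x_t | x_0) = N(sqrt(alphas_bar_t) * x_0, (1 - alphas_bar_t) * I)
+# so a noisy sample at an arbitrary timestep t can be drawn in one step:
+#     x_t = sqrt(alphas_bar_t) * x_0 + sqrt(1 - alphas_bar_t) * noise.
+# Training (Algorithm 1 of the DDPM paper) regresses the injected noise
+# with an MSE loss.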
+class GaussianDiffusionTrainer(nn.Module):
+    def __init__(self, model, beta_1, beta_T, T):
+        super().__init__()
+
+        self.model = model
+        self.T = T
+
+        self.register_buffer(
+            'betas', torch.linspace(beta_1, beta_T, T).float())
+        alphas = 1. - self.betas
+        alphas_bar = torch.cumprod(alphas, dim=0)
+
+        # calculations for diffusion q(x_t | x_{t-1}) and others
+        self.register_buffer(
+            'sqrt_alphas_bar', torch.sqrt(alphas_bar))
+        self.register_buffer(
+            'sqrt_one_minus_alphas_bar', torch.sqrt(1. - alphas_bar))
+
+    def forward(self, x_0):
+        """
+        Algorithm 1.
+        """
+        t = torch.randint(self.T, size=(x_0.shape[0], ), device=x_0.device)
+        noise = torch.randn_like(x_0)
+        x_t = (
+            extract(self.sqrt_alphas_bar, t, x_0.shape) * x_0 +
+            extract(self.sqrt_one_minus_alphas_bar, t, x_0.shape) * noise)
+        loss = F.mse_loss(self.model(x_t, t), noise, reduction='none')
+        return loss
+
+
+class GaussianDiffusionSampler(nn.Module):
+    def __init__(self, model, beta_1, beta_T, T, img_size=32,
+                 mean_type='epsilon', var_type='fixedlarge'):
+        assert mean_type in ['xprev', 'xstart', 'epsilon']
+        assert var_type in ['fixedlarge', 'fixedsmall']
+        super().__init__()
+
+        self.model = model
+        self.T = T
+        self.img_size = img_size
+        self.mean_type = mean_type
+        self.var_type = var_type
+
+        self.register_buffer(
+            'betas', torch.linspace(beta_1, beta_T, T).float())
+        alphas = 1. - self.betas
+        alphas_bar = torch.cumprod(alphas, dim=0)
+        alphas_bar_prev = F.pad(alphas_bar, [1, 0], value=1)[:T]
+
+        # calculations for diffusion q(x_t | x_{t-1}) and others
+        self.register_buffer(
+            'sqrt_recip_alphas_bar', torch.sqrt(1. / alphas_bar))
+        self.register_buffer(
+            'sqrt_recipm1_alphas_bar', torch.sqrt(1. / alphas_bar - 1))
+
+        # calculations for posterior q(x_{t-1} | x_t, x_0)
+        self.register_buffer(
+            'posterior_var',
+            self.betas * (1. - alphas_bar_prev) / (1. - alphas_bar))
+        # below: log calculation clipped because the posterior variance is 0 at
+        # the beginning of the diffusion chain
+        self.register_buffer(
+            'posterior_log_var_clipped',
+            torch.log(
+                torch.cat([self.posterior_var[1:2], self.posterior_var[1:]])))
+        self.register_buffer(
+            'posterior_mean_coef1',
+            torch.sqrt(alphas_bar_prev) * self.betas / (1. - alphas_bar))
+        self.register_buffer(
+            'posterior_mean_coef2',
+            torch.sqrt(alphas) * (1. - alphas_bar_prev) / (1. - alphas_bar))
+
+    def q_mean_variance(self, x_0, x_t, t):
+        """
+        Compute the mean and variance of the diffusion posterior
+        q(x_{t-1} | x_t, x_0)
+        """
+        assert x_0.shape == x_t.shape
+        posterior_mean = (
+            extract(self.posterior_mean_coef1, t, x_t.shape) * x_0 +
+            extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
+        )
+        posterior_log_var_clipped = extract(
+            self.posterior_log_var_clipped, t, x_t.shape)
+        return posterior_mean, posterior_log_var_clipped
+
+    def predict_xstart_from_eps(self, x_t, t, eps):
+        assert x_t.shape == eps.shape
+        return (
+            extract(self.sqrt_recip_alphas_bar, t, x_t.shape) * x_t -
+            extract(self.sqrt_recipm1_alphas_bar, t, x_t.shape) * eps
+        )
+
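+    # q(x_{t-1} | x_t, x_0) has mean coef1 * x_0 + coef2 * x_t, so when the
+    # model predicts x_{t-1} directly, x_0 can be recovered by inverting
+    # that relation: x_0 = (xprev - coef2 * x_t) / coef1.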
+    def predict_xstart_from_xprev(self, x_t, t, xprev):
+        assert x_t.shape == xprev.shape
+        return (  # (xprev - coef2*x_t) / coef1
+            extract(
+                1. / self.posterior_mean_coef1, t, x_t.shape) * xprev -
+            extract(
+                self.posterior_mean_coef2 / self.posterior_mean_coef1, t,
+                x_t.shape) * x_t
+        )
+
+    def p_mean_variance(self, x_t, t):
+        # below: only log_variance is used in the KL computations
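+        # 'fixedlarge' uses the forward-process variance beta_t, while
+        # 'fixedsmall' uses the posterior variance
+        #     beta_tilde_t = (1 - alphas_bar_{t-1}) / (1 - alphas_bar_t) * beta_t;
+        # the DDPM paper reports similar sample quality for both choices.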
+        model_log_var = {
+            # for fixedlarge, we set the initial (log-)variance like so to
+            # get a better decoder log likelihood
+            'fixedlarge': torch.log(torch.cat([self.posterior_var[1:2],
+                                               self.betas[1:]])),
+            'fixedsmall': self.posterior_log_var_clipped,
+        }[self.var_type]
+        model_log_var = extract(model_log_var, t, x_t.shape)
+
+        # Mean parameterization
+        if self.mean_type == 'xprev':       # the model predicts x_{t-1}
+            x_prev = self.model(x_t, t)
+            x_0 = self.predict_xstart_from_xprev(x_t, t, xprev=x_prev)
+            model_mean = x_prev
+        elif self.mean_type == 'xstart':    # the model predicts x_0
+            x_0 = self.model(x_t, t)
+            model_mean, _ = self.q_mean_variance(x_0, x_t, t)
+        elif self.mean_type == 'epsilon':   # the model predicts epsilon
+            eps = self.model(x_t, t)
+            x_0 = self.predict_xstart_from_eps(x_t, t, eps=eps)
+            model_mean, _ = self.q_mean_variance(x_0, x_t, t)
+        else:
+            raise NotImplementedError(self.mean_type)
+        x_0 = torch.clip(x_0, -1., 1.)
+
+        return model_mean, model_log_var
+
+    def forward(self, x_T):
+        """
+        Algorithm 2.
+        """
+        x_t = x_T
+        for time_step in reversed(range(self.T)):
+            t = x_t.new_ones([x_T.shape[0], ], dtype=torch.long) * time_step
+            mean, log_var = self.p_mean_variance(x_t=x_t, t=t)
+            # no noise when t == 0
+            if time_step > 0:
+                noise = torch.randn_like(x_t)
+            else:
+                noise = 0
+            x_t = mean + torch.exp(0.5 * log_var) * noise
+        x_0 = x_t
+        return torch.clip(x_0, -1, 1)
diff --git a/multimodal/diffusion/DDPM/pytorch-ddpm/images/cifar10_samples.png b/multimodal/diffusion/DDPM/pytorch-ddpm/images/cifar10_samples.png
new file mode 100644
index 0000000000000000000000000000000000000000..8a64e127997d6729df4202d4087ec8a9c8bb65f9
Binary files /dev/null and b/multimodal/diffusion/DDPM/pytorch-ddpm/images/cifar10_samples.png differ
diff --git a/multimodal/diffusion/DDPM/pytorch-ddpm/main.py b/multimodal/diffusion/DDPM/pytorch-ddpm/main.py
new file mode 100644
index 0000000000000000000000000000000000000000..c6875501f94a33099c900232be9c7e68b4f20169
--- /dev/null
+++ b/multimodal/diffusion/DDPM/pytorch-ddpm/main.py
@@ -0,0 +1,254 @@
+import copy
+import json
+import os
+import warnings
+from absl import app, flags
+
+import torch
+from tensorboardX import SummaryWriter
+from torchvision.datasets import CIFAR10
+from torchvision.utils import make_grid, save_image
+from torchvision import transforms
+from tqdm import trange
+
+from diffusion import GaussianDiffusionTrainer, GaussianDiffusionSampler
+from model import UNet
+from score.both import get_inception_and_fid_score
+
+
+FLAGS = flags.FLAGS
+flags.DEFINE_bool('train', False, help='train from scratch')
+flags.DEFINE_bool('eval', False, help='load ckpt.pt and evaluate FID and IS')
+# UNet
+flags.DEFINE_integer('ch', 128, help='base channel of UNet')
+flags.DEFINE_multi_integer('ch_mult', [1, 2, 2, 2], help='channel multiplier')
+flags.DEFINE_multi_integer('attn', [1], help='add attention to these levels')
+flags.DEFINE_integer('num_res_blocks', 2, help='# resblock in each level')
+flags.DEFINE_float('dropout', 0.1, help='dropout rate of resblock')
+# Gaussian Diffusion
+flags.DEFINE_float('beta_1', 1e-4, help='start beta value')
+flags.DEFINE_float('beta_T', 0.02, help='end beta value')
+flags.DEFINE_integer('T', 1000, help='total diffusion steps')
+flags.DEFINE_enum('mean_type', 'epsilon', ['xprev', 'xstart', 'epsilon'],
+                  help='variable the model predicts')
+flags.DEFINE_enum('var_type', 'fixedlarge', ['fixedlarge', 'fixedsmall'], help='variance type')
+# Training
+flags.DEFINE_float('lr', 2e-4, help='target learning rate')
+flags.DEFINE_float('grad_clip', 1., help="gradient norm clipping")
+flags.DEFINE_integer('total_steps', 800000, help='total training steps')
+flags.DEFINE_integer('img_size', 32, help='image size')
+flags.DEFINE_integer('warmup', 5000, help='learning rate warmup')
+flags.DEFINE_integer('batch_size', 128, help='batch size')
+flags.DEFINE_integer('num_workers', 4, help='workers of Dataloader')
+flags.DEFINE_float('ema_decay', 0.9999, help="ema decay rate")
+flags.DEFINE_bool('parallel', False, help='multi gpu training')
+# Logging & Sampling
+flags.DEFINE_string('logdir', './logs/DDPM_CIFAR10_EPS', help='log directory')
+flags.DEFINE_integer('sample_size', 64, "sampling size of images")
+flags.DEFINE_integer('sample_step', 1000, help='frequency of sampling')
+# Evaluation
+flags.DEFINE_integer('save_step', 5000, help='frequency of saving checkpoints, 0 to disable during training')
+flags.DEFINE_integer('eval_step', 0, help='frequency of evaluating model, 0 to disable during training')
+flags.DEFINE_integer('num_images', 50000, help='the number of generated images for evaluation')
+flags.DEFINE_bool('fid_use_torch', False, help='calculate IS and FID on gpu')
+flags.DEFINE_string('fid_cache', './stats/cifar10.train.npz', help='FID cache')
+
+device = torch.device('cuda:0')
+
+
+def ema(source, target, decay):
+    source_dict = source.state_dict()
+    target_dict = target.state_dict()
+    for key in source_dict.keys():
+        target_dict[key].data.copy_(
+            target_dict[key].data * decay +
+            source_dict[key].data * (1 - decay))
+
+
+def infiniteloop(dataloader):
+    while True:
+        for x, y in iter(dataloader):
+            yield x
+
+
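+# Linear warmup: the LambdaLR multiplier ramps from 0 to 1 over the first
+# FLAGS.warmup steps, then stays at 1; e.g. with warmup=5000, step 2500
+# trains at half the target learning rate.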
+def warmup_lr(step):
+    return min(step, FLAGS.warmup) / FLAGS.warmup
+
+
+def evaluate(sampler, model):
+    model.eval()
+    with torch.no_grad():
+        images = []
+        desc = "generating images"
+        for i in trange(0, FLAGS.num_images, FLAGS.batch_size, desc=desc):
+            batch_size = min(FLAGS.batch_size, FLAGS.num_images - i)
+            x_T = torch.randn((batch_size, 3, FLAGS.img_size, FLAGS.img_size))
+            batch_images = sampler(x_T.to(device)).cpu()
+            images.append((batch_images + 1) / 2)
+        images = torch.cat(images, dim=0).numpy()
+    model.train()
+    (IS, IS_std), FID = get_inception_and_fid_score(
+        images, FLAGS.fid_cache, num_images=FLAGS.num_images,
+        use_torch=FLAGS.fid_use_torch, verbose=True)
+    return (IS, IS_std), FID, images
+
+
+def train():
+    # dataset
+    dataset = CIFAR10(
+        root='./data', train=True, download=True,
+        transform=transforms.Compose([
+            transforms.RandomHorizontalFlip(),
+            transforms.ToTensor(),
+            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
+        ]))
+    dataloader = torch.utils.data.DataLoader(
+        dataset, batch_size=FLAGS.batch_size, shuffle=True,
+        num_workers=FLAGS.num_workers, drop_last=True)
+    datalooper = infiniteloop(dataloader)
+
+    # model setup
+    net_model = UNet(
+        T=FLAGS.T, ch=FLAGS.ch, ch_mult=FLAGS.ch_mult, attn=FLAGS.attn,
+        num_res_blocks=FLAGS.num_res_blocks, dropout=FLAGS.dropout)
+    ema_model = copy.deepcopy(net_model)
+    optim = torch.optim.Adam(net_model.parameters(), lr=FLAGS.lr)
+    sched = torch.optim.lr_scheduler.LambdaLR(optim, lr_lambda=warmup_lr)
+    trainer = GaussianDiffusionTrainer(
+        net_model, FLAGS.beta_1, FLAGS.beta_T, FLAGS.T).to(device)
+    net_sampler = GaussianDiffusionSampler(
+        net_model, FLAGS.beta_1, FLAGS.beta_T, FLAGS.T, FLAGS.img_size,
+        FLAGS.mean_type, FLAGS.var_type).to(device)
+    ema_sampler = GaussianDiffusionSampler(
+        ema_model, FLAGS.beta_1, FLAGS.beta_T, FLAGS.T, FLAGS.img_size,
+        FLAGS.mean_type, FLAGS.var_type).to(device)
+    if FLAGS.parallel:
+        trainer = torch.nn.DataParallel(trainer)
+        net_sampler = torch.nn.DataParallel(net_sampler)
+        ema_sampler = torch.nn.DataParallel(ema_sampler)
+
+    # log setup
+    os.makedirs(os.path.join(FLAGS.logdir, 'sample'))
+    x_T = torch.randn(FLAGS.sample_size, 3, FLAGS.img_size, FLAGS.img_size)
+    x_T = x_T.to(device)
+    grid = (make_grid(next(iter(dataloader))[0][:FLAGS.sample_size]) + 1) / 2
+    writer = SummaryWriter(FLAGS.logdir)
+    writer.add_image('real_sample', grid)
+    writer.flush()
+    # backup all arguments
+    with open(os.path.join(FLAGS.logdir, "flagfile.txt"), 'w') as f:
+        f.write(FLAGS.flags_into_string())
+    # show model size
+    model_size = 0
+    for param in net_model.parameters():
+        model_size += param.data.nelement()
+    print('Model params: %.2f M' % (model_size / 1024 / 1024))
+
+    # start training
+    with trange(FLAGS.total_steps, dynamic_ncols=True) as pbar:
+        for step in pbar:
+            # train
+            optim.zero_grad()
+            x_0 = next(datalooper).to(device)
+            loss = trainer(x_0).mean()
+            loss.backward()
+            torch.nn.utils.clip_grad_norm_(
+                net_model.parameters(), FLAGS.grad_clip)
+            optim.step()
+            sched.step()
+            ema(net_model, ema_model, FLAGS.ema_decay)
+
+            # log
+            writer.add_scalar('loss', loss, step)
+            pbar.set_postfix(loss='%.3f' % loss)
+
+            # sample
+            if FLAGS.sample_step > 0 and step % FLAGS.sample_step == 0:
+                net_model.eval()
+                with torch.no_grad():
+                    x_0 = ema_sampler(x_T)
+                    grid = (make_grid(x_0) + 1) / 2
+                    path = os.path.join(
+                        FLAGS.logdir, 'sample', '%d.png' % step)
+                    save_image(grid, path)
+                    writer.add_image('sample', grid, step)
+                net_model.train()
+
+            # save
+            if FLAGS.save_step > 0 and step % FLAGS.save_step == 0:
+                ckpt = {
+                    'net_model': net_model.state_dict(),
+                    'ema_model': ema_model.state_dict(),
+                    'sched': sched.state_dict(),
+                    'optim': optim.state_dict(),
+                    'step': step,
+                    'x_T': x_T,
+                }
+                torch.save(ckpt, os.path.join(FLAGS.logdir, 'ckpt.pt'))
+
+            # evaluate
+            if FLAGS.eval_step > 0 and step % FLAGS.eval_step == 0:
+                net_IS, net_FID, _ = evaluate(net_sampler, net_model)
+                ema_IS, ema_FID, _ = evaluate(ema_sampler, ema_model)
+                metrics = {
+                    'IS': net_IS[0],
+                    'IS_std': net_IS[1],
+                    'FID': net_FID,
+                    'IS_EMA': ema_IS[0],
+                    'IS_std_EMA': ema_IS[1],
+                    'FID_EMA': ema_FID
+                }
+                pbar.write(
+                    "%d/%d " % (step, FLAGS.total_steps) +
+                    ", ".join('%s:%.3f' % (k, v) for k, v in metrics.items()))
+                for name, value in metrics.items():
+                    writer.add_scalar(name, value, step)
+                writer.flush()
+                with open(os.path.join(FLAGS.logdir, 'eval.txt'), 'a') as f:
+                    metrics['step'] = step
+                    f.write(json.dumps(metrics) + "\n")
+    writer.close()
+
+
+def eval():
+    # model setup
+    model = UNet(
+        T=FLAGS.T, ch=FLAGS.ch, ch_mult=FLAGS.ch_mult, attn=FLAGS.attn,
+        num_res_blocks=FLAGS.num_res_blocks, dropout=FLAGS.dropout)
+    sampler = GaussianDiffusionSampler(
+        model, FLAGS.beta_1, FLAGS.beta_T, FLAGS.T, img_size=FLAGS.img_size,
+        mean_type=FLAGS.mean_type, var_type=FLAGS.var_type).to(device)
+    if FLAGS.parallel:
+        sampler = torch.nn.DataParallel(sampler)
+
+    # load model and evaluate
+    ckpt = torch.load(os.path.join(FLAGS.logdir, 'ckpt.pt'))
+    model.load_state_dict(ckpt['net_model'])
+    (IS, IS_std), FID, samples = evaluate(sampler, model)
+    print("Model     : IS:%6.3f(%.3f), FID:%7.3f" % (IS, IS_std, FID))
+    save_image(
+        torch.tensor(samples[:256]),
+        os.path.join(FLAGS.logdir, 'samples.png'),
+        nrow=16)
+
+    model.load_state_dict(ckpt['ema_model'])
+    (IS, IS_std), FID, samples = evaluate(sampler, model)
+    print("Model(EMA): IS:%6.3f(%.3f), FID:%7.3f" % (IS, IS_std, FID))
+    save_image(
+        torch.tensor(samples[:256]),
+        os.path.join(FLAGS.logdir, 'samples_ema.png'),
+        nrow=16)
+
+
+def main(argv):
+    # suppress annoying inception_v3 initialization warning
+    warnings.simplefilter(action='ignore', category=FutureWarning)
+    if FLAGS.train:
+        train()
+    if FLAGS.eval:
+        eval()
+    if not FLAGS.train and not FLAGS.eval:
+        print('Add --train and/or --eval to execute corresponding tasks')
+
+
+if __name__ == '__main__':
+    app.run(main)
diff --git a/multimodal/diffusion/DDPM/pytorch-ddpm/model.py b/multimodal/diffusion/DDPM/pytorch-ddpm/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..ac0c1b5f1020798297681673768dcd03e906236c
--- /dev/null
+++ b/multimodal/diffusion/DDPM/pytorch-ddpm/model.py
@@ -0,0 +1,244 @@
+import math
+import torch
+from torch import nn
+from torch.nn import init
+from torch.nn import functional as F
+
+
+class Swish(nn.Module):
+    def forward(self, x):
+        return x * torch.sigmoid(x)
+
+
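+# Sinusoidal timestep embedding in the style of Transformer positional
+# encodings: timestep t maps to sin/cos pairs at frequencies
+# 10000^(-2i/d_model), stored in a frozen nn.Embedding and refined by a
+# two-layer MLP with a Swish activation.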
+class TimeEmbedding(nn.Module):
+    def __init__(self, T, d_model, dim):
+        assert d_model % 2 == 0
+        super().__init__()
+        emb = torch.arange(0, d_model, step=2) / d_model * math.log(10000)
+        emb = torch.exp(-emb)
+        pos = torch.arange(T).float()
+        emb = pos[:, None] * emb[None, :]
+        assert list(emb.shape) == [T, d_model // 2]
+        emb = torch.stack([torch.sin(emb), torch.cos(emb)], dim=-1)
+        assert list(emb.shape) == [T, d_model // 2, 2]
+        emb = emb.view(T, d_model)
+
+        self.timembedding = nn.Sequential(
+            nn.Embedding.from_pretrained(emb),
+            nn.Linear(d_model, dim),
+            Swish(),
+            nn.Linear(dim, dim),
+        )
+        self.initialize()
+
+    def initialize(self):
+        for module in self.modules():
+            if isinstance(module, nn.Linear):
+                init.xavier_uniform_(module.weight)
+                init.zeros_(module.bias)
+
+    def forward(self, t):
+        emb = self.timembedding(t)
+        return emb
+
+
+class DownSample(nn.Module):
+    def __init__(self, in_ch):
+        super().__init__()
+        self.main = nn.Conv2d(in_ch, in_ch, 3, stride=2, padding=1)
+        self.initialize()
+
+    def initialize(self):
+        init.xavier_uniform_(self.main.weight)
+        init.zeros_(self.main.bias)
+
+    def forward(self, x, temb):
+        x = self.main(x)
+        return x
+
+
+class UpSample(nn.Module):
+    def __init__(self, in_ch):
+        super().__init__()
+        self.main = nn.Conv2d(in_ch, in_ch, 3, stride=1, padding=1)
+        self.initialize()
+
+    def initialize(self):
+        init.xavier_uniform_(self.main.weight)
+        init.zeros_(self.main.bias)
+
+    def forward(self, x, temb):
+        _, _, H, W = x.shape
+        x = F.interpolate(
+            x, scale_factor=2, mode='nearest')
+        x = self.main(x)
+        return x
+
+
+class AttnBlock(nn.Module):
+    def __init__(self, in_ch):
+        super().__init__()
+        self.group_norm = nn.GroupNorm(32, in_ch)
+        self.proj_q = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
+        self.proj_k = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
+        self.proj_v = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
+        self.proj = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
+        self.initialize()
+
+    def initialize(self):
+        for module in [self.proj_q, self.proj_k, self.proj_v, self.proj]:
+            init.xavier_uniform_(module.weight)
+            init.zeros_(module.bias)
+        init.xavier_uniform_(self.proj.weight, gain=1e-5)
+
+    def forward(self, x):
+        B, C, H, W = x.shape
+        h = self.group_norm(x)
+        q = self.proj_q(h)
+        k = self.proj_k(h)
+        v = self.proj_v(h)
+
+        q = q.permute(0, 2, 3, 1).view(B, H * W, C)
+        k = k.view(B, C, H * W)
+        w = torch.bmm(q, k) * (int(C) ** (-0.5))
+        assert list(w.shape) == [B, H * W, H * W]
+        w = F.softmax(w, dim=-1)
+
+        v = v.permute(0, 2, 3, 1).view(B, H * W, C)
+        h = torch.bmm(w, v)
+        assert list(h.shape) == [B, H * W, C]
+        h = h.view(B, H, W, C).permute(0, 3, 1, 2)
+        h = self.proj(h)
+
+        return x + h
+
+
+class ResBlock(nn.Module):
+    def __init__(self, in_ch, out_ch, tdim, dropout, attn=False):
+        super().__init__()
+        self.block1 = nn.Sequential(
+            nn.GroupNorm(32, in_ch),
+            Swish(),
+            nn.Conv2d(in_ch, out_ch, 3, stride=1, padding=1),
+        )
+        self.temb_proj = nn.Sequential(
+            Swish(),
+            nn.Linear(tdim, out_ch),
+        )
+        self.block2 = nn.Sequential(
+            nn.GroupNorm(32, out_ch),
+            Swish(),
+            nn.Dropout(dropout),
+            nn.Conv2d(out_ch, out_ch, 3, stride=1, padding=1),
+        )
+        if in_ch != out_ch:
+            self.shortcut = nn.Conv2d(in_ch, out_ch, 1, stride=1, padding=0)
+        else:
+            self.shortcut = nn.Identity()
+        if attn:
+            self.attn = AttnBlock(out_ch)
+        else:
+            self.attn = nn.Identity()
+        self.initialize()
+
+    def initialize(self):
+        for module in self.modules():
+            if isinstance(module, (nn.Conv2d, nn.Linear)):
+                init.xavier_uniform_(module.weight)
+                init.zeros_(module.bias)
+        init.xavier_uniform_(self.block2[-1].weight, gain=1e-5)
+
+    def forward(self, x, temb):
+        h = self.block1(x)
+        h += self.temb_proj(temb)[:, :, None, None]
+        h = self.block2(h)
+
+        h = h + self.shortcut(x)
+        h = self.attn(h)
+        return h
+
+
+class UNet(nn.Module):
+    def __init__(self, T, ch, ch_mult, attn, num_res_blocks, dropout):
+        super().__init__()
+        assert all([i < len(ch_mult) for i in attn]), 'attn index out of bound'
+        tdim = ch * 4
+        self.time_embedding = TimeEmbedding(T, ch, tdim)
+
+        self.head = nn.Conv2d(3, ch, kernel_size=3, stride=1, padding=1)
+        self.downblocks = nn.ModuleList()
+        chs = [ch]  # record output channels when downsampling, for upsampling
+        now_ch = ch
+        for i, mult in enumerate(ch_mult):
+            out_ch = ch * mult
+            for _ in range(num_res_blocks):
+                self.downblocks.append(ResBlock(
+                    in_ch=now_ch, out_ch=out_ch, tdim=tdim,
+                    dropout=dropout, attn=(i in attn)))
+                now_ch = out_ch
+                chs.append(now_ch)
+            if i != len(ch_mult) - 1:
+                self.downblocks.append(DownSample(now_ch))
+                chs.append(now_ch)
+
+        self.middleblocks = nn.ModuleList([
+            ResBlock(now_ch, now_ch, tdim, dropout, attn=True),
+            ResBlock(now_ch, now_ch, tdim, dropout, attn=False),
+        ])
+
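+        # Each decoder level runs num_res_blocks + 1 ResBlocks because every
+        # up-path block also consumes one skip tensor saved on the way down;
+        # chs.pop() gives the channel count of that skip tensor, which
+        # forward() concatenates onto the block input.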
+        self.upblocks = nn.ModuleList()
+        for i, mult in reversed(list(enumerate(ch_mult))):
+            out_ch = ch * mult
+            for _ in range(num_res_blocks + 1):
+                self.upblocks.append(ResBlock(
+                    in_ch=chs.pop() + now_ch, out_ch=out_ch, tdim=tdim,
+                    dropout=dropout, attn=(i in attn)))
+                now_ch = out_ch
+            if i != 0:
+                self.upblocks.append(UpSample(now_ch))
+        assert len(chs) == 0
+
+        self.tail = nn.Sequential(
+            nn.GroupNorm(32, now_ch),
+            Swish(),
+            nn.Conv2d(now_ch, 3, 3, stride=1, padding=1)
+        )
+        self.initialize()
+
+    def initialize(self):
+        init.xavier_uniform_(self.head.weight)
+        init.zeros_(self.head.bias)
+        init.xavier_uniform_(self.tail[-1].weight, gain=1e-5)
+        init.zeros_(self.tail[-1].bias)
+
+    def forward(self, x, t):
+        # Timestep embedding
+        temb = self.time_embedding(t)
+        # Downsampling
+        h = self.head(x)
+        hs = [h]
+        for layer in self.downblocks:
+            h = layer(h, temb)
+            hs.append(h)
+        # Middle
+        for layer in self.middleblocks:
+            h = layer(h, temb)
+        # Upsampling
+        for layer in self.upblocks:
+            if isinstance(layer, ResBlock):
+                h = torch.cat([h, hs.pop()], dim=1)
+            h = layer(h, temb)
+        h = self.tail(h)
+
+        assert len(hs) == 0
+        return h
+
+
+if __name__ == '__main__':
+    batch_size = 8
+    model = UNet(
+        T=1000, ch=128, ch_mult=[1, 2, 2, 2], attn=[1],
+        num_res_blocks=2, dropout=0.1)
+    x = torch.randn(batch_size, 3, 32, 32)
+    t = torch.randint(1000, (batch_size, ))
+    y = model(x, t)
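+    # the UNet output has the same shape as its input: (batch_size, 3, 32, 32)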
diff --git a/multimodal/diffusion/DDPM/pytorch-ddpm/requirements.txt b/multimodal/diffusion/DDPM/pytorch-ddpm/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..c63ce4b91983346b36b19bdbf845489626adfe1f
--- /dev/null
+++ b/multimodal/diffusion/DDPM/pytorch-ddpm/requirements.txt
@@ -0,0 +1,4 @@
+absl-py==2.0.0
+scipy==1.5.4
+tensorboardX==2.1
+tqdm==4.55.1
diff --git a/multimodal/diffusion/DDPM/pytorch-ddpm/score/both.py b/multimodal/diffusion/DDPM/pytorch-ddpm/score/both.py
new file mode 100644
index 0000000000000000000000000000000000000000..0b52fb40b165c2519aff795e4358f29dfd505b14
--- /dev/null
+++ b/multimodal/diffusion/DDPM/pytorch-ddpm/score/both.py
@@ -0,0 +1,115 @@
+import numpy as np
+import torch
+import types
+from tqdm import tqdm
+
+from .inception import InceptionV3
+from .fid import calculate_frechet_distance, torch_cov
+
+
+device = torch.device('cuda:0')
+
+
+def get_inception_and_fid_score(images, fid_cache, num_images=None,
+                                splits=10, batch_size=50,
+                                use_torch=False,
+                                verbose=False,
+                                parallel=False):
+    """when `images` is a python generator, `num_images` should be given"""
+
+    if num_images is None and isinstance(images, types.GeneratorType):
+        raise ValueError(
+            "when `images` is a python generator, "
+            "`num_images` should be given")
+
+    if num_images is None:
+        num_images = len(images)
+
+    block_idx1 = InceptionV3.BLOCK_INDEX_BY_DIM[2048]
+    block_idx2 = InceptionV3.BLOCK_INDEX_BY_DIM['prob']
+    model = InceptionV3([block_idx1, block_idx2]).to(device)
+    model.eval()
+
+    if parallel:
+        model = torch.nn.DataParallel(model)
+
+    if use_torch:
+        fid_acts = torch.empty((num_images, 2048)).to(device)
+        is_probs = torch.empty((num_images, 1008)).to(device)
+    else:
+        fid_acts = np.empty((num_images, 2048))
+        is_probs = np.empty((num_images, 1008))
+
+    iterator = iter(tqdm(
+        images, total=num_images,
+        dynamic_ncols=True, leave=False, disable=not verbose,
+        desc="get_inception_and_fid_score"))
+
+    start = 0
+    while True:
+        batch_images = []
+        # get a batch of images from iterator
+        try:
+            for _ in range(batch_size):
+                batch_images.append(next(iterator))
+        except StopIteration:
+            if len(batch_images) == 0:
+                break
+        batch_images = np.stack(batch_images, axis=0)
+        end = start + len(batch_images)
+
+        # calculate inception feature
+        batch_images = torch.from_numpy(batch_images).type(torch.FloatTensor)
+        batch_images = batch_images.to(device)
+        with torch.no_grad():
+            pred = model(batch_images)
+            if use_torch:
+                fid_acts[start: end] = pred[0].view(-1, 2048)
+                is_probs[start: end] = pred[1]
+            else:
+                fid_acts[start: end] = pred[0].view(-1, 2048).cpu().numpy()
+                is_probs[start: end] = pred[1].cpu().numpy()
+        start = end
+
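+    # Inception Score: IS = exp(E[KL(p(y|x) || p(y))]), estimated over
+    # `splits` equal partitions of the softmax outputs; the mean and
+    # standard deviation across partitions are returned.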
+    # Inception Score
+    scores = []
+    for i in range(splits):
+        part = is_probs[
+            (i * is_probs.shape[0] // splits):
+            ((i + 1) * is_probs.shape[0] // splits), :]
+        if use_torch:
+            kl = part * (
+                torch.log(part) -
+                torch.log(torch.unsqueeze(torch.mean(part, 0), 0)))
+            kl = torch.mean(torch.sum(kl, 1))
+            scores.append(torch.exp(kl))
+        else:
+            kl = part * (
+                np.log(part) -
+                np.log(np.expand_dims(np.mean(part, 0), 0)))
+            kl = np.mean(np.sum(kl, 1))
+            scores.append(np.exp(kl))
+    if use_torch:
+        scores = torch.stack(scores)
+        is_score = (torch.mean(scores).cpu().item(),
+                    torch.std(scores).cpu().item())
+    else:
+        is_score = (np.mean(scores), np.std(scores))
+
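+    # FID: fit a Gaussian N(m1, s1) to the generated activations and compare
+    # it with the dataset statistics N(m2, s2) loaded from `fid_cache`
+    # (an .npz file with keys 'mu' and 'sigma').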
+    # FID Score
+    f = np.load(fid_cache)
+    m2, s2 = f['mu'][:], f['sigma'][:]
+    f.close()
+    if use_torch:
+        m1 = torch.mean(fid_acts, axis=0)
+        s1 = torch_cov(fid_acts, rowvar=False)
+        m2 = torch.tensor(m2).to(m1.dtype).to(device)
+        s2 = torch.tensor(s2).to(s1.dtype).to(device)
+    else:
+        m1 = np.mean(fid_acts, axis=0)
+        s1 = np.cov(fid_acts, rowvar=False)
+    fid_score = calculate_frechet_distance(m1, s1, m2, s2, use_torch=use_torch)
+
+    del fid_acts, is_probs, scores, model
+    return is_score, fid_score
diff --git a/multimodal/diffusion/DDPM/pytorch-ddpm/score/fid.py b/multimodal/diffusion/DDPM/pytorch-ddpm/score/fid.py
new file mode 100644
index 0000000000000000000000000000000000000000..2af064a741bca179da287215e7c2c122d89ea67e
--- /dev/null
+++ b/multimodal/diffusion/DDPM/pytorch-ddpm/score/fid.py
@@ -0,0 +1,223 @@
+import numpy as np
+import torch
+from scipy import linalg
+from tqdm import tqdm
+from torch.nn.functional import adaptive_avg_pool2d
+
+from .inception import InceptionV3
+
+
+DIM = 2048
+device = torch.device('cuda:0')
+
+
+def torch_cov(m, rowvar=False):
+    '''Estimate a covariance matrix given data.
+    Covariance indicates the level to which two variables vary together.
+    If we examine N-dimensional samples, `X = [x_1, x_2, ... x_N]^T`,
+    then the covariance matrix element `C_{ij}` is the covariance of
+    `x_i` and `x_j`. The element `C_{ii}` is the variance of `x_i`.
+    Args:
+        m: A 1-D or 2-D array containing multiple variables and observations.
+            Each row of `m` represents a variable, and each column a single
+            observation of all those variables.
+        rowvar: If `rowvar` is True, then each row represents a
+            variable, with observations in the columns. Otherwise, the
+            relationship is transposed: each column represents a variable,
+            while the rows contain observations.
+    Returns:
+        The covariance matrix of the variables.
+    '''
+    if m.dim() > 2:
+        raise ValueError('m has more than 2 dimensions')
+    if m.dim() < 2:
+        m = m.view(1, -1)
+    if not rowvar and m.size(0) != 1:
+        m = m.t()
+    # m = m.type(torch.double)  # uncomment this line if desired
+    fact = 1.0 / (m.size(1) - 1)
+    m -= torch.mean(m, dim=1, keepdim=True)
+    mt = m.t()  # if complex: mt = m.t().conj()
+    return fact * m.matmul(mt).squeeze()
+
+
+# Pytorch implementation of matrix sqrt, from Tsung-Yu Lin, and Subhransu Maji
+# https://github.com/msubhransu/matrix-sqrt
+def sqrt_newton_schulz(A, numIters, dtype=None):
+    with torch.no_grad():
+        if dtype is None:
+            dtype = A.type()
+        batchSize = A.shape[0]
+        dim = A.shape[1]
+        normA = A.mul(A).sum(dim=1).sum(dim=1).sqrt()
+        Y = A.div(normA.view(batchSize, 1, 1).expand_as(A))
+        K = torch.eye(dim, dim).view(1, dim, dim).repeat(batchSize, 1, 1)
+        Z = torch.eye(dim, dim).view(1, dim, dim).repeat(batchSize, 1, 1)
+        K = K.type(dtype)
+        Z = Z.type(dtype)
+        for i in range(numIters):
+            T = 0.5 * (3.0 * K - Z.bmm(Y))
+            Y = Y.bmm(T)
+            Z = T.bmm(Z)
+        sA = Y * torch.sqrt(normA).view(batchSize, 1, 1).expand_as(A)
+    return sA
+
+
+def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6,
+                               use_torch=False):
+    """Numpy implementation of the Frechet Distance.
+    The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
+    and X_2 ~ N(mu_2, C_2) is
+            d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
+
+    Stable version by Dougal J. Sutherland.
+
+    Params:
+    -- mu1   : Numpy array containing the activations of a layer of the
+               inception net (like returned by the function 'get_predictions')
+               for generated samples.
+    -- mu2   : The sample mean over activations, precalculated on a
+               representative data set.
+    -- sigma1: The covariance matrix over activations for generated samples.
+    -- sigma2: The covariance matrix over activations, precalculated on a
+               representative data set.
+
+    Returns:
+    --   : The Frechet Distance.
+    """
+
+    if use_torch:
+        assert mu1.shape == mu2.shape, \
+            'Training and test mean vectors have different lengths'
+        assert sigma1.shape == sigma2.shape, \
+            'Training and test covariances have different dimensions'
+
+        diff = mu1 - mu2
+        # Run 50 itrs of newton-schulz to get the matrix sqrt of
+        # sigma1 dot sigma2
+        covmean = sqrt_newton_schulz(sigma1.mm(sigma2).unsqueeze(0), 50)
+        if torch.any(torch.isnan(covmean)):
+            return float('nan')
+        covmean = covmean.squeeze()
+        out = (diff.dot(diff) +
+               torch.trace(sigma1) +
+               torch.trace(sigma2) -
+               2 * torch.trace(covmean)).cpu().item()
+    else:
+        mu1 = np.atleast_1d(mu1)
+        mu2 = np.atleast_1d(mu2)
+
+        sigma1 = np.atleast_2d(sigma1)
+        sigma2 = np.atleast_2d(sigma2)
+
+        assert mu1.shape == mu2.shape, \
+            'Training and test mean vectors have different lengths'
+        assert sigma1.shape == sigma2.shape, \
+            'Training and test covariances have different dimensions'
+
+        diff = mu1 - mu2
+
+        # Product might be almost singular
+        covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
+        if not np.isfinite(covmean).all():
+            msg = ('fid calculation produces singular product; '
+                   'adding %s to diagonal of cov estimates') % eps
+            print(msg)
+            offset = np.eye(sigma1.shape[0]) * eps
+            covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
+
+        # Numerical error might give slight imaginary component
+        if np.iscomplexobj(covmean):
+            if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
+                m = np.max(np.abs(covmean.imag))
+                raise ValueError('Imaginary component {}'.format(m))
+            covmean = covmean.real
+
+        tr_covmean = np.trace(covmean)
+
+        out = (diff.dot(diff) +
+               np.trace(sigma1) +
+               np.trace(sigma2) -
+               2 * tr_covmean)
+    return out
+
+
+def get_statistics(images, num_images=None, batch_size=50, use_torch=False,
+                   verbose=False, parallel=False):
+    """when `images` is a python generator, `num_images` should be given"""
+
+    if num_images is None:
+        try:
+            num_images = len(images)
+        except TypeError:
+            raise ValueError(
+                "when `images` is not a list like object (e.g. generator), "
+                "`num_images` should be given")
+
+    block_idx1 = InceptionV3.BLOCK_INDEX_BY_DIM[2048]
+    model = InceptionV3([block_idx1]).to(device)
+    model.eval()
+
+    if parallel:
+        model = torch.nn.DataParallel(model)
+
+    if use_torch:
+        fid_acts = torch.empty((num_images, 2048)).to(device)
+    else:
+        fid_acts = np.empty((num_images, 2048))
+
+    iterator = iter(tqdm(
+        images, total=num_images,
+        dynamic_ncols=True, leave=False, disable=not verbose,
+        desc="get_statistics"))
+
+    start = 0
+    while True:
+        batch_images = []
+        # get a batch of images from iterator
+        try:
+            for _ in range(batch_size):
+                batch_images.append(next(iterator))
+        except StopIteration:
+            if len(batch_images) == 0:
+                break
+        batch_images = np.stack(batch_images, axis=0)
+        end = start + len(batch_images)
+
+        # calculate inception feature
+        batch_images = torch.from_numpy(batch_images).type(torch.FloatTensor)
+        batch_images = batch_images.to(device)
+        with torch.no_grad():
+            pred = model(batch_images)
+            if use_torch:
+                fid_acts[start: end] = pred[0].view(-1, 2048)
+            else:
+                fid_acts[start: end] = pred[0].view(-1, 2048).cpu().numpy()
+        start = end
+
+    if use_torch:
+        m1 = torch.mean(fid_acts, axis=0)
+        s1 = torch_cov(fid_acts, rowvar=False)
+    else:
+        m1 = np.mean(fid_acts, axis=0)
+        s1 = np.cov(fid_acts, rowvar=False)
+    return m1, s1
+
+
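+# Convenience wrapper: computes activation statistics for `images` and then
+# the FID against precomputed dataset statistics in `stats_cache` (an .npz
+# file with keys 'mu' and 'sigma', e.g. ./stats/cifar10.train.npz).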
+def get_fid_score(stats_cache, images, num_images=None, batch_size=50,
+                  use_torch=False, verbose=False, parallel=False):
+    m1, s1 = get_statistics(
+        images, num_images, batch_size, use_torch, verbose, parallel)
+
+    f = np.load(stats_cache)
+    m2, s2 = f['mu'][:], f['sigma'][:]
+    f.close()
+    if use_torch:
+        m2 = torch.tensor(m2).to(m1.dtype)
+        s2 = torch.tensor(s2).to(s1.dtype)
+    fid_value = calculate_frechet_distance(m1, s1, m2, s2, use_torch=use_torch)
+
+    if use_torch:
+        fid_value = fid_value.cpu().item()
+    return fid_value
diff --git a/multimodal/diffusion/DDPM/pytorch-ddpm/score/inception.py b/multimodal/diffusion/DDPM/pytorch-ddpm/score/inception.py
new file mode 100644
index 0000000000000000000000000000000000000000..5c9277d12643c7c34b2370aa7877709aed253765
--- /dev/null
+++ b/multimodal/diffusion/DDPM/pytorch-ddpm/score/inception.py
@@ -0,0 +1,324 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torchvision import models
+
+try:
+    from torchvision.models.utils import load_state_dict_from_url
+except ImportError:
+    from torch.utils.model_zoo import load_url as load_state_dict_from_url
+
+# Inception weights ported to Pytorch from
+# http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
+FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth'
+
+
+class InceptionV3(nn.Module):
+    """Pretrained InceptionV3 network returning feature maps"""
+
+    # Index of default block of inception to return,
+    # corresponds to output of final average pooling
+    DEFAULT_BLOCK_INDEX = 3
+
+    # Maps feature dimensionality to their output blocks indices
+    BLOCK_INDEX_BY_DIM = {
+        64: 0,       # First max pooling features
+        192: 1,      # Second max pooling features
+        768: 2,      # Pre-aux classifier features
+        2048: 3,     # Final average pooling features
+        'prob': 4,   # softmax layer
+    }
+
+    def __init__(self,
+                 output_blocks=[DEFAULT_BLOCK_INDEX],
+                 resize_input=True,
+                 normalize_input=True,
+                 requires_grad=False,
+                 use_fid_inception=True):
+        """Build pretrained InceptionV3
+
+        Parameters
+        ----------
+        output_blocks : list of int
+            Indices of blocks to return features of. Possible values are:
+                - 0: corresponds to output of first max pooling
+                - 1: corresponds to output of second max pooling
+                - 2: corresponds to output which is fed to aux classifier
+                - 3: corresponds to output of final average pooling
+        resize_input : bool
+            If true, bilinearly resizes input to width and height 299 before
+            feeding input to model. As the network without fully connected
+            layers is fully convolutional, it should be able to handle inputs
+            of arbitrary size, so resizing might not be strictly needed
+        normalize_input : bool
+            If true, scales the input from range (0, 1) to the range the
+            pretrained Inception network expects, namely (-1, 1)
+        requires_grad : bool
+            If true, parameters of the model require gradients. Possibly useful
+            for finetuning the network
+        use_fid_inception : bool
+            If true, uses the pretrained Inception model used in Tensorflow's
+            FID implementation. If false, uses the pretrained Inception model
+            available in torchvision. The FID Inception model has different
+            weights and a slightly different structure from torchvision's
+            Inception model. If you want to compute FID scores, you are
+            strongly advised to set this parameter to true to get comparable
+            results.
+        """
+        super(InceptionV3, self).__init__()
+
+        self.resize_input = resize_input
+        self.normalize_input = normalize_input
+        self.output_blocks = sorted(output_blocks)
+        self.last_needed_block = max(output_blocks)
+
+        # assert self.last_needed_block <= 3, \
+        #     'Last possible output block index is 3'
+
+        self.blocks = nn.ModuleList()
+
+        if use_fid_inception:
+            inception = fid_inception_v3()
+        else:
+            inception = models.inception_v3(pretrained=True)
+
+        # Block 0: input to maxpool1
+        block0 = [
+            inception.Conv2d_1a_3x3,
+            inception.Conv2d_2a_3x3,
+            inception.Conv2d_2b_3x3,
+            nn.MaxPool2d(kernel_size=3, stride=2)
+        ]
+        self.blocks.append(nn.Sequential(*block0))
+
+        # Block 1: maxpool1 to maxpool2
+        if self.last_needed_block >= 1:
+            block1 = [
+                inception.Conv2d_3b_1x1,
+                inception.Conv2d_4a_3x3,
+                nn.MaxPool2d(kernel_size=3, stride=2)
+            ]
+            self.blocks.append(nn.Sequential(*block1))
+
+        # Block 2: maxpool2 to aux classifier
+        if self.last_needed_block >= 2:
+            block2 = [
+                inception.Mixed_5b,
+                inception.Mixed_5c,
+                inception.Mixed_5d,
+                inception.Mixed_6a,
+                inception.Mixed_6b,
+                inception.Mixed_6c,
+                inception.Mixed_6d,
+                inception.Mixed_6e,
+            ]
+            self.blocks.append(nn.Sequential(*block2))
+
+        # Block 3: aux classifier to final avgpool
+        if self.last_needed_block >= 3:
+            block3 = [
+                inception.Mixed_7a,
+                inception.Mixed_7b,
+                inception.Mixed_7c,
+                nn.AdaptiveAvgPool2d(output_size=(1, 1))
+            ]
+            self.blocks.append(nn.Sequential(*block3))
+
+        if self.last_needed_block >= 4:
+            self.fc = inception.fc
+            self.fc.bias = None
+
+        for param in self.parameters():
+            param.requires_grad = requires_grad
+
+    def forward(self, inp):
+        """Get Inception feature maps
+
+        Parameters
+        ----------
+        inp : torch.Tensor
+            Input tensor of shape Bx3xHxW. Values are expected to be in
+            range (0, 1)
+
+        Returns
+        -------
+        List of torch.Tensor, corresponding to the selected output
+        block, sorted ascending by index
+        """
+        outp = []
+        x = inp
+
+        if self.resize_input:
+            x = F.interpolate(x,
+                              size=(299, 299),
+                              mode='bilinear',
+                              align_corners=False)
+
+        if self.normalize_input:
+            x = 2 * x - 1  # Scale from range (0, 1) to range (-1, 1)
+
+        for idx, block in enumerate(self.blocks):
+            x = block(x)
+            if idx in self.output_blocks:
+                outp.append(x)
+
+            if idx == self.last_needed_block:
+                break
+
+        if self.last_needed_block >= 4:
+            x = F.dropout(x, training=self.training)
+            # N x 2048 x 1 x 1
+            x = torch.flatten(x, 1)
+            # N x 2048
+            x = self.fc(x)
+            x = F.softmax(x, dim=1)
+            outp.append(x)
+
+        return outp
+
+
+def fid_inception_v3():
+    """Build pretrained Inception model for FID computation
+
+    The Inception model for FID computation uses a different set of weights
+    and has a slightly different structure than torchvision's Inception.
+
+    This method first constructs torchvision's Inception and then patches the
+    necessary parts that are different in the FID Inception model.
+    """
+    inception = models.inception_v3(num_classes=1008,
+                                    aux_logits=False,
+                                    pretrained=False)
+    inception.Mixed_5b = FIDInceptionA(192, pool_features=32)
+    inception.Mixed_5c = FIDInceptionA(256, pool_features=64)
+    inception.Mixed_5d = FIDInceptionA(288, pool_features=64)
+    inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128)
+    inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160)
+    inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160)
+    inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192)
+    inception.Mixed_7b = FIDInceptionE_1(1280)
+    inception.Mixed_7c = FIDInceptionE_2(2048)
+
+    state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=True)
+    inception.load_state_dict(state_dict)
+    return inception
+
+
+class FIDInceptionA(models.inception.InceptionA):
+    """InceptionA block patched for FID computation"""
+    def __init__(self, in_channels, pool_features):
+        super(FIDInceptionA, self).__init__(in_channels, pool_features)
+
+    def forward(self, x):
+        branch1x1 = self.branch1x1(x)
+
+        branch5x5 = self.branch5x5_1(x)
+        branch5x5 = self.branch5x5_2(branch5x5)
+
+        branch3x3dbl = self.branch3x3dbl_1(x)
+        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
+        branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)
+
+        # Patch: Tensorflow's average pool does not use the padded zeros in
+        # its average calculation
+        branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
+                                   count_include_pad=False)
+        branch_pool = self.branch_pool(branch_pool)
+
+        outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
+        return torch.cat(outputs, 1)
+
+
+class FIDInceptionC(models.inception.InceptionC):
+    """InceptionC block patched for FID computation"""
+    def __init__(self, in_channels, channels_7x7):
+        super(FIDInceptionC, self).__init__(in_channels, channels_7x7)
+
+    def forward(self, x):
+        branch1x1 = self.branch1x1(x)
+
+        branch7x7 = self.branch7x7_1(x)
+        branch7x7 = self.branch7x7_2(branch7x7)
+        branch7x7 = self.branch7x7_3(branch7x7)
+
+        branch7x7dbl = self.branch7x7dbl_1(x)
+        branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
+        branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
+        branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
+        branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)
+
+        # Patch: Tensorflow's average pool does not use the padded zeros in
+        # its average calculation
+        branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
+                                   count_include_pad=False)
+        branch_pool = self.branch_pool(branch_pool)
+
+        outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
+        return torch.cat(outputs, 1)
+
+
+class FIDInceptionE_1(models.inception.InceptionE):
+    """First InceptionE block patched for FID computation"""
+    def __init__(self, in_channels):
+        super(FIDInceptionE_1, self).__init__(in_channels)
+
+    def forward(self, x):
+        branch1x1 = self.branch1x1(x)
+
+        branch3x3 = self.branch3x3_1(x)
+        branch3x3 = [
+            self.branch3x3_2a(branch3x3),
+            self.branch3x3_2b(branch3x3),
+        ]
+        branch3x3 = torch.cat(branch3x3, 1)
+
+        branch3x3dbl = self.branch3x3dbl_1(x)
+        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
+        branch3x3dbl = [
+            self.branch3x3dbl_3a(branch3x3dbl),
+            self.branch3x3dbl_3b(branch3x3dbl),
+        ]
+        branch3x3dbl = torch.cat(branch3x3dbl, 1)
+
+        # Patch: Tensorflow's average pool does not use the padded zeros in
+        # its average calculation
+        branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
+                                   count_include_pad=False)
+        branch_pool = self.branch_pool(branch_pool)
+
+        outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
+        return torch.cat(outputs, 1)
+
+
+class FIDInceptionE_2(models.inception.InceptionE):
+    """Second InceptionE block patched for FID computation"""
+    def __init__(self, in_channels):
+        super(FIDInceptionE_2, self).__init__(in_channels)
+
+    def forward(self, x):
+        branch1x1 = self.branch1x1(x)
+
+        branch3x3 = self.branch3x3_1(x)
+        branch3x3 = [
+            self.branch3x3_2a(branch3x3),
+            self.branch3x3_2b(branch3x3),
+        ]
+        branch3x3 = torch.cat(branch3x3, 1)
+
+        branch3x3dbl = self.branch3x3dbl_1(x)
+        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
+        branch3x3dbl = [
+            self.branch3x3dbl_3a(branch3x3dbl),
+            self.branch3x3dbl_3b(branch3x3dbl),
+        ]
+        branch3x3dbl = torch.cat(branch3x3dbl, 1)
+
+        # Patch: The FID Inception model uses max pooling instead of average
+        # pooling. This is likely an error in this specific Inception
+        # implementation, as other Inception models use average pooling here
+        # (which matches the description in the paper).
+        branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1)
+        branch_pool = self.branch_pool(branch_pool)
+
+        outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
+        return torch.cat(outputs, 1)
diff --git a/multimodal/diffusion/DDPM/pytorch-ddpm/score/inception_score.py b/multimodal/diffusion/DDPM/pytorch-ddpm/score/inception_score.py
new file mode 100644
index 0000000000000000000000000000000000000000..2bb4ad384102cca4537e4cd552f8bed415a6965c
--- /dev/null
+++ b/multimodal/diffusion/DDPM/pytorch-ddpm/score/inception_score.py
@@ -0,0 +1,64 @@
+import numpy as np
+import torch
+from tqdm import trange
+
+from .inception import InceptionV3
+
+
+device = torch.device('cuda:0')
+
+
+def get_inception_score(images, splits=10, batch_size=32, use_torch=False,
+                        verbose=False, parallel=False):
+    block_idx = InceptionV3.BLOCK_INDEX_BY_DIM['prob']
+    model = InceptionV3([block_idx]).to(device)
+    model.eval()
+
+    if parallel:
+        model = torch.nn.DataParallel(model)
+
+    preds = []
+    iterator = trange(
+        0, len(images), batch_size, dynamic_ncols=True, leave=False,
+        disable=not verbose, desc="get_inception_score")
+
+    for start in iterator:
+        end = start + batch_size
+        batch_images = images[start: end]
+        batch_images = torch.from_numpy(batch_images).type(torch.FloatTensor)
+        batch_images = batch_images.to(device)
+        with torch.no_grad():
+            pred = model(batch_images)[0]
+        if use_torch:
+            preds.append(pred)
+        else:
+            preds.append(pred.cpu().numpy())
+    if use_torch:
+        preds = torch.cat(preds, 0)
+    else:
+        preds = np.concatenate(preds, 0)
+    scores = []
+    for i in range(splits):
+        part = preds[
+            (i * preds.shape[0] // splits):
+            ((i + 1) * preds.shape[0] // splits), :]
+        if use_torch:
+            kl = part * (
+                torch.log(part) -
+                torch.log(torch.unsqueeze(torch.mean(part, 0), 0)))
+            kl = torch.mean(torch.sum(kl, 1))
+            scores.append(torch.exp(kl))
+        else:
+            kl = part * (
+                np.log(part) -
+                np.log(np.expand_dims(np.mean(part, 0), 0)))
+            kl = np.mean(np.sum(kl, 1))
+            scores.append(np.exp(kl))
+    if use_torch:
+        scores = torch.stack(scores)
+        is_mean, is_std = (
+            torch.mean(scores).cpu().item(), torch.std(scores).cpu().item())
+    else:
+        is_mean, is_std = np.mean(scores), np.std(scores)
+    del preds, scores, model
+    return is_mean, is_std
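+
+
+# Example usage (sketch): `images` is a float numpy array of shape
+# (N, 3, H, W) with values in [0, 1], e.g. as produced by evaluate() in
+# main.py:
+#     is_mean, is_std = get_inception_score(images)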