# Denoising Diffusion Autoencoders (DDAE)

This is a multi-GPU PyTorch implementation of the paper [Denoising Diffusion Autoencoders are Unified Self-supervised Learners](https://arxiv.org/abs/2303.09769):

```bibtex
@inproceedings{ddae2023,
  title={Denoising Diffusion Autoencoders are Unified Self-supervised Learners},
  author={Xiang, Weilai and Yang, Hongyu and Huang, Di and Wang, Yunhong},
  booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision},
  year={2023}
}
```

:star: (News) Our paper is cited by Kaiming He's new paper [Deconstructing Denoising Diffusion Models for Self-Supervised Learning](https://arxiv.org/abs/2401.14404), check it out! :fire:

## Overview

This repo contains:

- [x] Pre-training, sampling and FID evaluation code for diffusion models, including
  - Frameworks:
    - [x] DDPM & DDIM
    - [x] EDM (w/ or w/o data augmentation)
  - Networks:
    - [x] The basic 35.7M DDPM UNet
    - [x] A larger 56M DDPM++ UNet
  - Datasets:
    - [x] CIFAR-10
    - [ ] Tiny-ImageNet
- [x] Feature quality evaluation code, including
  - [x] Linear probing and grid searching
  - [x] Contrastive metrics, i.e., alignment and uniformity
  - [ ] Fine-tuning
- [x] Noise-conditional classifier training and evaluation, including
  - [x] MLP classifier based on DDPM/EDM features
  - [x] WideResNet with VP/VE perturbation
- [x] Evaluation code for ImageNet-256 pre-trained [DiT-XL/2](https://github.com/facebookresearch/DiT) checkpoint

## Requirements

- In addition to PyTorch environments, please install:
  ```sh
  conda install pyyaml
  pip install pytorch-fid ema-pytorch
  ```
- We use 4 or 8 3080 Ti GPUs to conduct all the experiments presented in the paper. With automatic mixed precision enabled and 4 GPUs, training a basic 35.7M UNet on CIFAR-10 takes ~14 hours.
- `pytorch-fid` requires image files to calculate the FID metric. Please refer to `extract_cifar10_pngs.ipynb` to unpack the CIFAR-10 training dataset into 50000 `.png` image files.

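The notebook `extract_cifar10_pngs.ipynb` remains the reference for the PNG extraction step. As a rough sketch of what such an extraction could look like, the helper below writes every image of a dataset to a numbered `.png` file. The function name `extract_pngs` and the zero-padded file naming are illustrative assumptions, not taken from the notebook; the only thing assumed about the dataset is that it yields `(PIL image, label)` pairs, which is what torchvision's `CIFAR10` returns when no transform is set.

```python
import os


def extract_pngs(dataset, out_dir):
    """Save each image of an iterable of (PIL.Image, label) pairs as a PNG.

    Hypothetical helper: files are named 00000.png, 00001.png, ... in
    iteration order. Returns the number of images written.
    """
    os.makedirs(out_dir, exist_ok=True)
    count = 0
    for idx, (img, _label) in enumerate(dataset):
        img.save(os.path.join(out_dir, f"{idx:05d}.png"))
        count += 1
    return count


# Example usage (assumes torchvision is installed; downloads CIFAR-10 on first run):
#   from torchvision.datasets import CIFAR10
#   train_set = CIFAR10(root="data", train=True, download=True)
#   extract_pngs(train_set, "data/cifar10-pngs")
```

The resulting folder can then be passed to `pytorch_fid` as the reference image directory.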
## Main results

We present the generative and discriminative evaluation results that can be obtained by this codebase. The `EDM_ddpmpp_aug.yaml` training is performed on 8 GPUs, while other models are trained on 4 GPUs.

Please note that this is an *over-simplified* DDPM / EDM implementation, and some network details, initialization, and hyper-parameters may *differ from* official ones. Please refer to their respective official codebases to reproduce the *exact results* reported in the paper.

| Config | Model | Network | Best linear probe checkpoint: epoch | FID | acc (%) | Best FID checkpoint: epoch | FID | acc (%) |
|---|---|---|---|---|---|---|---|---|
| DDPM_ddpm.yaml | DDPM | 35.7M UNet | 800 | 4.09 | 90.05 | 1999 | 3.62 | 88.23 |
| EDM_ddpm.yaml | EDM | 35.7M UNet | 1200 | 3.97 | 90.44 | 1999 | 3.56 | 89.71 |
| DDPM_ddpmpp.yaml | DDPM | 56.5M DDPM++ | 1200 | 3.08 | 93.97 | 1999 | 2.98 | 93.03 |
| EDM_ddpmpp.yaml | EDM | 56.5M DDPM++ | 1200 | 2.23 | 94.50 | (same) | | |
| EDM_ddpmpp_aug.yaml | EDM + data aug | 56.5M DDPM++ | 2000 | 2.34 | 95.49 | 3200 | 2.12 | 95.19 |

FIDs are calculated using 50000 images generated by the deterministic fast sampler (DDIM 100 steps or EDM 18 steps).

## Latent-space DiT

We evaluate pre-trained Transformer-based diffusion networks, [DiT](https://github.com/facebookresearch/DiT), from the perspective of *transfer learning*. Please refer to the [ddae/DiT](DiT/) subfolder.

## Usage

Each script below is launched through `torch.distributed.launch`; the examples use 4 GPUs.

### Diffusion pre-training

To train a DDAE model and generate 50000 image samples with 4 GPUs, for example, run:

```sh
# diffusion pre-training with AMP enabled
python -m torch.distributed.launch --nproc_per_node=4 train.py --config config/DDPM_ddpm.yaml --use_amp
# deterministic fast sampling (i.e. DDIM 100 steps / EDM 18 steps)
python -m torch.distributed.launch --nproc_per_node=4 sample.py --config config/DDPM_ddpm.yaml --use_amp --epoch 400
# stochastic sampling (i.e. DDPM 1000 steps)
python -m torch.distributed.launch --nproc_per_node=4 sample.py --config config/DDPM_ddpm.yaml --use_amp --epoch 400 --mode DDPM
```

To calculate the FID metric on the training set, for example, run:

```sh
python -m pytorch_fid data/cifar10-pngs/ output_DDPM_ddpm/EMAgenerated_ep400_ddim_steps100_eta0.0/pngs/
```

### Features produced by DDAE

To evaluate the features produced by pre-trained DDAE, for example, run:

```sh
# grid searching for proper layer-noise combination
python -m torch.distributed.launch --nproc_per_node=4 linear.py --config config/DDPM_ddpm.yaml --use_amp --epoch 400 --grid
# linear probing, using the layer-noise combination specified by config.yaml
python -m torch.distributed.launch --nproc_per_node=4 linear.py --config config/DDPM_ddpm.yaml --use_amp --epoch 400
# showing the alignment-uniformity metrics with respect to different checkpoints
python -m torch.distributed.launch --nproc_per_node=4 contrastive.py --config config/DDPM_ddpm.yaml --use_amp
```

### Noise-conditional classifier

To train WideResNet-based classifiers from scratch:

```sh
# VP (DDPM) perturbation
python -m torch.distributed.launch --nproc_per_node=4 noisy_classifier_WRN.py --mode DDPM
# VE (EDM) perturbation
python -m torch.distributed.launch --nproc_per_node=4 noisy_classifier_WRN.py --mode EDM
```

and compare their noise-conditional recognition rates with DDAE-based MLP classifier heads:

```sh
# using DDPM DDAE encoder
python -m torch.distributed.launch --nproc_per_node=4 noisy_classifier_DDAE.py --config config/DDPM_ddpm.yaml --use_amp --epoch 1999
# using EDM DDAE encoder
python -m torch.distributed.launch --nproc_per_node=4 noisy_classifier_DDAE.py --config config/EDM_ddpmpp.yaml --use_amp --epoch 1200
```

## Acknowledgments

This repository is built on numerous open-source codebases such as [DDPM](https://github.com/hojonathanho/diffusion), [DDPM-pytorch](https://github.com/pesser/pytorch_diffusion), [DDIM](https://github.com/ermongroup/ddim), [EDM](https://github.com/NVlabs/edm), [Score-based SDE](https://github.com/yang-song/score_sde), [DiT](https://github.com/facebookresearch/DiT), and [align_uniform](https://github.com/SsnL/align_uniform).
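
### Note: alignment and uniformity

The alignment and uniformity metrics listed in the overview come from the align_uniform work acknowledged above. The sketch below follows the commonly used definitions of those two quantities and is not necessarily the exact code in `contrastive.py`. It assumes `x` and `y` are L2-normalized feature batches whose rows form positive pairs; the function names and the defaults `alpha=2` and `t=2` are assumptions made for illustration.

```python
import torch
import torch.nn.functional as F


def alignment(x, y, alpha=2):
    """Mean of ||x_i - y_i||^alpha over positive pairs (lower means better aligned)."""
    return (x - y).norm(p=2, dim=1).pow(alpha).mean()


def uniformity(x, t=2):
    """Log of the mean Gaussian potential exp(-t * ||x_i - x_j||^2) over all
    distinct pairs in the batch (lower means more uniformly spread)."""
    return torch.pdist(x, p=2).pow(2).mul(-t).exp().mean().log()


# Toy check on random unit vectors (illustrative data only, not DDAE features).
feats_a = F.normalize(torch.randn(128, 32), dim=1)
feats_b = F.normalize(feats_a + 0.1 * torch.randn(128, 32), dim=1)
print(alignment(feats_a, feats_b).item(), uniformity(feats_a).item())
```

Both functions expect features that have already been projected onto the unit sphere, which is why the toy check normalizes its inputs first.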