# Denoising Diffusion Autoencoders (DDAE)
This is a multi-GPU PyTorch implementation of the paper [Denoising Diffusion Autoencoders are Unified Self-supervised Learners](https://arxiv.org/abs/2303.09769):
```bibtex
@inproceedings{ddae2023,
  title={Denoising Diffusion Autoencoders are Unified Self-supervised Learners},
  author={Xiang, Weilai and Yang, Hongyu and Huang, Di and Wang, Yunhong},
  booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision},
  year={2023}
}
```
:star: (News) Our paper is cited by Kaiming He's new paper [Deconstructing Denoising Diffusion Models for Self-Supervised Learning](https://arxiv.org/abs/2401.14404), check it out! :fire:
## Overview
This repo contains:
- [x] Pre-training, sampling and FID evaluation code for diffusion models, including
  - Frameworks:
    - [x] DDPM & DDIM
    - [x] EDM (w/ or w/o data augmentation)
  - Networks:
    - [x] The basic 35.7M DDPM UNet
    - [x] A larger 56M DDPM++ UNet
  - Datasets:
    - [x] CIFAR-10
    - [ ] Tiny-ImageNet
- [x] Feature quality evaluation code, including
  - [x] Linear probing and grid searching
  - [x] Contrastive metrics, i.e., alignment and uniformity (a sketch follows this list)
  - [ ] Fine-tuning
- [x] Noise-conditional classifier training and evaluation, including
  - [x] MLP classifier based on DDPM/EDM features
  - [x] WideResNet with VP/VE perturbation
- [x] Evaluation code for the ImageNet-256 pre-trained [DiT-XL/2](https://github.com/facebookresearch/DiT) checkpoint
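The alignment and uniformity metrics follow Wang & Isola's formulation (see the acknowledged [align_uniform](https://github.com/SsnL/align_uniform) repo). As a reference for what is being measured, here is a minimal sketch of the two quantities:
```python
import torch
import torch.nn.functional as F

def alignment(x, y, alpha=2):
    # x, y: features of two augmented views of the same images, shape (N, D);
    # lower is better: matched pairs should map to nearby points
    x, y = F.normalize(x, dim=1), F.normalize(y, dim=1)
    return (x - y).norm(p=2, dim=1).pow(alpha).mean()

def uniformity(x, t=2):
    # x: features, shape (N, D); lower is better: normalized features
    # should spread evenly over the unit hypersphere
    x = F.normalize(x, dim=1)
    return torch.pdist(x, p=2).pow(2).mul(-t).exp().mean().log()
```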
## Requirements
- In addition to a working PyTorch environment, please install:
```sh
conda install pyyaml
pip install pytorch-fid ema-pytorch
```
- We use 4 or 8 RTX 3080 Ti GPUs to conduct all the experiments presented in the paper. With automatic mixed precision enabled and 4 GPUs, training the basic 35.7M UNet on CIFAR-10 takes ~14 hours.
- `pytorch-fid` requires image files to compute the FID metric. Please refer to `extract_cifar10_pngs.ipynb` to unpack the CIFAR-10 training set into 50000 `.png` image files; the gist is sketched below.
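If you prefer a plain script to the notebook, the extraction boils down to the following torchvision-based sketch (the notebook's exact filenames and layout may differ):
```python
import os
from torchvision.datasets import CIFAR10

out_dir = "data/cifar10-pngs"             # path used by the FID example below
os.makedirs(out_dir, exist_ok=True)
train_set = CIFAR10(root="data", train=True, download=True)
for i, (img, _) in enumerate(train_set):  # without a transform, img is a PIL.Image
    img.save(os.path.join(out_dir, f"{i:05d}.png"))
```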
## Main results
We present the generative and discriminative evaluation results that can be obtained with this codebase. The `EDM_ddpmpp_aug.yaml` training is performed on 8 GPUs, while the other models are trained on 4 GPUs.
Please note that this is an *over-simplified* DDPM / EDM implementation, and some network details, initialization, and hyper-parameters may *differ from* the official ones. Please refer to the respective official codebases to reproduce the *exact results* reported in the paper.
| Config | Model | Network | Best linear probe: epoch | FID | acc (%) | Best FID: epoch | FID | acc (%) |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| `DDPM_ddpm.yaml` | DDPM | 35.7M UNet | 800 | 4.09 | 90.05 | 1999 | 3.62 | 88.23 |
| `EDM_ddpm.yaml` | EDM | 35.7M UNet | 1200 | 3.97 | 90.44 | 1999 | 3.56 | 89.71 |
| `DDPM_ddpmpp.yaml` | DDPM | 56.5M DDPM++ | 1200 | 3.08 | 93.97 | 1999 | 2.98 | 93.03 |
| `EDM_ddpmpp.yaml` | EDM | 56.5M DDPM++ | 1200 | 2.23 | 94.50 | (same) | (same) | (same) |
| `EDM_ddpmpp_aug.yaml` | EDM + data aug | 56.5M DDPM++ | 2000 | 2.34 | 95.49 | 3200 | 2.12 | 95.19 |
FIDs are calculated using 50000 images generated by the deterministic fast sampler (DDIM 100 steps or EDM 18 steps).
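For context, "deterministic" means the sampler injects no fresh noise: DDIM with eta = 0 updates each step as a fixed function of the network's noise prediction, so the same checkpoint and seed always yield the same images. A schematic single DDIM step in the usual alpha-bar notation (not this repo's exact code):
```python
import torch

def ddim_step(x_t, eps_pred, a_bar_t, a_bar_prev):
    # a_bar_*: cumulative alphas at the current and previous step (0-dim tensors).
    # eta = 0: reconstruct x0 from the predicted noise, then re-noise it
    # to the previous (smaller) noise level without sampling new noise.
    x0_pred = (x_t - (1 - a_bar_t).sqrt() * eps_pred) / a_bar_t.sqrt()
    return a_bar_prev.sqrt() * x0_pred + (1 - a_bar_prev).sqrt() * eps_pred
```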
## Latent-space DiT
We evaluate pre-trained Transformer-based diffusion networks, [DiT](https://github.com/facebookresearch/DiT), from the perspective of *transfer learning*. Please refer to the [ddae/DiT](DiT/) subfolder.
## Usage
### Diffusion pre-training
To train a DDAE model and generate 50000 image samples with 4 GPUs, for example, run:
```sh
# diffusion pre-training with AMP enabled
python -m torch.distributed.launch --nproc_per_node=4 train.py --config config/DDPM_ddpm.yaml --use_amp
# deterministic fast sampling (i.e., DDIM 100 steps / EDM 18 steps)
python -m torch.distributed.launch --nproc_per_node=4 sample.py --config config/DDPM_ddpm.yaml --use_amp --epoch 400
# stochastic sampling (i.e., DDPM 1000 steps)
python -m torch.distributed.launch --nproc_per_node=4 sample.py --config config/DDPM_ddpm.yaml --use_amp --epoch 400 --mode DDPM
```
To calculate the FID metric on the training set, for example, run:
```sh
python -m pytorch_fid data/cifar10-pngs/ output_DDPM_ddpm/EMAgenerated_ep400_ddim_steps100_eta0.0/pngs/
```
### Features produced by DDAE
To evaluate the features produced by a pre-trained DDAE, for example, run:
```sh
# grid searching for a proper layer-noise combination
python -m torch.distributed.launch --nproc_per_node=4 linear.py --config config/DDPM_ddpm.yaml --use_amp --epoch 400 --grid
# linear probing, using the layer-noise combination specified by the config file
python -m torch.distributed.launch --nproc_per_node=4 linear.py --config config/DDPM_ddpm.yaml --use_amp --epoch 400
# alignment-uniformity metrics with respect to different checkpoints
python -m torch.distributed.launch --nproc_per_node=4 contrastive.py --config config/DDPM_ddpm.yaml --use_amp
```
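Conceptually, the probe freezes the diffusion UNet, noises each image to a chosen level, and fits a single linear layer on a pooled intermediate activation; the grid search simply sweeps the (layer, noise) pair. A simplified sketch, where the `return_layer` hook is hypothetical (the repo selects the combination through the YAML config instead):
```python
import torch
import torch.nn as nn

@torch.no_grad()
def pooled_features(unet, x, t, alpha_bar, layer_idx):
    # perturb clean images to noise level t (DDPM forward process), then
    # read out one intermediate activation of the frozen denoiser;
    # `return_layer` is a hypothetical hook, not this repo's actual API
    noise = torch.randn_like(x)
    x_t = alpha_bar[t].sqrt() * x + (1 - alpha_bar[t]).sqrt() * noise
    h = unet(x_t, t, return_layer=layer_idx)
    return h.mean(dim=(2, 3))  # global average pooling over H x W

# the probe itself: one linear layer trained on the frozen features
probe = nn.Linear(256, 10)  # feature width is illustrative; 10 CIFAR-10 classes
```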
### Noise-conditional classifier
To train WideResNet-based classifiers from scratch:
```sh
# VP (DDPM) perturbation
python -m torch.distributed.launch --nproc_per_node=4 noisy_classifier_WRN.py --mode DDPM
# VE (EDM) perturbation
python -m torch.distributed.launch --nproc_per_node=4 noisy_classifier_WRN.py --mode EDM
```
and compare their noise-conditional recognition rates with DDAE-based MLP classifier heads:
```sh
# using the DDPM DDAE encoder
python -m torch.distributed.launch --nproc_per_node=4 noisy_classifier_DDAE.py --config config/DDPM_ddpm.yaml --use_amp --epoch 1999
# using the EDM DDAE encoder
python -m torch.distributed.launch --nproc_per_node=4 noisy_classifier_DDAE.py --config config/EDM_ddpmpp.yaml --use_amp --epoch 1200
```
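For comparison context: the WideResNet baseline is a full network trained from scratch on perturbed images, whereas the DDAE-based classifier is only a small MLP head on frozen, noise-conditional encoder features. Roughly (widths are illustrative, not the repo's exact sizes):
```python
import torch.nn as nn

# a small MLP head on pooled, frozen DDAE features; only the head is
# trained, at the same noise level the features were extracted from
mlp_head = nn.Sequential(
    nn.Linear(512, 256),
    nn.BatchNorm1d(256),
    nn.ReLU(inplace=True),
    nn.Linear(256, 10),  # CIFAR-10 classes
)
```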
## Acknowledgments
This repository is built on numerous open-source codebases such as [DDPM](https://github.com/hojonathanho/diffusion), [DDPM-pytorch](https://github.com/pesser/pytorch_diffusion), [DDIM](https://github.com/ermongroup/ddim), [EDM](https://github.com/NVlabs/edm), [Score-based SDE](https://github.com/yang-song/score_sde), [DiT](https://github.com/facebookresearch/DiT), and [align_uniform](https://github.com/SsnL/align_uniform).