From cc1924035f0dc9a376189255bfc3e5b99274e952 Mon Sep 17 00:00:00 2001 From: liuqiyuan Date: Sat, 11 May 2024 11:28:19 +0800 Subject: [PATCH 1/2] Add DiT source code --- PyTorch/built-in/mlm/DiT/CODE_OF_CONDUCT.md | 80 ++ PyTorch/built-in/mlm/DiT/CONTRIBUTING.md | 34 + PyTorch/built-in/mlm/DiT/LICENSE.txt | 400 ++++++++ PyTorch/built-in/mlm/DiT/README.md | 163 ++++ .../built-in/mlm/DiT/diffusion/__init__.py | 46 + .../mlm/DiT/diffusion/diffusion_utils.py | 88 ++ .../mlm/DiT/diffusion/gaussian_diffusion.py | 873 ++++++++++++++++++ PyTorch/built-in/mlm/DiT/diffusion/respace.py | 129 +++ .../mlm/DiT/diffusion/timestep_sampler.py | 150 +++ PyTorch/built-in/mlm/DiT/download.py | 50 + PyTorch/built-in/mlm/DiT/environment.yml | 13 + PyTorch/built-in/mlm/DiT/models.py | 370 ++++++++ PyTorch/built-in/mlm/DiT/sample.py | 83 ++ PyTorch/built-in/mlm/DiT/sample_ddp.py | 166 ++++ PyTorch/built-in/mlm/DiT/train.py | 269 ++++++ 15 files changed, 2914 insertions(+) create mode 100644 PyTorch/built-in/mlm/DiT/CODE_OF_CONDUCT.md create mode 100644 PyTorch/built-in/mlm/DiT/CONTRIBUTING.md create mode 100644 PyTorch/built-in/mlm/DiT/LICENSE.txt create mode 100644 PyTorch/built-in/mlm/DiT/README.md create mode 100644 PyTorch/built-in/mlm/DiT/diffusion/__init__.py create mode 100644 PyTorch/built-in/mlm/DiT/diffusion/diffusion_utils.py create mode 100644 PyTorch/built-in/mlm/DiT/diffusion/gaussian_diffusion.py create mode 100644 PyTorch/built-in/mlm/DiT/diffusion/respace.py create mode 100644 PyTorch/built-in/mlm/DiT/diffusion/timestep_sampler.py create mode 100644 PyTorch/built-in/mlm/DiT/download.py create mode 100644 PyTorch/built-in/mlm/DiT/environment.yml create mode 100644 PyTorch/built-in/mlm/DiT/models.py create mode 100644 PyTorch/built-in/mlm/DiT/sample.py create mode 100644 PyTorch/built-in/mlm/DiT/sample_ddp.py create mode 100644 PyTorch/built-in/mlm/DiT/train.py diff --git a/PyTorch/built-in/mlm/DiT/CODE_OF_CONDUCT.md b/PyTorch/built-in/mlm/DiT/CODE_OF_CONDUCT.md new file mode 100644 index 0000000000..3232ed6655 --- /dev/null +++ b/PyTorch/built-in/mlm/DiT/CODE_OF_CONDUCT.md @@ -0,0 +1,80 @@ +# Code of Conduct + +## Our Pledge + +In the interest of fostering an open and welcoming environment, we as +contributors and maintainers pledge to make participation in our project and +our community a harassment-free experience for everyone, regardless of age, body +size, disability, ethnicity, sex characteristics, gender identity and expression, +level of experience, education, socio-economic status, nationality, personal +appearance, race, religion, or sexual identity and orientation. 
+ +## Our Standards + +Examples of behavior that contributes to creating a positive environment +include: + +* Using welcoming and inclusive language +* Being respectful of differing viewpoints and experiences +* Gracefully accepting constructive criticism +* Focusing on what is best for the community +* Showing empathy towards other community members + +Examples of unacceptable behavior by participants include: + +* The use of sexualized language or imagery and unwelcome sexual attention or +advances +* Trolling, insulting/derogatory comments, and personal or political attacks +* Public or private harassment +* Publishing others' private information, such as a physical or electronic +address, without explicit permission +* Other conduct which could reasonably be considered inappropriate in a +professional setting + +## Our Responsibilities + +Project maintainers are responsible for clarifying the standards of acceptable +behavior and are expected to take appropriate and fair corrective action in +response to any instances of unacceptable behavior. + +Project maintainers have the right and responsibility to remove, edit, or +reject comments, commits, code, wiki edits, issues, and other contributions +that are not aligned to this Code of Conduct, or to ban temporarily or +permanently any contributor for other behaviors that they deem inappropriate, +threatening, offensive, or harmful. + +## Scope + +This Code of Conduct applies within all project spaces, and it also applies when +an individual is representing the project or its community in public spaces. +Examples of representing a project or community include using an official +project e-mail address, posting via an official social media account, or acting +as an appointed representative at an online or offline event. Representation of +a project may be further defined and clarified by project maintainers. + +This Code of Conduct also applies outside the project spaces when there is a +reasonable belief that an individual's behavior may have a negative impact on +the project or its community. + +## Enforcement + +Instances of abusive, harassing, or otherwise unacceptable behavior may be +reported by contacting the project team at . All +complaints will be reviewed and investigated and will result in a response that +is deemed necessary and appropriate to the circumstances. The project team is +obligated to maintain confidentiality with regard to the reporter of an incident. +Further details of specific enforcement policies may be posted separately. + +Project maintainers who do not follow or enforce the Code of Conduct in good +faith may face temporary or permanent repercussions as determined by other +members of the project's leadership. + +## Attribution + +This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, +available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html + +[homepage]: https://www.contributor-covenant.org + +For answers to common questions about this code of conduct, see +https://www.contributor-covenant.org/faq diff --git a/PyTorch/built-in/mlm/DiT/CONTRIBUTING.md b/PyTorch/built-in/mlm/DiT/CONTRIBUTING.md new file mode 100644 index 0000000000..b45bfbaa59 --- /dev/null +++ b/PyTorch/built-in/mlm/DiT/CONTRIBUTING.md @@ -0,0 +1,34 @@ +# Contributing to DiT +We want to make contributing to this project as easy and transparent as +possible. + +## Our Development Process +Work on the `DiT` repo has mostly concluded. 
+ +## Pull Requests +We actively welcome your pull requests. + +1. Fork the repo and create your branch from `main`. +2. If you've added code that should be tested, add tests. +3. If you've changed APIs, update the documentation. +4. Ensure the test suite passes. +5. Make sure your code lints. +6. If you haven't already, complete the Contributor License Agreement ("CLA"). + +## Contributor License Agreement ("CLA") +In order to accept your pull request, we need you to submit a CLA. You only need +to do this once to work on any of Meta's open source projects. + +Complete your CLA here: + +## Issues +We use GitHub issues to track public bugs. Please ensure your description is +clear and has sufficient instructions to be able to reproduce the issue. + +Meta has a [bounty program](https://www.facebook.com/whitehat/) for the safe +disclosure of security bugs. In those cases, please go through the process +outlined on that page and do not file a public issue. + +## License +By contributing to `DiT`, you agree that your contributions will be licensed +under the LICENSE file in the root directory of this source tree. \ No newline at end of file diff --git a/PyTorch/built-in/mlm/DiT/LICENSE.txt b/PyTorch/built-in/mlm/DiT/LICENSE.txt new file mode 100644 index 0000000000..a115f899f8 --- /dev/null +++ b/PyTorch/built-in/mlm/DiT/LICENSE.txt @@ -0,0 +1,400 @@ + +Attribution-NonCommercial 4.0 International + +======================================================================= + +Creative Commons Corporation ("Creative Commons") is not a law firm and +does not provide legal services or legal advice. Distribution of +Creative Commons public licenses does not create a lawyer-client or +other relationship. Creative Commons makes its licenses and related +information available on an "as-is" basis. Creative Commons gives no +warranties regarding its licenses, any material licensed under their +terms and conditions, or any related information. Creative Commons +disclaims all liability for damages resulting from their use to the +fullest extent possible. + +Using Creative Commons Public Licenses + +Creative Commons public licenses provide a standard set of terms and +conditions that creators and other rights holders may use to share +original works of authorship and other material subject to copyright +and certain other rights specified in the public license below. The +following considerations are for informational purposes only, are not +exhaustive, and do not form part of our licenses. + + Considerations for licensors: Our public licenses are + intended for use by those authorized to give the public + permission to use material in ways otherwise restricted by + copyright and certain other rights. Our licenses are + irrevocable. Licensors should read and understand the terms + and conditions of the license they choose before applying it. + Licensors should also secure all rights necessary before + applying our licenses so that the public can reuse the + material as expected. Licensors should clearly mark any + material not subject to the license. This includes other CC- + licensed material, or material used under an exception or + limitation to copyright. More considerations for licensors: + wiki.creativecommons.org/Considerations_for_licensors + + Considerations for the public: By using one of our public + licenses, a licensor grants the public permission to use the + licensed material under specified terms and conditions. 
If + the licensor's permission is not necessary for any reason--for + example, because of any applicable exception or limitation to + copyright--then that use is not regulated by the license. Our + licenses grant only permissions under copyright and certain + other rights that a licensor has authority to grant. Use of + the licensed material may still be restricted for other + reasons, including because others have copyright or other + rights in the material. A licensor may make special requests, + such as asking that all changes be marked or described. + Although not required by our licenses, you are encouraged to + respect those requests where reasonable. More_considerations + for the public: + wiki.creativecommons.org/Considerations_for_licensees + +======================================================================= + +Creative Commons Attribution-NonCommercial 4.0 International Public +License + +By exercising the Licensed Rights (defined below), You accept and agree +to be bound by the terms and conditions of this Creative Commons +Attribution-NonCommercial 4.0 International Public License ("Public +License"). To the extent this Public License may be interpreted as a +contract, You are granted the Licensed Rights in consideration of Your +acceptance of these terms and conditions, and the Licensor grants You +such rights in consideration of benefits the Licensor receives from +making the Licensed Material available under these terms and +conditions. + +Section 1 -- Definitions. + + a. Adapted Material means material subject to Copyright and Similar + Rights that is derived from or based upon the Licensed Material + and in which the Licensed Material is translated, altered, + arranged, transformed, or otherwise modified in a manner requiring + permission under the Copyright and Similar Rights held by the + Licensor. For purposes of this Public License, where the Licensed + Material is a musical work, performance, or sound recording, + Adapted Material is always produced where the Licensed Material is + synched in timed relation with a moving image. + + b. Adapter's License means the license You apply to Your Copyright + and Similar Rights in Your contributions to Adapted Material in + accordance with the terms and conditions of this Public License. + + c. Copyright and Similar Rights means copyright and/or similar rights + closely related to copyright including, without limitation, + performance, broadcast, sound recording, and Sui Generis Database + Rights, without regard to how the rights are labeled or + categorized. For purposes of this Public License, the rights + specified in Section 2(b)(1)-(2) are not Copyright and Similar + Rights. + d. Effective Technological Measures means those measures that, in the + absence of proper authority, may not be circumvented under laws + fulfilling obligations under Article 11 of the WIPO Copyright + Treaty adopted on December 20, 1996, and/or similar international + agreements. + + e. Exceptions and Limitations means fair use, fair dealing, and/or + any other exception or limitation to Copyright and Similar Rights + that applies to Your use of the Licensed Material. + + f. Licensed Material means the artistic or literary work, database, + or other material to which the Licensor applied this Public + License. + + g. 
Licensed Rights means the rights granted to You subject to the + terms and conditions of this Public License, which are limited to + all Copyright and Similar Rights that apply to Your use of the + Licensed Material and that the Licensor has authority to license. + + h. Licensor means the individual(s) or entity(ies) granting rights + under this Public License. + + i. NonCommercial means not primarily intended for or directed towards + commercial advantage or monetary compensation. For purposes of + this Public License, the exchange of the Licensed Material for + other material subject to Copyright and Similar Rights by digital + file-sharing or similar means is NonCommercial provided there is + no payment of monetary compensation in connection with the + exchange. + + j. Share means to provide material to the public by any means or + process that requires permission under the Licensed Rights, such + as reproduction, public display, public performance, distribution, + dissemination, communication, or importation, and to make material + available to the public including in ways that members of the + public may access the material from a place and at a time + individually chosen by them. + + k. Sui Generis Database Rights means rights other than copyright + resulting from Directive 96/9/EC of the European Parliament and of + the Council of 11 March 1996 on the legal protection of databases, + as amended and/or succeeded, as well as other essentially + equivalent rights anywhere in the world. + + l. You means the individual or entity exercising the Licensed Rights + under this Public License. Your has a corresponding meaning. + +Section 2 -- Scope. + + a. License grant. + + 1. Subject to the terms and conditions of this Public License, + the Licensor hereby grants You a worldwide, royalty-free, + non-sublicensable, non-exclusive, irrevocable license to + exercise the Licensed Rights in the Licensed Material to: + + a. reproduce and Share the Licensed Material, in whole or + in part, for NonCommercial purposes only; and + + b. produce, reproduce, and Share Adapted Material for + NonCommercial purposes only. + + 2. Exceptions and Limitations. For the avoidance of doubt, where + Exceptions and Limitations apply to Your use, this Public + License does not apply, and You do not need to comply with + its terms and conditions. + + 3. Term. The term of this Public License is specified in Section + 6(a). + + 4. Media and formats; technical modifications allowed. The + Licensor authorizes You to exercise the Licensed Rights in + all media and formats whether now known or hereafter created, + and to make technical modifications necessary to do so. The + Licensor waives and/or agrees not to assert any right or + authority to forbid You from making technical modifications + necessary to exercise the Licensed Rights, including + technical modifications necessary to circumvent Effective + Technological Measures. For purposes of this Public License, + simply making modifications authorized by this Section 2(a) + (4) never produces Adapted Material. + + 5. Downstream recipients. + + a. Offer from the Licensor -- Licensed Material. Every + recipient of the Licensed Material automatically + receives an offer from the Licensor to exercise the + Licensed Rights under the terms and conditions of this + Public License. + + b. No downstream restrictions. 
You may not offer or impose + any additional or different terms or conditions on, or + apply any Effective Technological Measures to, the + Licensed Material if doing so restricts exercise of the + Licensed Rights by any recipient of the Licensed + Material. + + 6. No endorsement. Nothing in this Public License constitutes or + may be construed as permission to assert or imply that You + are, or that Your use of the Licensed Material is, connected + with, or sponsored, endorsed, or granted official status by, + the Licensor or others designated to receive attribution as + provided in Section 3(a)(1)(A)(i). + + b. Other rights. + + 1. Moral rights, such as the right of integrity, are not + licensed under this Public License, nor are publicity, + privacy, and/or other similar personality rights; however, to + the extent possible, the Licensor waives and/or agrees not to + assert any such rights held by the Licensor to the limited + extent necessary to allow You to exercise the Licensed + Rights, but not otherwise. + + 2. Patent and trademark rights are not licensed under this + Public License. + + 3. To the extent possible, the Licensor waives any right to + collect royalties from You for the exercise of the Licensed + Rights, whether directly or through a collecting society + under any voluntary or waivable statutory or compulsory + licensing scheme. In all other cases the Licensor expressly + reserves any right to collect such royalties, including when + the Licensed Material is used other than for NonCommercial + purposes. + +Section 3 -- License Conditions. + +Your exercise of the Licensed Rights is expressly made subject to the +following conditions. + + a. Attribution. + + 1. If You Share the Licensed Material (including in modified + form), You must: + + a. retain the following if it is supplied by the Licensor + with the Licensed Material: + + i. identification of the creator(s) of the Licensed + Material and any others designated to receive + attribution, in any reasonable manner requested by + the Licensor (including by pseudonym if + designated); + + ii. a copyright notice; + + iii. a notice that refers to this Public License; + + iv. a notice that refers to the disclaimer of + warranties; + + v. a URI or hyperlink to the Licensed Material to the + extent reasonably practicable; + + b. indicate if You modified the Licensed Material and + retain an indication of any previous modifications; and + + c. indicate the Licensed Material is licensed under this + Public License, and include the text of, or the URI or + hyperlink to, this Public License. + + 2. You may satisfy the conditions in Section 3(a)(1) in any + reasonable manner based on the medium, means, and context in + which You Share the Licensed Material. For example, it may be + reasonable to satisfy the conditions by providing a URI or + hyperlink to a resource that includes the required + information. + + 3. If requested by the Licensor, You must remove any of the + information required by Section 3(a)(1)(A) to the extent + reasonably practicable. + + 4. If You Share Adapted Material You produce, the Adapter's + License You apply must not prevent recipients of the Adapted + Material from complying with this Public License. + +Section 4 -- Sui Generis Database Rights. + +Where the Licensed Rights include Sui Generis Database Rights that +apply to Your use of the Licensed Material: + + a. 
for the avoidance of doubt, Section 2(a)(1) grants You the right + to extract, reuse, reproduce, and Share all or a substantial + portion of the contents of the database for NonCommercial purposes + only; + + b. if You include all or a substantial portion of the database + contents in a database in which You have Sui Generis Database + Rights, then the database in which You have Sui Generis Database + Rights (but not its individual contents) is Adapted Material; and + + c. You must comply with the conditions in Section 3(a) if You Share + all or a substantial portion of the contents of the database. + +For the avoidance of doubt, this Section 4 supplements and does not +replace Your obligations under this Public License where the Licensed +Rights include other Copyright and Similar Rights. + +Section 5 -- Disclaimer of Warranties and Limitation of Liability. + + a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE + EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS + AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF + ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, + IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, + WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR + PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, + ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT + KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT + ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. + + b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE + TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, + NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, + INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, + COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR + USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN + ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR + DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR + IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. + + c. The disclaimer of warranties and limitation of liability provided + above shall be interpreted in a manner that, to the extent + possible, most closely approximates an absolute disclaimer and + waiver of all liability. + +Section 6 -- Term and Termination. + + a. This Public License applies for the term of the Copyright and + Similar Rights licensed here. However, if You fail to comply with + this Public License, then Your rights under this Public License + terminate automatically. + + b. Where Your right to use the Licensed Material has terminated under + Section 6(a), it reinstates: + + 1. automatically as of the date the violation is cured, provided + it is cured within 30 days of Your discovery of the + violation; or + + 2. upon express reinstatement by the Licensor. + + For the avoidance of doubt, this Section 6(b) does not affect any + right the Licensor may have to seek remedies for Your violations + of this Public License. + + c. For the avoidance of doubt, the Licensor may also offer the + Licensed Material under separate terms or conditions or stop + distributing the Licensed Material at any time; however, doing so + will not terminate this Public License. + + d. Sections 1, 5, 6, 7, and 8 survive termination of this Public + License. + +Section 7 -- Other Terms and Conditions. + + a. 
The Licensor shall not be bound by any additional or different + terms or conditions communicated by You unless expressly agreed. + + b. Any arrangements, understandings, or agreements regarding the + Licensed Material not stated herein are separate from and + independent of the terms and conditions of this Public License. + +Section 8 -- Interpretation. + + a. For the avoidance of doubt, this Public License does not, and + shall not be interpreted to, reduce, limit, restrict, or impose + conditions on any use of the Licensed Material that could lawfully + be made without permission under this Public License. + + b. To the extent possible, if any provision of this Public License is + deemed unenforceable, it shall be automatically reformed to the + minimum extent necessary to make it enforceable. If the provision + cannot be reformed, it shall be severed from this Public License + without affecting the enforceability of the remaining terms and + conditions. + + c. No term or condition of this Public License will be waived and no + failure to comply consented to unless expressly agreed to by the + Licensor. + + d. Nothing in this Public License constitutes or may be interpreted + as a limitation upon, or waiver of, any privileges and immunities + that apply to the Licensor or You, including from the legal + processes of any jurisdiction or authority. + +======================================================================= + +Creative Commons is not a party to its public +licenses. Notwithstanding, Creative Commons may elect to apply one of +its public licenses to material it publishes and in those instances +will be considered the “Licensor.” The text of the Creative Commons +public licenses is dedicated to the public domain under the CC0 Public +Domain Dedication. Except for the limited purpose of indicating that +material is shared under a Creative Commons public license or as +otherwise permitted by the Creative Commons policies published at +creativecommons.org/policies, Creative Commons does not authorize the +use of the trademark "Creative Commons" or any other trademark or logo +of Creative Commons without its prior written consent including, +without limitation, in connection with any unauthorized modifications +to any of its public licenses or any other arrangements, +understandings, or agreements concerning use of licensed material. For +the avoidance of doubt, this paragraph does not form part of the +public licenses. + +Creative Commons may be contacted at creativecommons.org. diff --git a/PyTorch/built-in/mlm/DiT/README.md b/PyTorch/built-in/mlm/DiT/README.md new file mode 100644 index 0000000000..1337ab3468 --- /dev/null +++ b/PyTorch/built-in/mlm/DiT/README.md @@ -0,0 +1,163 @@ +## Scalable Diffusion Models with Transformers (DiT)
Official PyTorch Implementation + +### [Paper](http://arxiv.org/abs/2212.09748) | [Project Page](https://www.wpeebles.com/DiT) | Run DiT-XL/2 [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/wpeebles/DiT) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](http://colab.research.google.com/github/facebookresearch/DiT/blob/main/run_DiT.ipynb) + +![DiT samples](visuals/sample_grid_0.png) + +This repo contains PyTorch model definitions, pre-trained weights and training/sampling code for our paper exploring +diffusion models with transformers (DiTs). You can find more visualizations on our [project page](https://www.wpeebles.com/DiT). + +> [**Scalable Diffusion Models with Transformers**](https://www.wpeebles.com/DiT)
+> [William Peebles](https://www.wpeebles.com), [Saining Xie](https://www.sainingxie.com) +>
UC Berkeley, New York University
+ +We train latent diffusion models, replacing the commonly-used U-Net backbone with a transformer that operates on +latent patches. We analyze the scalability of our Diffusion Transformers (DiTs) through the lens of forward pass +complexity as measured by Gflops. We find that DiTs with higher Gflops---through increased transformer depth/width or +increased number of input tokens---consistently have lower FID. In addition to good scalability properties, our +DiT-XL/2 models outperform all prior diffusion models on the class-conditional ImageNet 512×512 and 256×256 benchmarks, +achieving a state-of-the-art FID of 2.27 on the latter. + +This repository contains: + +* 🪐 A simple PyTorch [implementation](models.py) of DiT +* ⚡️ Pre-trained class-conditional DiT models trained on ImageNet (512x512 and 256x256) +* 💥 A self-contained [Hugging Face Space](https://huggingface.co/spaces/wpeebles/DiT) and [Colab notebook](http://colab.research.google.com/github/facebookresearch/DiT/blob/main/run_DiT.ipynb) for running pre-trained DiT-XL/2 models +* 🛸 A DiT [training script](train.py) using PyTorch DDP + +An implementation of DiT directly in Hugging Face `diffusers` can also be found [here](https://github.com/huggingface/diffusers/blob/main/docs/source/en/api/pipelines/dit.mdx). + + +## Setup + +First, download and set up the repo: + +```bash +git clone https://github.com/facebookresearch/DiT.git +cd DiT +``` + +We provide an [`environment.yml`](environment.yml) file that can be used to create a Conda environment. If you only want +to run pre-trained models locally on CPU, you can remove the `cudatoolkit` and `pytorch-cuda` requirements from the file. + +```bash +conda env create -f environment.yml +conda activate DiT +``` + + +## Sampling [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/wpeebles/DiT) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](http://colab.research.google.com/github/facebookresearch/DiT/blob/main/run_DiT.ipynb) +![More DiT samples](visuals/sample_grid_1.png) + +**Pre-trained DiT checkpoints.** You can sample from our pre-trained DiT models with [`sample.py`](sample.py). Weights for our pre-trained DiT model will be +automatically downloaded depending on the model you use. The script has various arguments to switch between the 256x256 +and 512x512 models, adjust sampling steps, change the classifier-free guidance scale, etc. For example, to sample from +our 512x512 DiT-XL/2 model, you can use: + +```bash +python sample.py --image-size 512 --seed 1 +``` + +For convenience, our pre-trained DiT models can be downloaded directly here as well: + +| DiT Model | Image Resolution | FID-50K | Inception Score | Gflops | +|---------------|------------------|---------|-----------------|--------| +| [XL/2](https://dl.fbaipublicfiles.com/DiT/models/DiT-XL-2-256x256.pt) | 256x256 | 2.27 | 278.24 | 119 | +| [XL/2](https://dl.fbaipublicfiles.com/DiT/models/DiT-XL-2-512x512.pt) | 512x512 | 3.04 | 240.82 | 525 | + + +**Custom DiT checkpoints.** If you've trained a new DiT model with [`train.py`](train.py) (see [below](#training-dit)), you can add the `--ckpt` +argument to use your own checkpoint instead. For example, to sample from the EMA weights of a custom +256x256 DiT-L/4 model, run: + +```bash +python sample.py --model DiT-L/4 --image-size 256 --ckpt /path/to/model.pt +``` + + +## Training DiT + +We provide a training script for DiT in [`train.py`](train.py). 
This script can be used to train class-conditional +DiT models, but it can be easily modified to support other types of conditioning. To launch DiT-XL/2 (256x256) training with `N` GPUs on +one node: + +```bash +torchrun --nnodes=1 --nproc_per_node=N train.py --model DiT-XL/2 --data-path /path/to/imagenet/train +``` + +### PyTorch Training Results + +We've trained DiT-XL/2 and DiT-B/4 models from scratch with the PyTorch training script +to verify that it reproduces the original JAX results up to several hundred thousand training iterations. Across our experiments, the PyTorch-trained models give +similar (and sometimes slightly better) results compared to the JAX-trained models up to reasonable random variation. Some data points: + +| DiT Model | Train Steps | FID-50K
(JAX Training) | FID-50K
(PyTorch Training) | PyTorch Global Training Seed | +|------------|-------------|----------------------------|--------------------------------|------------------------------| +| XL/2 | 400K | 19.5 | **18.1** | 42 | +| B/4 | 400K | **68.4** | 68.9 | 42 | +| B/4 | 400K | 68.4 | **68.3** | 100 | + +These models were trained at 256x256 resolution; we used 8x A100s to train XL/2 and 4x A100s to train B/4. Note that FID +here is computed with 250 DDPM sampling steps, with the `mse` VAE decoder and without guidance (`cfg-scale=1`). + +**TF32 Note (important for A100 users).** When we ran the above tests, TF32 matmuls were disabled per PyTorch's defaults. +We've enabled them at the top of `train.py` and `sample.py` because it makes training and sampling way way way faster on +A100s (and should for other Ampere GPUs too), but note that the use of TF32 may lead to some differences compared to +the above results. + +### Enhancements +Training (and sampling) could likely be sped-up significantly by: +- [ ] using [Flash Attention](https://github.com/HazyResearch/flash-attention) in the DiT model +- [ ] using `torch.compile` in PyTorch 2.0 + +Basic features that would be nice to add: +- [ ] Monitor FID and other metrics +- [ ] Generate and save samples from the EMA model periodically +- [ ] Resume training from a checkpoint +- [ ] AMP/bfloat16 support + +**🔥 Feature Update** Check out this repository at https://github.com/chuanyangjin/fast-DiT to preview a selection of training speed acceleration and memory saving features including gradient checkpointing, mixed precision training and pre-extrated VAE features. With these advancements, we have achieved a training speed of 0.84 steps/sec for DiT-XL/2 using just a single A100 GPU. + +## Evaluation (FID, Inception Score, etc.) + +We include a [`sample_ddp.py`](sample_ddp.py) script which samples a large number of images from a DiT model in parallel. This script +generates a folder of samples as well as a `.npz` file which can be directly used with [ADM's TensorFlow +evaluation suite](https://github.com/openai/guided-diffusion/tree/main/evaluations) to compute FID, Inception Score and +other metrics. For example, to sample 50K images from our pre-trained DiT-XL/2 model over `N` GPUs, run: + +```bash +torchrun --nnodes=1 --nproc_per_node=N sample_ddp.py --model DiT-XL/2 --num-fid-samples 50000 +``` + +There are several additional options; see [`sample_ddp.py`](sample_ddp.py) for details. + + +## Differences from JAX + +Our models were originally trained in JAX on TPUs. The weights in this repo are ported directly from the JAX models. +There may be minor differences in results stemming from sampling with different floating point precisions. We re-evaluated +our ported PyTorch weights at FP32, and they actually perform marginally better than sampling in JAX (2.21 FID +versus 2.27 in the paper). + + +## BibTeX + +```bibtex +@article{Peebles2022DiT, + title={Scalable Diffusion Models with Transformers}, + author={William Peebles and Saining Xie}, + year={2022}, + journal={arXiv preprint arXiv:2212.09748}, +} +``` + + +## Acknowledgments +We thank Kaiming He, Ronghang Hu, Alexander Berg, Shoubhik Debnath, Tim Brooks, Ilija Radosavovic and Tete Xiao for helpful discussions. +William Peebles is supported by the NSF Graduate Research Fellowship. + +This codebase borrows from OpenAI's diffusion repos, most notably [ADM](https://github.com/openai/guided-diffusion). + + +## License +The code and model weights are licensed under CC-BY-NC. 
See [`LICENSE.txt`](LICENSE.txt) for details. diff --git a/PyTorch/built-in/mlm/DiT/diffusion/__init__.py b/PyTorch/built-in/mlm/DiT/diffusion/__init__.py new file mode 100644 index 0000000000..8c536a98da --- /dev/null +++ b/PyTorch/built-in/mlm/DiT/diffusion/__init__.py @@ -0,0 +1,46 @@ +# Modified from OpenAI's diffusion repos +# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py +# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion +# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py + +from . import gaussian_diffusion as gd +from .respace import SpacedDiffusion, space_timesteps + + +def create_diffusion( + timestep_respacing, + noise_schedule="linear", + use_kl=False, + sigma_small=False, + predict_xstart=False, + learn_sigma=True, + rescale_learned_sigmas=False, + diffusion_steps=1000 +): + betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps) + if use_kl: + loss_type = gd.LossType.RESCALED_KL + elif rescale_learned_sigmas: + loss_type = gd.LossType.RESCALED_MSE + else: + loss_type = gd.LossType.MSE + if timestep_respacing is None or timestep_respacing == "": + timestep_respacing = [diffusion_steps] + return SpacedDiffusion( + use_timesteps=space_timesteps(diffusion_steps, timestep_respacing), + betas=betas, + model_mean_type=( + gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X + ), + model_var_type=( + ( + gd.ModelVarType.FIXED_LARGE + if not sigma_small + else gd.ModelVarType.FIXED_SMALL + ) + if not learn_sigma + else gd.ModelVarType.LEARNED_RANGE + ), + loss_type=loss_type + # rescale_timesteps=rescale_timesteps, + ) diff --git a/PyTorch/built-in/mlm/DiT/diffusion/diffusion_utils.py b/PyTorch/built-in/mlm/DiT/diffusion/diffusion_utils.py new file mode 100644 index 0000000000..e493a6a3ec --- /dev/null +++ b/PyTorch/built-in/mlm/DiT/diffusion/diffusion_utils.py @@ -0,0 +1,88 @@ +# Modified from OpenAI's diffusion repos +# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py +# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion +# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py + +import torch as th +import numpy as np + + +def normal_kl(mean1, logvar1, mean2, logvar2): + """ + Compute the KL divergence between two gaussians. + Shapes are automatically broadcasted, so batches can be compared to + scalars, among other use cases. + """ + tensor = None + for obj in (mean1, logvar1, mean2, logvar2): + if isinstance(obj, th.Tensor): + tensor = obj + break + assert tensor is not None, "at least one argument must be a Tensor" + + # Force variances to be Tensors. Broadcasting helps convert scalars to + # Tensors, but it does not work for th.exp(). + logvar1, logvar2 = [ + x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor) + for x in (logvar1, logvar2) + ] + + return 0.5 * ( + -1.0 + + logvar2 + - logvar1 + + th.exp(logvar1 - logvar2) + + ((mean1 - mean2) ** 2) * th.exp(-logvar2) + ) + + +def approx_standard_normal_cdf(x): + """ + A fast approximation of the cumulative distribution function of the + standard normal. + """ + return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3)))) + + +def continuous_gaussian_log_likelihood(x, *, means, log_scales): + """ + Compute the log-likelihood of a continuous Gaussian distribution. 
+ :param x: the targets + :param means: the Gaussian mean Tensor. + :param log_scales: the Gaussian log stddev Tensor. + :return: a tensor like x of log probabilities (in nats). + """ + centered_x = x - means + inv_stdv = th.exp(-log_scales) + normalized_x = centered_x * inv_stdv + log_probs = th.distributions.Normal(th.zeros_like(x), th.ones_like(x)).log_prob(normalized_x) + return log_probs + + +def discretized_gaussian_log_likelihood(x, *, means, log_scales): + """ + Compute the log-likelihood of a Gaussian distribution discretizing to a + given image. + :param x: the target images. It is assumed that this was uint8 values, + rescaled to the range [-1, 1]. + :param means: the Gaussian mean Tensor. + :param log_scales: the Gaussian log stddev Tensor. + :return: a tensor like x of log probabilities (in nats). + """ + assert x.shape == means.shape == log_scales.shape + centered_x = x - means + inv_stdv = th.exp(-log_scales) + plus_in = inv_stdv * (centered_x + 1.0 / 255.0) + cdf_plus = approx_standard_normal_cdf(plus_in) + min_in = inv_stdv * (centered_x - 1.0 / 255.0) + cdf_min = approx_standard_normal_cdf(min_in) + log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12)) + log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12)) + cdf_delta = cdf_plus - cdf_min + log_probs = th.where( + x < -0.999, + log_cdf_plus, + th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))), + ) + assert log_probs.shape == x.shape + return log_probs diff --git a/PyTorch/built-in/mlm/DiT/diffusion/gaussian_diffusion.py b/PyTorch/built-in/mlm/DiT/diffusion/gaussian_diffusion.py new file mode 100644 index 0000000000..ccbcefeca4 --- /dev/null +++ b/PyTorch/built-in/mlm/DiT/diffusion/gaussian_diffusion.py @@ -0,0 +1,873 @@ +# Modified from OpenAI's diffusion repos +# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py +# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion +# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py + + +import math + +import numpy as np +import torch as th +import enum + +from .diffusion_utils import discretized_gaussian_log_likelihood, normal_kl + + +def mean_flat(tensor): + """ + Take the mean over all non-batch dimensions. + """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +class ModelMeanType(enum.Enum): + """ + Which type of output the model predicts. + """ + + PREVIOUS_X = enum.auto() # the model predicts x_{t-1} + START_X = enum.auto() # the model predicts x_0 + EPSILON = enum.auto() # the model predicts epsilon + + +class ModelVarType(enum.Enum): + """ + What is used as the model's output variance. + The LEARNED_RANGE option has been added to allow the model to predict + values between FIXED_SMALL and FIXED_LARGE, making its job easier. 
+ """ + + LEARNED = enum.auto() + FIXED_SMALL = enum.auto() + FIXED_LARGE = enum.auto() + LEARNED_RANGE = enum.auto() + + +class LossType(enum.Enum): + MSE = enum.auto() # use raw MSE loss (and KL when learning variances) + RESCALED_MSE = ( + enum.auto() + ) # use raw MSE loss (with RESCALED_KL when learning variances) + KL = enum.auto() # use the variational lower-bound + RESCALED_KL = enum.auto() # like KL, but rescale to estimate the full VLB + + def is_vb(self): + return self == LossType.KL or self == LossType.RESCALED_KL + + +def _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, warmup_frac): + betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64) + warmup_time = int(num_diffusion_timesteps * warmup_frac) + betas[:warmup_time] = np.linspace(beta_start, beta_end, warmup_time, dtype=np.float64) + return betas + + +def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps): + """ + This is the deprecated API for creating beta schedules. + See get_named_beta_schedule() for the new library of schedules. + """ + if beta_schedule == "quad": + betas = ( + np.linspace( + beta_start ** 0.5, + beta_end ** 0.5, + num_diffusion_timesteps, + dtype=np.float64, + ) + ** 2 + ) + elif beta_schedule == "linear": + betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64) + elif beta_schedule == "warmup10": + betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.1) + elif beta_schedule == "warmup50": + betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.5) + elif beta_schedule == "const": + betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64) + elif beta_schedule == "jsd": # 1/T, 1/(T-1), 1/(T-2), ..., 1 + betas = 1.0 / np.linspace( + num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64 + ) + else: + raise NotImplementedError(beta_schedule) + assert betas.shape == (num_diffusion_timesteps,) + return betas + + +def get_named_beta_schedule(schedule_name, num_diffusion_timesteps): + """ + Get a pre-defined beta schedule for the given name. + The beta schedule library consists of beta schedules which remain similar + in the limit of num_diffusion_timesteps. + Beta schedules may be added, but should not be removed or changed once + they are committed to maintain backwards compatibility. + """ + if schedule_name == "linear": + # Linear schedule from Ho et al, extended to work for any number of + # diffusion steps. + scale = 1000 / num_diffusion_timesteps + return get_beta_schedule( + "linear", + beta_start=scale * 0.0001, + beta_end=scale * 0.02, + num_diffusion_timesteps=num_diffusion_timesteps, + ) + elif schedule_name == "squaredcos_cap_v2": + return betas_for_alpha_bar( + num_diffusion_timesteps, + lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2, + ) + else: + raise NotImplementedError(f"unknown beta schedule: {schedule_name}") + + +def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, + which defines the cumulative product of (1-beta) over time from t = [0,1]. + :param num_diffusion_timesteps: the number of betas to produce. + :param alpha_bar: a lambda that takes an argument t from 0 to 1 and + produces the cumulative product of (1-beta) up to that + part of the diffusion process. + :param max_beta: the maximum beta to use; use values lower than 1 to + prevent singularities. 
+ """ + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return np.array(betas) + + +class GaussianDiffusion: + """ + Utilities for training and sampling diffusion models. + Original ported from this codebase: + https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42 + :param betas: a 1-D numpy array of betas for each diffusion timestep, + starting at T and going to 1. + """ + + def __init__( + self, + *, + betas, + model_mean_type, + model_var_type, + loss_type + ): + + self.model_mean_type = model_mean_type + self.model_var_type = model_var_type + self.loss_type = loss_type + + # Use float64 for accuracy. + betas = np.array(betas, dtype=np.float64) + self.betas = betas + assert len(betas.shape) == 1, "betas must be 1-D" + assert (betas > 0).all() and (betas <= 1).all() + + self.num_timesteps = int(betas.shape[0]) + + alphas = 1.0 - betas + self.alphas_cumprod = np.cumprod(alphas, axis=0) + self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1]) + self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0) + assert self.alphas_cumprod_prev.shape == (self.num_timesteps,) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod) + self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod) + self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod) + self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod) + self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1) + + # calculations for posterior q(x_{t-1} | x_t, x_0) + self.posterior_variance = ( + betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) + ) + # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain + self.posterior_log_variance_clipped = np.log( + np.append(self.posterior_variance[1], self.posterior_variance[1:]) + ) if len(self.posterior_variance) > 1 else np.array([]) + + self.posterior_mean_coef1 = ( + betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) + ) + self.posterior_mean_coef2 = ( + (1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod) + ) + + def q_mean_variance(self, x_start, t): + """ + Get the distribution q(x_t | x_0). + :param x_start: the [N x C x ...] tensor of noiseless inputs. + :param t: the number of diffusion steps (minus 1). Here, 0 means one step. + :return: A tuple (mean, variance, log_variance), all of x_start's shape. + """ + mean = _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape) + log_variance = _extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape) + return mean, variance, log_variance + + def q_sample(self, x_start, t, noise=None): + """ + Diffuse the data for a given number of diffusion steps. + In other words, sample from q(x_t | x_0). + :param x_start: the initial data batch. + :param t: the number of diffusion steps (minus 1). Here, 0 means one step. + :param noise: if specified, the split-out normal noise. + :return: A noisy version of x_start. 
+ """ + if noise is None: + noise = th.randn_like(x_start) + assert noise.shape == x_start.shape + return ( + _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise + ) + + def q_posterior_mean_variance(self, x_start, x_t, t): + """ + Compute the mean and variance of the diffusion posterior: + q(x_{t-1} | x_t, x_0) + """ + assert x_start.shape == x_t.shape + posterior_mean = ( + _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start + + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t + ) + posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape) + posterior_log_variance_clipped = _extract_into_tensor( + self.posterior_log_variance_clipped, t, x_t.shape + ) + assert ( + posterior_mean.shape[0] + == posterior_variance.shape[0] + == posterior_log_variance_clipped.shape[0] + == x_start.shape[0] + ) + return posterior_mean, posterior_variance, posterior_log_variance_clipped + + def p_mean_variance(self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None): + """ + Apply the model to get p(x_{t-1} | x_t), as well as a prediction of + the initial x, x_0. + :param model: the model, which takes a signal and a batch of timesteps + as input. + :param x: the [N x C x ...] tensor at time t. + :param t: a 1-D Tensor of timesteps. + :param clip_denoised: if True, clip the denoised signal into [-1, 1]. + :param denoised_fn: if not None, a function which applies to the + x_start prediction before it is used to sample. Applies before + clip_denoised. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :return: a dict with the following keys: + - 'mean': the model mean output. + - 'variance': the model variance output. + - 'log_variance': the log of 'variance'. + - 'pred_xstart': the prediction for x_0. + """ + if model_kwargs is None: + model_kwargs = {} + + B, C = x.shape[:2] + assert t.shape == (B,) + model_output = model(x, t, **model_kwargs) + if isinstance(model_output, tuple): + model_output, extra = model_output + else: + extra = None + + if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]: + assert model_output.shape == (B, C * 2, *x.shape[2:]) + model_output, model_var_values = th.split(model_output, C, dim=1) + min_log = _extract_into_tensor(self.posterior_log_variance_clipped, t, x.shape) + max_log = _extract_into_tensor(np.log(self.betas), t, x.shape) + # The model_var_values is [-1, 1] for [min_var, max_var]. + frac = (model_var_values + 1) / 2 + model_log_variance = frac * max_log + (1 - frac) * min_log + model_variance = th.exp(model_log_variance) + else: + model_variance, model_log_variance = { + # for fixedlarge, we set the initial (log-)variance like so + # to get a better decoder log likelihood. 
+ ModelVarType.FIXED_LARGE: ( + np.append(self.posterior_variance[1], self.betas[1:]), + np.log(np.append(self.posterior_variance[1], self.betas[1:])), + ), + ModelVarType.FIXED_SMALL: ( + self.posterior_variance, + self.posterior_log_variance_clipped, + ), + }[self.model_var_type] + model_variance = _extract_into_tensor(model_variance, t, x.shape) + model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape) + + def process_xstart(x): + if denoised_fn is not None: + x = denoised_fn(x) + if clip_denoised: + return x.clamp(-1, 1) + return x + + if self.model_mean_type == ModelMeanType.START_X: + pred_xstart = process_xstart(model_output) + else: + pred_xstart = process_xstart( + self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output) + ) + model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t) + + assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape + return { + "mean": model_mean, + "variance": model_variance, + "log_variance": model_log_variance, + "pred_xstart": pred_xstart, + "extra": extra, + } + + def _predict_xstart_from_eps(self, x_t, t, eps): + assert x_t.shape == eps.shape + return ( + _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t + - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps + ) + + def _predict_eps_from_xstart(self, x_t, t, pred_xstart): + return ( + _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart + ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) + + def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None): + """ + Compute the mean for the previous step, given a function cond_fn that + computes the gradient of a conditional log probability with respect to + x. In particular, cond_fn computes grad(log(p(y|x))), and we want to + condition on y. + This uses the conditioning strategy from Sohl-Dickstein et al. (2015). + """ + gradient = cond_fn(x, t, **model_kwargs) + new_mean = p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float() + return new_mean + + def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None): + """ + Compute what the p_mean_variance output would have been, should the + model's score function be conditioned by cond_fn. + See condition_mean() for details on cond_fn. + Unlike condition_mean(), this instead uses the conditioning strategy + from Song et al (2020). + """ + alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) + + eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"]) + eps = eps - (1 - alpha_bar).sqrt() * cond_fn(x, t, **model_kwargs) + + out = p_mean_var.copy() + out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps) + out["mean"], _, _ = self.q_posterior_mean_variance(x_start=out["pred_xstart"], x_t=x, t=t) + return out + + def p_sample( + self, + model, + x, + t, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + ): + """ + Sample x_{t-1} from the model at the given timestep. + :param model: the model to sample from. + :param x: the current tensor at x_{t-1}. + :param t: the value of t, starting at 0 for the first diffusion step. + :param clip_denoised: if True, clip the x_start prediction to [-1, 1]. + :param denoised_fn: if not None, a function which applies to the + x_start prediction before it is used to sample. + :param cond_fn: if not None, this is a gradient function that acts + similarly to the model. 
+ :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :return: a dict containing the following keys: + - 'sample': a random sample from the model. + - 'pred_xstart': a prediction of x_0. + """ + out = self.p_mean_variance( + model, + x, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs=model_kwargs, + ) + noise = th.randn_like(x) + nonzero_mask = ( + (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) + ) # no noise when t == 0 + if cond_fn is not None: + out["mean"] = self.condition_mean(cond_fn, out, x, t, model_kwargs=model_kwargs) + sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise + return {"sample": sample, "pred_xstart": out["pred_xstart"]} + + def p_sample_loop( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + ): + """ + Generate samples from the model. + :param model: the model module. + :param shape: the shape of the samples, (N, C, H, W). + :param noise: if specified, the noise from the encoder to sample. + Should be of the same shape as `shape`. + :param clip_denoised: if True, clip x_start predictions to [-1, 1]. + :param denoised_fn: if not None, a function which applies to the + x_start prediction before it is used to sample. + :param cond_fn: if not None, this is a gradient function that acts + similarly to the model. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :param device: if specified, the device to create the samples on. + If not specified, use a model parameter's device. + :param progress: if True, show a tqdm progress bar. + :return: a non-differentiable batch of samples. + """ + final = None + for sample in self.p_sample_loop_progressive( + model, + shape, + noise=noise, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + device=device, + progress=progress, + ): + final = sample + return final["sample"] + + def p_sample_loop_progressive( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + ): + """ + Generate samples from the model and yield intermediate samples from + each timestep of diffusion. + Arguments are the same as p_sample_loop(). + Returns a generator over dicts, where each dict is the return value of + p_sample(). + """ + if device is None: + device = next(model.parameters()).device + assert isinstance(shape, (tuple, list)) + if noise is not None: + img = noise + else: + img = th.randn(*shape, device=device) + indices = list(range(self.num_timesteps))[::-1] + + if progress: + # Lazy import so that we don't depend on tqdm. + from tqdm.auto import tqdm + + indices = tqdm(indices) + + for i in indices: + t = th.tensor([i] * shape[0], device=device) + with th.no_grad(): + out = self.p_sample( + model, + img, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + ) + yield out + img = out["sample"] + + def ddim_sample( + self, + model, + x, + t, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + eta=0.0, + ): + """ + Sample x_{t-1} from the model using DDIM. + Same usage as p_sample(). 
+ """ + out = self.p_mean_variance( + model, + x, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs=model_kwargs, + ) + if cond_fn is not None: + out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs) + + # Usually our model outputs epsilon, but we re-derive it + # in case we used x_start or x_prev prediction. + eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"]) + + alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) + alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape) + sigma = ( + eta + * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar)) + * th.sqrt(1 - alpha_bar / alpha_bar_prev) + ) + # Equation 12. + noise = th.randn_like(x) + mean_pred = ( + out["pred_xstart"] * th.sqrt(alpha_bar_prev) + + th.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps + ) + nonzero_mask = ( + (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) + ) # no noise when t == 0 + sample = mean_pred + nonzero_mask * sigma * noise + return {"sample": sample, "pred_xstart": out["pred_xstart"]} + + def ddim_reverse_sample( + self, + model, + x, + t, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + eta=0.0, + ): + """ + Sample x_{t+1} from the model using DDIM reverse ODE. + """ + assert eta == 0.0, "Reverse ODE only for deterministic path" + out = self.p_mean_variance( + model, + x, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs=model_kwargs, + ) + if cond_fn is not None: + out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs) + # Usually our model outputs epsilon, but we re-derive it + # in case we used x_start or x_prev prediction. + eps = ( + _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x + - out["pred_xstart"] + ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape) + alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape) + + # Equation 12. reversed + mean_pred = out["pred_xstart"] * th.sqrt(alpha_bar_next) + th.sqrt(1 - alpha_bar_next) * eps + + return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]} + + def ddim_sample_loop( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + eta=0.0, + ): + """ + Generate samples from the model using DDIM. + Same usage as p_sample_loop(). + """ + final = None + for sample in self.ddim_sample_loop_progressive( + model, + shape, + noise=noise, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + device=device, + progress=progress, + eta=eta, + ): + final = sample + return final["sample"] + + def ddim_sample_loop_progressive( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + eta=0.0, + ): + """ + Use DDIM to sample from the model and yield intermediate samples from + each timestep of DDIM. + Same usage as p_sample_loop_progressive(). + """ + if device is None: + device = next(model.parameters()).device + assert isinstance(shape, (tuple, list)) + if noise is not None: + img = noise + else: + img = th.randn(*shape, device=device) + indices = list(range(self.num_timesteps))[::-1] + + if progress: + # Lazy import so that we don't depend on tqdm. 
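An illustrative scalar version (not from the repo) of the DDIM update in `ddim_sample` above: `eta` scales the extra noise term of Equation 12, and `eta = 0` makes the step fully deterministic.

```python
import torch as th

alpha_bar, alpha_bar_prev = th.tensor(0.5), th.tensor(0.7)   # a_bar decreases with t
x0_pred, eps = th.tensor(0.3), th.tensor(-1.2)               # made-up predictions

def ddim_step(eta):
    sigma = (
        eta
        * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
        * th.sqrt(1 - alpha_bar / alpha_bar_prev)
    )
    mean = x0_pred * th.sqrt(alpha_bar_prev) + th.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps
    return mean + sigma * th.randn(())

a, b = ddim_step(eta=0.0), ddim_step(eta=0.0)
assert th.equal(a, b)            # sigma == 0, so repeated calls agree exactly
```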
+ from tqdm.auto import tqdm + + indices = tqdm(indices) + + for i in indices: + t = th.tensor([i] * shape[0], device=device) + with th.no_grad(): + out = self.ddim_sample( + model, + img, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + eta=eta, + ) + yield out + img = out["sample"] + + def _vb_terms_bpd( + self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None + ): + """ + Get a term for the variational lower-bound. + The resulting units are bits (rather than nats, as one might expect). + This allows for comparison to other papers. + :return: a dict with the following keys: + - 'output': a shape [N] tensor of NLLs or KLs. + - 'pred_xstart': the x_0 predictions. + """ + true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance( + x_start=x_start, x_t=x_t, t=t + ) + out = self.p_mean_variance( + model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs + ) + kl = normal_kl( + true_mean, true_log_variance_clipped, out["mean"], out["log_variance"] + ) + kl = mean_flat(kl) / np.log(2.0) + + decoder_nll = -discretized_gaussian_log_likelihood( + x_start, means=out["mean"], log_scales=0.5 * out["log_variance"] + ) + assert decoder_nll.shape == x_start.shape + decoder_nll = mean_flat(decoder_nll) / np.log(2.0) + + # At the first timestep return the decoder NLL, + # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t)) + output = th.where((t == 0), decoder_nll, kl) + return {"output": output, "pred_xstart": out["pred_xstart"]} + + def training_losses(self, model, x_start, t, model_kwargs=None, noise=None): + """ + Compute training losses for a single timestep. + :param model: the model to evaluate loss on. + :param x_start: the [N x C x ...] tensor of inputs. + :param t: a batch of timestep indices. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :param noise: if specified, the specific Gaussian noise to try to remove. + :return: a dict with the key "loss" containing a tensor of shape [N]. + Some mean or variance settings may also have other keys. + """ + if model_kwargs is None: + model_kwargs = {} + if noise is None: + noise = th.randn_like(x_start) + x_t = self.q_sample(x_start, t, noise=noise) + + terms = {} + + if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL: + terms["loss"] = self._vb_terms_bpd( + model=model, + x_start=x_start, + x_t=x_t, + t=t, + clip_denoised=False, + model_kwargs=model_kwargs, + )["output"] + if self.loss_type == LossType.RESCALED_KL: + terms["loss"] *= self.num_timesteps + elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE: + model_output = model(x_t, t, **model_kwargs) + + if self.model_var_type in [ + ModelVarType.LEARNED, + ModelVarType.LEARNED_RANGE, + ]: + B, C = x_t.shape[:2] + assert model_output.shape == (B, C * 2, *x_t.shape[2:]) + model_output, model_var_values = th.split(model_output, C, dim=1) + # Learn the variance using the variational bound, but don't let + # it affect our mean prediction. + frozen_out = th.cat([model_output.detach(), model_var_values], dim=1) + terms["vb"] = self._vb_terms_bpd( + model=lambda *args, r=frozen_out: r, + x_start=x_start, + x_t=x_t, + t=t, + clip_denoised=False, + )["output"] + if self.loss_type == LossType.RESCALED_MSE: + # Divide by 1000 for equivalence with initial implementation. + # Without a factor of 1/1000, the VB term hurts the MSE term. 
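A small illustration (dummy shapes only) of the learned-variance branch above: the model output is split into mean and variance channels, and `frozen_out` detaches the mean half, so the variational-bound term can only push gradient into the variance prediction.

```python
import torch as th

B, C, H, W = 2, 4, 8, 8
model_output = th.randn(B, 2 * C, H, W, requires_grad=True)

mean_part, var_part = th.split(model_output, C, dim=1)        # (B, C, H, W) each
frozen_out = th.cat([mean_part.detach(), var_part], dim=1)    # the VB loss sees a frozen mean

frozen_out.sum().backward()
grad = model_output.grad
assert grad[:, :C].abs().sum() == 0   # mean half: no gradient through the detached copy
assert (grad[:, C:] == 1).all()       # variance half: gradient flows
```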
+ terms["vb"] *= self.num_timesteps / 1000.0 + + target = { + ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance( + x_start=x_start, x_t=x_t, t=t + )[0], + ModelMeanType.START_X: x_start, + ModelMeanType.EPSILON: noise, + }[self.model_mean_type] + assert model_output.shape == target.shape == x_start.shape + terms["mse"] = mean_flat((target - model_output) ** 2) + if "vb" in terms: + terms["loss"] = terms["mse"] + terms["vb"] + else: + terms["loss"] = terms["mse"] + else: + raise NotImplementedError(self.loss_type) + + return terms + + def _prior_bpd(self, x_start): + """ + Get the prior KL term for the variational lower-bound, measured in + bits-per-dim. + This term can't be optimized, as it only depends on the encoder. + :param x_start: the [N x C x ...] tensor of inputs. + :return: a batch of [N] KL values (in bits), one per batch element. + """ + batch_size = x_start.shape[0] + t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device) + qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t) + kl_prior = normal_kl( + mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0 + ) + return mean_flat(kl_prior) / np.log(2.0) + + def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None): + """ + Compute the entire variational lower-bound, measured in bits-per-dim, + as well as other related quantities. + :param model: the model to evaluate loss on. + :param x_start: the [N x C x ...] tensor of inputs. + :param clip_denoised: if True, clip denoised samples. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :return: a dict containing the following keys: + - total_bpd: the total variational lower-bound, per batch element. + - prior_bpd: the prior term in the lower-bound. + - vb: an [N x T] tensor of terms in the lower-bound. + - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep. + - mse: an [N x T] tensor of epsilon MSEs for each timestep. + """ + device = x_start.device + batch_size = x_start.shape[0] + + vb = [] + xstart_mse = [] + mse = [] + for t in list(range(self.num_timesteps))[::-1]: + t_batch = th.tensor([t] * batch_size, device=device) + noise = th.randn_like(x_start) + x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise) + # Calculate VLB term at the current timestep + with th.no_grad(): + out = self._vb_terms_bpd( + model, + x_start=x_start, + x_t=x_t, + t=t_batch, + clip_denoised=clip_denoised, + model_kwargs=model_kwargs, + ) + vb.append(out["output"]) + xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2)) + eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"]) + mse.append(mean_flat((eps - noise) ** 2)) + + vb = th.stack(vb, dim=1) + xstart_mse = th.stack(xstart_mse, dim=1) + mse = th.stack(mse, dim=1) + + prior_bpd = self._prior_bpd(x_start) + total_bpd = vb.sum(dim=1) + prior_bpd + return { + "total_bpd": total_bpd, + "prior_bpd": prior_bpd, + "vb": vb, + "xstart_mse": xstart_mse, + "mse": mse, + } + + +def _extract_into_tensor(arr, timesteps, broadcast_shape): + """ + Extract values from a 1-D numpy array for a batch of indices. + :param arr: the 1-D numpy array. + :param timesteps: a tensor of indices into the array to extract. + :param broadcast_shape: a larger shape of K dimensions with the batch + dimension equal to the length of timesteps. + :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims. 
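The bits-per-dim bookkeeping above (`_vb_terms_bpd`, `_prior_bpd`, `calc_bpd_loop`) averages per-dimension KL values in nats and divides by ln 2. A self-contained sketch using the standard closed-form KL between two diagonal Gaussians (presumably what `normal_kl` in `diffusion_utils.py` evaluates):

```python
import numpy as np

def normal_kl_nats(mean1, logvar1, mean2, logvar2):
    # KL( N(mean1, e^logvar1) || N(mean2, e^logvar2) ), element-wise, in nats
    return 0.5 * (
        -1.0 + logvar2 - logvar1
        + np.exp(logvar1 - logvar2)
        + (mean1 - mean2) ** 2 * np.exp(-logvar2)
    )

mean1, logvar1 = np.zeros(16), np.zeros(16)
mean2, logvar2 = np.full(16, 0.5), np.full(16, -1.0)

kl_nats = normal_kl_nats(mean1, logvar1, mean2, logvar2).mean()   # the mean_flat analogue
kl_bits = kl_nats / np.log(2.0)                                   # nats -> bits per dim
print(f"{kl_nats:.4f} nats/dim = {kl_bits:.4f} bits/dim")
```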
+ """ + res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float() + while len(res.shape) < len(broadcast_shape): + res = res[..., None] + return res + th.zeros(broadcast_shape, device=timesteps.device) diff --git a/PyTorch/built-in/mlm/DiT/diffusion/respace.py b/PyTorch/built-in/mlm/DiT/diffusion/respace.py new file mode 100644 index 0000000000..0a2cc0435d --- /dev/null +++ b/PyTorch/built-in/mlm/DiT/diffusion/respace.py @@ -0,0 +1,129 @@ +# Modified from OpenAI's diffusion repos +# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py +# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion +# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py + +import numpy as np +import torch as th + +from .gaussian_diffusion import GaussianDiffusion + + +def space_timesteps(num_timesteps, section_counts): + """ + Create a list of timesteps to use from an original diffusion process, + given the number of timesteps we want to take from equally-sized portions + of the original process. + For example, if there's 300 timesteps and the section counts are [10,15,20] + then the first 100 timesteps are strided to be 10 timesteps, the second 100 + are strided to be 15 timesteps, and the final 100 are strided to be 20. + If the stride is a string starting with "ddim", then the fixed striding + from the DDIM paper is used, and only one section is allowed. + :param num_timesteps: the number of diffusion steps in the original + process to divide up. + :param section_counts: either a list of numbers, or a string containing + comma-separated numbers, indicating the step count + per section. As a special case, use "ddimN" where N + is a number of steps to use the striding from the + DDIM paper. + :return: a set of diffusion steps from the original process to use. + """ + if isinstance(section_counts, str): + if section_counts.startswith("ddim"): + desired_count = int(section_counts[len("ddim") :]) + for i in range(1, num_timesteps): + if len(range(0, num_timesteps, i)) == desired_count: + return set(range(0, num_timesteps, i)) + raise ValueError( + f"cannot create exactly {num_timesteps} steps with an integer stride" + ) + section_counts = [int(x) for x in section_counts.split(",")] + size_per = num_timesteps // len(section_counts) + extra = num_timesteps % len(section_counts) + start_idx = 0 + all_steps = [] + for i, section_count in enumerate(section_counts): + size = size_per + (1 if i < extra else 0) + if size < section_count: + raise ValueError( + f"cannot divide section of {size} steps into {section_count}" + ) + if section_count <= 1: + frac_stride = 1 + else: + frac_stride = (size - 1) / (section_count - 1) + cur_idx = 0.0 + taken_steps = [] + for _ in range(section_count): + taken_steps.append(start_idx + round(cur_idx)) + cur_idx += frac_stride + all_steps += taken_steps + start_idx += size + return set(all_steps) + + +class SpacedDiffusion(GaussianDiffusion): + """ + A diffusion process which can skip steps in a base diffusion process. + :param use_timesteps: a collection (sequence or set) of timesteps from the + original diffusion process to retain. + :param kwargs: the kwargs to create the base diffusion process. 
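A usage sketch for `space_timesteps` above, assuming the package is importable as `diffusion.respace` (i.e. the DiT directory is on `PYTHONPATH`); the numbers follow the function's own docstring example.

```python
from diffusion.respace import space_timesteps

# "ddimN": pick one fixed stride that yields exactly N steps out of 1000.
assert len(space_timesteps(1000, "ddim50")) == 50

# Three equal sections of 100 original steps, strided down to 10, 15 and 20 steps.
steps = space_timesteps(300, [10, 15, 20])
assert len(steps) == 10 + 15 + 20
assert min(steps) == 0 and max(steps) < 300
```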
+ """ + + def __init__(self, use_timesteps, **kwargs): + self.use_timesteps = set(use_timesteps) + self.timestep_map = [] + self.original_num_steps = len(kwargs["betas"]) + + base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa + last_alpha_cumprod = 1.0 + new_betas = [] + for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod): + if i in self.use_timesteps: + new_betas.append(1 - alpha_cumprod / last_alpha_cumprod) + last_alpha_cumprod = alpha_cumprod + self.timestep_map.append(i) + kwargs["betas"] = np.array(new_betas) + super().__init__(**kwargs) + + def p_mean_variance( + self, model, *args, **kwargs + ): # pylint: disable=signature-differs + return super().p_mean_variance(self._wrap_model(model), *args, **kwargs) + + def training_losses( + self, model, *args, **kwargs + ): # pylint: disable=signature-differs + return super().training_losses(self._wrap_model(model), *args, **kwargs) + + def condition_mean(self, cond_fn, *args, **kwargs): + return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs) + + def condition_score(self, cond_fn, *args, **kwargs): + return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs) + + def _wrap_model(self, model): + if isinstance(model, _WrappedModel): + return model + return _WrappedModel( + model, self.timestep_map, self.original_num_steps + ) + + def _scale_timesteps(self, t): + # Scaling is done by the wrapped model. + return t + + +class _WrappedModel: + def __init__(self, model, timestep_map, original_num_steps): + self.model = model + self.timestep_map = timestep_map + # self.rescale_timesteps = rescale_timesteps + self.original_num_steps = original_num_steps + + def __call__(self, x, ts, **kwargs): + map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype) + new_ts = map_tensor[ts] + # if self.rescale_timesteps: + # new_ts = new_ts.float() * (1000.0 / self.original_num_steps) + return self.model(x, new_ts, **kwargs) diff --git a/PyTorch/built-in/mlm/DiT/diffusion/timestep_sampler.py b/PyTorch/built-in/mlm/DiT/diffusion/timestep_sampler.py new file mode 100644 index 0000000000..a3f3698476 --- /dev/null +++ b/PyTorch/built-in/mlm/DiT/diffusion/timestep_sampler.py @@ -0,0 +1,150 @@ +# Modified from OpenAI's diffusion repos +# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py +# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion +# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py + +from abc import ABC, abstractmethod + +import numpy as np +import torch as th +import torch.distributed as dist + + +def create_named_schedule_sampler(name, diffusion): + """ + Create a ScheduleSampler from a library of pre-defined samplers. + :param name: the name of the sampler. + :param diffusion: the diffusion object to sample for. + """ + if name == "uniform": + return UniformSampler(diffusion) + elif name == "loss-second-moment": + return LossSecondMomentResampler(diffusion) + else: + raise NotImplementedError(f"unknown schedule sampler: {name}") + + +class ScheduleSampler(ABC): + """ + A distribution over timesteps in the diffusion process, intended to reduce + variance of the objective. + By default, samplers perform unbiased importance sampling, in which the + objective's mean is unchanged. + However, subclasses may override sample() to change how the resampled + terms are reweighted, allowing for actual changes in the objective. 
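An illustrative reconstruction (with a made-up linear beta schedule) of what `SpacedDiffusion.__init__` and `_WrappedModel` above do: the recomputed betas preserve the original `alphas_cumprod` at every retained step, and spaced indices are mapped back to original timesteps before the model is called.

```python
import numpy as np
import torch as th

betas = np.linspace(1e-4, 0.02, 1000)
alphas_cumprod = np.cumprod(1.0 - betas)

use_timesteps = [0, 250, 500, 750, 999]
kept = set(use_timesteps)
last, new_betas, timestep_map = 1.0, [], []
for i, ac in enumerate(alphas_cumprod):
    if i in kept:
        new_betas.append(1 - ac / last)
        last = ac
        timestep_map.append(i)

# Retained steps keep their original cumulative alphas.
assert np.allclose(np.cumprod(1.0 - np.array(new_betas)), alphas_cumprod[use_timesteps])

# _WrappedModel: spaced indices (0..4 here) map back to original timesteps.
ts = th.tensor([0, 3, 4])
assert th.tensor(timestep_map)[ts].tolist() == [0, 750, 999]
```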
+ """ + + @abstractmethod + def weights(self): + """ + Get a numpy array of weights, one per diffusion step. + The weights needn't be normalized, but must be positive. + """ + + def sample(self, batch_size, device): + """ + Importance-sample timesteps for a batch. + :param batch_size: the number of timesteps. + :param device: the torch device to save to. + :return: a tuple (timesteps, weights): + - timesteps: a tensor of timestep indices. + - weights: a tensor of weights to scale the resulting losses. + """ + w = self.weights() + p = w / np.sum(w) + indices_np = np.random.choice(len(p), size=(batch_size,), p=p) + indices = th.from_numpy(indices_np).long().to(device) + weights_np = 1 / (len(p) * p[indices_np]) + weights = th.from_numpy(weights_np).float().to(device) + return indices, weights + + +class UniformSampler(ScheduleSampler): + def __init__(self, diffusion): + self.diffusion = diffusion + self._weights = np.ones([diffusion.num_timesteps]) + + def weights(self): + return self._weights + + +class LossAwareSampler(ScheduleSampler): + def update_with_local_losses(self, local_ts, local_losses): + """ + Update the reweighting using losses from a model. + Call this method from each rank with a batch of timesteps and the + corresponding losses for each of those timesteps. + This method will perform synchronization to make sure all of the ranks + maintain the exact same reweighting. + :param local_ts: an integer Tensor of timesteps. + :param local_losses: a 1D Tensor of losses. + """ + batch_sizes = [ + th.tensor([0], dtype=th.int32, device=local_ts.device) + for _ in range(dist.get_world_size()) + ] + dist.all_gather( + batch_sizes, + th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device), + ) + + # Pad all_gather batches to be the maximum batch size. + batch_sizes = [x.item() for x in batch_sizes] + max_bs = max(batch_sizes) + + timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes] + loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes] + dist.all_gather(timestep_batches, local_ts) + dist.all_gather(loss_batches, local_losses) + timesteps = [ + x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs] + ] + losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]] + self.update_with_all_losses(timesteps, losses) + + @abstractmethod + def update_with_all_losses(self, ts, losses): + """ + Update the reweighting using losses from a model. + Sub-classes should override this method to update the reweighting + using losses from the model. + This method directly updates the reweighting without synchronizing + between workers. It is called by update_with_local_losses from all + ranks with identical arguments. Thus, it should have deterministic + behavior to maintain state across workers. + :param ts: a list of int timesteps. + :param losses: a list of float losses, one per timestep. 
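Why the reweighting in `ScheduleSampler.sample` above keeps the training objective unbiased: the expectation, under the sampling distribution, of weight times loss equals the plain uniform average over timesteps. A quick numerical check with made-up values:

```python
import numpy as np

rng = np.random.default_rng(0)
T = 1000
w = rng.uniform(0.1, 5.0, size=T)        # any positive per-timestep weights
p = w / w.sum()
importance_w = 1.0 / (T * p)             # the reweighting used in sample()

f = rng.normal(size=T)                   # stand-in per-timestep losses
assert np.isclose((p * importance_w * f).sum(), f.mean())
```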
+ """ + + +class LossSecondMomentResampler(LossAwareSampler): + def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001): + self.diffusion = diffusion + self.history_per_term = history_per_term + self.uniform_prob = uniform_prob + self._loss_history = np.zeros( + [diffusion.num_timesteps, history_per_term], dtype=np.float64 + ) + self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int) + + def weights(self): + if not self._warmed_up(): + return np.ones([self.diffusion.num_timesteps], dtype=np.float64) + weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1)) + weights /= np.sum(weights) + weights *= 1 - self.uniform_prob + weights += self.uniform_prob / len(weights) + return weights + + def update_with_all_losses(self, ts, losses): + for t, loss in zip(ts, losses): + if self._loss_counts[t] == self.history_per_term: + # Shift out the oldest loss term. + self._loss_history[t, :-1] = self._loss_history[t, 1:] + self._loss_history[t, -1] = loss + else: + self._loss_history[t, self._loss_counts[t]] = loss + self._loss_counts[t] += 1 + + def _warmed_up(self): + return (self._loss_counts == self.history_per_term).all() diff --git a/PyTorch/built-in/mlm/DiT/download.py b/PyTorch/built-in/mlm/DiT/download.py new file mode 100644 index 0000000000..de22d459a3 --- /dev/null +++ b/PyTorch/built-in/mlm/DiT/download.py @@ -0,0 +1,50 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +""" +Functions for downloading pre-trained DiT models +""" +from torchvision.datasets.utils import download_url +import torch +import os + + +pretrained_models = {'DiT-XL-2-512x512.pt', 'DiT-XL-2-256x256.pt'} + + +def find_model(model_name): + """ + Finds a pre-trained DiT model, downloading it if necessary. Alternatively, loads a model from a local path. + """ + if model_name in pretrained_models: # Find/download our pre-trained DiT checkpoints + return download_model(model_name) + else: # Load a custom DiT checkpoint: + assert os.path.isfile(model_name), f'Could not find DiT checkpoint at {model_name}' + checkpoint = torch.load(model_name, map_location=lambda storage, loc: storage) + if "ema" in checkpoint: # supports checkpoints from train.py + checkpoint = checkpoint["ema"] + return checkpoint + + +def download_model(model_name): + """ + Downloads a pre-trained DiT model from the web. 
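Two notes on `LossSecondMomentResampler` above: `np.int` in its constructor has been removed from recent NumPy releases (plain `int` or `np.int64` is the drop-in replacement), and the weighting itself is just the per-timestep RMS of recent losses, normalized and mixed with a small uniform floor, as this sketch with random history values shows.

```python
import numpy as np

num_timesteps, history_per_term, uniform_prob = 10, 4, 0.001
rng = np.random.default_rng(0)
loss_history = np.abs(rng.normal(size=(num_timesteps, history_per_term)))

weights = np.sqrt(np.mean(loss_history ** 2, axis=-1))   # RMS of recent losses per t
weights /= weights.sum()
weights = weights * (1 - uniform_prob) + uniform_prob / len(weights)

assert np.isclose(weights.sum(), 1.0)                          # a proper distribution
assert (weights >= uniform_prob / len(weights) - 1e-12).all()  # every t keeps a floor
```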
+ """ + assert model_name in pretrained_models + local_path = f'pretrained_models/{model_name}' + if not os.path.isfile(local_path): + os.makedirs('pretrained_models', exist_ok=True) + web_path = f'https://dl.fbaipublicfiles.com/DiT/models/{model_name}' + download_url(web_path, 'pretrained_models') + model = torch.load(local_path, map_location=lambda storage, loc: storage) + return model + + +if __name__ == "__main__": + # Download all DiT checkpoints + for model in pretrained_models: + download_model(model) + print('Done.') diff --git a/PyTorch/built-in/mlm/DiT/environment.yml b/PyTorch/built-in/mlm/DiT/environment.yml new file mode 100644 index 0000000000..b5abcab94f --- /dev/null +++ b/PyTorch/built-in/mlm/DiT/environment.yml @@ -0,0 +1,13 @@ +name: DiT +channels: + - pytorch + - nvidia +dependencies: + - python >= 3.8 + - pytorch >= 1.13 + - torchvision + - pytorch-cuda=11.7 + - pip: + - timm + - diffusers + - accelerate diff --git a/PyTorch/built-in/mlm/DiT/models.py b/PyTorch/built-in/mlm/DiT/models.py new file mode 100644 index 0000000000..c90eeba7b2 --- /dev/null +++ b/PyTorch/built-in/mlm/DiT/models.py @@ -0,0 +1,370 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# -------------------------------------------------------- +# References: +# GLIDE: https://github.com/openai/glide-text2im +# MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py +# -------------------------------------------------------- + +import torch +import torch.nn as nn +import numpy as np +import math +from timm.models.vision_transformer import PatchEmbed, Attention, Mlp + + +def modulate(x, shift, scale): + return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) + + +################################################################################# +# Embedding Layers for Timesteps and Class Labels # +################################################################################# + +class TimestepEmbedder(nn.Module): + """ + Embeds scalar timesteps into vector representations. + """ + def __init__(self, hidden_size, frequency_embedding_size=256): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size, bias=True), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size, bias=True), + ) + self.frequency_embedding_size = frequency_embedding_size + + @staticmethod + def timestep_embedding(t, dim, max_period=10000): + """ + Create sinusoidal timestep embeddings. + :param t: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an (N, D) Tensor of positional embeddings. + """ + # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half + ).to(device=t.device) + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + def forward(self, t): + t_freq = self.timestep_embedding(t, self.frequency_embedding_size) + t_emb = self.mlp(t_freq) + return t_emb + + +class LabelEmbedder(nn.Module): + """ + Embeds class labels into vector representations. 
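For reference, the sinusoidal embedding produced by `TimestepEmbedder.timestep_embedding` above can be checked in isolation; this standalone copy mirrors the even-`dim` path of that method.

```python
import math
import torch

def timestep_embedding(t, dim, max_period=10000):
    half = dim // 2
    freqs = torch.exp(-math.log(max_period) * torch.arange(half, dtype=torch.float32) / half)
    args = t[:, None].float() * freqs[None]
    return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)

emb = timestep_embedding(torch.tensor([0, 17, 999]), dim=256)
assert emb.shape == (3, 256)
assert torch.allclose(emb[0, :128], torch.ones(128))    # cos(0) = 1 at every frequency
assert torch.allclose(emb[0, 128:], torch.zeros(128))   # sin(0) = 0
```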
Also handles label dropout for classifier-free guidance. + """ + def __init__(self, num_classes, hidden_size, dropout_prob): + super().__init__() + use_cfg_embedding = dropout_prob > 0 + self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size) + self.num_classes = num_classes + self.dropout_prob = dropout_prob + + def token_drop(self, labels, force_drop_ids=None): + """ + Drops labels to enable classifier-free guidance. + """ + if force_drop_ids is None: + drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob + else: + drop_ids = force_drop_ids == 1 + labels = torch.where(drop_ids, self.num_classes, labels) + return labels + + def forward(self, labels, train, force_drop_ids=None): + use_dropout = self.dropout_prob > 0 + if (train and use_dropout) or (force_drop_ids is not None): + labels = self.token_drop(labels, force_drop_ids) + embeddings = self.embedding_table(labels) + return embeddings + + +################################################################################# +# Core DiT Model # +################################################################################# + +class DiTBlock(nn.Module): + """ + A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning. + """ + def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs): + super().__init__() + self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs) + self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + mlp_hidden_dim = int(hidden_size * mlp_ratio) + approx_gelu = lambda: nn.GELU(approximate="tanh") + self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0) + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(hidden_size, 6 * hidden_size, bias=True) + ) + + def forward(self, x, c): + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1) + x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa)) + x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) + return x + + +class FinalLayer(nn.Module): + """ + The final layer of DiT. + """ + def __init__(self, hidden_size, patch_size, out_channels): + super().__init__() + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True) + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(hidden_size, 2 * hidden_size, bias=True) + ) + + def forward(self, x, c): + shift, scale = self.adaLN_modulation(c).chunk(2, dim=1) + x = modulate(self.norm_final(x), shift, scale) + x = self.linear(x) + return x + + +class DiT(nn.Module): + """ + Diffusion model with a Transformer backbone. 
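A shape-level sketch (dummy tensors only) of the adaLN-Zero conditioning used by `DiTBlock` above: the conditioning vector is chunked into six modulation vectors, `modulate` broadcasts them over the token dimension, and because the modulation MLP is zero-initialized, every gate starts at zero and each block begins training as the identity map.

```python
import torch

def modulate(x, shift, scale):
    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)

N, T, D = 2, 16, 32
x = torch.randn(N, T, D)                 # token sequence
c = torch.randn(N, 6 * D)                # conditioning -> six modulation vectors
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = c.chunk(6, dim=1)
assert modulate(x, shift_msa, scale_msa).shape == (N, T, D)

zero_gate = torch.zeros(N, D)            # adaLN-Zero initialization
branch = torch.randn(N, T, D)            # stand-in for attn(...) or mlp(...)
assert torch.equal(x + zero_gate.unsqueeze(1) * branch, x)   # residual branch is a no-op
```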
+ """ + def __init__( + self, + input_size=32, + patch_size=2, + in_channels=4, + hidden_size=1152, + depth=28, + num_heads=16, + mlp_ratio=4.0, + class_dropout_prob=0.1, + num_classes=1000, + learn_sigma=True, + ): + super().__init__() + self.learn_sigma = learn_sigma + self.in_channels = in_channels + self.out_channels = in_channels * 2 if learn_sigma else in_channels + self.patch_size = patch_size + self.num_heads = num_heads + + self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True) + self.t_embedder = TimestepEmbedder(hidden_size) + self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob) + num_patches = self.x_embedder.num_patches + # Will use fixed sin-cos embedding: + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False) + + self.blocks = nn.ModuleList([ + DiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth) + ]) + self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels) + self.initialize_weights() + + def initialize_weights(self): + # Initialize transformer layers: + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + self.apply(_basic_init) + + # Initialize (and freeze) pos_embed by sin-cos embedding: + pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches ** 0.5)) + self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) + + # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): + w = self.x_embedder.proj.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + nn.init.constant_(self.x_embedder.proj.bias, 0) + + # Initialize label embedding table: + nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02) + + # Initialize timestep embedding MLP: + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + + # Zero-out adaLN modulation layers in DiT blocks: + for block in self.blocks: + nn.init.constant_(block.adaLN_modulation[-1].weight, 0) + nn.init.constant_(block.adaLN_modulation[-1].bias, 0) + + # Zero-out output layers: + nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) + nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) + nn.init.constant_(self.final_layer.linear.weight, 0) + nn.init.constant_(self.final_layer.linear.bias, 0) + + def unpatchify(self, x): + """ + x: (N, T, patch_size**2 * C) + imgs: (N, H, W, C) + """ + c = self.out_channels + p = self.x_embedder.patch_size[0] + h = w = int(x.shape[1] ** 0.5) + assert h * w == x.shape[1] + + x = x.reshape(shape=(x.shape[0], h, w, p, p, c)) + x = torch.einsum('nhwpqc->nchpwq', x) + imgs = x.reshape(shape=(x.shape[0], c, h * p, h * p)) + return imgs + + def forward(self, x, t, y): + """ + Forward pass of DiT. 
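A round-trip check (illustrative) of the token layout that `unpatchify` above inverts: patchifying an image into (N, T, p*p*C) tokens and applying the same reshape/einsum as the method recovers the input exactly.

```python
import torch

N, C, H, W, p = 2, 8, 32, 32, 2
h, w = H // p, W // p
imgs = torch.randn(N, C, H, W)

# Patchify: (N, C, H, W) -> (N, h*w, p*p*C), the layout emitted by the final layer.
x = imgs.reshape(N, C, h, p, w, p)
x = torch.einsum('nchpwq->nhwpqc', x).reshape(N, h * w, p * p * C)

# Unpatchify, exactly as in the method above.
y = x.reshape(N, h, w, p, p, C)
y = torch.einsum('nhwpqc->nchpwq', y).reshape(N, C, h * p, w * p)
assert torch.equal(y, imgs)
```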
+ x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images) + t: (N,) tensor of diffusion timesteps + y: (N,) tensor of class labels + """ + x = self.x_embedder(x) + self.pos_embed # (N, T, D), where T = H * W / patch_size ** 2 + t = self.t_embedder(t) # (N, D) + y = self.y_embedder(y, self.training) # (N, D) + c = t + y # (N, D) + for block in self.blocks: + x = block(x, c) # (N, T, D) + x = self.final_layer(x, c) # (N, T, patch_size ** 2 * out_channels) + x = self.unpatchify(x) # (N, out_channels, H, W) + return x + + def forward_with_cfg(self, x, t, y, cfg_scale): + """ + Forward pass of DiT, but also batches the unconditional forward pass for classifier-free guidance. + """ + # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb + half = x[: len(x) // 2] + combined = torch.cat([half, half], dim=0) + model_out = self.forward(combined, t, y) + # For exact reproducibility reasons, we apply classifier-free guidance on only + # three channels by default. The standard approach to cfg applies it to all channels. + # This can be done by uncommenting the following line and commenting-out the line following that. + # eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:] + eps, rest = model_out[:, :3], model_out[:, 3:] + cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0) + half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps) + eps = torch.cat([half_eps, half_eps], dim=0) + return torch.cat([eps, rest], dim=1) + + +################################################################################# +# Sine/Cosine Positional Embedding Functions # +################################################################################# +# https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py + +def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0): + """ + grid_size: int of the grid height and width + return: + pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) + """ + grid_h = np.arange(grid_size, dtype=np.float32) + grid_w = np.arange(grid_size, dtype=np.float32) + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) + + grid = grid.reshape([2, 1, grid_size, grid_size]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + if cls_token and extra_tokens > 0: + pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0) + return pos_embed + + +def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): + assert embed_dim % 2 == 0 + + # use half of dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) + + emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) + return emb + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + embed_dim: output dimension for each position + pos: a list of positions to be encoded: size (M,) + out: (M, D) + """ + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=np.float64) + omega /= embed_dim / 2. + omega = 1. 
/ 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb + + +################################################################################# +# DiT Configs # +################################################################################# + +def DiT_XL_2(**kwargs): + return DiT(depth=28, hidden_size=1152, patch_size=2, num_heads=16, **kwargs) + +def DiT_XL_4(**kwargs): + return DiT(depth=28, hidden_size=1152, patch_size=4, num_heads=16, **kwargs) + +def DiT_XL_8(**kwargs): + return DiT(depth=28, hidden_size=1152, patch_size=8, num_heads=16, **kwargs) + +def DiT_L_2(**kwargs): + return DiT(depth=24, hidden_size=1024, patch_size=2, num_heads=16, **kwargs) + +def DiT_L_4(**kwargs): + return DiT(depth=24, hidden_size=1024, patch_size=4, num_heads=16, **kwargs) + +def DiT_L_8(**kwargs): + return DiT(depth=24, hidden_size=1024, patch_size=8, num_heads=16, **kwargs) + +def DiT_B_2(**kwargs): + return DiT(depth=12, hidden_size=768, patch_size=2, num_heads=12, **kwargs) + +def DiT_B_4(**kwargs): + return DiT(depth=12, hidden_size=768, patch_size=4, num_heads=12, **kwargs) + +def DiT_B_8(**kwargs): + return DiT(depth=12, hidden_size=768, patch_size=8, num_heads=12, **kwargs) + +def DiT_S_2(**kwargs): + return DiT(depth=12, hidden_size=384, patch_size=2, num_heads=6, **kwargs) + +def DiT_S_4(**kwargs): + return DiT(depth=12, hidden_size=384, patch_size=4, num_heads=6, **kwargs) + +def DiT_S_8(**kwargs): + return DiT(depth=12, hidden_size=384, patch_size=8, num_heads=6, **kwargs) + + +DiT_models = { + 'DiT-XL/2': DiT_XL_2, 'DiT-XL/4': DiT_XL_4, 'DiT-XL/8': DiT_XL_8, + 'DiT-L/2': DiT_L_2, 'DiT-L/4': DiT_L_4, 'DiT-L/8': DiT_L_8, + 'DiT-B/2': DiT_B_2, 'DiT-B/4': DiT_B_4, 'DiT-B/8': DiT_B_8, + 'DiT-S/2': DiT_S_2, 'DiT-S/4': DiT_S_4, 'DiT-S/8': DiT_S_8, +} diff --git a/PyTorch/built-in/mlm/DiT/sample.py b/PyTorch/built-in/mlm/DiT/sample.py new file mode 100644 index 0000000000..a4152afd88 --- /dev/null +++ b/PyTorch/built-in/mlm/DiT/sample.py @@ -0,0 +1,83 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +""" +Sample new images from a pre-trained DiT. +""" +import torch +torch.backends.cuda.matmul.allow_tf32 = True +torch.backends.cudnn.allow_tf32 = True +from torchvision.utils import save_image +from diffusion import create_diffusion +from diffusers.models import AutoencoderKL +from download import find_model +from models import DiT_models +import argparse + + +def main(args): + # Setup PyTorch: + torch.manual_seed(args.seed) + torch.set_grad_enabled(False) + device = "cuda" if torch.cuda.is_available() else "cpu" + + if args.ckpt is None: + assert args.model == "DiT-XL/2", "Only DiT-XL/2 models are available for auto-download." + assert args.image_size in [256, 512] + assert args.num_classes == 1000 + + # Load model: + latent_size = args.image_size // 8 + model = DiT_models[args.model]( + input_size=latent_size, + num_classes=args.num_classes + ).to(device) + # Auto-download a pre-trained model or load a custom DiT checkpoint from train.py: + ckpt_path = args.ckpt or f"DiT-XL-2-{args.image_size}x{args.image_size}.pt" + state_dict = find_model(ckpt_path) + model.load_state_dict(state_dict) + model.eval() # important! 
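The classifier-free guidance mix applied by `forward_with_cfg` above (and relied on by `sample.py` below), shown with stand-in tensors: conditional and unconditional predictions come from a single doubled batch, guidance extrapolates between them, and, as the code comment notes, the released models apply it to the first three channels only for exact reproducibility.

```python
import torch

cfg_scale, n, C = 4.0, 4, 8                    # n conditional samples; C = 2 * in_channels (learn_sigma)
model_out = torch.randn(2 * n, C, 32, 32)      # stand-in for DiT output on the doubled batch

eps, rest = model_out[:, :3], model_out[:, 3:]
cond_eps, uncond_eps = torch.split(eps, n, dim=0)
half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)

assert half_eps.shape == (n, 3, 32, 32)
# cfg_scale == 1.0 reduces to the plain conditional prediction:
assert torch.allclose(uncond_eps + 1.0 * (cond_eps - uncond_eps), cond_eps)
```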
+ diffusion = create_diffusion(str(args.num_sampling_steps)) + vae = AutoencoderKL.from_pretrained(f"stabilityai/sd-vae-ft-{args.vae}").to(device) + + # Labels to condition the model with (feel free to change): + class_labels = [207, 360, 387, 974, 88, 979, 417, 279] + + # Create sampling noise: + n = len(class_labels) + z = torch.randn(n, 4, latent_size, latent_size, device=device) + y = torch.tensor(class_labels, device=device) + + # Setup classifier-free guidance: + z = torch.cat([z, z], 0) + y_null = torch.tensor([1000] * n, device=device) + y = torch.cat([y, y_null], 0) + model_kwargs = dict(y=y, cfg_scale=args.cfg_scale) + + # Sample images: + samples = diffusion.p_sample_loop( + model.forward_with_cfg, z.shape, z, clip_denoised=False, model_kwargs=model_kwargs, progress=True, device=device + ) + samples, _ = samples.chunk(2, dim=0) # Remove null class samples + samples = vae.decode(samples / 0.18215).sample + + # Save and display images: + save_image(samples, "sample.png", nrow=4, normalize=True, value_range=(-1, 1)) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--model", type=str, choices=list(DiT_models.keys()), default="DiT-XL/2") + parser.add_argument("--vae", type=str, choices=["ema", "mse"], default="mse") + parser.add_argument("--image-size", type=int, choices=[256, 512], default=256) + parser.add_argument("--num-classes", type=int, default=1000) + parser.add_argument("--cfg-scale", type=float, default=4.0) + parser.add_argument("--num-sampling-steps", type=int, default=250) + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--ckpt", type=str, default=None, + help="Optional path to a DiT checkpoint (default: auto-download a pre-trained DiT-XL/2 model).") + args = parser.parse_args() + main(args) diff --git a/PyTorch/built-in/mlm/DiT/sample_ddp.py b/PyTorch/built-in/mlm/DiT/sample_ddp.py new file mode 100644 index 0000000000..0a6b1ab6fb --- /dev/null +++ b/PyTorch/built-in/mlm/DiT/sample_ddp.py @@ -0,0 +1,166 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +""" +Samples a large number of images from a pre-trained DiT model using DDP. +Subsequently saves a .npz file that can be used to compute FID and other +evaluation metrics via the ADM repo: https://github.com/openai/guided-diffusion/tree/main/evaluations + +For a simple single-GPU/CPU sampling script, see sample.py. +""" +import torch +import torch.distributed as dist +from models import DiT_models +from download import find_model +from diffusion import create_diffusion +from diffusers.models import AutoencoderKL +from tqdm import tqdm +import os +from PIL import Image +import numpy as np +import math +import argparse + + +def create_npz_from_sample_folder(sample_dir, num=50_000): + """ + Builds a single .npz file from a folder of .png samples. + """ + samples = [] + for i in tqdm(range(num), desc="Building .npz file from samples"): + sample_pil = Image.open(f"{sample_dir}/{i:06d}.png") + sample_np = np.asarray(sample_pil).astype(np.uint8) + samples.append(sample_np) + samples = np.stack(samples) + assert samples.shape == (num, samples.shape[1], samples.shape[2], 3) + npz_path = f"{sample_dir}.npz" + np.savez(npz_path, arr_0=samples) + print(f"Saved .npz file to {npz_path} [shape={samples.shape}].") + return npz_path + + +def main(args): + """ + Run sampling. 
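A sketch of the latent round trip that `sample.py` above relies on (this needs the `diffusers` VAE weights, which `from_pretrained` downloads): images map to 4-channel latents at 1/8 resolution, the encoder output is scaled by 0.18215, and decoding divides the same factor back out.

```python
import torch
from diffusers.models import AutoencoderKL

vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")
img = torch.randn(1, 3, 256, 256)        # dummy image tensor in roughly [-1, 1]
with torch.no_grad():
    z = vae.encode(img).latent_dist.sample().mul_(0.18215)
    assert z.shape == (1, 4, 32, 32)     # latent_size = image_size // 8
    rec = vae.decode(z / 0.18215).sample
    assert rec.shape == img.shape
```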
+ """ + torch.backends.cuda.matmul.allow_tf32 = args.tf32 # True: fast but may lead to some small numerical differences + assert torch.cuda.is_available(), "Sampling with DDP requires at least one GPU. sample.py supports CPU-only usage" + torch.set_grad_enabled(False) + + # Setup DDP: + dist.init_process_group("nccl") + rank = dist.get_rank() + device = rank % torch.cuda.device_count() + seed = args.global_seed * dist.get_world_size() + rank + torch.manual_seed(seed) + torch.cuda.set_device(device) + print(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.") + + if args.ckpt is None: + assert args.model == "DiT-XL/2", "Only DiT-XL/2 models are available for auto-download." + assert args.image_size in [256, 512] + assert args.num_classes == 1000 + + # Load model: + latent_size = args.image_size // 8 + model = DiT_models[args.model]( + input_size=latent_size, + num_classes=args.num_classes + ).to(device) + # Auto-download a pre-trained model or load a custom DiT checkpoint from train.py: + ckpt_path = args.ckpt or f"DiT-XL-2-{args.image_size}x{args.image_size}.pt" + state_dict = find_model(ckpt_path) + model.load_state_dict(state_dict) + model.eval() # important! + diffusion = create_diffusion(str(args.num_sampling_steps)) + vae = AutoencoderKL.from_pretrained(f"stabilityai/sd-vae-ft-{args.vae}").to(device) + assert args.cfg_scale >= 1.0, "In almost all cases, cfg_scale be >= 1.0" + using_cfg = args.cfg_scale > 1.0 + + # Create folder to save samples: + model_string_name = args.model.replace("/", "-") + ckpt_string_name = os.path.basename(args.ckpt).replace(".pt", "") if args.ckpt else "pretrained" + folder_name = f"{model_string_name}-{ckpt_string_name}-size-{args.image_size}-vae-{args.vae}-" \ + f"cfg-{args.cfg_scale}-seed-{args.global_seed}" + sample_folder_dir = f"{args.sample_dir}/{folder_name}" + if rank == 0: + os.makedirs(sample_folder_dir, exist_ok=True) + print(f"Saving .png samples at {sample_folder_dir}") + dist.barrier() + + # Figure out how many samples we need to generate on each GPU and how many iterations we need to run: + n = args.per_proc_batch_size + global_batch_size = n * dist.get_world_size() + # To make things evenly-divisible, we'll sample a bit more than we need and then discard the extra samples: + total_samples = int(math.ceil(args.num_fid_samples / global_batch_size) * global_batch_size) + if rank == 0: + print(f"Total number of images that will be sampled: {total_samples}") + assert total_samples % dist.get_world_size() == 0, "total_samples must be divisible by world_size" + samples_needed_this_gpu = int(total_samples // dist.get_world_size()) + assert samples_needed_this_gpu % n == 0, "samples_needed_this_gpu must be divisible by the per-GPU batch size" + iterations = int(samples_needed_this_gpu // n) + pbar = range(iterations) + pbar = tqdm(pbar) if rank == 0 else pbar + total = 0 + for _ in pbar: + # Sample inputs: + z = torch.randn(n, model.in_channels, latent_size, latent_size, device=device) + y = torch.randint(0, args.num_classes, (n,), device=device) + + # Setup classifier-free guidance: + if using_cfg: + z = torch.cat([z, z], 0) + y_null = torch.tensor([1000] * n, device=device) + y = torch.cat([y, y_null], 0) + model_kwargs = dict(y=y, cfg_scale=args.cfg_scale) + sample_fn = model.forward_with_cfg + else: + model_kwargs = dict(y=y) + sample_fn = model.forward + + # Sample images: + samples = diffusion.p_sample_loop( + sample_fn, z.shape, z, clip_denoised=False, model_kwargs=model_kwargs, progress=False, device=device + ) + if 
using_cfg: + samples, _ = samples.chunk(2, dim=0) # Remove null class samples + + samples = vae.decode(samples / 0.18215).sample + samples = torch.clamp(127.5 * samples + 128.0, 0, 255).permute(0, 2, 3, 1).to("cpu", dtype=torch.uint8).numpy() + + # Save samples to disk as individual .png files + for i, sample in enumerate(samples): + index = i * dist.get_world_size() + rank + total + Image.fromarray(sample).save(f"{sample_folder_dir}/{index:06d}.png") + total += global_batch_size + + # Make sure all processes have finished saving their samples before attempting to convert to .npz + dist.barrier() + if rank == 0: + create_npz_from_sample_folder(sample_folder_dir, args.num_fid_samples) + print("Done.") + dist.barrier() + dist.destroy_process_group() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--model", type=str, choices=list(DiT_models.keys()), default="DiT-XL/2") + parser.add_argument("--vae", type=str, choices=["ema", "mse"], default="ema") + parser.add_argument("--sample-dir", type=str, default="samples") + parser.add_argument("--per-proc-batch-size", type=int, default=32) + parser.add_argument("--num-fid-samples", type=int, default=50_000) + parser.add_argument("--image-size", type=int, choices=[256, 512], default=256) + parser.add_argument("--num-classes", type=int, default=1000) + parser.add_argument("--cfg-scale", type=float, default=1.5) + parser.add_argument("--num-sampling-steps", type=int, default=250) + parser.add_argument("--global-seed", type=int, default=0) + parser.add_argument("--tf32", action=argparse.BooleanOptionalAction, default=True, + help="By default, use TF32 matmuls. This massively accelerates sampling on Ampere GPUs.") + parser.add_argument("--ckpt", type=str, default=None, + help="Optional path to a DiT checkpoint (default: auto-download a pre-trained DiT-XL/2 model).") + args = parser.parse_args() + main(args) diff --git a/PyTorch/built-in/mlm/DiT/train.py b/PyTorch/built-in/mlm/DiT/train.py new file mode 100644 index 0000000000..7cfee8089b --- /dev/null +++ b/PyTorch/built-in/mlm/DiT/train.py @@ -0,0 +1,269 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +""" +A minimal training script for DiT using PyTorch DDP. +""" +import torch +# the first flag below was False when we tested this script but True makes A100 training a lot faster: +torch.backends.cuda.matmul.allow_tf32 = True +torch.backends.cudnn.allow_tf32 = True +import torch.distributed as dist +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.data import DataLoader +from torch.utils.data.distributed import DistributedSampler +from torchvision.datasets import ImageFolder +from torchvision import transforms +import numpy as np +from collections import OrderedDict +from PIL import Image +from copy import deepcopy +from glob import glob +from time import time +import argparse +import logging +import os + +from models import DiT_models +from diffusion import create_diffusion +from diffusers.models import AutoencoderKL + + +################################################################################# +# Training Helper Functions # +################################################################################# + +@torch.no_grad() +def update_ema(ema_model, model, decay=0.9999): + """ + Step the EMA model towards the current model. 
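The sharding arithmetic in `sample_ddp.py` above, spelled out with its default settings (illustrative): the sample count is rounded up so every rank runs the same number of full batches, each image gets a globally unique file index, and the surplus images are simply not read when the .npz file is built.

```python
import math

num_fid_samples, per_proc_batch_size, world_size = 50_000, 32, 8

global_batch_size = per_proc_batch_size * world_size                       # 256
total_samples = math.ceil(num_fid_samples / global_batch_size) * global_batch_size
samples_per_gpu = total_samples // world_size
iterations = samples_per_gpu // per_proc_batch_size
assert (total_samples, samples_per_gpu, iterations) == (50_176, 6_272, 196)

# Unique file index for image i of a batch on a given rank:
rank, i, done_so_far = 3, 0, 0
assert i * world_size + rank + done_so_far == 3        # -> 000003.png
```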
+ """ + ema_params = OrderedDict(ema_model.named_parameters()) + model_params = OrderedDict(model.named_parameters()) + + for name, param in model_params.items(): + # TODO: Consider applying only to params that require_grad to avoid small numerical changes of pos_embed + ema_params[name].mul_(decay).add_(param.data, alpha=1 - decay) + + +def requires_grad(model, flag=True): + """ + Set requires_grad flag for all parameters in a model. + """ + for p in model.parameters(): + p.requires_grad = flag + + +def cleanup(): + """ + End DDP training. + """ + dist.destroy_process_group() + + +def create_logger(logging_dir): + """ + Create a logger that writes to a log file and stdout. + """ + if dist.get_rank() == 0: # real logger + logging.basicConfig( + level=logging.INFO, + format='[\033[34m%(asctime)s\033[0m] %(message)s', + datefmt='%Y-%m-%d %H:%M:%S', + handlers=[logging.StreamHandler(), logging.FileHandler(f"{logging_dir}/log.txt")] + ) + logger = logging.getLogger(__name__) + else: # dummy logger (does nothing) + logger = logging.getLogger(__name__) + logger.addHandler(logging.NullHandler()) + return logger + + +def center_crop_arr(pil_image, image_size): + """ + Center cropping implementation from ADM. + https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126 + """ + while min(*pil_image.size) >= 2 * image_size: + pil_image = pil_image.resize( + tuple(x // 2 for x in pil_image.size), resample=Image.BOX + ) + + scale = image_size / min(*pil_image.size) + pil_image = pil_image.resize( + tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC + ) + + arr = np.array(pil_image) + crop_y = (arr.shape[0] - image_size) // 2 + crop_x = (arr.shape[1] - image_size) // 2 + return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size]) + + +################################################################################# +# Training Loop # +################################################################################# + +def main(args): + """ + Trains a new DiT model. + """ + assert torch.cuda.is_available(), "Training currently requires at least one GPU." + + # Setup DDP: + dist.init_process_group("nccl") + assert args.global_batch_size % dist.get_world_size() == 0, f"Batch size must be divisible by world size." + rank = dist.get_rank() + device = rank % torch.cuda.device_count() + seed = args.global_seed * dist.get_world_size() + rank + torch.manual_seed(seed) + torch.cuda.set_device(device) + print(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.") + + # Setup an experiment folder: + if rank == 0: + os.makedirs(args.results_dir, exist_ok=True) # Make results folder (holds all experiment subfolders) + experiment_index = len(glob(f"{args.results_dir}/*")) + model_string_name = args.model.replace("/", "-") # e.g., DiT-XL/2 --> DiT-XL-2 (for naming folders) + experiment_dir = f"{args.results_dir}/{experiment_index:03d}-{model_string_name}" # Create an experiment folder + checkpoint_dir = f"{experiment_dir}/checkpoints" # Stores saved model checkpoints + os.makedirs(checkpoint_dir, exist_ok=True) + logger = create_logger(experiment_dir) + logger.info(f"Experiment directory created at {experiment_dir}") + else: + logger = create_logger(None) + + # Create model: + assert args.image_size % 8 == 0, "Image size must be divisible by 8 (for the VAE encoder)." 
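One EMA step, exactly as `update_ema` above applies it, demonstrated on a throwaway module: each EMA parameter moves toward the online parameter by a factor of (1 - decay).

```python
import torch
import torch.nn as nn
from copy import deepcopy
from collections import OrderedDict

model = nn.Linear(4, 4)
ema = deepcopy(model)
decay = 0.9999

with torch.no_grad():
    ema_params = OrderedDict(ema.named_parameters())
    for name, param in model.named_parameters():
        expected = decay * ema_params[name].detach().clone() + (1 - decay) * param
        ema_params[name].mul_(decay).add_(param.data, alpha=1 - decay)
        assert torch.allclose(ema_params[name], expected)
```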
+ latent_size = args.image_size // 8 + model = DiT_models[args.model]( + input_size=latent_size, + num_classes=args.num_classes + ) + # Note that parameter initialization is done within the DiT constructor + ema = deepcopy(model).to(device) # Create an EMA of the model for use after training + requires_grad(ema, False) + model = DDP(model.to(device), device_ids=[rank]) + diffusion = create_diffusion(timestep_respacing="") # default: 1000 steps, linear noise schedule + vae = AutoencoderKL.from_pretrained(f"stabilityai/sd-vae-ft-{args.vae}").to(device) + logger.info(f"DiT Parameters: {sum(p.numel() for p in model.parameters()):,}") + + # Setup optimizer (we used default Adam betas=(0.9, 0.999) and a constant learning rate of 1e-4 in our paper): + opt = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0) + + # Setup data: + transform = transforms.Compose([ + transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, args.image_size)), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True) + ]) + dataset = ImageFolder(args.data_path, transform=transform) + sampler = DistributedSampler( + dataset, + num_replicas=dist.get_world_size(), + rank=rank, + shuffle=True, + seed=args.global_seed + ) + loader = DataLoader( + dataset, + batch_size=int(args.global_batch_size // dist.get_world_size()), + shuffle=False, + sampler=sampler, + num_workers=args.num_workers, + pin_memory=True, + drop_last=True + ) + logger.info(f"Dataset contains {len(dataset):,} images ({args.data_path})") + + # Prepare models for training: + update_ema(ema, model.module, decay=0) # Ensure EMA is initialized with synced weights + model.train() # important! This enables embedding dropout for classifier-free guidance + ema.eval() # EMA model should always be in eval mode + + # Variables for monitoring/logging purposes: + train_steps = 0 + log_steps = 0 + running_loss = 0 + start_time = time() + + logger.info(f"Training for {args.epochs} epochs...") + for epoch in range(args.epochs): + sampler.set_epoch(epoch) + logger.info(f"Beginning epoch {epoch}...") + for x, y in loader: + x = x.to(device) + y = y.to(device) + with torch.no_grad(): + # Map input images to latent space + normalize latents: + x = vae.encode(x).latent_dist.sample().mul_(0.18215) + t = torch.randint(0, diffusion.num_timesteps, (x.shape[0],), device=device) + model_kwargs = dict(y=y) + loss_dict = diffusion.training_losses(model, x, t, model_kwargs) + loss = loss_dict["loss"].mean() + opt.zero_grad() + loss.backward() + opt.step() + update_ema(ema, model.module) + + # Log loss values: + running_loss += loss.item() + log_steps += 1 + train_steps += 1 + if train_steps % args.log_every == 0: + # Measure training speed: + torch.cuda.synchronize() + end_time = time() + steps_per_sec = log_steps / (end_time - start_time) + # Reduce loss history over all processes: + avg_loss = torch.tensor(running_loss / log_steps, device=device) + dist.all_reduce(avg_loss, op=dist.ReduceOp.SUM) + avg_loss = avg_loss.item() / dist.get_world_size() + logger.info(f"(step={train_steps:07d}) Train Loss: {avg_loss:.4f}, Train Steps/Sec: {steps_per_sec:.2f}") + # Reset monitoring variables: + running_loss = 0 + log_steps = 0 + start_time = time() + + # Save DiT checkpoint: + if train_steps % args.ckpt_every == 0 and train_steps > 0: + if rank == 0: + checkpoint = { + "model": model.module.state_dict(), + "ema": ema.state_dict(), + "opt": opt.state_dict(), + "args": args + } + checkpoint_path = 
f"{checkpoint_dir}/{train_steps:07d}.pt" + torch.save(checkpoint, checkpoint_path) + logger.info(f"Saved checkpoint to {checkpoint_path}") + dist.barrier() + + model.eval() # important! This disables randomized embedding dropout + # do any sampling/FID calculation/etc. with ema (or model) in eval mode ... + + logger.info("Done!") + cleanup() + + +if __name__ == "__main__": + # Default args here will train DiT-XL/2 with the hyperparameters we used in our paper (except training iters). + parser = argparse.ArgumentParser() + parser.add_argument("--data-path", type=str, required=True) + parser.add_argument("--results-dir", type=str, default="results") + parser.add_argument("--model", type=str, choices=list(DiT_models.keys()), default="DiT-XL/2") + parser.add_argument("--image-size", type=int, choices=[256, 512], default=256) + parser.add_argument("--num-classes", type=int, default=1000) + parser.add_argument("--epochs", type=int, default=1400) + parser.add_argument("--global-batch-size", type=int, default=256) + parser.add_argument("--global-seed", type=int, default=0) + parser.add_argument("--vae", type=str, choices=["ema", "mse"], default="ema") # Choice doesn't affect training + parser.add_argument("--num-workers", type=int, default=4) + parser.add_argument("--log-every", type=int, default=100) + parser.add_argument("--ckpt-every", type=int, default=50_000) + args = parser.parse_args() + main(args) -- Gitee From 6087e9973242f48618383b76726158ad3e047a3e Mon Sep 17 00:00:00 2001 From: liuqiyuan Date: Mon, 13 May 2024 19:32:14 +0800 Subject: [PATCH 2/2] adapt to npu --- PyTorch/built-in/mlm/DiT/README.md | 260 ++++++++++-------- .../mlm/DiT/diffusion/gaussian_diffusion.py | 19 +- .../mlm/DiT/public_address_statement.md | 13 + PyTorch/built-in/mlm/DiT/requirement.txt | 8 + PyTorch/built-in/mlm/DiT/sample.py | 17 +- PyTorch/built-in/mlm/DiT/test/env_npu.sh | 51 ++++ PyTorch/built-in/mlm/DiT/test/train_8p.sh | 145 ++++++++++ PyTorch/built-in/mlm/DiT/train.py | 32 ++- PyTorch/built-in/mlm/DiT/utils/adamw.py | 150 ++++++++++ .../built-in/mlm/DiT/utils/device_utils.py | 31 +++ 10 files changed, 609 insertions(+), 117 deletions(-) create mode 100644 PyTorch/built-in/mlm/DiT/public_address_statement.md create mode 100644 PyTorch/built-in/mlm/DiT/requirement.txt create mode 100644 PyTorch/built-in/mlm/DiT/test/env_npu.sh create mode 100644 PyTorch/built-in/mlm/DiT/test/train_8p.sh create mode 100644 PyTorch/built-in/mlm/DiT/utils/adamw.py create mode 100644 PyTorch/built-in/mlm/DiT/utils/device_utils.py diff --git a/PyTorch/built-in/mlm/DiT/README.md b/PyTorch/built-in/mlm/DiT/README.md index 1337ab3468..68e69f08e6 100644 --- a/PyTorch/built-in/mlm/DiT/README.md +++ b/PyTorch/built-in/mlm/DiT/README.md @@ -1,163 +1,201 @@ -## Scalable Diffusion Models with Transformers (DiT)
Official PyTorch Implementation +# DiT for PyTorch -### [Paper](http://arxiv.org/abs/2212.09748) | [Project Page](https://www.wpeebles.com/DiT) | Run DiT-XL/2 [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/wpeebles/DiT) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](http://colab.research.google.com/github/facebookresearch/DiT/blob/main/run_DiT.ipynb) +## 目录 -![DiT samples](visuals/sample_grid_0.png) +- [简介](#简介) + - [模型介绍](#模型介绍) + - [支持任务列表](#支持任务列表) + - [代码实现](#代码实现) +- [DiT](#DiT) + - [准备训练环境](#准备训练环境) + - [快速开始](#快速开始) + - [训练任务](#训练任务) + - [在线推理](#在线推理) +- [公网地址说明](#公网地址说明) +- [变更说明](#变更说明) +- [FAQ](#FAQ) -This repo contains PyTorch model definitions, pre-trained weights and training/sampling code for our paper exploring -diffusion models with transformers (DiTs). You can find more visualizations on our [project page](https://www.wpeebles.com/DiT). +# 简介 -> [**Scalable Diffusion Models with Transformers**](https://www.wpeebles.com/DiT)
-> [William Peebles](https://www.wpeebles.com), [Saining Xie](https://www.sainingxie.com) ->
UC Berkeley, New York University
+## 模型介绍 -We train latent diffusion models, replacing the commonly-used U-Net backbone with a transformer that operates on -latent patches. We analyze the scalability of our Diffusion Transformers (DiTs) through the lens of forward pass -complexity as measured by Gflops. We find that DiTs with higher Gflops---through increased transformer depth/width or -increased number of input tokens---consistently have lower FID. In addition to good scalability properties, our -DiT-XL/2 models outperform all prior diffusion models on the class-conditional ImageNet 512×512 and 256×256 benchmarks, -achieving a state-of-the-art FID of 2.27 on the latter. +Scalable Diffusion Models with Transformers,是完全基于transformer架构的扩散模型,这个工作不仅将transformer成功应用在扩散模型,还探究了transformer架构在扩散模型上的scalability能力,其中最大的模型DiT-XL/2在ImageNet 256x256的类别条件生成上达到了SOTA。 -This repository contains: +## 支持任务列表 +本仓已经支持以下模型任务类型 -* 🪐 A simple PyTorch [implementation](models.py) of DiT -* ⚡️ Pre-trained class-conditional DiT models trained on ImageNet (512x512 and 256x256) -* 💥 A self-contained [Hugging Face Space](https://huggingface.co/spaces/wpeebles/DiT) and [Colab notebook](http://colab.research.google.com/github/facebookresearch/DiT/blob/main/run_DiT.ipynb) for running pre-trained DiT-XL/2 models -* 🛸 A DiT [training script](train.py) using PyTorch DDP +| 模型 | 任务列表 | 是否支持 | +| :------: | :------: | :------: | +| DiT-XL/2 | 训练 | ✔ | -An implementation of DiT directly in Hugging Face `diffusers` can also be found [here](https://github.com/huggingface/diffusers/blob/main/docs/source/en/api/pipelines/dit.mdx). +## 代码实现 +- 参考实现: -## Setup + ``` + url=https://github.com/facebookresearch/DiT + commit_id=https://github.com/facebookresearch/DiT + ``` +- 适配昇腾 AI 处理器的实现: -First, download and set up the repo: + ``` + url=https://gitee.com/ascend/ModelZoo-PyTorch.git + code_path=PyTorch/built-in/mlm/DiT + ``` -```bash -git clone https://github.com/facebookresearch/DiT.git -cd DiT -``` +# DiT -We provide an [`environment.yml`](environment.yml) file that can be used to create a Conda environment. If you only want -to run pre-trained models locally on CPU, you can remove the `cudatoolkit` and `pytorch-cuda` requirements from the file. +## 准备训练环境 -```bash -conda env create -f environment.yml -conda activate DiT -``` +### 安装环境 + **表 1** 三方库版本支持表 -## Sampling [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/wpeebles/DiT) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](http://colab.research.google.com/github/facebookresearch/DiT/blob/main/run_DiT.ipynb) -![More DiT samples](visuals/sample_grid_1.png) +| 三方库 | 支持版本 | +| :-----: | :------: | +| PyTorch | 2.1.0 | -**Pre-trained DiT checkpoints.** You can sample from our pre-trained DiT models with [`sample.py`](sample.py). Weights for our pre-trained DiT model will be -automatically downloaded depending on the model you use. The script has various arguments to switch between the 256x256 -and 512x512 models, adjust sampling steps, change the classifier-free guidance scale, etc. 
For example, to sample from -our 512x512 DiT-XL/2 model, you can use: +- 在模型根目录下执行以下命令,安装模型对应PyTorch版本需要的依赖。 -```bash -python sample.py --image-size 512 --seed 1 -``` + ```shell + pip install -r requirements.txt + ``` -For convenience, our pre-trained DiT models can be downloaded directly here as well: +### 安装昇腾环境 -| DiT Model | Image Resolution | FID-50K | Inception Score | Gflops | -|---------------|------------------|---------|-----------------|--------| -| [XL/2](https://dl.fbaipublicfiles.com/DiT/models/DiT-XL-2-256x256.pt) | 256x256 | 2.27 | 278.24 | 119 | -| [XL/2](https://dl.fbaipublicfiles.com/DiT/models/DiT-XL-2-512x512.pt) | 512x512 | 3.04 | 240.82 | 525 | + 请参考昇腾社区中《[Pytorch框架训练环境准备](https://www.hiascend.com/document/detail/zh/ModelZoo/pytorchframework/ptes)》文档搭建昇腾环境,本仓已支持表2中软件版本。 + **表 2** 昇腾软件版本支持表 -**Custom DiT checkpoints.** If you've trained a new DiT model with [`train.py`](train.py) (see [below](#training-dit)), you can add the `--ckpt` -argument to use your own checkpoint instead. For example, to sample from the EMA weights of a custom -256x256 DiT-L/4 model, run: +| 软件类型 | 支持版本 | +| :---------------: | :------: | +| FrameworkPTAdaper | 6.0.RC2 | +| CANN | 8.0.RC2 | +| 昇腾NPU固件 | 24.1.RC2 | +| 昇腾NPU驱动 | 24.1.RC2 | -```bash -python sample.py --model DiT-L/4 --image-size 256 --ckpt /path/to/model.pt -``` +### 准备预训练权重 +- 联网环境下使用以下命令会自动下载**stabilityai/sd-vae-ft-mse**预训练模型。如果网络问题无法自动下载,需要在官网手动下载,并放在./sd-vae-ft-mse路径下,目录结构如下所示。并修改train.py--line160指向上述路径 -## Training DiT + ``` + sd-vae-ft-mse + ├── config.json + ├── diffusion_pytorch_model.bin + ├── diffusion_pytorch_model.safetensors + ├── README.md + ``` -We provide a training script for DiT in [`train.py`](train.py). This script can be used to train class-conditional -DiT models, but it can be easily modified to support other types of conditioning. To launch DiT-XL/2 (256x256) training with `N` GPUs on -one node: -```bash -torchrun --nnodes=1 --nproc_per_node=N train.py --model DiT-XL/2 --data-path /path/to/imagenet/train +### 准备数据集 + +- 自行下载准备imageNet2012数据集,目录结构如下。 + +``` +├── ImageNet2012 + ├──train + ├──类别1 + │──图片1 + │──图片2 + │ ... + ├──类别2 + │──图片1 + │──图片2 + │ ... + ├──... + ├──val + ├──类别1 + │──图片1 + │──图片2 + │ ... + ├──类别2 + │──图片1 + │──图片2 + │ ... ``` -### PyTorch Training Results +> **说明:** +> 该数据集的训练过程脚本只作为一种参考示例。 -We've trained DiT-XL/2 and DiT-B/4 models from scratch with the PyTorch training script -to verify that it reproduces the original JAX results up to several hundred thousand training iterations. Across our experiments, the PyTorch-trained models give -similar (and sometimes slightly better) results compared to the JAX-trained models up to reasonable random variation. Some data points: -| DiT Model | Train Steps | FID-50K
(JAX Training) | FID-50K
(PyTorch Training) | PyTorch Global Training Seed | -|------------|-------------|----------------------------|--------------------------------|------------------------------| -| XL/2 | 400K | 19.5 | **18.1** | 42 | -| B/4 | 400K | **68.4** | 68.9 | 42 | -| B/4 | 400K | 68.4 | **68.3** | 100 | +## 快速开始 +### 训练任务 -These models were trained at 256x256 resolution; we used 8x A100s to train XL/2 and 4x A100s to train B/4. Note that FID -here is computed with 250 DDPM sampling steps, with the `mse` VAE decoder and without guidance (`cfg-scale=1`). +本任务主要提供**单机**的**8卡**训练脚本。 -**TF32 Note (important for A100 users).** When we ran the above tests, TF32 matmuls were disabled per PyTorch's defaults. -We've enabled them at the top of `train.py` and `sample.py` because it makes training and sampling way way way faster on -A100s (and should for other Ampere GPUs too), but note that the use of TF32 may lead to some differences compared to -the above results. +#### 开始训练 -### Enhancements -Training (and sampling) could likely be sped-up significantly by: -- [ ] using [Flash Attention](https://github.com/HazyResearch/flash-attention) in the DiT model -- [ ] using `torch.compile` in PyTorch 2.0 + 1. 进入源码根目录。 -Basic features that would be nice to add: -- [ ] Monitor FID and other metrics -- [ ] Generate and save samples from the EMA model periodically -- [ ] Resume training from a checkpoint -- [ ] AMP/bfloat16 support + ``` + cd /${模型文件夹名称} + ``` -**🔥 Feature Update** Check out this repository at https://github.com/chuanyangjin/fast-DiT to preview a selection of training speed acceleration and memory saving features including gradient checkpointing, mixed precision training and pre-extrated VAE features. With these advancements, we have achieved a training speed of 0.84 steps/sec for DiT-XL/2 using just a single A100 GPU. + 2. 运行训练脚本。 -## Evaluation (FID, Inception Score, etc.) + 该模型支持单机8卡训练。 -We include a [`sample_ddp.py`](sample_ddp.py) script which samples a large number of images from a DiT model in parallel. This script -generates a folder of samples as well as a `.npz` file which can be directly used with [ADM's TensorFlow -evaluation suite](https://github.com/openai/guided-diffusion/tree/main/evaluations) to compute FID, Inception Score and -other metrics. For example, to sample 50K images from our pre-trained DiT-XL/2 model over `N` GPUs, run: + - 单机8卡训练 -```bash -torchrun --nnodes=1 --nproc_per_node=N sample_ddp.py --model DiT-XL/2 --num-fid-samples 50000 -``` + ``` + bash test/train_8p.sh --data_path=/PATH/imagenet2012/train --image_size=256 global_batch_size=256 --precision=fp32 --epochs=1 + ``` + + 模型训练脚本参数说明如下。 + + ``` + train_8p.sh + --data_path //训练数据路径 + --image_size //图片大小,支持256和512 + --global_batch_size //全局batch size设置 + --precision // 训练精度,支持fp32和bf16 + --epochs //训练轮数 + ``` -There are several additional options; see [`sample_ddp.py`](sample_ddp.py) for details. +#### 训练结果 +| 芯片 | 卡数 | image size | global batch size | Precision | 性能FPS | +| ------------- | :--: | :--------: | :---------------: | :-------: | :-----: | +| GPU | 8p | 256 | 256 | fp32 | 432 | +| Atlas 800T A2 | 8p | 256 | 256 | fp32 | 376 | +| GPU | 8p | 256 | 512 | bf16 | 727 | +| Atlas 800T A2 | 8p | 256 | 512 | bf16 | 586 | +| GPU | 8p | 512 | 64 | fp32 | 80 | +| Atlas 800T A2 | 8p | 512 | 64 | fp32 | 77 | +| GPU | 8p | 512 | 128 | bf16 | 151 | +| Atlas 800T A2 | 8p | 512 | 128 | bf16 | 122 | -## Differences from JAX +### 在线推理 -Our models were originally trained in JAX on TPUs. 
The weights in this repo are ported directly from the JAX models. -There may be minor differences in results stemming from sampling with different floating point precisions. We re-evaluated -our ported PyTorch weights at FP32, and they actually perform marginally better than sampling in JAX (2.21 FID -versus 2.27 in the paper). +本任务主要提供**单卡**推理功能。 +#### 开始推理 -## BibTeX +1. 单卡推理命令 -```bibtex -@article{Peebles2022DiT, - title={Scalable Diffusion Models with Transformers}, - author={William Peebles and Saining Xie}, - year={2022}, - journal={arXiv preprint arXiv:2212.09748}, -} ``` +python sample.py --model DiT-XL/2 --image-size 256 --ckpt /path/to/model.pt +``` + +脚本入参说明如下。 + +``` +sample.py + --model //模型结构 + --image-size //图片大小,支持256和512 + --ckpt //权重路径,支持官方开源权重和自己训练的权重 +``` + + + +# 公网地址说明 +代码涉及公网地址参考 public_address_statement.md -## Acknowledgments -We thank Kaiming He, Ronghang Hu, Alexander Berg, Shoubhik Debnath, Tim Brooks, Ilija Radosavovic and Tete Xiao for helpful discussions. -William Peebles is supported by the NSF Graduate Research Fellowship. +# 变更说明 -This codebase borrows from OpenAI's diffusion repos, most notably [ADM](https://github.com/openai/guided-diffusion). +2024.05.15:首次发布。 +# FAQ -## License -The code and model weights are licensed under CC-BY-NC. See [`LICENSE.txt`](LICENSE.txt) for details. +无 \ No newline at end of file diff --git a/PyTorch/built-in/mlm/DiT/diffusion/gaussian_diffusion.py b/PyTorch/built-in/mlm/DiT/diffusion/gaussian_diffusion.py index ccbcefeca4..8488374647 100644 --- a/PyTorch/built-in/mlm/DiT/diffusion/gaussian_diffusion.py +++ b/PyTorch/built-in/mlm/DiT/diffusion/gaussian_diffusion.py @@ -1,3 +1,17 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + # Modified from OpenAI's diffusion repos # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion @@ -712,7 +726,7 @@ class GaussianDiffusion: output = th.where((t == 0), decoder_nll, kl) return {"output": output, "pred_xstart": out["pred_xstart"]} - def training_losses(self, model, x_start, t, model_kwargs=None, noise=None): + def training_losses(self, model, x_start, t, precision, model_kwargs=None, noise=None): """ Compute training losses for a single timestep. :param model: the model to evaluate loss on. 
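The hunk above threads a new `precision` argument into `GaussianDiffusion.training_losses`, and the hunk that follows uses it to cast the noised latents to bfloat16, matching the model weights that `train.py` casts with `model.to(torch.bfloat16)` when `--precision bf16` is passed. A minimal, self-contained sketch of that cast (the standalone helper and its name `maybe_cast_latents` are illustrative, not part of the repo):

```python
import torch


def maybe_cast_latents(x_t: torch.Tensor, precision: str = "fp32") -> torch.Tensor:
    """Cast noised latents to bfloat16 when bf16 training is requested.

    Mirrors the check this patch adds inside GaussianDiffusion.training_losses;
    the standalone form and the function name are illustrative only.
    """
    if precision == "bf16":
        # bfloat16 keeps fp32's exponent range with a shorter mantissa, so the
        # latents can be fed to a model whose weights were cast to bfloat16.
        return x_t.to(torch.bfloat16)
    # "fp32" (the default) leaves the tensor unchanged.
    return x_t


if __name__ == "__main__":
    latents = torch.randn(4, 4, 32, 32)               # e.g. a batch of VAE latents
    print(maybe_cast_latents(latents, "fp32").dtype)  # torch.float32
    print(maybe_cast_latents(latents, "bf16").dtype)  # torch.bfloat16
```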
@@ -730,6 +744,9 @@ class GaussianDiffusion:
             noise = th.randn_like(x_start)
         x_t = self.q_sample(x_start, t, noise=noise)
 
+        if precision == "bf16":
+            x_t = x_t.to(th.bfloat16)
+
         terms = {}
 
         if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL:
diff --git a/PyTorch/built-in/mlm/DiT/public_address_statement.md b/PyTorch/built-in/mlm/DiT/public_address_statement.md
new file mode 100644
index 0000000000..879c3b9e1d
--- /dev/null
+++ b/PyTorch/built-in/mlm/DiT/public_address_statement.md
@@ -0,0 +1,12 @@
+| 类型 | 开源代码地址 | 文件名 | 公网IP地址/公网URL地址/域名/邮箱地址 | 用途说明 |
+| ------------ | ------------------------------------------------------------ | ------------------------------- | ------------------------------------------------------------ | -------------- |
+| 开源代码引入 | https://github.com/facebookresearch/DiT/blob/main/download.py | download.py | https://dl.fbaipublicfiles.com/DiT/models/ | 下载预训练权重 |
+| 开源代码引入 | https://github.com/facebookresearch/DiT/blob/main/models.py | models.py | https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py | 引用说明 |
+| 开源代码引入 | https://github.com/facebookresearch/DiT/blob/main/models.py | models.py | https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb | 引用说明 |
+| 开源代码引入 | https://github.com/facebookresearch/DiT/blob/main/models.py | models.py | https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py | 引用说明 |
+| 开源代码引入 | https://github.com/facebookresearch/DiT/blob/main/sample_ddp.py | sample_ddp.py | https://github.com/openai/guided-diffusion/tree/main/evaluations | 引用说明 |
+| 开源代码引入 | https://github.com/facebookresearch/DiT/blob/main/train.py | train.py | https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py | 引用说明 |
+| 开源代码引入 | https://github.com/facebookresearch/DiT/blob/main/diffusion/gaussian_diffusion.py | diffusion/gaussian_diffusion.py | https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py | 引用说明 |
+| 开源代码引入 | https://github.com/facebookresearch/DiT/blob/main/diffusion/gaussian_diffusion.py | diffusion/gaussian_diffusion.py | https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py | 引用说明 |
+| 开源代码引入 | https://github.com/facebookresearch/DiT/blob/main/diffusion/gaussian_diffusion.py | diffusion/gaussian_diffusion.py | https://github.com/openai/guided-diffusion/blob/main/guided_diffusion | 引用说明 |
+| 开源代码引入 | https://github.com/facebookresearch/DiT/blob/main/diffusion/gaussian_diffusion.py | diffusion/gaussian_diffusion.py | https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py | 引用说明 |
\ No newline at end of file
diff --git a/PyTorch/built-in/mlm/DiT/requirement.txt b/PyTorch/built-in/mlm/DiT/requirement.txt
new file mode 100644
index 0000000000..0185ed31b6
--- /dev/null
+++ b/PyTorch/built-in/mlm/DiT/requirement.txt
@@ -0,0 +1,8 @@
+timm==0.9.16
+diffusers==0.27.2
+accelerate==0.29.2
+torchvision==0.16.0
+protobuf
+decorator
+scipy
+attrs
\ No newline at end of file
diff --git a/PyTorch/built-in/mlm/DiT/sample.py b/PyTorch/built-in/mlm/DiT/sample.py
index a4152afd88..7d77198200 100644
--- a/PyTorch/built-in/mlm/DiT/sample.py
+++ b/PyTorch/built-in/mlm/DiT/sample.py
@@ -1,3 +1,16 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. @@ -16,7 +29,9 @@ from diffusers.models import AutoencoderKL from download import find_model from models import DiT_models import argparse - +from utils.device_utils import is_npu_available +if is_npu_available(): + from torch_npu.contrib import transfer_to_npu def main(args): # Setup PyTorch: diff --git a/PyTorch/built-in/mlm/DiT/test/env_npu.sh b/PyTorch/built-in/mlm/DiT/test/env_npu.sh new file mode 100644 index 0000000000..19880a441e --- /dev/null +++ b/PyTorch/built-in/mlm/DiT/test/env_npu.sh @@ -0,0 +1,51 @@ +#!/bin/bash +CANN_INSTALL_PATH_CONF='/etc/Ascend/ascend_cann_install.info' + +if [ -f $CANN_INSTALL_PATH_CONF ]; then + CANN_INSTALL_PATH=$(cat $CANN_INSTALL_PATH_CONF | grep Install_Path | cut -d "=" -f 2) +else + CANN_INSTALL_PATH="/usr/local/Ascend" +fi + +if [ -d ${CANN_INSTALL_PATH}/ascend-toolkit/latest ]; then + source ${CANN_INSTALL_PATH}/ascend-toolkit/set_env.sh +else + source ${CANN_INSTALL_PATH}/nnae/set_env.sh +fi + +#将Host日志输出到串口,0-关闭/1-开启 +export ASCEND_SLOG_PRINT_TO_STDOUT=0 +#设置默认日志级别,0-debug/1-info/2-warning/3-error +export ASCEND_GLOBAL_LOG_LEVEL=3 +#设置Event日志开启标志,0-关闭/1-开启 +export ASCEND_GLOBAL_EVENT_ENABLE=0 + + +#设置device侧日志登记为error +msnpureport -g error -d 0 +msnpureport -g error -d 1 +msnpureport -g error -d 2 +msnpureport -g error -d 3 +msnpureport -g error -d 4 +msnpureport -g error -d 5 +msnpureport -g error -d 6 +msnpureport -g error -d 7 +#关闭Device侧Event日志 +msnpureport -e disable + +path_lib=$(python3 -c """ +import sys +import re +result='' +for index in range(len(sys.path)): + match_sit = re.search('-packages', sys.path[index]) + if match_sit is not None: + match_lib = re.search('lib', sys.path[index]) + + if match_lib is not None: + end=match_lib.span()[1] + result += sys.path[index][0:end] + ':' + + result+=sys.path[index] + '/torch/lib:' +print(result)""" +) diff --git a/PyTorch/built-in/mlm/DiT/test/train_8p.sh b/PyTorch/built-in/mlm/DiT/test/train_8p.sh new file mode 100644 index 0000000000..453a754c48 --- /dev/null +++ b/PyTorch/built-in/mlm/DiT/test/train_8p.sh @@ -0,0 +1,145 @@ +#!/bin/bash +################基础配置参数,需要模型审视修改################## +# 网络名称,同目录名称 +Network="DiT" +WORLD_SIZE=8 +WORK_DIR="" +LOAD_FROM="" + +NNODES=${NNODES:-1} +NODE_RANK=${NODE_RANK:-0} +PORT=${PORT:-29500} +MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} + +for para in $* +do + if [[ $para == --data_path* ]];then + data_path=`echo ${para#*=}` + elif [[ $para == --image_size* ]];then + image_size=`echo ${para#*=}` + elif [[ $para == --global_batch_size* ]];then + global_batch_size=`echo ${para#*=}` + elif [[ $para == --precision* ]];then + precision=`echo ${para#*=}` + elif [[ $para == --epochs* ]];then + epochs=`echo ${para#*=}` + fi +done + +# 校验是否传入data_path +if [[ $data_path == "" ]];then + echo "[Error] para \"data_path\" must be confing" + exit 1 +fi + +# 校验是否传入image_size +if [[ $image_size == "" ]];then + echo "[Error] para \"image_size\" must be confing" + exit 1 +fi + +# 校验是否传入global_batch_size +if [[ $global_batch_size == "" ]];then + echo "[Error] para \"global_batch_size\" must be 
confing" + exit 1 +fi + +# 校验是否传入precision +if [[ $precision == "" ]];then + echo "[Error] para \"precision\" must be confing" + exit 1 +fi + +# 校验是否传入epochs +if [[ $epochs == "" ]];then + echo "[Error] para \"epochs\" must be confing" + exit 1 +fi + +###############指定训练脚本执行路径############### +# cd到与test文件夹同层级目录下执行脚本,提高兼容性;test_path_dir为包含test文件夹的路径 +cur_path=$(pwd) +cur_path_last_dirname=${cur_path##*/} +if [ x"${cur_path_last_dirname}" == x"test" ]; then + test_path_dir=${cur_path} + cd .. + cur_path=$(pwd) +else + test_path_dir=${cur_path}/test +fi + +ASCEND_DEVICE_ID=0 + +if [ -d ${cur_path}/test/output/${ASCEND_DEVICE_ID} ]; then + rm -rf ${cur_path}/test/output/${ASCEND_DEVICE_ID} + mkdir -p ${cur_path}/test/output/${ASCEND_DEVICE_ID} +else + mkdir -p ${cur_path}/test/output/${ASCEND_DEVICE_ID} +fi + +start_time=$(date +%s) +# 非平台场景时source 环境变量 +check_etp_flag=$(env | grep etp_running_flag) +etp_flag=$(echo ${check_etp_flag#*=}) +if [ x"${etp_flag}" != x"true" ]; then + source ${test_path_dir}/env_npu.sh +fi + +torchrun --nproc_per_node 8 train.py \ + --model DiT-XL/2 \ + --data-path ${data_path} \ + --image-size ${image_size} \ + --global-batch-size ${global_batch_size} \ + --precision ${precision} \ + --epochs ${epochs} \ + >$cur_path/test/output/${ASCEND_DEVICE_ID}/train_${ASCEND_DEVICE_ID}.log 2>&1 & +wait + + +# 训练结束时间,不需要修改 +end_time=$(date +%s) +e2e_time=$(( $end_time - $start_time )) + +# 训练用例信息,不需要修改 +BatchSize=${global_batch_size} +DeviceType=$(uname -m) +CaseName=${Network}_bs${BatchSize}_${WORLD_SIZE}'p'_'acc' + +# 结果打印,不需要修改 +echo "------------------ Final result ------------------" +# 输出性能FPS,需要模型审视修改 +avg_time=`grep -a 'Steps/Sec:' ${test_path_dir}/output/${ASCEND_DEVICE_ID}/train_${ASCEND_DEVICE_ID}.log|awk -F "Steps/Sec: " '{print $2}' | awk '{a+=$1} END {if (NR != 0) printf("%.3f",a/NR)}'` +FPS=`echo "$avg_time * $BatchSize" |bc` +# 打印,不需要修改 +echo "Final Performance images/sec : $FPS" + +# 输出训练精度,需要模型审视修改 +train_loss=$(grep -a "Steps/Sec:" ${test_path_dir}/output/${ASCEND_DEVICE_ID}/train_${ASCEND_DEVICE_ID}.log |tail -1|awk -F "Train Loss: " '{print $2}' |awk -F ", " '{print $1}'| awk -F ", " '{print $1}') +# 打印,不需要修改 +echo "Final Train Accuracy : ${train_loss}" +echo "E2E Training Duration sec : $e2e_time" + + +# 性能看护结果汇总 +# 获取性能数据,不需要修改 +# 吞吐量 +ActualFPS=${FPS} +# 训练总时长 +TrainingTime=`grep -a 'Time' ${test_path_dir}/output/${ASCEND_DEVICE_ID}/train_${ASCEND_DEVICE_ID}.log|awk -F "Time: " '{print $2}'|awk -F "," '{print $1}'| awk '{a+=$1} END {printf("%.3f",a)}'` + +# 从train_$ASCEND_DEVICE_ID.log提取Loss到train_${CaseName}_loss.txt中,需要根据模型审视 +grep "Time" ${test_path_dir}/output/$ASCEND_DEVICE_ID/train_$ASCEND_DEVICE_ID.log | awk -F "Loss: " '{print $2}' >>${test_path_dir}/output/$ASCEND_DEVICE_ID/train_${CaseName}_loss.txt + +# # 最后一个迭代loss值,不需要修改 +# ActualLoss=$(awk 'END {print}' ${test_path_dir}/output/$ASCEND_DEVICE_ID/train_${CaseName}_loss.txt) + +# 关键信息打印到${CaseName}.log中,不需要修改 +echo "Network = ${Network}" >${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "RankSize = ${WORLD_SIZE}" >>${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "BatchSize = ${BatchSize}" >>${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "DeviceType = ${DeviceType}" >>${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "CaseName = ${CaseName}" >>${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "TrainAccuracy = ${train_loss}" >>${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "ActualFPS = ${ActualFPS}" 
>>${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}_perf_report.log +echo "TrainingTime = ${TrainingTime}" >>${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}_perf_report.log +echo "E2ETrainingTime = ${e2e_time}" >>${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}_perf_report.log \ No newline at end of file diff --git a/PyTorch/built-in/mlm/DiT/train.py b/PyTorch/built-in/mlm/DiT/train.py index 7cfee8089b..765e0b7282 100644 --- a/PyTorch/built-in/mlm/DiT/train.py +++ b/PyTorch/built-in/mlm/DiT/train.py @@ -1,3 +1,17 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. @@ -30,6 +44,10 @@ import os from models import DiT_models from diffusion import create_diffusion from diffusers.models import AutoencoderKL +from utils.adamw import AdamW +from utils.device_utils import is_npu_available +if is_npu_available(): + from torch_npu.contrib import transfer_to_npu ################################################################################# @@ -143,6 +161,11 @@ def main(args): input_size=latent_size, num_classes=args.num_classes ) + + if args.precision == "bf16": + print("Enable bfloat16...") + model.to(torch.bfloat16) + # Note that parameter initialization is done within the DiT constructor ema = deepcopy(model).to(device) # Create an EMA of the model for use after training requires_grad(ema, False) @@ -151,8 +174,8 @@ def main(args): vae = AutoencoderKL.from_pretrained(f"stabilityai/sd-vae-ft-{args.vae}").to(device) logger.info(f"DiT Parameters: {sum(p.numel() for p in model.parameters()):,}") - # Setup optimizer (we used default Adam betas=(0.9, 0.999) and a constant learning rate of 1e-4 in our paper): - opt = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0) + # Setup optimizer (we used Adam betas=(0.9, 0.998) and a constant learning rate of 1e-4 in our paper): + opt = AdamW(model.parameters(), lr=1e-4, weight_decay=0, betas=(0.9, 0.998)) # Setup data: transform = transforms.Compose([ @@ -203,7 +226,7 @@ def main(args): x = vae.encode(x).latent_dist.sample().mul_(0.18215) t = torch.randint(0, diffusion.num_timesteps, (x.shape[0],), device=device) model_kwargs = dict(y=y) - loss_dict = diffusion.training_losses(model, x, t, model_kwargs) + loss_dict = diffusion.training_losses(model, x, t, args.precision, model_kwargs) loss = loss_dict["loss"].mean() opt.zero_grad() loss.backward() @@ -263,7 +286,8 @@ if __name__ == "__main__": parser.add_argument("--global-seed", type=int, default=0) parser.add_argument("--vae", type=str, choices=["ema", "mse"], default="ema") # Choice doesn't affect training parser.add_argument("--num-workers", type=int, default=4) - parser.add_argument("--log-every", type=int, default=100) + parser.add_argument("--log-every", type=int, default=1) parser.add_argument("--ckpt-every", type=int, default=50_000) + parser.add_argument("--precision", type=str, choices=["fp32", "bf16"], default="fp32") args = parser.parse_args() 
main(args) diff --git a/PyTorch/built-in/mlm/DiT/utils/adamw.py b/PyTorch/built-in/mlm/DiT/utils/adamw.py new file mode 100644 index 0000000000..9f73ae5833 --- /dev/null +++ b/PyTorch/built-in/mlm/DiT/utils/adamw.py @@ -0,0 +1,150 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import List, Optional +import torch +import torch_npu +from torch import Tensor +from torch.optim.optimizer import Optimizer + + +def adamw(params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + max_exp_avg_sqs: List[Tensor], + step: int, + *, + amsgrad: bool, + beta1: float, + beta2: float, + lr: float, + weight_decay: float, + eps: float, + maximize: bool): + r"""Functional API that performs AdamW algorithm computation. + See :class:`~torch.optim.AdamW` for details. + """ + for i, param in enumerate(params): + grad = grads[i] + exp_avg = exp_avgs[i] + exp_avg_sq = exp_avg_sqs[i] + + # Perform stepweight decay + ## param.mul_(1 - lr * weight_decay) + bias_correction1 = beta1 ** (step - 1) + bias_correction2 = beta2 ** (step - 1) + + param.data, exp_avg, exp_avg_sq = torch_npu.npu_apply_adam_w( + bias_correction1, + bias_correction2, + lr, + weight_decay, + beta1, + beta2, + eps, + grad, + None, + amsgrad, + maximize, + out=(param.data, exp_avg, exp_avg_sq) + ) + + +class AdamW(Optimizer): + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, + weight_decay=1e-2, amsgrad=False, *, maximize: bool = False): + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + if not 0.0 <= weight_decay: + raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) + defaults = dict(lr=lr, betas=betas, eps=eps, + weight_decay=weight_decay, amsgrad=amsgrad, maximize=maximize) + super(AdamW, self).__init__(params, defaults) + + def __setstate__(self, state): + super(AdamW, self).__setstate__(state) + for group in self.param_groups: + group.setdefault('amsgrad', False) + group.setdefault('maximize', False) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params_with_grad = [] + grads = [] + exp_avgs = [] + exp_avg_sqs = [] + state_sums = [] + max_exp_avg_sqs = [] + state_steps = [] + amsgrad = group['amsgrad'] + beta1, beta2 = group['betas'] + + if 'step' in group: + group['step'] += 1 + else: + group['step'] = 1 + + for p in group['params']: + if p.grad is None: + continue + params_with_grad.append(p) + if p.grad.is_sparse: + raise RuntimeError('AdamW does not support sparse gradients') + grads.append(p.grad) + + state = self.state[p] + + # State initialization 
+ if len(state) == 0: + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) + if amsgrad: + # Maintains max of all exp. moving avg. of sq. grad. values + state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) + + exp_avgs.append(state['exp_avg']) + exp_avg_sqs.append(state['exp_avg_sq']) + + if amsgrad: + max_exp_avg_sqs.append(state['max_exp_avg_sq']) + + adamw(params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + group['step'], + amsgrad=amsgrad, + beta1=beta1, + beta2=beta2, + lr=group['lr'], + weight_decay=group['weight_decay'], + eps=group['eps'], + maximize=group['maximize']) + + return loss diff --git a/PyTorch/built-in/mlm/DiT/utils/device_utils.py b/PyTorch/built-in/mlm/DiT/utils/device_utils.py new file mode 100644 index 0000000000..ea890e43b0 --- /dev/null +++ b/PyTorch/built-in/mlm/DiT/utils/device_utils.py @@ -0,0 +1,31 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import importlib + + +def is_npu_available(): + "Checks if `torch_npu` is installed and potentially if a NPU is in the environment" + if importlib.util.find_spec("torch") is None or importlib.util.find_spec("torch_npu") is None: + return False + + import torch_npu + + try: + # Will raise a RuntimeError if no NPU is found + _ = torch.npu.device_count() + return torch.npu.is_available() + except RuntimeError: + return False \ No newline at end of file -- Gitee
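The `is_npu_available()` helper added in `utils/device_utils.py` lets the scripts detect an Ascend NPU before importing `torch_npu`, which is exactly how `train.py` and `sample.py` guard their `from torch_npu.contrib import transfer_to_npu` import in this patch. A minimal usage sketch, assuming it is run from the DiT source root so that `utils` is importable; the explicit device-selection fallback below is illustrative and not taken from the repo:

```python
import torch

from utils.device_utils import is_npu_available  # helper added by this patch

if is_npu_available():
    # Same guarded import used by train.py and sample.py; transfer_to_npu is
    # expected to remap CUDA-style torch calls onto the Ascend NPU backend.
    from torch_npu.contrib import transfer_to_npu  # noqa: F401
    device = torch.device("npu:0")
elif torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Allocate a small tensor on whichever device was selected.
print(torch.ones(2, 2, device=device))
```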