From 7c8a91d9ddf2bcb62c13610fb3fba34c20733d14 Mon Sep 17 00:00:00 2001 From: mookie Date: Fri, 13 Oct 2023 18:06:19 +0800 Subject: [PATCH] init open_flamingo --- .../OpenFlamingo_ for PyTorch/HISTORY.md | 3 + .../others/OpenFlamingo_ for PyTorch/LICENSE | 21 + .../OpenFlamingo_ for PyTorch/MODEL_CARD.md | 44 + .../others/OpenFlamingo_ for PyTorch/Makefile | 19 + .../OpenFlamingo_ for PyTorch/README.md | 128 +++ .../OpenFlamingo_ for PyTorch/README_EN.md | 225 ++++ .../TERMS_AND_CONDITIONS.md | 15 + .../OpenFlamingo_ for PyTorch/environment.yml | 10 + .../open_flamingo/__init__.py | 2 + .../open_flamingo/eval/README.md | 29 + .../open_flamingo/eval/__init__.py | 1 + .../open_flamingo/eval/coco_metric.py | 22 + .../open_flamingo/eval/eval_datasets.py | 126 +++ .../open_flamingo/eval/eval_model.py | 63 ++ .../open_flamingo/eval/evaluate.py | 988 ++++++++++++++++ .../open_flamingo/eval/imagenet_utils.py | 1007 +++++++++++++++++ .../open_flamingo/eval/models/blip.py | 110 ++ .../eval/models/open_flamingo.py | 112 ++ .../open_flamingo/eval/ok_vqa_utils.py | 214 ++++ .../open_flamingo/eval/vqa_metric.py | 581 ++++++++++ .../open_flamingo/scripts/run_eval.sh | 38 + .../open_flamingo/scripts/run_eval_backup.sh | 45 + .../open_flamingo/src/__init__.py | 0 .../open_flamingo/src/factory.py | 109 ++ .../open_flamingo/src/flamingo.py | 198 ++++ .../open_flamingo/src/flamingo_lm.py | 138 +++ .../open_flamingo/src/helpers.py | 275 +++++ .../open_flamingo/src/utils.py | 31 + .../open_flamingo/train/__init__.py | 1 + .../train/convert_mmc4_to_wds.py | 72 ++ .../open_flamingo/train/data.py | 576 ++++++++++ .../open_flamingo/train/distributed.py | 128 +++ .../open_flamingo/train/train.py | 490 ++++++++ .../open_flamingo/train/train_utils.py | 398 +++++++ .../requirements-dev.txt | 5 + .../requirements.txt | 19 + .../others/OpenFlamingo_ for PyTorch/setup.py | 57 + .../tests/test_flamingo_model.py | 77 ++ .../OpenFlamingo_ for PyTorch/train_4_npus.sh | 17 + 39 files changed, 6394 insertions(+) create mode 100644 PyTorch/contrib/others/OpenFlamingo_ for PyTorch/HISTORY.md create mode 100644 PyTorch/contrib/others/OpenFlamingo_ for PyTorch/LICENSE create mode 100644 PyTorch/contrib/others/OpenFlamingo_ for PyTorch/MODEL_CARD.md create mode 100644 PyTorch/contrib/others/OpenFlamingo_ for PyTorch/Makefile create mode 100644 PyTorch/contrib/others/OpenFlamingo_ for PyTorch/README.md create mode 100644 PyTorch/contrib/others/OpenFlamingo_ for PyTorch/README_EN.md create mode 100644 PyTorch/contrib/others/OpenFlamingo_ for PyTorch/TERMS_AND_CONDITIONS.md create mode 100644 PyTorch/contrib/others/OpenFlamingo_ for PyTorch/environment.yml create mode 100644 PyTorch/contrib/others/OpenFlamingo_ for PyTorch/open_flamingo/__init__.py create mode 100644 PyTorch/contrib/others/OpenFlamingo_ for PyTorch/open_flamingo/eval/README.md create mode 100644 PyTorch/contrib/others/OpenFlamingo_ for PyTorch/open_flamingo/eval/__init__.py create mode 100644 PyTorch/contrib/others/OpenFlamingo_ for PyTorch/open_flamingo/eval/coco_metric.py create mode 100644 PyTorch/contrib/others/OpenFlamingo_ for PyTorch/open_flamingo/eval/eval_datasets.py create mode 100644 PyTorch/contrib/others/OpenFlamingo_ for PyTorch/open_flamingo/eval/eval_model.py create mode 100644 PyTorch/contrib/others/OpenFlamingo_ for PyTorch/open_flamingo/eval/evaluate.py create mode 100644 PyTorch/contrib/others/OpenFlamingo_ for PyTorch/open_flamingo/eval/imagenet_utils.py create mode 100644 PyTorch/contrib/others/OpenFlamingo_ for 
PyTorch/open_flamingo/eval/models/blip.py create mode 100644 PyTorch/contrib/others/OpenFlamingo_ for PyTorch/open_flamingo/eval/models/open_flamingo.py create mode 100644 PyTorch/contrib/others/OpenFlamingo_ for PyTorch/open_flamingo/eval/ok_vqa_utils.py create mode 100644 PyTorch/contrib/others/OpenFlamingo_ for PyTorch/open_flamingo/eval/vqa_metric.py create mode 100644 PyTorch/contrib/others/OpenFlamingo_ for PyTorch/open_flamingo/scripts/run_eval.sh create mode 100644 PyTorch/contrib/others/OpenFlamingo_ for PyTorch/open_flamingo/scripts/run_eval_backup.sh create mode 100644 PyTorch/contrib/others/OpenFlamingo_ for PyTorch/open_flamingo/src/__init__.py create mode 100644 PyTorch/contrib/others/OpenFlamingo_ for PyTorch/open_flamingo/src/factory.py create mode 100644 PyTorch/contrib/others/OpenFlamingo_ for PyTorch/open_flamingo/src/flamingo.py create mode 100644 PyTorch/contrib/others/OpenFlamingo_ for PyTorch/open_flamingo/src/flamingo_lm.py create mode 100644 PyTorch/contrib/others/OpenFlamingo_ for PyTorch/open_flamingo/src/helpers.py create mode 100644 PyTorch/contrib/others/OpenFlamingo_ for PyTorch/open_flamingo/src/utils.py create mode 100644 PyTorch/contrib/others/OpenFlamingo_ for PyTorch/open_flamingo/train/__init__.py create mode 100644 PyTorch/contrib/others/OpenFlamingo_ for PyTorch/open_flamingo/train/convert_mmc4_to_wds.py create mode 100644 PyTorch/contrib/others/OpenFlamingo_ for PyTorch/open_flamingo/train/data.py create mode 100644 PyTorch/contrib/others/OpenFlamingo_ for PyTorch/open_flamingo/train/distributed.py create mode 100644 PyTorch/contrib/others/OpenFlamingo_ for PyTorch/open_flamingo/train/train.py create mode 100644 PyTorch/contrib/others/OpenFlamingo_ for PyTorch/open_flamingo/train/train_utils.py create mode 100644 PyTorch/contrib/others/OpenFlamingo_ for PyTorch/requirements-dev.txt create mode 100644 PyTorch/contrib/others/OpenFlamingo_ for PyTorch/requirements.txt create mode 100644 PyTorch/contrib/others/OpenFlamingo_ for PyTorch/setup.py create mode 100644 PyTorch/contrib/others/OpenFlamingo_ for PyTorch/tests/test_flamingo_model.py create mode 100644 PyTorch/contrib/others/OpenFlamingo_ for PyTorch/train_4_npus.sh diff --git a/PyTorch/contrib/others/OpenFlamingo_ for PyTorch/HISTORY.md b/PyTorch/contrib/others/OpenFlamingo_ for PyTorch/HISTORY.md new file mode 100644 index 0000000000..5567205091 --- /dev/null +++ b/PyTorch/contrib/others/OpenFlamingo_ for PyTorch/HISTORY.md @@ -0,0 +1,3 @@ +## 1.0.0 + +* it works \ No newline at end of file diff --git a/PyTorch/contrib/others/OpenFlamingo_ for PyTorch/LICENSE b/PyTorch/contrib/others/OpenFlamingo_ for PyTorch/LICENSE new file mode 100644 index 0000000000..206be3ebbf --- /dev/null +++ b/PyTorch/contrib/others/OpenFlamingo_ for PyTorch/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2023 Anas Awadalla, Irena Gao, Joshua Gardner, Jack Hessel, Yusuf Hanafy, Wanrong Zhu, Kalyani Marathe, Yonatan Bitton, Samir Gadre, Jenia Jitsev, Simon Kornblith, Pang Wei Koh, Gabriel Ilharco, Mitchell Wortsman, Ludwig Schmidt. 
+ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/PyTorch/contrib/others/OpenFlamingo_ for PyTorch/MODEL_CARD.md b/PyTorch/contrib/others/OpenFlamingo_ for PyTorch/MODEL_CARD.md new file mode 100644 index 0000000000..b1264ae72d --- /dev/null +++ b/PyTorch/contrib/others/OpenFlamingo_ for PyTorch/MODEL_CARD.md @@ -0,0 +1,44 @@ +--- +language: en +datasets: +- laion2b +--- + +# OpenFlamingo-9B + +[Blog post]() | [Code](https://github.com/mlfoundations/open_flamingo) | [Demo](https://7164d2142d11.ngrok.app) + +OpenFlamingo is an open source implementation of DeepMind's [Flamingo](https://www.deepmind.com/blog/tackling-multiple-tasks-with-a-single-visual-language-model) models. +OpenFlamingo-9B is built off of [CLIP ViT-L/14](https://huggingface.co/openai/clip-vit-large-patch14) and [LLaMA-7B](https://ai.facebook.com/blog/large-language-model-llama-meta-ai/). + + +## Model Details +We freeze the pretrained vision encoder and language model, and then we train connecting Perceiver modules and cross-attention layers, following the original Flamingo paper. + +Our training data is a mixture of [LAION 2B](https://huggingface.co/datasets/laion/laion2B-en) and a large interleaved image-text dataset called Multimodal C4, which will be released soon. + +The current model is an early checkpoint of an ongoing effort. This checkpoint has seen 5 million interleaved image-text examples from Multimodal C4 and 10 million samples from LAION 2B. + +## Uses +OpenFlamingo-9B is intended to be used **for academic research purposes only.** Commercial use is prohibited, in line with LLaMA's non-commercial license. + +### Bias, Risks, and Limitations +This model may generate inaccurate or offensive outputs, reflecting biases in its training data and pretrained priors. + +In an effort to mitigate current potential biases and harms, we have deployed a text content filter on model outputs in the OpenFlamingo demo. We continue to red-team the model to understand and improve its safety. + +## Evaluation +We've evaluated this checkpoint on the validation sets for two vision-language tasks: COCO captioning and VQAv2. Results are displayed below. 
+ +**COCO (CIDEr)** + +|0-shot|4-shot|8-shot|16-shot|32-shot| +|--|--|--|--|--| +|65.52|74.28|79.26|81.84|84.52| + + +**VQAv2 (VQA accuracy)** + +|0-shot|4-shot|8-shot|16-shot|32-shot| +|---|---|---|---|---| +|43.55|44.05|47.5|48.87|50.34| diff --git a/PyTorch/contrib/others/OpenFlamingo_ for PyTorch/Makefile b/PyTorch/contrib/others/OpenFlamingo_ for PyTorch/Makefile new file mode 100644 index 0000000000..bdaab6f436 --- /dev/null +++ b/PyTorch/contrib/others/OpenFlamingo_ for PyTorch/Makefile @@ -0,0 +1,19 @@ +install: ## [Local development] Upgrade pip, install requirements, install package. + python -m pip install -U pip + python -m pip install -e . + +install-dev: ## [Local development] Install test requirements + python -m pip install -r requirements-dev.txt + +lint: ## [Local development] Run mypy, pylint and black + python -m mypy open_flamingo + python -m pylint open_flamingo + python -m black --check -l 120 open_flamingo + +black: ## [Local development] Auto-format python code using black + python -m black -l 120 . + +.PHONY: help + +help: # Run `make help` to get help on the make commands + @grep -E '^[0-9a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}' diff --git a/PyTorch/contrib/others/OpenFlamingo_ for PyTorch/README.md b/PyTorch/contrib/others/OpenFlamingo_ for PyTorch/README.md new file mode 100644 index 0000000000..2476826a8d --- /dev/null +++ b/PyTorch/contrib/others/OpenFlamingo_ for PyTorch/README.md @@ -0,0 +1,128 @@ +# OpenFlamingo for PyTorch + +- [概述](#概述) +- [准备训练环境](#准备训练环境) +- [开始训练](#开始训练) +- [训练结果展示](#训练结果展示) +- [版本说明](#版本说明) + + + +# 概述 +这是对OpenFlamingo官方仓的迁移,使其能在NPU上进行训练和推理。 +## 简述 + +OpenFlamingo是一个用于通过上下文学习训练视觉语言模型的开源框架,也是DeepMind的Flamingo模型的开源复制品。OpenFlamingo的核心是一个支持大型多模态模型(LMM)训练和评估的框架。 + +- 参考实现: + + ``` + url=https://github.com/mlfoundations/open_flamingo.git + commit_id=c2e80b4f37f12677c1925e36fe2101dea07e01a8 + ``` + +- 适配昇腾 AI 处理器的实现: + + ``` + url=https://gitee.com/ascend/ModelZoo-PyTorch.git + code_path=PyTorch/contrib/others/ + ``` + +# 准备训练环境 + +## 准备环境 + +- 当前模型支持的PyTorch 如下表所示。 + + **表 1** 版本配套表 + + | Torch_Version | 三方库依赖版本 | + | :-----------: | :-------------------------------------: | + | PyTorch 1.11.0 | - | + +- 环境准备指导。 + + 请参考《[Pytorch框架训练环境准备](https://www.hiascend.com/document/detail/zh/ModelZoo/pytorchframework/ptes)》。 + +- 安装依赖。 + + ``` + pip install -r requirements.txt + ``` + + +## 准备数据集 + +1. 获取数据集 + + 用户自行下载原始数据集LAION-2B,在源码包根目录下新建目录datasets,并将数据集解压至该目录,数据集目录结构参考如下所示: + + ``` + ├── laion2b + ├──00000.parquet + ├──00000.tar + ├──00000_stats.json + ├──00001.parquet + ├──00001.tar + ├──00001_stats.json + ├──... + ``` + +# 开始训练 + +## 训练模型 + +1. 进入解压后的源码包根目录 + + ``` + cd /${模型文件夹名称} + ``` + +2. 
运行训练脚本 + + 支持多机多卡训练,以单机4卡为例: + + ``` + # 启动单机4卡训练 + bash ./train_4_npus.sh + ``` + 模型训练脚本参数说明如下。 + ``` + 分布式训练参数: + nnodes=1 // 节点个数 + nproc_per_node=4 // 每个节点上使用的NPU数量 + + 公共参数: + --run_name // 任务名称 + --vision_encoder_pretrained // clip预训练数据集名称 + --lm_path // 语言编码器路径 + --tokenizer_path // 加载模型权重路径 + --dataset_resampled // 数据集重采样 + --laion_shards // laion shards路径 + --batch_size_laion // laion shards批次大小 + --train_num_samples_laion // 每轮训练样本数 + --logging_steps // 日志打印step间隔 + --learning_rate // 学习率 + --loss_multiplier_laion // laion loss权重 + --workers // 数据集并行处理数 + --lr_scheduler // 每个节点上的NPU数量 + --warmup_steps // 热身学习率steps数 + --use_media_placement_augmentation // 使用媒体位置增强 + ``` + +# 训练结果展示 + +**表 2** 训练结果展示表 + +| NAME | FPS | batch_size | AMP_Type | Torch_Version | +|:------:|:--:|:----------:|:--------:|:-------------:| +| 4p-竞品A | \ | \ | \ | \ | +| 4p-NPU | 9.64 | 4 | fp32 | 1.11.0 | + +# 版本说明 + +## 变更 + +2023.11.7:首次发布。 + +## FAQ \ No newline at end of file diff --git a/PyTorch/contrib/others/OpenFlamingo_ for PyTorch/README_EN.md b/PyTorch/contrib/others/OpenFlamingo_ for PyTorch/README_EN.md new file mode 100644 index 0000000000..8bd3ab633f --- /dev/null +++ b/PyTorch/contrib/others/OpenFlamingo_ for PyTorch/README_EN.md @@ -0,0 +1,225 @@ +# 🦩 OpenFlamingo + +[![PyPI version](https://badge.fury.io/py/open_flamingo.svg)](https://badge.fury.io/py/open_flamingo) + +[Blog post](https://laion.ai/blog/open-flamingo/) | Paper (coming soon) + +Welcome to our open source version of DeepMind's [Flamingo](https://www.deepmind.com/blog/tackling-multiple-tasks-with-a-single-visual-language-model) model! In this repository, we provide a PyTorch implementation for training and evaluating OpenFlamingo models. We also provide an initial [OpenFlamingo 9B model](https://huggingface.co/openflamingo/OpenFlamingo-9B) trained on a new [Multimodal C4](https://github.com/allenai/mmc4) dataset. Please refer to our blog post for more details. + +This repo is still under development, and we hope to release better performing and larger OpenFlamingo models soon. If you have any questions, please feel free to open an issue. We also welcome contributions! + +# Table of Contents +- [Installation](#installation) +- [Approach](#approach) + * [Model architecture](#model-architecture) +- [Usage](#usage) + * [Initializing an OpenFlamingo model](#initializing-an-openflamingo-model) + * [Generating text](#generating-text) +- [Training](#training) + * [Dataset](#dataset) +- [Evaluation](#evaluation) +- [Future plans](#future-plans) +- [Team](#team) +- [Acknowledgments](#acknowledgments) +- [Citing](#citing) + +# Installation + +To install the package in an existing environment, run +``` +pip install open-flamingo +``` + +or to create a conda environment for running OpenFlamingo, run +``` +conda env create -f environment.yml +``` + +# Usage +We provide an initial [OpenFlamingo 9B model](https://huggingface.co/openflamingo/OpenFlamingo-9B) using a CLIP ViT-Large vision encoder and a LLaMA-7B language model. In general, we support any [CLIP vision encoder](https://huggingface.co/models?search=clip). For the language model, we support [LLaMA](https://huggingface.co/models?search=llama), [OPT](https://huggingface.co/models?search=opt), [GPT-Neo](https://huggingface.co/models?search=gpt-neo), [GPT-J](https://huggingface.co/models?search=gptj), and [Pythia](https://huggingface.co/models?search=pythia) models. 
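+For example (a minimal sketch, not a released configuration), the same factory call shown below can pair the CLIP ViT-L/14 encoder with one of the other supported language models, such as the OPT-1.3B model used in the Training section; the `cross_attn_every_n_layers` value here is an illustrative assumption:
+
+``` python
+from open_flamingo import create_model_and_transforms
+
+# Sketch: instantiate OpenFlamingo with an OPT language model instead of LLaMA.
+# "facebook/opt-1.3b" is the LM from the Training example below; the
+# cross-attention spacing is only an illustrative choice.
+model, image_processor, tokenizer = create_model_and_transforms(
+    clip_vision_encoder_path="ViT-L-14",
+    clip_vision_encoder_pretrained="openai",
+    lang_encoder_path="facebook/opt-1.3b",
+    tokenizer_path="facebook/opt-1.3b",
+    cross_attn_every_n_layers=1,
+)
+```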
+
+NOTE: To use LLaMA models, you will need to use this [script](https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/convert_llama_weights_to_hf.py) for converting LLaMA weights to HuggingFace format.
+
+## Initializing an OpenFlamingo model
+``` python
+from open_flamingo import create_model_and_transforms
+
+model, image_processor, tokenizer = create_model_and_transforms(
+    clip_vision_encoder_path="ViT-L-14",
+    clip_vision_encoder_pretrained="openai",
+    lang_encoder_path="<path to llama weights in HuggingFace format>",
+    tokenizer_path="<path to llama tokenizer in HuggingFace format>",
+    cross_attn_every_n_layers=4
+)
+
+# grab model checkpoint from huggingface hub
+from huggingface_hub import hf_hub_download
+import torch
+
+checkpoint_path = hf_hub_download("openflamingo/OpenFlamingo-9B", "checkpoint.pt")
+model.load_state_dict(torch.load(checkpoint_path), strict=False)
+```
+
+## Generating text
+Here is an example of generating text conditioned on interleaved images/text; in this case, we do few-shot image captioning.
+
+``` python
+from PIL import Image
+import requests
+
+"""
+Step 1: Load images
+"""
+demo_image_one = Image.open(
+    requests.get(
+        "http://images.cocodataset.org/val2017/000000039769.jpg", stream=True
+    ).raw
+)
+
+demo_image_two = Image.open(
+    requests.get(
+        "http://images.cocodataset.org/test-stuff2017/000000028137.jpg",
+        stream=True
+    ).raw
+)
+
+query_image = Image.open(
+    requests.get(
+        "http://images.cocodataset.org/test-stuff2017/000000028352.jpg",
+        stream=True
+    ).raw
+)
+
+
+"""
+Step 2: Preprocessing images
+Details: For OpenFlamingo, we expect the image to be a torch tensor of shape
+ batch_size x num_media x num_frames x channels x height x width.
+ In this case batch_size = 1, num_media = 3, num_frames = 1
+ (this will always be one, except for video, which we don't support yet),
+ channels = 3, height = 224, width = 224.
+"""
+vision_x = [image_processor(demo_image_one).unsqueeze(0), image_processor(demo_image_two).unsqueeze(0), image_processor(query_image).unsqueeze(0)]
+vision_x = torch.cat(vision_x, dim=0)
+vision_x = vision_x.unsqueeze(1).unsqueeze(0)
+
+"""
+Step 3: Preprocessing text
+Details: In the text we expect an <image> special token to indicate where an image is.
+ We also expect an <|endofchunk|> special token to indicate the end of the text
+ portion associated with an image.
+"""
+tokenizer.padding_side = "left" # For generation padding tokens should be on the left
+lang_x = tokenizer(
+    ["<image>An image of two cats.<|endofchunk|><image>An image of a bathroom sink.<|endofchunk|><image>An image of"],
+    return_tensors="pt",
+)
+
+
+"""
+Step 4: Generate text
+"""
+generated_text = model.generate(
+    vision_x=vision_x,
+    lang_x=lang_x["input_ids"],
+    attention_mask=lang_x["attention_mask"],
+    max_new_tokens=20,
+    num_beams=3,
+)
+
+print("Generated text: ", tokenizer.decode(generated_text[0]))
+```
+
+# Approach
+OpenFlamingo is a multimodal language model that can be used for a variety of tasks. It is trained on a large multimodal dataset (e.g. [Multimodal C4](https://github.com/allenai/mmc4)) and can be used to generate text conditioned on interleaved images/text. For example, OpenFlamingo can be used to generate a caption for an image, or to generate a question given an image and a text passage. The benefit of this approach is that we are able to rapidly adapt to new tasks using in-context learning.
+
+## Model architecture
+OpenFlamingo seeks to fuse a pretrained vision encoder and a language model using cross attention layers. The model architecture is shown below.
+ +![OpenFlamingo architecture](docs/flamingo.png) +Credit: [Flamingo](https://www.deepmind.com/blog/tackling-multiple-tasks-with-a-single-visual-language-model) + +# Training +To train a model, modify the following example command, which uses OPT 1.3B as an example LM: +``` +torchrun --nnodes=1 --nproc_per_node=4 train.py \ +--run_name flamingo3B \ +--lm_path facebook/opt-1.3b \ +--tokenizer_path facebook/opt-1.3b \ +--dataset_resampled \ +--laion_shards "/path/to/shards/shard-{0000..0999}.tar" \ +--mmc4_shards "/path/to/shards/shard-{0000..0999}.tar" \ +--batch_size_mmc4 4 \ +--batch_size_laion 8 \ +--train_num_samples_mmc4 125000 \ +--train_num_samples_laion 250000 \ +--loss_multiplier_laion 0.2 \ +--workers=6 \ +--num_epochs 250 \ +--lr_scheduler constant \ +--warmup_steps 5000 \ +--use_media_placement_augmentation \ +--mmc4_textsim_threshold 0.32 +``` + +## Dataset +We expect all our training datasets to be [WebDataset](https://github.com/webdataset/webdataset) shards. +We train our models on the [LAION 2B](https://huggingface.co/datasets/laion/laion2B-en) and [Multimodal C4](https://github.com/allenai/mmc4) datasets. By default the LAION 2B dataset is in WebDataset format if it is downloaded using the [img2dataset tool](https://github.com/rom1504/img2dataset) and Multimodal C4 can be converted to the WebDataset format using this [script](https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/train/convert_mmc4_to_wds.py). + + +# Evaluation +We currently support running evaluations on [COCO](https://cocodataset.org/#home), [VQAv2](https://visualqa.org/index.html), [OKVQA](https://okvqa.allenai.org), [Flickr30k](https://www.kaggle.com/datasets/hsankesara/flickr-image-dataset), and [ImageNet](https://image-net.org/index.php). Note that currently these evaluations are ran in validation mode (as specified in the Flamingo paper). We will be adding support for running evaluations in test mode in the future. + + +To run evaluations on OKVQA you will need to run the following command: +``` +import nltk +nltk.download('wordnet') +``` + +To evaluate the model, run the script at `open_flamingo/scripts/run_eval.sh` + +# Future plans +- [ ] Add support for video input +- [ ] Release better performing and larger OpenFlamingo models +- [ ] Expand our evaluation suite +- [ ] Add support for FSDP training + +# Team + +OpenFlamingo is developed by: + +[Anas Awadalla](https://anas-awadalla.streamlit.app/), [Irena Gao](https://i-gao.github.io/), [Joshua Gardner](https://homes.cs.washington.edu/~jpgard/), [Jack Hessel](https://jmhessel.com/), [Yusuf Hanafy](https://www.linkedin.com/in/yusufhanafy/), [Wanrong Zhu](https://wanrong-zhu.com/), [Kalyani Marathe](https://sites.google.com/uw.edu/kalyanimarathe/home?authuser=0), [Yonatan Bitton](https://yonatanbitton.github.io/), [Samir Gadre](https://sagadre.github.io/), [Jenia Jitsev](https://scholar.google.de/citations?user=p1FuAMkAAAAJ&hl=en), [Simon Kornblith](https://simonster.com/), [Pang Wei Koh](https://koh.pw/), [Gabriel Ilharco](https://gabrielilharco.com/), [Mitchell Wortsman](https://mitchellnw.github.io/), [Ludwig Schmidt](https://people.csail.mit.edu/ludwigs/). + +The team is primarily from the University of Washington, Stanford, AI2, UCSB, and Google. + +# Acknowledgments +This code is based on Lucidrains' [flamingo implementation](https://github.com/lucidrains/flamingo-pytorch) and David Hansmair's [flamingo-mini repo](https://github.com/dhansmair/flamingo-mini). Thank you for making your code public! 
We also thank the [OpenCLIP](https://github.com/mlfoundations/open_clip) team as we use their data loading code and take inspiration from their library design. + +We would also like to thank [Jean-Baptiste Alayrac](https://www.jbalayrac.com) and [Antoine Miech](https://antoine77340.github.io) for their advice, [Rohan Taori](https://www.rohantaori.com/), [Nicholas Schiefer](https://nicholasschiefer.com/), [Deep Ganguli](https://hai.stanford.edu/people/deep-ganguli), [Thomas Liao](https://thomasliao.com/), [Tatsunori Hashimoto](https://thashim.github.io/), and [Nicholas Carlini](https://nicholas.carlini.com/) for their help with assessing the safety risks of our release, and to [Stability AI](https://stability.ai) for providing us with compute resources to train these models. + +# Citing +If you found this repository useful, please consider citing: + +``` +@software{anas_awadalla_2023_7733589, + author = {Awadalla, Anas and Gao, Irena and Gardner, Joshua and Hessel, Jack and Hanafy, Yusuf and Zhu, Wanrong and Marathe, Kalyani and Bitton, Yonatan and Gadre, Samir and Jitsev, Jenia and Kornblith, Simon and Koh, Pang Wei and Ilharco, Gabriel and Wortsman, Mitchell and Schmidt, Ludwig}, + title = {OpenFlamingo}, + month = mar, + year = 2023, + publisher = {Zenodo}, + version = {v0.1.1}, + doi = {10.5281/zenodo.7733589}, + url = {https://doi.org/10.5281/zenodo.7733589} +} +``` + +``` +@article{Alayrac2022FlamingoAV, + title={Flamingo: a Visual Language Model for Few-Shot Learning}, + author={Jean-Baptiste Alayrac and Jeff Donahue and Pauline Luc and Antoine Miech and Iain Barr and Yana Hasson and Karel Lenc and Arthur Mensch and Katie Millican and Malcolm Reynolds and Roman Ring and Eliza Rutherford and Serkan Cabi and Tengda Han and Zhitao Gong and Sina Samangooei and Marianne Monteiro and Jacob Menick and Sebastian Borgeaud and Andy Brock and Aida Nematzadeh and Sahand Sharifzadeh and Mikolaj Binkowski and Ricardo Barreira and Oriol Vinyals and Andrew Zisserman and Karen Simonyan}, + journal={ArXiv}, + year={2022}, + volume={abs/2204.14198} +} +``` diff --git a/PyTorch/contrib/others/OpenFlamingo_ for PyTorch/TERMS_AND_CONDITIONS.md b/PyTorch/contrib/others/OpenFlamingo_ for PyTorch/TERMS_AND_CONDITIONS.md new file mode 100644 index 0000000000..3571e77642 --- /dev/null +++ b/PyTorch/contrib/others/OpenFlamingo_ for PyTorch/TERMS_AND_CONDITIONS.md @@ -0,0 +1,15 @@ +**Please read the following information carefully before proceeding.** + +OpenFlamingo is a **research prototype** that aims to enable users to interact with AI through both language and images. AI agents equipped with both language and visual understanding can be useful on a larger variety of tasks compared to models that communicate solely via language. By releasing an open-source research prototype, we hope to help the research community better understand the risks and limitations of modern visual-language AI models and accelerate the development of safer and more reliable methods. + +- [ ] I understand that OpenFlamingo is a research prototype and I will only use it for non-commercial research purposes. + +**Limitations.** OpenFlamingo is built on top of the LLaMA large language model developed by Meta AI. Large language models, including LLaMA, are trained on mostly unfiltered internet data, and have been shown to be able to produce toxic, unethical, inaccurate, and harmful content. 
On top of this, OpenFlamingo’s ability to support visual inputs creates additional risks, since it can be used in a wider variety of applications; image+text models may carry additional risks specific to multimodality. Please use discretion when assessing the accuracy or appropriateness of the model’s outputs, and be mindful before sharing its results. + +- [ ] I understand that OpenFlamingo may produce unintended, inappropriate, offensive, and/or inaccurate results. I agree to take full responsibility for any use of the OpenFlamingo outputs that I generate. + +**Privacy and data collection.** This demo does NOT store any personal information on its users, and it does NOT store user queries. + +**Licensing.** As OpenFlamingo is built on top of the LLaMA large language model from Meta AI, the LLaMA license agreement (as documented in the Meta request form) also applies. + +- [ ] I have read and agree to the terms of the LLaMA license agreement. diff --git a/PyTorch/contrib/others/OpenFlamingo_ for PyTorch/environment.yml b/PyTorch/contrib/others/OpenFlamingo_ for PyTorch/environment.yml new file mode 100644 index 0000000000..1d477e5bf5 --- /dev/null +++ b/PyTorch/contrib/others/OpenFlamingo_ for PyTorch/environment.yml @@ -0,0 +1,10 @@ +name: openflamingo +channels: + - defaults +dependencies: + - python=3.9 + - conda-forge::openjdk + - pip + - pip: + - -r requirements.txt + - -e . diff --git a/PyTorch/contrib/others/OpenFlamingo_ for PyTorch/open_flamingo/__init__.py b/PyTorch/contrib/others/OpenFlamingo_ for PyTorch/open_flamingo/__init__.py new file mode 100644 index 0000000000..ab67750bb7 --- /dev/null +++ b/PyTorch/contrib/others/OpenFlamingo_ for PyTorch/open_flamingo/__init__.py @@ -0,0 +1,2 @@ +from .src.flamingo import Flamingo +from .src.factory import create_model_and_transforms diff --git a/PyTorch/contrib/others/OpenFlamingo_ for PyTorch/open_flamingo/eval/README.md b/PyTorch/contrib/others/OpenFlamingo_ for PyTorch/open_flamingo/eval/README.md new file mode 100644 index 0000000000..3289e50e24 --- /dev/null +++ b/PyTorch/contrib/others/OpenFlamingo_ for PyTorch/open_flamingo/eval/README.md @@ -0,0 +1,29 @@ +# OpenFlamingo Evaluation Suite + +This is the evaluation module of OpenFlamingo. It contains a set of utilities for evaluating multimodal models on various benchmarking datasets. + +*This module is a work in progress! We will be updating this README as it develops. In the meantime, if you notice an issue, please file a Bug Report or Feature Request [here](https://github.com/mlfoundations/open_flamingo/issues/new/choose).* + +# Running the evaluation suite on OpenFlamingo-9B + +The easiest way to run the evaluation suite is by using the script at `open_flamingo/open_flamingo/scripts/run_eval.sh`. + +Before running that script, we suggest to download a local copy of the OpenFlamingo model, as follows: + +``` +# grab model checkpoint from huggingface hub +from huggingface_hub import hf_hub_download +HF_TOKEN="" + +checkpoint_path = hf_hub_download("openflamingo/OpenFlamingo-9B", "checkpoint.pt") +checkpoint_path= hf_hub_download("openflamingo/OpenFlamingo-9B", + "checkpoint.pt", + local_dir="openflamingo/OpenFlamingo-9B", + cache_dir="openflamingo/OpenFlamingo-9B", + local_dir_use_symlinks=False, + token=HF_TOKEN) +print(checkpoint_path) +## openflamingo/OpenFlamingo-9B/checkpoint.pt +``` + +This should place the OpenFlamingo model at the expected location in the evaluation script. 
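+As a quick sanity check that the download succeeded, the checkpoint can also be loaded manually, mirroring the initialization example in the top-level README_EN.md (a sketch: the LLaMA paths are placeholders, and `cross_attn_every_n_layers=4` is the setting quoted there for OpenFlamingo-9B):
+
+```
+import torch
+from open_flamingo import create_model_and_transforms
+
+# Build the model as in README_EN.md (paths below are placeholders).
+model, image_processor, tokenizer = create_model_and_transforms(
+    clip_vision_encoder_path="ViT-L-14",
+    clip_vision_encoder_pretrained="openai",
+    lang_encoder_path="<path to llama weights in HuggingFace format>",
+    tokenizer_path="<path to llama tokenizer in HuggingFace format>",
+    cross_attn_every_n_layers=4,
+)
+
+# Load the locally downloaded checkpoint from the snippet above.
+checkpoint_path = "openflamingo/OpenFlamingo-9B/checkpoint.pt"
+model.load_state_dict(torch.load(checkpoint_path), strict=False)
+```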
diff --git a/PyTorch/contrib/others/OpenFlamingo_ for PyTorch/open_flamingo/eval/__init__.py b/PyTorch/contrib/others/OpenFlamingo_ for PyTorch/open_flamingo/eval/__init__.py new file mode 100644 index 0000000000..8b13789179 --- /dev/null +++ b/PyTorch/contrib/others/OpenFlamingo_ for PyTorch/open_flamingo/eval/__init__.py @@ -0,0 +1 @@ + diff --git a/PyTorch/contrib/others/OpenFlamingo_ for PyTorch/open_flamingo/eval/coco_metric.py b/PyTorch/contrib/others/OpenFlamingo_ for PyTorch/open_flamingo/eval/coco_metric.py new file mode 100644 index 0000000000..6f159c6259 --- /dev/null +++ b/PyTorch/contrib/others/OpenFlamingo_ for PyTorch/open_flamingo/eval/coco_metric.py @@ -0,0 +1,22 @@ +from pycocoevalcap.eval import COCOEvalCap +from pycocotools.coco import COCO + + +def compute_cider( + result_path, + annotations_path, +): + # create coco object and coco_result object + coco = COCO(annotations_path) + coco_result = coco.loadRes(result_path) + + # create coco_eval object by taking coco and coco_result + coco_eval = COCOEvalCap(coco, coco_result) + coco_eval.params["image_id"] = coco_result.getImgIds() + coco_eval.evaluate() + + return coco_eval.eval + + +def postprocess_captioning_generation(predictions): + return predictions.split("Output", 1)[0] diff --git a/PyTorch/contrib/others/OpenFlamingo_ for PyTorch/open_flamingo/eval/eval_datasets.py b/PyTorch/contrib/others/OpenFlamingo_ for PyTorch/open_flamingo/eval/eval_datasets.py new file mode 100644 index 0000000000..51d05b380d --- /dev/null +++ b/PyTorch/contrib/others/OpenFlamingo_ for PyTorch/open_flamingo/eval/eval_datasets.py @@ -0,0 +1,126 @@ +import json +import os + +from PIL import Image +from torch.utils.data import Dataset +from torchvision.datasets import ImageFolder + +from open_flamingo.eval.imagenet_utils import IMAGENET_1K_CLASS_ID_TO_LABEL + + +class CaptionDataset(Dataset): + def __init__( + self, + image_train_dir_path, + annotations_path, + is_train, + dataset_name, + image_val_dir_path=None, + ): + self.image_train_dir_path = image_train_dir_path + self.image_val_dir_path = image_val_dir_path + self.annotations = [] + self.is_train = is_train + self.dataset_name = dataset_name + + print(annotations_path) + full_annotations = json.load(open(annotations_path))["images"] + + print(full_annotations[0:3]) + print(self.is_train) + + for i in range(len(full_annotations)): + if self.is_train and full_annotations[i]["split"] != "train": + continue + elif not self.is_train and full_annotations[i]["split"] != "test": + continue + + self.annotations.append(full_annotations[i]) + + def __len__(self): + return len(self.annotations) + + def __getitem__(self, idx): + if self.dataset_name == "coco": + image = Image.open( + os.path.join( + self.image_train_dir_path, self.annotations[idx]["filename"] + ) + if self.annotations[idx]["filepath"] == "train2014" + else os.path.join( + self.image_val_dir_path, self.annotations[idx]["filename"] + ) + ) + elif self.dataset_name == "flickr": + image = Image.open( + os.path.join( + self.image_train_dir_path, self.annotations[idx]["filename"] + ) + ) + image.load() + caption = self.annotations[idx]["sentences"][0]["raw"] + return { + "image": image, + "caption": caption, + "image_id": self.annotations[idx]["cocoid"] + if self.dataset_name == "coco" + else self.annotations[idx]["filename"].split(".")[0], + } + + +class VQADataset(Dataset): + def __init__( + self, image_dir_path, question_path, annotations_path, is_train, dataset_name + ): + self.questions = json.load(open(question_path, 
"r"))["questions"] + self.answers = json.load(open(annotations_path, "r"))["annotations"] + self.image_dir_path = image_dir_path + self.is_train = is_train + self.dataset_name = dataset_name + + def __len__(self): + return len(self.questions) + + def get_img_path(self, question): + if self.dataset_name in {"vqav2", "ok-vqa"}: + return os.path.join( + self.image_dir_path, + f"COCO_train2014_{question['image_id']:012d}.jpg" + if self.is_train + else f"COCO_val2014_{question['image_id']:012d}.jpg", + ) + elif self.dataset_name == "vizwiz": + return os.path.join(self.image_dir_path, question["image_id"]) + elif self.dataset_name == "textvqa": + return os.path.join(self.image_dir_path, f"{question['image_id']}.jpg") + else: + raise Exception(f"Unknown VQA dataset {self.dataset_name}") + + def __getitem__(self, idx): + question = self.questions[idx] + answers = self.answers[idx] + img_path = self.get_img_path(question) + image = Image.open(img_path) + image.load() + return { + "image": image, + "question": question["question"], + "answers": [a["answer"] for a in answers["answers"]], + "question_id": question["question_id"], + } + + +class ImageNetDataset(ImageFolder): + """Class to represent the ImageNet1k dataset.""" + + def __init__(self, root, **kwargs): + super().__init__(root=root, **kwargs) + + def __getitem__(self, idx): + sample, target = super().__getitem__(idx) + target_label = IMAGENET_1K_CLASS_ID_TO_LABEL[target] + return { + "image": sample, + "class_id": target, # numeric ID of the ImageNet class + "class_name": target_label, # human-readable name of ImageNet class + } diff --git a/PyTorch/contrib/others/OpenFlamingo_ for PyTorch/open_flamingo/eval/eval_model.py b/PyTorch/contrib/others/OpenFlamingo_ for PyTorch/open_flamingo/eval/eval_model.py new file mode 100644 index 0000000000..9bd99606e1 --- /dev/null +++ b/PyTorch/contrib/others/OpenFlamingo_ for PyTorch/open_flamingo/eval/eval_model.py @@ -0,0 +1,63 @@ +import abc +import argparse +from typing import List + +from PIL import Image + + +class BaseEvalModel(abc.ABC): + """Base class encapsulating functionality needed to evaluate a model.""" + + def __init__(self, args: List[str]): + """Initialize model. + + Args: + args: arguments to model. These should be parsed, or if the model + has no applicable arguments, an error should be thrown if `args` + is non-empty. + """ + + def get_outputs( + self, + batch_text: List[str], + batch_images: List[List[Image.Image]], + max_generation_length: int, + num_beams: int, + length_penalty: float, + ) -> List[str]: + """Get outputs for a batch of images and text. + + Args: + batch_text: list of text strings, with the text "" in place + of any images to be included. + batch_images: images to provide to model. Should be a list of lists, + where each list contains the images for a single example. + max_generation_length: maximum length of the generated caption. + Defaults to 10. + num_beams: number of beams to use for beam search. Defaults to 3. + length_penalty: length penalty for beam search. Defaults to -2.0. + + Returns: + List of decoded output strings. + """ + + def vqa_prompt(self, question, answer=None) -> str: + """Get the prompt to use for VQA evaluation. If the answer is not provided, it should be left blank to be generated by the model. + + Returns: + The prompt to use for VQA. + """ + + def caption_prompt(self, caption=None) -> str: + """Get the prompt to use for caption evaluation. If the caption is not provided, it should be left blank to be generated by the model. 
+ + Returns: + The prompt to use for captioning. + """ + + def classification_prompt(self, class_str=None) -> str: + """Get the prompt to use for classification evaluation. If the class_str is not provided, it should be left blank to be generated by the model. + + Returns: + The prompt to use for classification. + """ diff --git a/PyTorch/contrib/others/OpenFlamingo_ for PyTorch/open_flamingo/eval/evaluate.py b/PyTorch/contrib/others/OpenFlamingo_ for PyTorch/open_flamingo/eval/evaluate.py new file mode 100644 index 0000000000..cd59fa0284 --- /dev/null +++ b/PyTorch/contrib/others/OpenFlamingo_ for PyTorch/open_flamingo/eval/evaluate.py @@ -0,0 +1,988 @@ +import argparse +import importlib +import json +import os +import random +import uuid +from collections import defaultdict + +from einops import repeat +import more_itertools +import numpy as np +import torch +import torch_npu +from torch_npu.contrib import transfer_to_npu + + +from coco_metric import compute_cider, postprocess_captioning_generation +from eval_datasets import CaptionDataset, VQADataset, ImageNetDataset +from tqdm import tqdm + + +from eval_datasets import VQADataset, ImageNetDataset +from open_flamingo.eval.imagenet_utils import ( + openai_imagenet_classnames, + IMAGENET_1K_CLASS_ID_TO_LABEL, +) + +from eval_model import BaseEvalModel + +from open_flamingo.eval.ok_vqa_utils import postprocess_ok_vqa_generation +from open_flamingo.src.flamingo import Flamingo +from vqa_metric import compute_vqa_accuracy, postprocess_vqa_generation + +parser = argparse.ArgumentParser() +parser.add_argument( + "--results_file", type=str, default=None, help="JSON file to save results" +) + +# Trial arguments +parser.add_argument("--shots", nargs="+", default=[0, 4, 8, 16, 32], type=int) +parser.add_argument( + "--num_trials", + type=int, + default=1, + help="Number of trials to run for each shot using different demonstrations", +) +parser.add_argument( + "--trial_seeds", + nargs="+", + default=[42], + help="Seeds to use for each trial for picking demonstrations and eval sets", +) +parser.add_argument( + "--num_samples", type=int, default=5000, help="Number of samples to evaluate on" +) +parser.add_argument( + "--query_set_size", type=int, default=2048, help="Size of demonstration query set" +) + +parser.add_argument("--batch_size", type=int, default=8) + +# Per-dataset evaluation flags +parser.add_argument( + "--eval_coco", + action="store_true", + default=False, + help="Whether to evaluate on COCO.", +) +parser.add_argument( + "--eval_vqav2", + action="store_true", + default=False, + help="Whether to evaluate on VQAV2.", +) +parser.add_argument( + "--eval_ok_vqa", + action="store_true", + default=False, + help="Whether to evaluate on OK-VQA.", +) +parser.add_argument( + "--eval_vizwiz", + action="store_true", + default=False, + help="Whether to evaluate on VizWiz.", +) +parser.add_argument( + "--eval_textvqa", + action="store_true", + default=False, + help="Whether to evaluate on TextVQA.", +) +parser.add_argument( + "--eval_imagenet", + action="store_true", + default=False, + help="Whether to evaluate on ImageNet.", +) + +parser.add_argument( + "--eval_flickr30", + action="store_true", + default=False, + help="Whether to evaluate on Flickr30.", +) + +# Dataset arguments + +## Flickr30 Dataset +parser.add_argument( + "--flickr_image_dir_path", + type=str, + help="Path to the flickr30/flickr30k_images directory.", + default=None, +) +parser.add_argument( + "--flickr_karpathy_json_path", + type=str, + help="Path to the dataset_flickr30k.json 
file.", + default=None, +) +parser.add_argument( + "--flickr_annotations_json_path", + type=str, + help="Path to the dataset_flickr30k_coco_style.json file.", +) +## COCO Dataset +parser.add_argument( + "--coco_train_image_dir_path", + type=str, + default=None, +) +parser.add_argument( + "--coco_val_image_dir_path", + type=str, + default=None, +) +parser.add_argument( + "--coco_karpathy_json_path", + type=str, + default=None, +) +parser.add_argument( + "--coco_annotations_json_path", + type=str, + default=None, +) + +## VQAV2 Dataset +parser.add_argument( + "--vqav2_train_image_dir_path", + type=str, + default=None, +) +parser.add_argument( + "--vqav2_train_questions_json_path", + type=str, + default=None, +) +parser.add_argument( + "--vqav2_train_annotations_json_path", + type=str, + default=None, +) +parser.add_argument( + "--vqav2_test_image_dir_path", + type=str, + default=None, +) +parser.add_argument( + "--vqav2_test_questions_json_path", + type=str, + default=None, +) +parser.add_argument( + "--vqav2_test_annotations_json_path", + type=str, + default=None, +) + +## OK-VQA Dataset +parser.add_argument( + "--ok_vqa_train_image_dir_path", + type=str, + help="Path to the vqav2/train2014 directory.", + default=None, +) +parser.add_argument( + "--ok_vqa_train_questions_json_path", + type=str, + help="Path to the v2_OpenEnded_mscoco_train2014_questions.json file.", + default=None, +) +parser.add_argument( + "--ok_vqa_train_annotations_json_path", + type=str, + help="Path to the v2_mscoco_train2014_annotations.json file.", + default=None, +) +parser.add_argument( + "--ok_vqa_test_image_dir_path", + type=str, + help="Path to the vqav2/val2014 directory.", + default=None, +) +parser.add_argument( + "--ok_vqa_test_questions_json_path", + type=str, + help="Path to the v2_OpenEnded_mscoco_val2014_questions.json file.", + default=None, +) +parser.add_argument( + "--ok_vqa_test_annotations_json_path", + type=str, + help="Path to the v2_mscoco_val2014_annotations.json file.", + default=None, +) + +## VizWiz Dataset +parser.add_argument( + "--vizwiz_train_image_dir_path", + type=str, + help="Path to the vizwiz train images directory.", + default=None, +) +parser.add_argument( + "--vizwiz_test_image_dir_path", + type=str, + help="Path to the vizwiz test images directory.", + default=None, +) +parser.add_argument( + "--vizwiz_train_questions_json_path", + type=str, + help="Path to the vizwiz questions json file.", + default=None, +) +parser.add_argument( + "--vizwiz_train_annotations_json_path", + type=str, + help="Path to the vizwiz annotations json file.", + default=None, +) +parser.add_argument( + "--vizwiz_test_questions_json_path", + type=str, + help="Path to the vizwiz questions json file.", + default=None, +) +parser.add_argument( + "--vizwiz_test_annotations_json_path", + type=str, + help="Path to the vizwiz annotations json file.", + default=None, +) + +# TextVQA Dataset +parser.add_argument( + "--textvqa_image_dir_path", + type=str, + help="Path to the textvqa images directory.", + default=None, +) +parser.add_argument( + "--textvqa_train_questions_json_path", + type=str, + help="Path to the textvqa questions json file.", + default=None, +) +parser.add_argument( + "--textvqa_train_annotations_json_path", + type=str, + help="Path to the textvqa annotations json file.", + default=None, +) +parser.add_argument( + "--textvqa_test_questions_json_path", + type=str, + help="Path to the textvqa questions json file.", + default=None, +) +parser.add_argument( + "--textvqa_test_annotations_json_path", 
+ type=str, + help="Path to the textvqa annotations json file.", + default=None, +) + +## Imagenet dataset +parser.add_argument("--imagenet_root", type=str, default="/tmp") + +parser.add_argument( + "--model", + type=str, + help="Model name. Currently only `OpenFlamingo` is supported.", + default="open_flamingo", +) + + +def main(): + args, leftovers = parser.parse_known_args() + module = importlib.import_module(f"open_flamingo.eval.models.{args.model}") + + model_args = { + leftovers[i].lstrip("-"): leftovers[i + 1] for i in range(0, len(leftovers) - 1, 2) + } + model_args['vision_encoder_pretrained'] = "" + model_args = {'lm_path': 'facebook/opt-1.3b', 'lm_tokenizer_path': 'facebook/opt-1.3b', 'vision_encoder_path': 'ViT-L-14', 'checkpoint_path': '/home/data2/linzheyuan/open_flamingo/flamingo3B/checkpoint_0.pt', 'cross_attn_every_n_layers': '4', 'device': '0', 'coco_image_dir_path': '/home/data1/coco/train2017', 'vision_encoder_pretrained': ''} + eval_model = module.EvalModel(model_args) + + if args.model != "open_flamingo" and args.shots != [0]: + raise ValueError("Only 0 shot eval is supported for non-open_flamingo models") + + if len(args.trial_seeds) != args.num_trials: + raise ValueError("Number of trial seeds must be == number of trials.") + + results = defaultdict(list) + + if args.eval_flickr30: + print("Evaluating on Flickr30k...") + for shot in args.shots: + scores = [] + for seed, trial in zip(args.trial_seeds, range(args.num_trials)): + cider_score = evaluate_captioning( + args, + eval_model=eval_model, + num_shots=shot, + seed=seed, + dataset_name="flickr", + ) + print(f"Shots {shot} Trial {trial} CIDEr score: {cider_score}") + scores.append(cider_score) + print(f"Shots {shot} Mean CIDEr score: {np.mean(scores)}") + results["flickr30"].append( + {"shots": shot, "trials": scores, "mean": np.mean(scores)} + ) + + if args.eval_coco: + print("Evaluating on COCO...") + for shot in args.shots: + scores = [] + for seed, trial in zip(args.trial_seeds, range(args.num_trials)): + cider_score = evaluate_captioning( + args, + eval_model=eval_model, + num_shots=shot, + seed=seed, + dataset_name="coco", + ) + print(f"Shots {shot} Trial {trial} CIDEr score: {cider_score}") + scores.append(cider_score) + print(f"Shots {shot} Mean CIDEr score: {np.mean(scores)}") + results["coco"].append( + {"shots": shot, "trials": scores, "mean": np.mean(scores)} + ) + + if args.eval_ok_vqa: + print("Evaluating on OK-VQA...") + for shot in args.shots: + scores = [] + for seed, trial in zip(args.trial_seeds, range(args.num_trials)): + ok_vqa_score = evaluate_vqa( + args=args, + eval_model=eval_model, + num_shots=shot, + seed=seed, + dataset_name="ok_vqa", + ) + print(f"Shots {shot} Trial {trial} OK-VQA score: {ok_vqa_score}") + scores.append(ok_vqa_score) + print(f"Shots {shot} Mean OK-VQA score: {np.mean(scores)}") + results["ok_vqa"].append( + {"shots": shot, "trials": scores, "mean": np.mean(scores)} + ) + + if args.eval_vqav2: + print("Evaluating on VQAv2...") + for shot in args.shots: + scores = [] + for seed, trial in zip(args.trial_seeds, range(args.num_trials)): + vqa_score = evaluate_vqa( + args=args, + eval_model=eval_model, + num_shots=shot, + seed=seed, + dataset_name="vqav2", + ) + print(f"Shots {shot} Trial {trial} VQA score: {vqa_score}") + scores.append(vqa_score) + print(f"Shots {shot} Mean VQA score: {np.mean(scores)}") + results["vqav2"].append( + {"shots": shot, "trials": scores, "mean": np.mean(scores)} + ) + + if args.eval_vizwiz: + print("Evaluating on VizWiz...") + for shot in 
args.shots: + scores = [] + for seed, trial in zip(args.trial_seeds, range(args.num_trials)): + vizwiz_score = evaluate_vqa( + args=args, + eval_model=eval_model, + num_shots=shot, + seed=seed, + dataset_name="vizwiz", + ) + print(f"Shots {shot} Trial {trial} VizWiz score: {vizwiz_score}") + scores.append(vizwiz_score) + print(f"Shots {shot} Mean VizWiz score: {np.mean(scores)}") + results["vizwiz"].append( + {"shots": shot, "trials": scores, "mean": np.mean(scores)} + ) + + if args.eval_textvqa: + print("Evaluating on TextVQA...") + for shot in args.shots: + scores = [] + for seed, trial in zip(args.trial_seeds, range(args.num_trials)): + textvqa_score = evaluate_vqa( + args=args, + eval_model=eval_model, + num_shots=shot, + seed=seed, + dataset_name="textvqa", + ) + print(f"Shots {shot} Trial {trial} TextVQA score: {textvqa_score}") + scores.append(textvqa_score) + print(f"Shots {shot} Mean TextVQA score: {np.mean(scores)}") + results["textvqa"].append( + {"shots": shot, "trials": scores, "mean": np.mean(scores)} + ) + + if args.eval_imagenet: + print("Evaluating on ImageNet...") + for shot in args.shots: + scores = [] + for seed, trial in zip(args.trial_seeds, range(args.num_trials)): + imagenet_score = evaluate_imagenet( + eval_model=eval_model, + batch_size=args.batch_size, + num_samples=args.num_samples, + num_shots=shot, + seed=seed, + imagenet_root=args.imagenet_root, + ) + print( + f"Shots {shot} Trial {trial} " f"ImageNet score: {imagenet_score}" + ) + scores.append(imagenet_score) + print(f"Shots {shot} Mean ImageNet score: {np.mean(scores)}") + results["imagenet"].append( + {"shots": shot, "trials": scores, "mean": np.mean(scores)} + ) + + if args.results_file is not None: + with open(args.results_file, "w") as f: + json.dump(results, f) + + +def get_random_indices(num_samples, query_set_size, full_dataset, seed): + if num_samples + query_set_size > len(full_dataset): + raise ValueError( + f"num_samples + query_set_size must be less than {len(full_dataset)}" + ) + + # get a random subset of the dataset + np.random.seed(seed) + random_indices = np.random.choice( + len(full_dataset), num_samples + query_set_size, replace=False + ) + return random_indices + + +def get_query_set(train_dataset, query_set_size, seed): + np.random.seed(seed) + query_set = np.random.choice(len(train_dataset), query_set_size, replace=False) + return [train_dataset[i] for i in query_set] + + +def prepare_eval_samples(test_dataset, num_samples, seed): + np.random.seed(seed) + random_indices = np.random.choice(len(test_dataset), num_samples, replace=False) + return torch.utils.data.Subset(test_dataset, random_indices) + + +def sample_batch_demos_from_query_set(query_set, num_samples, batch_size): + return [random.sample(query_set, num_samples) for _ in range(batch_size)] + + +def compute_effective_num_shots(num_shots, model_type): + if model_type == "open_flamingo": + return num_shots if num_shots > 0 else 2 + return num_shots + + +def evaluate_captioning( + args: argparse.Namespace, + eval_model: BaseEvalModel, + seed: int = 42, + max_generation_length: int = 20, + num_beams: int = 3, + length_penalty: float = -2.0, + num_shots: int = 8, + dataset_name: str = "coco", +): + """Evaluate a model on COCO dataset. + + Args: + args (argparse.Namespace): arguments + eval_model (BaseEvalModel): model to evaluate + seed (int, optional): seed for random number generator. Defaults to 42. + max_generation_length (int, optional): maximum length of the generated caption. Defaults to 20. 
+        num_beams (int, optional): number of beams to use for beam search. Defaults to 3.
+        length_penalty (float, optional): length penalty for beam search. Defaults to -2.0.
+        num_shots (int, optional): number of in-context samples to use. Defaults to 8.
+        dataset_name (str, optional): dataset to evaluate on. Can be "coco" or "flickr". Defaults to "coco".
+    Returns:
+        float: CIDEr score
+
+    """
+
+    if dataset_name == "coco":
+        image_train_dir_path = args.coco_train_image_dir_path
+        image_val_dir_path = args.coco_val_image_dir_path
+        annotations_path = args.coco_karpathy_json_path
+    elif dataset_name == "flickr":
+        image_train_dir_path = (
+            args.flickr_image_dir_path
+        )  # Note: calling this "train" for consistency with COCO but Flickr only has one split for images
+        image_val_dir_path = None
+        annotations_path = args.flickr_karpathy_json_path
+    else:
+        raise ValueError(f"Unsupported dataset: {dataset_name}")
+
+    train_dataset = CaptionDataset(
+        image_train_dir_path=image_train_dir_path,
+        image_val_dir_path=image_val_dir_path,
+        annotations_path=annotations_path,
+        is_train=True,
+        dataset_name=dataset_name,
+    )
+
+    test_dataset = CaptionDataset(
+        image_train_dir_path=image_train_dir_path,
+        image_val_dir_path=image_val_dir_path,
+        annotations_path=annotations_path,
+        is_train=False,
+        dataset_name=dataset_name,
+    )
+
+    effective_num_shots = compute_effective_num_shots(num_shots, args.model)
+
+    test_dataset = prepare_eval_samples(
+        test_dataset,
+        args.num_samples if args.num_samples > 0 else len(test_dataset),
+        seed,
+    )
+
+    in_context_samples = get_query_set(train_dataset, args.query_set_size, seed)
+
+    predictions = defaultdict()
+
+    for batch in more_itertools.chunked(
+        tqdm(test_dataset, desc=f"Running inference {dataset_name.upper()}"),
+        args.batch_size,
+    ):
+        batch_demo_samples = sample_batch_demos_from_query_set(
+            in_context_samples, effective_num_shots, len(batch)
+        )
+
+        batch_images = []
+        batch_text = []
+        for i in range(len(batch)):
+            if num_shots > 0:
+                context_images = [x["image"] for x in batch_demo_samples[i]]
+            else:
+                context_images = []
+            batch_images.append(context_images + [batch[i]["image"]])
+
+            context_text = "".join(
+                [
+                    eval_model.get_caption_prompt(caption=x["caption"].strip())
+                    for x in batch_demo_samples[i]
+                ]
+            )
+
+            # Keep the text but remove the image tags for the zero-shot case
+            if num_shots == 0:
+                context_text = context_text.replace("<image>", "")
+
+            batch_text.append(context_text + eval_model.get_caption_prompt())
+
+        outputs = eval_model.get_outputs(
+            batch_images=batch_images,
+            batch_text=batch_text,
+            max_generation_length=max_generation_length,
+            num_beams=num_beams,
+            length_penalty=length_penalty,
+        )
+
+        new_predictions = [
+            postprocess_captioning_generation(out).replace('"', "") for out in outputs
+        ]
+
+        for i, sample in enumerate(batch):
+            predictions[sample["image_id"]] = {
+                "caption": new_predictions[i],
+            }
+
+    # save the predictions to a temporary file
+    results_path = f"{dataset_name}results_{uuid.uuid4()}.json"
+
+    with open(results_path, "w") as f:
+        f.write(
+            json.dumps(
+                [
+                    {"image_id": k, "caption": predictions[k]["caption"]}
+                    for k in predictions
+                ],
+                indent=4,
+            )
+        )
+
+    metrics = compute_cider(
+        result_path=results_path,
+        annotations_path=args.coco_annotations_json_path
+        if dataset_name == "coco"
+        else args.flickr_annotations_json_path,
+    )
+
+    # delete the temporary file
+    os.remove(results_path)
+
+    return metrics["CIDEr"] * 100.0
+
+
+def evaluate_vqa(
+    args: argparse.Namespace,
+    eval_model: BaseEvalModel,
+ seed: int = 42, + max_generation_length: int = 5, + num_beams: int = 3, + length_penalty: float = -2.0, + num_shots: int = 8, + dataset_name: str = "vqav2", +): + """ + Evaluate a model on VQA datasets. Currently supports VQA v2.0, OK-VQA, VizWiz and TextVQA. + + Args: + args (argparse.Namespace): arguments + eval_model (BaseEvalModel): model to evaluate + seed (int, optional): random seed. Defaults to 42. + max_generation_length (int, optional): max generation length. Defaults to 5. + num_beams (int, optional): number of beams to use for beam search. Defaults to 3. + length_penalty (float, optional): length penalty for beam search. Defaults to -2.0. + num_shots (int, optional): number of shots to use. Defaults to 8. + dataset_name (string): type of vqa dataset: currently supports vqav2, ok_vqa. Defaults to vqav2. + Returns: + float: accuracy score + """ + + if dataset_name == "ok_vqa": + train_image_dir_path = args.ok_vqa_train_image_dir_path + train_questions_json_path = args.ok_vqa_train_questions_json_path + train_annotations_json_path = args.ok_vqa_train_annotations_json_path + test_image_dir_path = args.ok_vqa_test_image_dir_path + test_questions_json_path = args.ok_vqa_test_questions_json_path + test_annotations_json_path = args.ok_vqa_test_annotations_json_path + elif dataset_name == "vqav2": + train_image_dir_path = args.vqav2_train_image_dir_path + train_questions_json_path = args.vqav2_train_questions_json_path + train_annotations_json_path = args.vqav2_train_annotations_json_path + test_image_dir_path = args.vqav2_test_image_dir_path + test_questions_json_path = args.vqav2_test_questions_json_path + test_annotations_json_path = args.vqav2_test_annotations_json_path + elif dataset_name == "vizwiz": + train_image_dir_path = args.vizwiz_train_image_dir_path + train_questions_json_path = args.vizwiz_train_questions_json_path + train_annotations_json_path = args.vizwiz_train_annotations_json_path + test_image_dir_path = args.vizwiz_test_image_dir_path + test_questions_json_path = args.vizwiz_test_questions_json_path + test_annotations_json_path = args.vizwiz_test_annotations_json_path + elif dataset_name == "textvqa": + train_image_dir_path = args.textvqa_image_dir_path + train_questions_json_path = args.textvqa_train_questions_json_path + train_annotations_json_path = args.textvqa_train_annotations_json_path + test_image_dir_path = args.textvqa_image_dir_path + test_questions_json_path = args.textvqa_test_questions_json_path + test_annotations_json_path = args.textvqa_test_annotations_json_path + else: + raise ValueError(f"Unsupported dataset: {dataset_name}") + + train_dataset = VQADataset( + image_dir_path=train_image_dir_path, + question_path=train_questions_json_path, + annotations_path=train_annotations_json_path, + is_train=True, + dataset_name=dataset_name, + ) + + test_dataset = VQADataset( + image_dir_path=test_image_dir_path, + question_path=test_questions_json_path, + annotations_path=test_annotations_json_path, + is_train=False, + dataset_name=dataset_name, + ) + + effective_num_shots = compute_effective_num_shots(num_shots, args.model) + + test_dataset = prepare_eval_samples( + test_dataset, + args.num_samples if args.num_samples > 0 else len(test_dataset), + seed, + ) + + in_context_samples = get_query_set(train_dataset, args.query_set_size, seed) + predictions = [] + + for batch in more_itertools.chunked( + tqdm(test_dataset, desc=f"Running inference {dataset_name.upper()}"), + args.batch_size, + ): + batch_demo_samples = sample_batch_demos_from_query_set( + 
in_context_samples, effective_num_shots, len(batch) + ) + + batch_images = [] + batch_text = [] + for i in range(len(batch)): + if num_shots > 0: + context_images = [x["image"] for x in batch_demo_samples[i]] + else: + context_images = [] + batch_images.append(context_images + [batch[i]["image"]]) + + context_text = "".join( + [ + eval_model.get_vqa_prompt( + question=x["question"], answer=x["answers"][0] + ) + for x in batch_demo_samples[i] + ] + ) + + # Keep the text but remove the image tags for the zero-shot case + if num_shots == 0: + context_text = context_text.replace("", "") + + batch_text.append( + context_text + eval_model.get_vqa_prompt(question=batch[i]["question"]) + ) + + outputs = eval_model.get_outputs( + batch_images=batch_images, + batch_text=batch_text, + max_generation_length=max_generation_length, + num_beams=num_beams, + length_penalty=length_penalty, + ) + + process_function = ( + postprocess_ok_vqa_generation + if dataset_name == "ok_vqa" + else postprocess_vqa_generation + ) + + new_predictions = map(process_function, outputs) + + predictions.extend( + [ + {"answer": p, "question_id": sample["question_id"]} + for p, sample in zip(new_predictions, batch) + ] + ) + # save the predictions to a temporary file + random_uuid = str(uuid.uuid4()) + with open(f"{dataset_name}results_{random_uuid}.json", "w") as f: + f.write(json.dumps(predictions, indent=4)) + + acc = compute_vqa_accuracy( + f"{dataset_name}results_{random_uuid}.json", + test_questions_json_path, + test_annotations_json_path, + ) + + # delete the temporary file + os.remove(f"{dataset_name}results_{random_uuid}.json") + + return acc + + +def evaluate_imagenet( + eval_model, + batch_size: int, + imagenet_root: str, + seed: int = 42, + num_samples: int = 5000, + num_shots: int = 8, +): + """ + Evaluate a model on ImageNet dataset. + + Args: + eval_model (BaseEvalModel): model to evaluate + batch_size (int): batch size + imagenet_root (str): path to imagenet root for the specified split. + seed (int, optional): random seed. Defaults to 42. + num_samples (int, optional): number of samples to evaluate on. Defaults to 5000 samples. + num_shots (int, optional): number of shots to use. Defaults to 8. 
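# Illustrative sketch (not part of the patch): the few-shot VQA prompt built by the
# loop above. The "Question:/Short answer:" format matches the OpenFlamingo eval
# model's get_vqa_prompt further below; upstream OpenFlamingo additionally prefixes
# each chunk with the "<image>" special token. The demo question/answer pairs are invented.
def get_vqa_prompt(question, answer=None):
    return (
        f"<image>Question:{question} Short answer:{answer if answer is not None else ''}"
        f"{'<|endofchunk|>' if answer is not None else ''}"
    )

demos = [("What color is the bus?", "red"), ("How many dogs are there?", "two")]
context_text = "".join(get_vqa_prompt(q, a) for q, a in demos)
query_prompt = context_text + get_vqa_prompt("What is the man holding?")
# The generated continuation of query_prompt is then trimmed by postprocess_vqa_generation,
# which keeps only the text before the next "Question"/"Answer"/"Short" marker.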
+ + Returns: + float: accuracy score + """ + if not hasattr(eval_model, "model") or not hasattr(eval_model, "tokenizer"): + raise NotImplementedError( + "evaluate_imagenet is currently only supported for OpenFlamingo " "models" + ) + np.random.seed(seed) + model, tokenizer = eval_model.model, eval_model.tokenizer + assert isinstance(model, Flamingo) + + train_dataset = ImageNetDataset(os.path.join(imagenet_root, "train")) + val_dataset = ImageNetDataset(os.path.join(imagenet_root, "val")) + + effective_num_shots = compute_effective_num_shots(num_shots, args.model) + tokenizer.padding_side = ( + "left" # For generation padding tokens should be on the left + ) + + acc1 = 0 + acc5 = 0 + prompt_text = "A photo of a" + + val_iterator = more_itertools.chunked(val_dataset, batch_size) + for batch_idx, batch in enumerate(val_iterator): + batch_images = [] + batch_text = [] + + for idx in range(len(batch)): + # Choose a different set of random context samples for each sample + # from the training set + context_indices = np.random.choice( + len(train_dataset), effective_num_shots, replace=False + ) + + in_context_samples = [train_dataset[i] for i in context_indices] + + vision_x = [ + eval_model.image_processor(data["image"]).unsqueeze(0) + for data in in_context_samples + ] + [eval_model.image_processor(batch[idx]["image"]).unsqueeze(0)] + batch_images.append(torch.cat(vision_x, dim=0)) + + context_class_names = [ + in_context_samples[i]["class_name"] for i in range(effective_num_shots) + ] + context_text = "".join( + f"{prompt_text} {classname}<|endofchunk|>" + for classname in context_class_names + ) + batch_text.append(context_text) + + # shape [B, T_img, C, h, w] + vision_x = torch.stack(batch_images, dim=0) + # shape [B, T_img, 1, C, h, w] where 1 is the frame dimension + vision_x = vision_x.unsqueeze(2) + model._encode_vision_x(vision_x.cuda()) + + # Cache the context text: tokenize context and prompt, + # e.g. ' a picture of a ' + ctx_and_prompt_tokenized = tokenizer( + [context_text + prompt_text + " " for context_text in batch_text], + return_tensors="pt", + padding=True, + truncation=True, + max_length=2048, + ) + + with torch.no_grad(): + precomputed = model( + vision_x=None, + lang_x=ctx_and_prompt_tokenized["input_ids"].cuda(), + attention_mask=ctx_and_prompt_tokenized["attention_mask"].cuda(), + clear_conditioned_layers=False, + use_cached_vision_x=True, + use_cache=True, + ) + + def _detach_pkvs(pkvs): + """Detach a set of past key values.""" + return tuple([tuple([x.detach() for x in inner]) for inner in pkvs]) + + precomputed_pkvs = _detach_pkvs(precomputed.past_key_values) + + precomputed_logits = precomputed.logits.detach() + + overall_probs = [] + for imagenet_class_name in tqdm(openai_imagenet_classnames): + past_key_values = None + # Tokenize only the class name and iteratively decode the model's + # predictions for this class. + classname_tokens = tokenizer( + imagenet_class_name, add_special_tokens=False, return_tensors="pt" + )["input_ids"].cuda() + + if classname_tokens.ndim == 1: # Case: classname is only 1 token + classname_tokens = torch.unsqueeze(classname_tokens, 1) + + classname_tokens = repeat( + classname_tokens, "b s -> (repeat b) s", repeat=batch_size + ) + + # Compute the outputs one token at a time, using cached + # activations. + + # Initialize the elementwise predictions with the last set of + # logits from precomputed; this will correspond to the predicted + # probability of the first position/token in the imagenet + # classname. 
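# Illustrative sketch (not part of the patch): the same prefix-caching idea with a plain
# Hugging Face causal LM. GPT-2 is used purely for illustration and there is no vision
# conditioning here; the point is that the shared prompt is encoded once and its
# past_key_values are reused to score every candidate class label.
import copy

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
lm = AutoModelForCausalLM.from_pretrained("gpt2").eval()

prompt_ids = tok("A photo of a", return_tensors="pt")["input_ids"]
with torch.no_grad():
    prefix = lm(prompt_ids, use_cache=True)  # run the shared prefix exactly once

for label in ["golden retriever", "church", "volcano"]:
    label_ids = tok(" " + label, return_tensors="pt", add_special_tokens=False)["input_ids"]
    with torch.no_grad():
        out = lm(
            label_ids,
            past_key_values=copy.deepcopy(prefix.past_key_values),  # keep the cached prefix untouched
            use_cache=True,
        )
    # The distribution over label token i comes from the logits at the previous position:
    # the prefix's last logits predict the first label token, out.logits[:, :-1] the rest.
    logits = torch.cat([prefix.logits[:, -1:], out.logits[:, :-1]], dim=1)
    log_probs = torch.log_softmax(logits, dim=-1)
    label_logprob = log_probs.gather(2, label_ids[:, :, None]).squeeze(-1).sum()
    print(label, float(label_logprob))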
We will append the logits for each token to this + # list (each element has shape [B, 1, vocab_size]). + elementwise_logits = [precomputed_logits[:, -2:-1, :]] + + for token_idx in range(classname_tokens.shape[1]): + _lang_x = classname_tokens[:, token_idx].reshape((-1, 1)) + with torch.no_grad(): + outputs = model( + vision_x=None, + lang_x=_lang_x, + clear_conditioned_layers=False, + use_cached_vision_x=True, + past_key_values=( + past_key_values if token_idx > 0 else precomputed_pkvs + ), + use_cache=True, + ) + past_key_values = _detach_pkvs(outputs.past_key_values) + elementwise_logits.append(outputs.logits.detach()) + + # logits/probs has shape [B, classname_tokens + 1, vocab_size] + logits = torch.concat(elementwise_logits, 1) + probs = torch.softmax(logits, dim=-1).detach() + + # collect the probability of the generated token -- probability + # at index 0 corresponds to the token at index 1. + probs = probs[:, :-1, :] # shape [B, classname_tokens, vocab_size] + + gen_probs = torch.gather(probs, 2, classname_tokens[:, :, None]).squeeze(-1) + + class_prob = torch.prod(gen_probs, 1).detach().cpu().numpy() + overall_probs.append(class_prob) + + overall_probs = np.row_stack(overall_probs).T # shape [B, num_classes] + + def topk(probs_ary: np.ndarray, k: int) -> np.ndarray: + """Return the indices of the top k elements in probs_ary.""" + return np.argsort(probs_ary)[::-1][:k] + + for i in range(batch_size): + top5 = [ + IMAGENET_1K_CLASS_ID_TO_LABEL[pred] + for pred in topk(overall_probs[i], 5) + ] + + y_i = batch[i]["class_name"] + acc5 += int(y_i in set(top5)) + acc1 += int(y_i == top5[0]) + + print( + f"DEBUG: batch {idx} elem {i} of {batch_size}:" + f"label {y_i} // top5 {top5}" + ) + + examples_seen = (batch_idx + 1) * batch_size + print( + "eval {}/{}: acc@1 ({}), acc@5 ({})".format( + examples_seen, num_samples, acc1 / examples_seen, acc5 / examples_seen + ) + ) + if batch_idx * batch_size >= num_samples - 1: + break + + return float(acc1) / num_samples + + +if __name__ == "__main__": + main() diff --git a/PyTorch/contrib/others/OpenFlamingo_ for PyTorch/open_flamingo/eval/imagenet_utils.py b/PyTorch/contrib/others/OpenFlamingo_ for PyTorch/open_flamingo/eval/imagenet_utils.py new file mode 100644 index 0000000000..5803c70024 --- /dev/null +++ b/PyTorch/contrib/others/OpenFlamingo_ for PyTorch/open_flamingo/eval/imagenet_utils.py @@ -0,0 +1,1007 @@ +# classnames via https://github.com/mlfoundations/wise-ft/blob/master/src/datasets/imagenet_classnames.py#L1 +openai_imagenet_classnames = [ + "tench", + "goldfish", + "great white shark", + "tiger shark", + "hammerhead shark", + "electric ray", + "stingray", + "rooster", + "hen", + "ostrich", + "brambling", + "goldfinch", + "house finch", + "junco", + "indigo bunting", + "American robin", + "bulbul", + "jay", + "magpie", + "chickadee", + "American dipper", + "kite (bird of prey)", + "bald eagle", + "vulture", + "great grey owl", + "fire salamander", + "smooth newt", + "newt", + "spotted salamander", + "axolotl", + "American bullfrog", + "tree frog", + "tailed frog", + "loggerhead sea turtle", + "leatherback sea turtle", + "mud turtle", + "terrapin", + "box turtle", + "banded gecko", + "green iguana", + "Carolina anole", + "desert grassland whiptail lizard", + "agama", + "frilled-necked lizard", + "alligator lizard", + "Gila monster", + "European green lizard", + "chameleon", + "Komodo dragon", + "Nile crocodile", + "American alligator", + "triceratops", + "worm snake", + "ring-necked snake", + "eastern hog-nosed snake", + "smooth 
green snake", + "kingsnake", + "garter snake", + "water snake", + "vine snake", + "night snake", + "boa constrictor", + "African rock python", + "Indian cobra", + "green mamba", + "sea snake", + "Saharan horned viper", + "eastern diamondback rattlesnake", + "sidewinder rattlesnake", + "trilobite", + "harvestman", + "scorpion", + "yellow garden spider", + "barn spider", + "European garden spider", + "southern black widow", + "tarantula", + "wolf spider", + "tick", + "centipede", + "black grouse", + "ptarmigan", + "ruffed grouse", + "prairie grouse", + "peafowl", + "quail", + "partridge", + "african grey parrot", + "macaw", + "sulphur-crested cockatoo", + "lorikeet", + "coucal", + "bee eater", + "hornbill", + "hummingbird", + "jacamar", + "toucan", + "duck", + "red-breasted merganser", + "goose", + "black swan", + "tusker", + "echidna", + "platypus", + "wallaby", + "koala", + "wombat", + "jellyfish", + "sea anemone", + "brain coral", + "flatworm", + "nematode", + "conch", + "snail", + "slug", + "sea slug", + "chiton", + "chambered nautilus", + "Dungeness crab", + "rock crab", + "fiddler crab", + "red king crab", + "American lobster", + "spiny lobster", + "crayfish", + "hermit crab", + "isopod", + "white stork", + "black stork", + "spoonbill", + "flamingo", + "little blue heron", + "great egret", + "bittern bird", + "crane bird", + "limpkin", + "common gallinule", + "American coot", + "bustard", + "ruddy turnstone", + "dunlin", + "common redshank", + "dowitcher", + "oystercatcher", + "pelican", + "king penguin", + "albatross", + "grey whale", + "killer whale", + "dugong", + "sea lion", + "Chihuahua", + "Japanese Chin", + "Maltese", + "Pekingese", + "Shih Tzu", + "King Charles Spaniel", + "Papillon", + "toy terrier", + "Rhodesian Ridgeback", + "Afghan Hound", + "Basset Hound", + "Beagle", + "Bloodhound", + "Bluetick Coonhound", + "Black and Tan Coonhound", + "Treeing Walker Coonhound", + "English foxhound", + "Redbone Coonhound", + "borzoi", + "Irish Wolfhound", + "Italian Greyhound", + "Whippet", + "Ibizan Hound", + "Norwegian Elkhound", + "Otterhound", + "Saluki", + "Scottish Deerhound", + "Weimaraner", + "Staffordshire Bull Terrier", + "American Staffordshire Terrier", + "Bedlington Terrier", + "Border Terrier", + "Kerry Blue Terrier", + "Irish Terrier", + "Norfolk Terrier", + "Norwich Terrier", + "Yorkshire Terrier", + "Wire Fox Terrier", + "Lakeland Terrier", + "Sealyham Terrier", + "Airedale Terrier", + "Cairn Terrier", + "Australian Terrier", + "Dandie Dinmont Terrier", + "Boston Terrier", + "Miniature Schnauzer", + "Giant Schnauzer", + "Standard Schnauzer", + "Scottish Terrier", + "Tibetan Terrier", + "Australian Silky Terrier", + "Soft-coated Wheaten Terrier", + "West Highland White Terrier", + "Lhasa Apso", + "Flat-Coated Retriever", + "Curly-coated Retriever", + "Golden Retriever", + "Labrador Retriever", + "Chesapeake Bay Retriever", + "German Shorthaired Pointer", + "Vizsla", + "English Setter", + "Irish Setter", + "Gordon Setter", + "Brittany dog", + "Clumber Spaniel", + "English Springer Spaniel", + "Welsh Springer Spaniel", + "Cocker Spaniel", + "Sussex Spaniel", + "Irish Water Spaniel", + "Kuvasz", + "Schipperke", + "Groenendael dog", + "Malinois", + "Briard", + "Australian Kelpie", + "Komondor", + "Old English Sheepdog", + "Shetland Sheepdog", + "collie", + "Border Collie", + "Bouvier des Flandres dog", + "Rottweiler", + "German Shepherd Dog", + "Dobermann", + "Miniature Pinscher", + "Greater Swiss Mountain Dog", + "Bernese Mountain Dog", + "Appenzeller Sennenhund", + 
"Entlebucher Sennenhund", + "Boxer", + "Bullmastiff", + "Tibetan Mastiff", + "French Bulldog", + "Great Dane", + "St. Bernard", + "husky", + "Alaskan Malamute", + "Siberian Husky", + "Dalmatian", + "Affenpinscher", + "Basenji", + "pug", + "Leonberger", + "Newfoundland dog", + "Great Pyrenees dog", + "Samoyed", + "Pomeranian", + "Chow Chow", + "Keeshond", + "brussels griffon", + "Pembroke Welsh Corgi", + "Cardigan Welsh Corgi", + "Toy Poodle", + "Miniature Poodle", + "Standard Poodle", + "Mexican hairless dog (xoloitzcuintli)", + "grey wolf", + "Alaskan tundra wolf", + "red wolf or maned wolf", + "coyote", + "dingo", + "dhole", + "African wild dog", + "hyena", + "red fox", + "kit fox", + "Arctic fox", + "grey fox", + "tabby cat", + "tiger cat", + "Persian cat", + "Siamese cat", + "Egyptian Mau", + "cougar", + "lynx", + "leopard", + "snow leopard", + "jaguar", + "lion", + "tiger", + "cheetah", + "brown bear", + "American black bear", + "polar bear", + "sloth bear", + "mongoose", + "meerkat", + "tiger beetle", + "ladybug", + "ground beetle", + "longhorn beetle", + "leaf beetle", + "dung beetle", + "rhinoceros beetle", + "weevil", + "fly", + "bee", + "ant", + "grasshopper", + "cricket insect", + "stick insect", + "cockroach", + "praying mantis", + "cicada", + "leafhopper", + "lacewing", + "dragonfly", + "damselfly", + "red admiral butterfly", + "ringlet butterfly", + "monarch butterfly", + "small white butterfly", + "sulphur butterfly", + "gossamer-winged butterfly", + "starfish", + "sea urchin", + "sea cucumber", + "cottontail rabbit", + "hare", + "Angora rabbit", + "hamster", + "porcupine", + "fox squirrel", + "marmot", + "beaver", + "guinea pig", + "common sorrel horse", + "zebra", + "pig", + "wild boar", + "warthog", + "hippopotamus", + "ox", + "water buffalo", + "bison", + "ram (adult male sheep)", + "bighorn sheep", + "Alpine ibex", + "hartebeest", + "impala (antelope)", + "gazelle", + "arabian camel", + "llama", + "weasel", + "mink", + "European polecat", + "black-footed ferret", + "otter", + "skunk", + "badger", + "armadillo", + "three-toed sloth", + "orangutan", + "gorilla", + "chimpanzee", + "gibbon", + "siamang", + "guenon", + "patas monkey", + "baboon", + "macaque", + "langur", + "black-and-white colobus", + "proboscis monkey", + "marmoset", + "white-headed capuchin", + "howler monkey", + "titi monkey", + "Geoffroy's spider monkey", + "common squirrel monkey", + "ring-tailed lemur", + "indri", + "Asian elephant", + "African bush elephant", + "red panda", + "giant panda", + "snoek fish", + "eel", + "silver salmon", + "rock beauty fish", + "clownfish", + "sturgeon", + "gar fish", + "lionfish", + "pufferfish", + "abacus", + "abaya", + "academic gown", + "accordion", + "acoustic guitar", + "aircraft carrier", + "airliner", + "airship", + "altar", + "ambulance", + "amphibious vehicle", + "analog clock", + "apiary", + "apron", + "trash can", + "assault rifle", + "backpack", + "bakery", + "balance beam", + "balloon", + "ballpoint pen", + "Band-Aid", + "banjo", + "baluster / handrail", + "barbell", + "barber chair", + "barbershop", + "barn", + "barometer", + "barrel", + "wheelbarrow", + "baseball", + "basketball", + "bassinet", + "bassoon", + "swimming cap", + "bath towel", + "bathtub", + "station wagon", + "lighthouse", + "beaker", + "military hat (bearskin or shako)", + "beer bottle", + "beer glass", + "bell tower", + "baby bib", + "tandem bicycle", + "bikini", + "ring binder", + "binoculars", + "birdhouse", + "boathouse", + "bobsleigh", + "bolo tie", + "poke bonnet", + "bookcase", + 
"bookstore", + "bottle cap", + "hunting bow", + "bow tie", + "brass memorial plaque", + "bra", + "breakwater", + "breastplate", + "broom", + "bucket", + "buckle", + "bulletproof vest", + "high-speed train", + "butcher shop", + "taxicab", + "cauldron", + "candle", + "cannon", + "canoe", + "can opener", + "cardigan", + "car mirror", + "carousel", + "tool kit", + "cardboard box / carton", + "car wheel", + "automated teller machine", + "cassette", + "cassette player", + "castle", + "catamaran", + "CD player", + "cello", + "mobile phone", + "chain", + "chain-link fence", + "chain mail", + "chainsaw", + "storage chest", + "chiffonier", + "bell or wind chime", + "china cabinet", + "Christmas stocking", + "church", + "movie theater", + "cleaver", + "cliff dwelling", + "cloak", + "clogs", + "cocktail shaker", + "coffee mug", + "coffeemaker", + "spiral or coil", + "combination lock", + "computer keyboard", + "candy store", + "container ship", + "convertible", + "corkscrew", + "cornet", + "cowboy boot", + "cowboy hat", + "cradle", + "construction crane", + "crash helmet", + "crate", + "infant bed", + "Crock Pot", + "croquet ball", + "crutch", + "cuirass", + "dam", + "desk", + "desktop computer", + "rotary dial telephone", + "diaper", + "digital clock", + "digital watch", + "dining table", + "dishcloth", + "dishwasher", + "disc brake", + "dock", + "dog sled", + "dome", + "doormat", + "drilling rig", + "drum", + "drumstick", + "dumbbell", + "Dutch oven", + "electric fan", + "electric guitar", + "electric locomotive", + "entertainment center", + "envelope", + "espresso machine", + "face powder", + "feather boa", + "filing cabinet", + "fireboat", + "fire truck", + "fire screen", + "flagpole", + "flute", + "folding chair", + "football helmet", + "forklift", + "fountain", + "fountain pen", + "four-poster bed", + "freight car", + "French horn", + "frying pan", + "fur coat", + "garbage truck", + "gas mask or respirator", + "gas pump", + "goblet", + "go-kart", + "golf ball", + "golf cart", + "gondola", + "gong", + "gown", + "grand piano", + "greenhouse", + "radiator grille", + "grocery store", + "guillotine", + "hair clip", + "hair spray", + "half-track", + "hammer", + "hamper", + "hair dryer", + "hand-held computer", + "handkerchief", + "hard disk drive", + "harmonica", + "harp", + "combine harvester", + "hatchet", + "holster", + "home theater", + "honeycomb", + "hook", + "hoop skirt", + "gymnastic horizontal bar", + "horse-drawn vehicle", + "hourglass", + "iPod", + "clothes iron", + "carved pumpkin", + "jeans", + "jeep", + "T-shirt", + "jigsaw puzzle", + "rickshaw", + "joystick", + "kimono", + "knee pad", + "knot", + "lab coat", + "ladle", + "lampshade", + "laptop computer", + "lawn mower", + "lens cap", + "letter opener", + "library", + "lifeboat", + "lighter", + "limousine", + "ocean liner", + "lipstick", + "slip-on shoe", + "lotion", + "music speaker", + "loupe magnifying glass", + "sawmill", + "magnetic compass", + "messenger bag", + "mailbox", + "tights", + "one-piece bathing suit", + "manhole cover", + "maraca", + "marimba", + "mask", + "matchstick", + "maypole", + "maze", + "measuring cup", + "medicine cabinet", + "megalith", + "microphone", + "microwave oven", + "military uniform", + "milk can", + "minibus", + "miniskirt", + "minivan", + "missile", + "mitten", + "mixing bowl", + "mobile home", + "ford model t", + "modem", + "monastery", + "monitor", + "moped", + "mortar and pestle", + "graduation cap", + "mosque", + "mosquito net", + "vespa", + "mountain bike", + "tent", + "computer mouse", + 
"mousetrap", + "moving van", + "muzzle", + "metal nail", + "neck brace", + "necklace", + "baby pacifier", + "notebook computer", + "obelisk", + "oboe", + "ocarina", + "odometer", + "oil filter", + "pipe organ", + "oscilloscope", + "overskirt", + "bullock cart", + "oxygen mask", + "product packet / packaging", + "paddle", + "paddle wheel", + "padlock", + "paintbrush", + "pajamas", + "palace", + "pan flute", + "paper towel", + "parachute", + "parallel bars", + "park bench", + "parking meter", + "railroad car", + "patio", + "payphone", + "pedestal", + "pencil case", + "pencil sharpener", + "perfume", + "Petri dish", + "photocopier", + "plectrum", + "Pickelhaube", + "picket fence", + "pickup truck", + "pier", + "piggy bank", + "pill bottle", + "pillow", + "ping-pong ball", + "pinwheel", + "pirate ship", + "drink pitcher", + "block plane", + "planetarium", + "plastic bag", + "plate rack", + "farm plow", + "plunger", + "Polaroid camera", + "pole", + "police van", + "poncho", + "pool table", + "soda bottle", + "plant pot", + "potter's wheel", + "power drill", + "prayer rug", + "printer", + "prison", + "missile", + "projector", + "hockey puck", + "punching bag", + "purse", + "quill", + "quilt", + "race car", + "racket", + "radiator", + "radio", + "radio telescope", + "rain barrel", + "recreational vehicle", + "fishing casting reel", + "reflex camera", + "refrigerator", + "remote control", + "restaurant", + "revolver", + "rifle", + "rocking chair", + "rotisserie", + "eraser", + "rugby ball", + "ruler measuring stick", + "sneaker", + "safe", + "safety pin", + "salt shaker", + "sandal", + "sarong", + "saxophone", + "scabbard", + "weighing scale", + "school bus", + "schooner", + "scoreboard", + "CRT monitor", + "screw", + "screwdriver", + "seat belt", + "sewing machine", + "shield", + "shoe store", + "shoji screen / room divider", + "shopping basket", + "shopping cart", + "shovel", + "shower cap", + "shower curtain", + "ski", + "balaclava ski mask", + "sleeping bag", + "slide rule", + "sliding door", + "slot machine", + "snorkel", + "snowmobile", + "snowplow", + "soap dispenser", + "soccer ball", + "sock", + "solar thermal collector", + "sombrero", + "soup bowl", + "keyboard space bar", + "space heater", + "space shuttle", + "spatula", + "motorboat", + "spider web", + "spindle", + "sports car", + "spotlight", + "stage", + "steam locomotive", + "through arch bridge", + "steel drum", + "stethoscope", + "scarf", + "stone wall", + "stopwatch", + "stove", + "strainer", + "tram", + "stretcher", + "couch", + "stupa", + "submarine", + "suit", + "sundial", + "sunglasses", + "sunglasses", + "sunscreen", + "suspension bridge", + "mop", + "sweatshirt", + "swim trunks / shorts", + "swing", + "electrical switch", + "syringe", + "table lamp", + "tank", + "tape player", + "teapot", + "teddy bear", + "television", + "tennis ball", + "thatched roof", + "front curtain", + "thimble", + "threshing machine", + "throne", + "tile roof", + "toaster", + "tobacco shop", + "toilet seat", + "torch", + "totem pole", + "tow truck", + "toy store", + "tractor", + "semi-trailer truck", + "tray", + "trench coat", + "tricycle", + "trimaran", + "tripod", + "triumphal arch", + "trolleybus", + "trombone", + "hot tub", + "turnstile", + "typewriter keyboard", + "umbrella", + "unicycle", + "upright piano", + "vacuum cleaner", + "vase", + "vaulted or arched ceiling", + "velvet fabric", + "vending machine", + "vestment", + "viaduct", + "violin", + "volleyball", + "waffle iron", + "wall clock", + "wallet", + "wardrobe", + "military aircraft", + 
"sink", + "washing machine", + "water bottle", + "water jug", + "water tower", + "whiskey jug", + "whistle", + "hair wig", + "window screen", + "window shade", + "Windsor tie", + "wine bottle", + "airplane wing", + "wok", + "wooden spoon", + "wool", + "split-rail fence", + "shipwreck", + "sailboat", + "yurt", + "website", + "comic book", + "crossword", + "traffic or street sign", + "traffic light", + "dust jacket", + "menu", + "plate", + "guacamole", + "consomme", + "hot pot", + "trifle", + "ice cream", + "popsicle", + "baguette", + "bagel", + "pretzel", + "cheeseburger", + "hot dog", + "mashed potatoes", + "cabbage", + "broccoli", + "cauliflower", + "zucchini", + "spaghetti squash", + "acorn squash", + "butternut squash", + "cucumber", + "artichoke", + "bell pepper", + "cardoon", + "mushroom", + "Granny Smith apple", + "strawberry", + "orange", + "lemon", + "fig", + "pineapple", + "banana", + "jackfruit", + "cherimoya (custard apple)", + "pomegranate", + "hay", + "carbonara", + "chocolate syrup", + "dough", + "meatloaf", + "pizza", + "pot pie", + "burrito", + "red wine", + "espresso", + "tea cup", + "eggnog", + "mountain", + "bubble", + "cliff", + "coral reef", + "geyser", + "lakeshore", + "promontory", + "sandbar", + "beach", + "valley", + "volcano", + "baseball player", + "bridegroom", + "scuba diver", + "rapeseed", + "daisy", + "yellow lady's slipper", + "corn", + "acorn", + "rose hip", + "horse chestnut seed", + "coral fungus", + "agaric", + "gyromitra", + "stinkhorn mushroom", + "earth star fungus", + "hen of the woods mushroom", + "bolete", + "corn cob", + "toilet paper", +] +# Maps numeric class ids to labels +IMAGENET_1K_CLASS_ID_TO_LABEL = dict( + zip(range(len(openai_imagenet_classnames)), openai_imagenet_classnames) +) diff --git a/PyTorch/contrib/others/OpenFlamingo_ for PyTorch/open_flamingo/eval/models/blip.py b/PyTorch/contrib/others/OpenFlamingo_ for PyTorch/open_flamingo/eval/models/blip.py new file mode 100644 index 0000000000..d6aed17b42 --- /dev/null +++ b/PyTorch/contrib/others/OpenFlamingo_ for PyTorch/open_flamingo/eval/models/blip.py @@ -0,0 +1,110 @@ +from typing import List + +from PIL import Image +import torch + +from transformers import Blip2Processor, Blip2ForConditionalGeneration +from open_flamingo.eval.eval_model import BaseEvalModel + + +class EvalModel(BaseEvalModel): + """BLIP-2 model evaluation. + + Attributes: + model (nn.Module): Underlying Torch model. + tokenizer (transformers.PreTrainedTokenizer): Tokenizer for model. + device: Index of GPU to use, or the string "cpu" + """ + + def __init__(self, model_args): + assert ( + "processor_path" in model_args + and "lm_path" in model_args + and "device" in model_args + ), "BLIP-2 requires processor_path, lm_path, and device arguments to be specified" + + model_args["device"] = int(model_args["device"]) + + self.device = model_args["device"] if model_args["device"] >= 0 else "cpu" + self.processor = Blip2Processor.from_pretrained(model_args["processor_path"]) + self.model = Blip2ForConditionalGeneration.from_pretrained( + model_args["lm_path"] + ) + self.model.to(self.device) + self.model.eval() + self.processor.tokenizer.padding_side = "left" + + def _prepare_images(self, batch: List[List[torch.Tensor]]) -> torch.Tensor: + """Preprocess images and stack them. + + Args: + batch: A list of lists of images. + + Returns: + A Tensor of shape + (batch_size, channels, height, width). 
+ """ + batch_images = None + assert all( + len(example) == 1 for example in batch + ), "BLIP-2 only supports one image per example" + + for example in batch: + assert len(example) == 1, "BLIP-2 only supports one image per example" + batch_images = torch.cat( + [ + batch_images, + self.processor.image_processor(example, return_tensors="pt")[ + "pixel_values" + ], + ] + if batch_images is not None + else [ + self.processor.image_processor(example, return_tensors="pt")[ + "pixel_values" + ] + ], + dim=0, + ) + return batch_images + + def get_outputs( + self, + batch_text: List[str], + batch_images: List[List[Image.Image]], + max_generation_length: int, + num_beams: int, + length_penalty: float, + ) -> List[str]: + encodings = self.processor.tokenizer( + batch_text, + padding="longest", + truncation=True, + return_tensors="pt", + max_length=2000, + ) + input_ids = encodings["input_ids"] + attention_mask = encodings["attention_mask"] + + with torch.inference_mode(): + outputs = self.model.generate( + self._prepare_images(batch_images).to(self.device), + input_ids.to(self.device), + attention_mask=attention_mask.to(self.device), + max_new_tokens=max_generation_length, + num_beams=num_beams, + length_penalty=length_penalty, + ) + + return self.processor.tokenizer.batch_decode(outputs, skip_special_tokens=True) + + def get_vqa_prompt(self, question, answer=None) -> str: + return ( + f"Question:{question} Short answer:{answer if answer is not None else ''}" + ) + + def get_caption_prompt(self, caption=None) -> str: + return f"A photo of {caption if caption is not None else ''}" + + def get_classification_prompt(self, class_str=None) -> str: + raise NotImplementedError diff --git a/PyTorch/contrib/others/OpenFlamingo_ for PyTorch/open_flamingo/eval/models/open_flamingo.py b/PyTorch/contrib/others/OpenFlamingo_ for PyTorch/open_flamingo/eval/models/open_flamingo.py new file mode 100644 index 0000000000..1fdcb543b7 --- /dev/null +++ b/PyTorch/contrib/others/OpenFlamingo_ for PyTorch/open_flamingo/eval/models/open_flamingo.py @@ -0,0 +1,112 @@ +from typing import List + +from PIL import Image +import torch + +from open_flamingo.eval.eval_model import BaseEvalModel +from open_flamingo.src.factory import create_model_and_transforms + + +class EvalModel(BaseEvalModel): + """OpenFlamingo model evaluation. + + Attributes: + model (nn.Module): Underlying Torch model. + tokenizer (transformers.PreTrainedTokenizer): Tokenizer for model. 
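# Illustrative sketch (not part of the patch): the packing convention used by the
# OpenFlamingo eval model's _prepare_images below. Each example may carry a different
# number of images, and the batch is zero-padded up to the largest count, giving a
# tensor of shape (batch_size, max_images, frames=1, channels, height, width).
# Random tensors stand in for preprocessed images here.
import torch

def pack_images(batch, image_size=224):
    images_per_example = max(len(example) for example in batch)
    packed = torch.zeros(len(batch), images_per_example, 1, 3, image_size, image_size)
    for i, example in enumerate(batch):
        for j, image in enumerate(example):  # image: a preprocessed (3, H, W) tensor
            packed[i, j, 0] = image
    return packed

batch = [
    [torch.rand(3, 224, 224)],                           # example with one image
    [torch.rand(3, 224, 224), torch.rand(3, 224, 224)],  # example with two images
]
print(pack_images(batch).shape)  # torch.Size([2, 2, 1, 3, 224, 224])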
+ device: Index of GPU to use, or the string "CPU" + """ + + def __init__(self, model_args): + assert ( + "vision_encoder_path" in model_args + and "lm_path" in model_args + and "device" in model_args + and "checkpoint_path" in model_args + and "lm_tokenizer_path" in model_args + and "cross_attn_every_n_layers" in model_args + and "vision_encoder_pretrained" in model_args + ), "OpenFlamingo requires vision_encoder_path, lm_path, device, checkpoint_path, lm_tokenizer_path, cross_attn_every_n_layers, and vision_encoder_pretrained arguments to be specified" + + model_args["device"] = int(model_args["device"]) + self.device = model_args["device"] if model_args["device"] >= 0 else "cpu" + ( + self.model, + self.image_processor, + self.tokenizer, + ) = create_model_and_transforms( + model_args["vision_encoder_path"], + model_args["vision_encoder_pretrained"], + model_args["lm_path"], + model_args["lm_tokenizer_path"], + cross_attn_every_n_layers=int(model_args["cross_attn_every_n_layers"]), + ) + checkpoint = torch.load(model_args["checkpoint_path"], map_location="cpu") + self.model.load_state_dict(checkpoint, strict=False) + self.model.to("npu") + self.model.eval() + self.tokenizer.padding_side = "left" + + def _prepare_images(self, batch: List[List[torch.Tensor]]) -> torch.Tensor: + """Preprocess images and stack them. + + Args: + batch: A list of lists of images. + + Returns: + A Tensor of shape + (batch_size, images_per_example, frames, channels, height, width). + """ + images_per_example = max(len(x) for x in batch) + batch_images = None + for iexample, example in enumerate(batch): + for iimage, image in enumerate(example): + preprocessed = self.image_processor(image) + + if batch_images is None: + batch_images = torch.zeros( + (len(batch), images_per_example, 1) + preprocessed.shape, + dtype=preprocessed.dtype, + ) + batch_images[iexample, iimage, 0] = preprocessed + return batch_images + + def get_outputs( + self, + batch_text: List[str], + batch_images: List[List[Image.Image]], + max_generation_length: int, + num_beams: int, + length_penalty: float, + ) -> List[str]: + encodings = self.tokenizer( + batch_text, + padding="longest", + truncation=True, + return_tensors="pt", + max_length=2000, + ) + input_ids = encodings["input_ids"] + attention_mask = encodings["attention_mask"] + + with torch.inference_mode(): + outputs = self.model.generate( + self._prepare_images(batch_images).to(self.device), + input_ids.to(self.device), + attention_mask=attention_mask.to(self.device), + max_new_tokens=max_generation_length, + num_beams=num_beams, + length_penalty=length_penalty, + ) + + outputs = outputs[:, len(input_ids[0]) :] + + return self.tokenizer.batch_decode(outputs, skip_special_tokens=True) + + def get_vqa_prompt(self, question, answer=None) -> str: + return f"Question:{question} Short answer:{answer if answer is not None else ''}{'<|endofchunk|>' if answer is not None else ''}" + + def get_caption_prompt(self, caption=None) -> str: + return f"Output:{caption if caption is not None else ''}{'<|endofchunk|>' if caption is not None else ''}" + + def get_classification_prompt(self, class_str=None) -> str: + return f"A photo of a {class_str if class_str is not None else ''}{'<|endofchunk|>' if class_str is not None else ''}" diff --git a/PyTorch/contrib/others/OpenFlamingo_ for PyTorch/open_flamingo/eval/ok_vqa_utils.py b/PyTorch/contrib/others/OpenFlamingo_ for PyTorch/open_flamingo/eval/ok_vqa_utils.py new file mode 100644 index 0000000000..cbe6feeed4 --- /dev/null +++ 
b/PyTorch/contrib/others/OpenFlamingo_ for PyTorch/open_flamingo/eval/ok_vqa_utils.py @@ -0,0 +1,214 @@ +# Those are manual mapping that are not caught by our stemming rules or would +# would be done incorrectly by our automatic stemming rule. In details, +# the keys of the _MANUAL_MATCHES dict contains the original word and the value +# contains the transformation of the word expected by the OKVQA stemming rule. +# These manual rules were found by checking the `raw_answers` and the `answers` +# fields of the released OKVQA dataset and checking all things that were not +# properly mapped by our automatic rules. In particular some of the mapping +# are sometimes constant, e.g. christmas -> christmas which was incorrectly +# singularized by our inflection.singularize. +import re +import nltk +from nltk.corpus.reader import VERB +import inflection + +_MANUAL_MATCHES = { + "police": "police", + "las": "las", + "vegas": "vegas", + "yes": "yes", + "jeans": "jean", + "hell's": "hell", + "domino's": "domino", + "morning": "morn", + "clothes": "cloth", + "are": "are", + "riding": "ride", + "leaves": "leaf", + "dangerous": "danger", + "clothing": "cloth", + "texting": "text", + "kiting": "kite", + "firefighters": "firefight", + "ties": "tie", + "married": "married", + "teething": "teeth", + "gloves": "glove", + "tennis": "tennis", + "dining": "dine", + "directions": "direct", + "waves": "wave", + "christmas": "christmas", + "drives": "drive", + "pudding": "pud", + "coding": "code", + "plating": "plate", + "quantas": "quanta", + "hornes": "horn", + "graves": "grave", + "mating": "mate", + "paned": "pane", + "alertness": "alert", + "sunbathing": "sunbath", + "tenning": "ten", + "wetness": "wet", + "urinating": "urine", + "sickness": "sick", + "braves": "brave", + "firefighting": "firefight", + "lenses": "lens", + "reflections": "reflect", + "backpackers": "backpack", + "eatting": "eat", + "designers": "design", + "curiousity": "curious", + "playfulness": "play", + "blindness": "blind", + "hawke": "hawk", + "tomatoe": "tomato", + "rodeoing": "rodeo", + "brightness": "bright", + "circuses": "circus", + "skateboarders": "skateboard", + "staring": "stare", + "electronics": "electron", + "electicity": "elect", + "mountainous": "mountain", + "socializing": "social", + "hamburgers": "hamburg", + "caves": "cave", + "transitions": "transit", + "wading": "wade", + "creame": "cream", + "toileting": "toilet", + "sautee": "saute", + "buildings": "build", + "belongings": "belong", + "stockings": "stock", + "walle": "wall", + "cumulis": "cumuli", + "travelers": "travel", + "conducter": "conduct", + "browsing": "brows", + "pooping": "poop", + "haircutting": "haircut", + "toppings": "top", + "hearding": "heard", + "sunblocker": "sunblock", + "bases": "base", + "markings": "mark", + "mopeds": "mope", + "kindergartener": "kindergarten", + "pies": "pie", + "scrapbooking": "scrapbook", + "couponing": "coupon", + "meetings": "meet", + "elevators": "elev", + "lowes": "low", + "men's": "men", + "childrens": "children", + "shelves": "shelve", + "paintings": "paint", + "raines": "rain", + "paring": "pare", + "expressions": "express", + "routes": "rout", + "pease": "peas", + "vastness": "vast", + "awning": "awn", + "boy's": "boy", + "drunkenness": "drunken", + "teasing": "teas", + "conferences": "confer", + "ripeness": "ripe", + "suspenders": "suspend", + "earnings": "earn", + "reporters": "report", + "kid's": "kid", + "containers": "contain", + "corgie": "corgi", + "porche": "porch", + "microwaves": "microwave", + "batter's": 
"batter", + "sadness": "sad", + "apartments": "apart", + "oxygenize": "oxygen", + "striping": "stripe", + "purring": "pure", + "professionals": "profession", + "piping": "pipe", + "farmer's": "farmer", + "potatoe": "potato", + "emirates": "emir", + "womens": "women", + "veteran's": "veteran", + "wilderness": "wilder", + "propellers": "propel", + "alpes": "alp", + "charioteering": "chariot", + "swining": "swine", + "illness": "ill", + "crepte": "crept", + "adhesives": "adhesive", + "regent's": "regent", + "decorations": "decor", + "rabbies": "rabbi", + "overseas": "oversea", + "travellers": "travel", + "casings": "case", + "smugness": "smug", + "doves": "dove", + "nationals": "nation", + "mustange": "mustang", + "ringe": "ring", + "gondoliere": "gondolier", + "vacationing": "vacate", + "reminders": "remind", + "baldness": "bald", + "settings": "set", + "glaced": "glace", + "coniferous": "conifer", + "revelations": "revel", + "personals": "person", + "daughter's": "daughter", + "badness": "bad", + "projections": "project", + "polarizing": "polar", + "vandalizers": "vandal", + "minerals": "miner", + "protesters": "protest", + "controllers": "control", + "weddings": "wed", + "sometimes": "sometime", + "earing": "ear", +} + + +class OKVQAStemmer: + """Stemmer to match OKVQA v1.1 procedure.""" + + def __init__(self): + self._wordnet_lemmatizer = nltk.stem.WordNetLemmatizer() + + def stem(self, input_string): + """Apply stemming.""" + word_and_pos = nltk.pos_tag(nltk.tokenize.word_tokenize(input_string)) + stemmed_words = [] + for w, p in word_and_pos: + if w in _MANUAL_MATCHES: + w = _MANUAL_MATCHES[w] + elif w.endswith("ing"): + w = self._wordnet_lemmatizer.lemmatize(w, VERB) + elif p.startswith("NNS") or p.startswith("NNPS"): + w = inflection.singularize(w) + stemmed_words.append(w) + return " ".join(stemmed_words) + + +stemmer = OKVQAStemmer() + + +def postprocess_ok_vqa_generation(predictions) -> str: + prediction = re.split("Question|Answer|Short", predictions, 1)[0] + prediction_stem = stemmer.stem(prediction) + return prediction_stem diff --git a/PyTorch/contrib/others/OpenFlamingo_ for PyTorch/open_flamingo/eval/vqa_metric.py b/PyTorch/contrib/others/OpenFlamingo_ for PyTorch/open_flamingo/eval/vqa_metric.py new file mode 100644 index 0000000000..47c0aaa6f0 --- /dev/null +++ b/PyTorch/contrib/others/OpenFlamingo_ for PyTorch/open_flamingo/eval/vqa_metric.py @@ -0,0 +1,581 @@ +import copy +import datetime +import json +import os +import random +import re +import sys + +# Interface for accessing the VQA dataset. + +# This code is based on the code written by Tsung-Yi Lin for MSCOCO Python API available at the following link: +# (https://github.com/pdollar/coco/blob/master/PythonAPI/pycocotools/coco.py). + +# The following functions are defined: +# VQA - VQA class that loads VQA annotation file and prepares data structures. +# getQuesIds - Get question ids that satisfy given filter conditions. +# getImgIds - Get image ids that satisfy given filter conditions. +# loadQA - Load questions and answers with the specified question ids. +# showQA - Display the specified questions and answers. +# loadRes - Load result file and create result object. + +# Help on each function can be accessed by: "help(COCO.function)" + + +class VQA: + def __init__(self, annotation_file=None, question_file=None): + """ + Constructor of VQA helper class for reading and visualizing questions and answers. 
+ :param annotation_file (str): location of VQA annotation file + :return: + """ + # load dataset + self.dataset = {} + self.questions = {} + self.qa = {} + self.qqa = {} + self.imgToQA = {} + if not annotation_file == None and not question_file == None: + print("loading VQA annotations and questions into memory...") + time_t = datetime.datetime.utcnow() + dataset = json.load(open(annotation_file, "r")) + questions = json.load(open(question_file, "r")) + print(datetime.datetime.utcnow() - time_t) + self.dataset = dataset + self.questions = questions + self.createIndex() + + def createIndex(self): + # create index + print("creating index...") + imgToQA = {ann["image_id"]: [] for ann in self.dataset["annotations"]} + qa = {ann["question_id"]: [] for ann in self.dataset["annotations"]} + qqa = {ann["question_id"]: [] for ann in self.dataset["annotations"]} + for ann in self.dataset["annotations"]: + imgToQA[ann["image_id"]] += [ann] + qa[ann["question_id"]] = ann + for ques in self.questions["questions"]: + qqa[ques["question_id"]] = ques + print("index created!") + + # create class members + self.qa = qa + self.qqa = qqa + self.imgToQA = imgToQA + + def info(self): + """ + Print information about the VQA annotation file. + :return: + """ + for key, value in self.dataset["info"].items(): + print("%s: %s" % (key, value)) + + def getQuesIds(self, imgIds=[], quesTypes=[], ansTypes=[]): + """ + Get question ids that satisfy given filter conditions. default skips that filter + :param imgIds (int array) : get question ids for given imgs + quesTypes (str array) : get question ids for given question types + ansTypes (str array) : get question ids for given answer types + :return: ids (int array) : integer array of question ids + """ + imgIds = imgIds if type(imgIds) == list else [imgIds] + quesTypes = quesTypes if type(quesTypes) == list else [quesTypes] + ansTypes = ansTypes if type(ansTypes) == list else [ansTypes] + + if len(imgIds) == len(quesTypes) == len(ansTypes) == 0: + anns = self.dataset["annotations"] + else: + if not len(imgIds) == 0: + anns = sum( + [self.imgToQA[imgId] for imgId in imgIds if imgId in self.imgToQA], + [], + ) + else: + anns = self.dataset["annotations"] + anns = ( + anns + if len(quesTypes) == 0 + else [ann for ann in anns if ann["question_type"] in quesTypes] + ) + anns = ( + anns + if len(ansTypes) == 0 + else [ann for ann in anns if ann["answer_type"] in ansTypes] + ) + ids = [ann["question_id"] for ann in anns] + return ids + + def getImgIds(self, quesIds=[], quesTypes=[], ansTypes=[]): + """ + Get image ids that satisfy given filter conditions. 
default skips that filter + :param quesIds (int array) : get image ids for given question ids + quesTypes (str array) : get image ids for given question types + ansTypes (str array) : get image ids for given answer types + :return: ids (int array) : integer array of image ids + """ + quesIds = quesIds if type(quesIds) == list else [quesIds] + quesTypes = quesTypes if type(quesTypes) == list else [quesTypes] + ansTypes = ansTypes if type(ansTypes) == list else [ansTypes] + + if len(quesIds) == len(quesTypes) == len(ansTypes) == 0: + anns = self.dataset["annotations"] + else: + if not len(quesIds) == 0: + anns = sum( + [self.qa[quesId] for quesId in quesIds if quesId in self.qa], [] + ) + else: + anns = self.dataset["annotations"] + anns = ( + anns + if len(quesTypes) == 0 + else [ann for ann in anns if ann["question_type"] in quesTypes] + ) + anns = ( + anns + if len(ansTypes) == 0 + else [ann for ann in anns if ann["answer_type"] in ansTypes] + ) + ids = [ann["image_id"] for ann in anns] + return ids + + def loadQA(self, ids=[]): + """ + Load questions and answers with the specified question ids. + :param ids (int array) : integer ids specifying question ids + :return: qa (object array) : loaded qa objects + """ + if type(ids) == list: + return [self.qa[id] for id in ids] + elif type(ids) == int: + return [self.qa[ids]] + + def showQA(self, anns): + """ + Display the specified annotations. + :param anns (array of object): annotations to display + :return: None + """ + if len(anns) == 0: + return 0 + for ann in anns: + quesId = ann["question_id"] + print("Question: %s" % (self.qqa[quesId]["question"])) + for ans in ann["answers"]: + print("Answer %d: %s" % (ans["answer_id"], ans["answer"])) + + def loadRes(self, resFile, quesFile): + """ + Load result file and return a result object. + :param resFile (str) : file name of result file + :return: res (obj) : result api object + """ + res = VQA() + res.questions = json.load(open(quesFile)) + res.dataset["info"] = copy.deepcopy(self.questions["info"]) + res.dataset["task_type"] = copy.deepcopy(self.questions["task_type"]) + res.dataset["data_type"] = copy.deepcopy(self.questions["data_type"]) + res.dataset["data_subtype"] = copy.deepcopy(self.questions["data_subtype"]) + res.dataset["license"] = copy.deepcopy(self.questions["license"]) + + print("Loading and preparing results... ") + time_t = datetime.datetime.utcnow() + anns = json.load(open(resFile)) + assert type(anns) == list, "results is not an array of objects" + annsQuesIds = [ann["question_id"] for ann in anns] + # print set of question ids that do not have corresponding annotations + + # assert set(annsQuesIds) == set(self.getQuesIds()), \ + # 'Results do not correspond to current VQA set. Either the results do not have predictions for all question ids in annotation file or there is atleast one question id that does not belong to the question ids in the annotation file.' 
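# Illustrative usage sketch (not part of the patch) for the VQA helper class above.
# The file names follow the VQAv2 naming used by the eval scripts in this patch and
# must point at real local copies; the image id is a made-up placeholder.
from open_flamingo.eval.vqa_metric import VQA

vqa = VQA(
    annotation_file="v2_mscoco_train2014_annotations.json",
    question_file="v2_OpenEnded_mscoco_train2014_questions.json",
)
question_ids = vqa.getQuesIds(imgIds=[262148])  # all question ids for one image
annotations = vqa.loadQA(question_ids)
vqa.showQA(annotations)                         # prints "Question: ..." / "Answer n: ..."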
+ for ann in anns: + quesId = ann["question_id"] + if res.dataset["task_type"] == "Multiple Choice": + assert ( + ann["answer"] in self.qqa[quesId]["multiple_choices"] + ), "predicted answer is not one of the multiple choices" + qaAnn = self.qa[quesId] + ann["image_id"] = qaAnn["image_id"] + ann["question_type"] = qaAnn["question_type"] + if "answer_type" in ann: + ann["answer_type"] = qaAnn["answer_type"] + print( + "DONE (t=%0.2fs)" % ((datetime.datetime.utcnow() - time_t).total_seconds()) + ) + + res.dataset["annotations"] = anns + res.createIndex() + return res + + +class VQAEval: + def __init__(self, vqa, vqaRes, n=2): + self.n = n + self.accuracy = {} + self.evalQA = {} + self.evalQuesType = {} + self.evalAnsType = {} + self.vqa = vqa + self.vqaRes = vqaRes + self.params = {"question_id": vqaRes.getQuesIds()} + self.contractions = { + "aint": "ain't", + "arent": "aren't", + "cant": "can't", + "couldve": "could've", + "couldnt": "couldn't", + "couldn'tve": "couldn't've", + "couldnt've": "couldn't've", + "didnt": "didn't", + "doesnt": "doesn't", + "dont": "don't", + "hadnt": "hadn't", + "hadnt've": "hadn't've", + "hadn'tve": "hadn't've", + "hasnt": "hasn't", + "havent": "haven't", + "hed": "he'd", + "hed've": "he'd've", + "he'dve": "he'd've", + "hes": "he's", + "howd": "how'd", + "howll": "how'll", + "hows": "how's", + "Id've": "I'd've", + "I'dve": "I'd've", + "Im": "I'm", + "Ive": "I've", + "isnt": "isn't", + "itd": "it'd", + "itd've": "it'd've", + "it'dve": "it'd've", + "itll": "it'll", + "let's": "let's", + "maam": "ma'am", + "mightnt": "mightn't", + "mightnt've": "mightn't've", + "mightn'tve": "mightn't've", + "mightve": "might've", + "mustnt": "mustn't", + "mustve": "must've", + "neednt": "needn't", + "notve": "not've", + "oclock": "o'clock", + "oughtnt": "oughtn't", + "ow's'at": "'ow's'at", + "'ows'at": "'ow's'at", + "'ow'sat": "'ow's'at", + "shant": "shan't", + "shed've": "she'd've", + "she'dve": "she'd've", + "she's": "she's", + "shouldve": "should've", + "shouldnt": "shouldn't", + "shouldnt've": "shouldn't've", + "shouldn'tve": "shouldn't've", + "somebody'd": "somebodyd", + "somebodyd've": "somebody'd've", + "somebody'dve": "somebody'd've", + "somebodyll": "somebody'll", + "somebodys": "somebody's", + "someoned": "someone'd", + "someoned've": "someone'd've", + "someone'dve": "someone'd've", + "someonell": "someone'll", + "someones": "someone's", + "somethingd": "something'd", + "somethingd've": "something'd've", + "something'dve": "something'd've", + "somethingll": "something'll", + "thats": "that's", + "thered": "there'd", + "thered've": "there'd've", + "there'dve": "there'd've", + "therere": "there're", + "theres": "there's", + "theyd": "they'd", + "theyd've": "they'd've", + "they'dve": "they'd've", + "theyll": "they'll", + "theyre": "they're", + "theyve": "they've", + "twas": "'twas", + "wasnt": "wasn't", + "wed've": "we'd've", + "we'dve": "we'd've", + "weve": "we've", + "werent": "weren't", + "whatll": "what'll", + "whatre": "what're", + "whats": "what's", + "whatve": "what've", + "whens": "when's", + "whered": "where'd", + "wheres": "where's", + "whereve": "where've", + "whod": "who'd", + "whod've": "who'd've", + "who'dve": "who'd've", + "wholl": "who'll", + "whos": "who's", + "whove": "who've", + "whyll": "why'll", + "whyre": "why're", + "whys": "why's", + "wont": "won't", + "wouldve": "would've", + "wouldnt": "wouldn't", + "wouldnt've": "wouldn't've", + "wouldn'tve": "wouldn't've", + "yall": "y'all", + "yall'll": "y'all'll", + "y'allll": "y'all'll", + "yall'd've": 
"y'all'd've", + "y'alld've": "y'all'd've", + "y'all'dve": "y'all'd've", + "youd": "you'd", + "youd've": "you'd've", + "you'dve": "you'd've", + "youll": "you'll", + "youre": "you're", + "youve": "you've", + } + self.manualMap = { + "none": "0", + "zero": "0", + "one": "1", + "two": "2", + "three": "3", + "four": "4", + "five": "5", + "six": "6", + "seven": "7", + "eight": "8", + "nine": "9", + "ten": "10", + } + self.articles = ["a", "an", "the"] + + self.periodStrip = re.compile("(?!<=\d)(\.)(?!\d)") + self.commaStrip = re.compile("(\d)(\,)(\d)") + self.punct = [ + ";", + r"/", + "[", + "]", + '"', + "{", + "}", + "(", + ")", + "=", + "+", + "\\", + "_", + "-", + ">", + "<", + "@", + "`", + ",", + "?", + "!", + ] + + def evaluate(self, quesIds=None): + if quesIds == None: + quesIds = [quesId for quesId in self.params["question_id"]] + gts = {} + res = {} + for quesId in quesIds: + gts[quesId] = self.vqa.qa[quesId] + res[quesId] = self.vqaRes.qa[quesId] + + # ================================================= + # Compute accuracy + # ================================================= + accQA = [] + accQuesType = {} + accAnsType = {} + print("computing accuracy") + step = 0 + for quesId in quesIds: + for ansDic in gts[quesId]["answers"]: + ansDic["answer"] = ansDic["answer"].replace("\n", " ") + ansDic["answer"] = ansDic["answer"].replace("\t", " ") + ansDic["answer"] = ansDic["answer"].strip() + resAns = res[quesId]["answer"] + resAns = resAns.replace("\n", " ") + resAns = resAns.replace("\t", " ") + resAns = resAns.strip() + resAns = self.processPunctuation(resAns) + resAns = self.processDigitArticle(resAns) + gtAcc = [] + + for ansDic in gts[quesId]["answers"]: + ansDic["answer"] = self.processPunctuation(ansDic["answer"]) + ansDic["answer"] = self.processDigitArticle(ansDic["answer"]) + + for gtAnsDatum in gts[quesId]["answers"]: + otherGTAns = [ + item for item in gts[quesId]["answers"] if item != gtAnsDatum + ] + matchingAns = [item for item in otherGTAns if item["answer"] == resAns] + acc = min(1, float(len(matchingAns)) / 3) + gtAcc.append(acc) + quesType = gts[quesId]["question_type"] + ansType = ( + gts[quesId]["answer_type"] if "answer_type" in gts[quesId] else "other" + ) + avgGTAcc = float(sum(gtAcc)) / len(gtAcc) + accQA.append(avgGTAcc) + if quesType not in accQuesType: + accQuesType[quesType] = [] + accQuesType[quesType].append(avgGTAcc) + if ansType not in accAnsType: + accAnsType[ansType] = [] + accAnsType[ansType].append(avgGTAcc) + self.setEvalQA(quesId, avgGTAcc) + self.setEvalQuesType(quesId, quesType, avgGTAcc) + self.setEvalAnsType(quesId, ansType, avgGTAcc) + if step % 100 == 0: + self.updateProgress(step / float(len(quesIds))) + step = step + 1 + + self.setAccuracy(accQA, accQuesType, accAnsType) + print("Done computing accuracy") + + def processPunctuation(self, inText): + outText = inText + for p in self.punct: + if (p + " " in inText or " " + p in inText) or ( + re.search(self.commaStrip, inText) != None + ): + outText = outText.replace(p, "") + else: + outText = outText.replace(p, " ") + outText = self.periodStrip.sub("", outText, re.UNICODE) + return outText + + def processDigitArticle(self, inText): + outText = [] + tempText = inText.lower().split() + for word in tempText: + word = self.manualMap.setdefault(word, word) + if word not in self.articles: + outText.append(word) + else: + pass + for wordId, word in enumerate(outText): + if word in self.contractions: + outText[wordId] = self.contractions[word] + outText = " ".join(outText) + return outText + + def 
setAccuracy(self, accQA, accQuesType, accAnsType): + self.accuracy["overall"] = round(100 * float(sum(accQA)) / len(accQA), self.n) + self.accuracy["perQuestionType"] = { + quesType: round( + 100 * float(sum(accQuesType[quesType])) / len(accQuesType[quesType]), + self.n, + ) + for quesType in accQuesType + } + self.accuracy["perAnswerType"] = { + ansType: round( + 100 * float(sum(accAnsType[ansType])) / len(accAnsType[ansType]), self.n + ) + for ansType in accAnsType + } + + def setEvalQA(self, quesId, acc): + self.evalQA[quesId] = round(100 * acc, self.n) + + def setEvalQuesType(self, quesId, quesType, acc): + if quesType not in self.evalQuesType: + self.evalQuesType[quesType] = {} + self.evalQuesType[quesType][quesId] = round(100 * acc, self.n) + + def setEvalAnsType(self, quesId, ansType, acc): + if ansType not in self.evalAnsType: + self.evalAnsType[ansType] = {} + self.evalAnsType[ansType][quesId] = round(100 * acc, self.n) + + def updateProgress(self, progress): + barLength = 20 + status = "" + if isinstance(progress, int): + progress = float(progress) + if not isinstance(progress, float): + progress = 0 + status = "error: progress var must be float\r\n" + if progress < 0: + progress = 0 + status = "Halt...\r\n" + if progress >= 1: + progress = 1 + status = "Done...\r\n" + block = int(round(barLength * progress)) + text = "\rFinshed Percent: [{0}] {1}% {2}".format( + "#" * block + "-" * (barLength - block), int(progress * 100), status + ) + sys.stdout.write(text) + sys.stdout.flush() + + +def compute_vqa_accuracy(result_json_path, question_json_path, annotation_json_path): + """Compute the VQA accuracy metric. + + Args: + predictions (List): list of predictions + ground_truth (List[List]): list of all possible ground truth answers + + Returns: + float: VQA accuracy + """ + # coding: utf-8 + # dataDir = data_dir + + # set up file names and paths + # versionType = 'v2_' # this should be '' when using VQA v2.0 dataset + # 'OpenEnded' only for v2.0. 'OpenEnded' or 'MultipleChoice' for v1.0 + # taskType = 'OpenEnded' + # 'mscoco' only for v1.0. 'mscoco' for real and 'abstract_v002' for abstract for v1.0. + # dataType = 'mscoco' + # dataSubType = 'train2014' + # annFile = '%s/%s%s_%s_annotations.json' % ( + # dataDir, versionType, dataType, dataSubType) + # quesFile = '%s/%s%s_%s_%s_questions.json' % ( + # dataDir, versionType, taskType, dataType, dataSubType) + # imgDir = '%s/%s/%s/' % (dataDir, dataType, dataSubType) + # resultType = res_file_name + # fileTypes = ['results', 'accuracy', + # 'evalQA', 'evalQuesType', 'evalAnsType'] + + # An example result json file has been provided in './Results' folder. 
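# Illustrative usage sketch (not part of the patch) for compute_vqa_accuracy, which
# wraps the VQA / VQAEval classes above. The results file is the temporary JSON that
# evaluate_vqa writes (a list of {"answer": ..., "question_id": ...} entries); the
# question and annotation file names are placeholders following the VQAv2 naming used
# by the eval scripts in this patch.
from open_flamingo.eval.vqa_metric import compute_vqa_accuracy

accuracy = compute_vqa_accuracy(
    "vqav2results_example.json",
    "v2_OpenEnded_mscoco_train2014_questions.json",
    "v2_mscoco_train2014_annotations.json",
)
print(f"overall VQA accuracy: {accuracy:.2f}")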
+ + # [resFile, accuracyFile, evalQAFile, evalQuesTypeFile, evalAnsTypeFile] = ['%s/%s%s_%s_%s_%s_%s.json' % (dataDir, versionType, taskType, dataType, dataSubType, + # resultType, fileType) for fileType in fileTypes] + + # create vqa object and vqaRes object + vqa = VQA(annotation_json_path, question_json_path) + vqaRes = vqa.loadRes(result_json_path, question_json_path) + + # create vqaEval object by taking vqa and vqaRes + # n is precision of accuracy (number of places after decimal), default is 2 + vqaEval = VQAEval(vqa, vqaRes, n=2) + + # evaluate results + """ + If you have a list of question ids on which you would like to evaluate your results, pass it as a list to below function + By default it uses all the question ids in annotation file + """ + vqaEval.evaluate() + + return vqaEval.accuracy["overall"] + + +def postprocess_vqa_generation(predictions): + answer = re.split("Question|Answer|Short", predictions, 1)[0] + answer = re.split(", ", answer, 1)[0] + return answer diff --git a/PyTorch/contrib/others/OpenFlamingo_ for PyTorch/open_flamingo/scripts/run_eval.sh b/PyTorch/contrib/others/OpenFlamingo_ for PyTorch/open_flamingo/scripts/run_eval.sh new file mode 100644 index 0000000000..5b93d0b29d --- /dev/null +++ b/PyTorch/contrib/others/OpenFlamingo_ for PyTorch/open_flamingo/scripts/run_eval.sh @@ -0,0 +1,38 @@ +# echo 'activating virtual environment' +# source ~/.bashrc +# eval "$(conda shell.bash hook)" +# conda activate openflamingo +# which python + +LM_PATH="facebook/opt-1.3b" +LM_TOKENIZER_PATH="facebook/opt-1.3b" +VISION_ENCODER_NAME="ViT-L-14" +VISION_ENCODER_PRETRAINED="" +CKPT_PATH="/home/data2/linzheyuan/open_flamingo/flamingo3B/checkpoint_0.pt" +DEVICE="0" + +COCO_IMG_PATH="/home/data1/coco/train2017" +COCO_ANNO_PATH="/home/data1/coco/annotations/captions_train2017.json" + +RANDOM_ID=$$ +RESULTS_FILE="results_${RANDOM_ID}.json" + +python open_flamingo/eval/evaluate.py \ + --lm_path $LM_PATH \ + --lm_tokenizer_path $LM_TOKENIZER_PATH \ + --vision_encoder_path $VISION_ENCODER_NAME \ + --vision_encoder_pretrained $VISION_ENCODER_PRETRAINED \ + --checkpoint_path $CKPT_PATH \ + --cross_attn_every_n_layers 4 \ + --device $DEVICE \ + --coco_image_dir_path $COCO_IMG_PATH \ + --coco_annotations_json_path $COCO_ANNO_PATH \ + --results_file $RESULTS_FILE \ + --eval_coco \ + --num_samples 16 \ + --shots 8 \ + --num_trials 1 \ + --batch_size 1 + + +echo "evaluation complete! 
results written to ${RESULTS_FILE}" diff --git a/PyTorch/contrib/others/OpenFlamingo_ for PyTorch/open_flamingo/scripts/run_eval_backup.sh b/PyTorch/contrib/others/OpenFlamingo_ for PyTorch/open_flamingo/scripts/run_eval_backup.sh new file mode 100644 index 0000000000..ddafae06c6 --- /dev/null +++ b/PyTorch/contrib/others/OpenFlamingo_ for PyTorch/open_flamingo/scripts/run_eval_backup.sh @@ -0,0 +1,45 @@ +echo 'activating virtual environment' +source ~/.bashrc +eval "$(conda shell.bash hook)" +conda activate openflamingo +which python + +LM_PATH="luodian/llama-7b-hf" +LM_TOKENIZER_PATH="luodian/llama-7b-hf" +VISION_ENCODER_NAME="ViT-L-14" +VISION_ENCODER_PRETRAINED="openai" +CKPT_PATH="openflamingo/OpenFlamingo-9B/checkpoint.pt" +DEVICE="0" + +COCO_IMG_PATH="/train2017/" +COCO_ANNO_PATH="/annotations/captions_train2017.json" +VQAV2_IMG_PATH="/train2014" +VQAV2_ANNO_PATH="/v2_mscoco_train2014_annotations.json" +VQAV2_QUESTION_PATH="/v2_OpenEnded_mscoco_train2014_questions.json" + +RANDOM_ID=$$ +RESULTS_FILE="results_${RANDOM_ID}.json" + +python open_flamingo/eval/evaluate.py \ + --lm_path $LM_PATH \ + --lm_tokenizer_path $LM_TOKENIZER_PATH \ + --vision_encoder_path $VISION_ENCODER_NAME \ + --vision_encoder_pretrained $VISION_ENCODER_PRETRAINED \ + --checkpoint_path $CKPT_PATH \ + --cross_attn_every_n_layers 4 \ + --device $DEVICE \ + --coco_image_dir_path $COCO_IMG_PATH \ + --coco_annotations_json_path $COCO_ANNO_PATH \ + --vqav2_image_dir_path $VQAV2_IMG_PATH \ + --vqav2_annotations_json_path $VQAV2_ANNO_PATH \ + --vqav2_questions_json_path $VQAV2_QUESTION_PATH \ + --results_file $RESULTS_FILE \ + --eval_coco \ + --eval_vqav2 \ + --num_samples 5000 \ + --shots 8 \ + --num_trials 1 \ + --batch_size 1 + + +echo "evaluation complete! results written to ${RESULTS_FILE}" diff --git a/PyTorch/contrib/others/OpenFlamingo_ for PyTorch/open_flamingo/src/__init__.py b/PyTorch/contrib/others/OpenFlamingo_ for PyTorch/open_flamingo/src/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/PyTorch/contrib/others/OpenFlamingo_ for PyTorch/open_flamingo/src/factory.py b/PyTorch/contrib/others/OpenFlamingo_ for PyTorch/open_flamingo/src/factory.py new file mode 100644 index 0000000000..a67c2c37f2 --- /dev/null +++ b/PyTorch/contrib/others/OpenFlamingo_ for PyTorch/open_flamingo/src/factory.py @@ -0,0 +1,109 @@ +from transformers import AutoModelForCausalLM, AutoTokenizer +import open_clip + +from .flamingo import Flamingo +from .flamingo_lm import FlamingoLMMixin +from .utils import extend_instance + + +def create_model_and_transforms( + clip_vision_encoder_path: str, + clip_vision_encoder_pretrained: str, + lang_encoder_path: str, + tokenizer_path: str, + cross_attn_every_n_layers: int = 1, + use_local_files: bool = False, + decoder_layers_attr_name: str = None, + **flamingo_kwargs, +): + """ + Initialize a Flamingo model from a pretrained vision encoder and language encoder. + Appends special tokens to the tokenizer and freezes backbones. + + Args: + clip_vision_encoder_path (str): path to pretrained clip model (e.g. "ViT-B-32") + clip_vision_encoder_pretrained (str): name of pretraining dataset for clip model (e.g. "laion2b_s32b_b79k") + lang_encoder_path (str): path to pretrained language encoder + tokenizer_path (str): path to pretrained tokenizer + cross_attn_every_n_layers (int, optional): determines how often to add a cross-attention layer. Defaults to 1. + use_local_files (bool, optional): whether to use local files. Defaults to False. 
+        decoder_layers_attr_name (str, optional): name of the decoder layers attribute. Defaults to None.
+    Returns:
+        Flamingo: Flamingo model from pretrained vision and language encoders
+        Image processor: Pipeline to preprocess input images
+        Tokenizer: A tokenizer for the language model
+    """
+    vision_encoder, _, image_processor = open_clip.create_model_and_transforms(
+        clip_vision_encoder_path, pretrained=clip_vision_encoder_pretrained
+    )
+    # set the vision encoder to output the visual features
+    vision_encoder.visual.output_tokens = True
+
+    text_tokenizer = AutoTokenizer.from_pretrained(
+        tokenizer_path, local_files_only=use_local_files
+    )
+    # add Flamingo special tokens to the tokenizer
+    text_tokenizer.add_special_tokens(
+        {"additional_special_tokens": ["<|endofchunk|>", "<image>"]}
+    )
+    if text_tokenizer.pad_token is None:
+        # Issue: GPT models don't have a pad token, which we use to
+        # modify labels for the loss.
+        text_tokenizer.add_special_tokens({"pad_token": "<PAD>"})
+
+    lang_encoder = AutoModelForCausalLM.from_pretrained(
+        lang_encoder_path, local_files_only=use_local_files
+    )
+    extend_instance(lang_encoder, FlamingoLMMixin)
+
+    if decoder_layers_attr_name is None:
+        decoder_layers_attr_name = _infer_decoder_layers_attr_name(lang_encoder)
+    lang_encoder.set_decoder_layers_attr_name(decoder_layers_attr_name)
+    lang_encoder.resize_token_embeddings(len(text_tokenizer))
+
+    model = Flamingo(
+        vision_encoder,
+        lang_encoder,
+        text_tokenizer.encode("<|endofchunk|>")[-1],
+        text_tokenizer.encode("<image>")[-1],
+        vis_dim=open_clip.get_model_config(clip_vision_encoder_path)["vision_cfg"][
+            "width"
+        ],
+        cross_attn_every_n_layers=cross_attn_every_n_layers,
+        **flamingo_kwargs,
+    )
+
+    # Freeze all parameters
+    model.requires_grad_(False)
+    assert sum(p.numel() for p in model.parameters() if p.requires_grad) == 0
+
+    # Unfreeze perceiver, gated_cross_attn_layers, and LM input embeddings
+    model.perceiver.requires_grad_(True)
+    model.lang_encoder.gated_cross_attn_layers.requires_grad_(True)
+    model.lang_encoder.get_input_embeddings().requires_grad_(True)
+
+    print(
+        f"Flamingo model initialized with {sum(p.numel() for p in model.parameters() if p.requires_grad)} trainable parameters"
+    )
+
+    return model, image_processor, text_tokenizer
+
+
+def _infer_decoder_layers_attr_name(model):
+    for k in __KNOWN_DECODER_LAYERS_ATTR_NAMES:
+        if k.lower() in model.__class__.__name__.lower():
+            return __KNOWN_DECODER_LAYERS_ATTR_NAMES[k]
+
+    raise ValueError(
+        "We require the attribute name for the nn.ModuleList in the decoder storing the transformer block layers. Please supply this string manually."
+ ) + + +__KNOWN_DECODER_LAYERS_ATTR_NAMES = { + "opt": "model.decoder.layers", + "gptneo": "transformer.h", + "gptj": "transformer.h", + "gpt-j": "transformer.h", + "pythia": "gpt_neox.layers", + "llama": "model.layers", +} diff --git a/PyTorch/contrib/others/OpenFlamingo_ for PyTorch/open_flamingo/src/flamingo.py b/PyTorch/contrib/others/OpenFlamingo_ for PyTorch/open_flamingo/src/flamingo.py new file mode 100644 index 0000000000..a8eae254b3 --- /dev/null +++ b/PyTorch/contrib/others/OpenFlamingo_ for PyTorch/open_flamingo/src/flamingo.py @@ -0,0 +1,198 @@ +import torch +from einops import rearrange +from torch import nn + +from .helpers import PerceiverResampler + + +class Flamingo(nn.Module): + def __init__( + self, + vision_encoder: nn.Module, + lang_encoder: nn.Module, + eoc_token_id: int, + media_token_id: int, + vis_dim: int, + cross_attn_every_n_layers: int = 1, + use_media_placement_augmentation: bool = False, + ): + """ + Args: + vision_encoder (nn.Module): HF CLIPModel + lang_encoder (nn.Module): HF causal language model + eoc_token_id (int): Token id for <|endofchunk|> + media_token_id (int): Token id for + vis_dim (int): Dimension of the visual features. + Visual features are projected to match this shape along the last dimension. + cross_attn_every_n_layers (int, optional): How often to apply cross attention after transformer layer. Defaults to 1. + use_media_placement_augmentation (bool, optional): Whether to randomly assign images to the preceding or following text in training. Defaults to False. + """ + super().__init__() + self.eoc_token_id = eoc_token_id + self.media_token_id = media_token_id + self.use_media_placement_augmentation = use_media_placement_augmentation + self.vis_dim = vis_dim + self.vision_encoder = vision_encoder + self.perceiver = PerceiverResampler(dim=self.vis_dim) + self.lang_encoder = lang_encoder + self.lang_encoder.init_flamingo( + media_token_id=media_token_id, + vis_hidden_size=self.vis_dim, + cross_attn_every_n_layers=cross_attn_every_n_layers, + use_media_placement_augmentation=self.use_media_placement_augmentation, + ) + + def forward( + self, + vision_x: torch.Tensor, + lang_x: torch.Tensor, + attention_mask: torch.Tensor = None, + labels: torch.Tensor = None, + use_cached_vision_x: bool = False, + clear_conditioned_layers: bool = True, + past_key_values=None, + use_cache: bool = False, + ): + """ + Forward pass of Flamingo. + + Args: + vision_x (torch.Tensor): Vision input + shape (B, T_img, F, C, H, W) with F=1 + lang_x (torch.Tensor): Language input ids + shape (B, T_txt) + attention_mask (torch.Tensor, optional): Attention mask. Defaults to None. + labels (torch.Tensor, optional): Labels. Defaults to None. + clear_conditioned_layers: if True, clear the conditioned layers + once the foward pass is completed. Set this to false if the + same set of images will be reused in another subsequent + forward pass. + past_key_values: pre-computed values to pass to language model. + See past_key_values documentation in Hugging Face + CausalLM models. + use_cache: whether to use cached key values. See use_cache + documentation in Hugging Face CausalLM models. + """ + assert ( + vision_x is not None + ) or use_cached_vision_x, ( + "Must provide either vision_x or use_cached_vision_x to True." + ) + + if use_cached_vision_x: + # Case: use cached; vision_x should be cached and other + # vision-related inputs should not be provided. + assert ( + vision_x is None + ), "Expect vision_x to be None when use_cached_vision_x is True." 
+ assert self.lang_encoder.is_conditioned() + + else: + # Case: do not use caching (i.e. this is a standard forward pass); + self._encode_vision_x(vision_x=vision_x) + + output = self.lang_encoder( + input_ids=lang_x, + attention_mask=attention_mask, + labels=labels, + past_key_values=past_key_values, + use_cache=use_cache, + ) + + if clear_conditioned_layers: + self.lang_encoder.clear_conditioned_layers() + + return output + + def generate( + self, + vision_x: torch.Tensor, + lang_x: torch.Tensor, + attention_mask: torch.Tensor = None, + num_beams=1, + max_new_tokens=None, + temperature=1.0, + top_k=0, + top_p=1.0, + no_repeat_ngram_size=0, + prefix_allowed_tokens_fn=None, + length_penalty=1.0, + num_return_sequences=1, + do_sample=False, + early_stopping=False, + ): + """ + Generate text conditioned on vision and language inputs. + + Args: + vision_x (torch.Tensor): Vision input + shape (B, T_img, F, C, H, W) + images in the same chunk are collated along T_img, and frames are collated along F + currently only F=1 is supported (single-frame videos) + lang_x (torch.Tensor): Language input + shape (B, T_txt) + max_length (int, optional): Maximum length of the output. Defaults to None. + attention_mask (torch.Tensor, optional): Attention mask. Defaults to None. + num_beams (int, optional): Number of beams. Defaults to 1. + max_new_tokens (int, optional): Maximum new tokens. Defaults to None. + temperature (float, optional): Temperature. Defaults to 1.0. + top_k (int, optional): Top k. Defaults to 0. + top_p (float, optional): Top p. Defaults to 1.0. + no_repeat_ngram_size (int, optional): No repeat ngram size. Defaults to 0. + length_penalty (float, optional): Length penalty. Defaults to 1.0. + num_return_sequences (int, optional): Number of return sequences. Defaults to 1. + do_sample (bool, optional): Do sample. Defaults to False. + early_stopping (bool, optional): Early stopping. Defaults to False. + Returns: + torch.Tensor: lang_x with generated tokens appended to it + """ + if num_beams > 1: + vision_x = vision_x.repeat_interleave(num_beams, dim=0) + + self._encode_vision_x(vision_x=vision_x) + + output = self.lang_encoder.generate( + lang_x, + attention_mask=attention_mask, + eos_token_id=self.eoc_token_id, + num_beams=num_beams, + max_new_tokens=max_new_tokens, + temperature=temperature, + top_k=top_k, + top_p=top_p, + prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, + no_repeat_ngram_size=no_repeat_ngram_size, + length_penalty=length_penalty, + num_return_sequences=num_return_sequences, + do_sample=do_sample, + early_stopping=early_stopping, + ) + + self.lang_encoder.clear_conditioned_layers() + return output + + def _encode_vision_x(self, vision_x: torch.Tensor): + """ + Compute media tokens from vision input by passing it through vision encoder and conditioning language model. 
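+
+        Example (illustrative shapes only): a batch of 2 samples, each holding
+        3 single-frame images, enters as vision_x of shape (2, 3, 1, 3, 224, 224);
+        after the CLIP visual backbone and the PerceiverResampler, the decoder
+        layers are conditioned on features of shape (2, 3, 64, vis_dim), where
+        64 is the resampler's default number of latents.
+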
+ Args: + vision_x (torch.Tensor): Vision input + shape (B, T_img, F, C, H, W) + Images in the same chunk are collated along T_img, and frames are collated along F + Currently only F=1 is supported (single-frame videos) + + rearrange code based on https://github.com/dhansmair/flamingo-mini + """ + + assert vision_x.ndim == 6, "vision_x should be of shape (b, T_img, F, C, H, W)" + b, T, F = vision_x.shape[:3] + assert F == 1, "Only single frame supported" + + vision_x = rearrange(vision_x, "b T F c h w -> (b T F) c h w") + with torch.no_grad(): + vision_x = self.vision_encoder.visual(vision_x)[1] + vision_x = rearrange(vision_x, "(b T F) v d -> b T F v d", b=b, T=T, F=F) + + vision_x = self.perceiver(vision_x) # reshapes to (b, T, n, d) + + for layer in self.lang_encoder._get_decoder_layers(): + layer.condition_vis_x(vision_x) diff --git a/PyTorch/contrib/others/OpenFlamingo_ for PyTorch/open_flamingo/src/flamingo_lm.py b/PyTorch/contrib/others/OpenFlamingo_ for PyTorch/open_flamingo/src/flamingo_lm.py new file mode 100644 index 0000000000..c064d87ed1 --- /dev/null +++ b/PyTorch/contrib/others/OpenFlamingo_ for PyTorch/open_flamingo/src/flamingo_lm.py @@ -0,0 +1,138 @@ +import random + +import torch.nn as nn + +from .helpers import GatedCrossAttentionBlock +from .utils import getattr_recursive, setattr_recursive + + +class FlamingoLayer(nn.Module): + def __init__(self, gated_cross_attn_layer, decoder_layer): + super().__init__() + self.gated_cross_attn_layer = gated_cross_attn_layer + self.decoder_layer = decoder_layer + self.vis_x = None + self.media_locations = None + + def is_conditioned(self) -> bool: + """Check whether the layer is conditioned.""" + return self.vis_x is not None + + # Used this great idea from this implementation of Flamingo (https://github.com/dhansmair/flamingo-mini/) + def condition_vis_x(self, vis_x): + self.vis_x = vis_x + + def condition_media_locations(self, media_locations): + self.media_locations = media_locations + + def condition_attend_previous(self, attend_previous): + self.attend_previous = attend_previous + + def forward( + self, + lang_x, + attention_mask=None, + **decoder_layer_kwargs, + ): + if self.gated_cross_attn_layer is None: + return self.decoder_layer( + lang_x, attention_mask=attention_mask, **decoder_layer_kwargs + ) + + if self.vis_x is None: + raise ValueError("vis_x must be conditioned before forward pass") + + if self.media_locations is None: + raise ValueError("media_locations must be conditioned before forward pass") + + lang_x = self.gated_cross_attn_layer( + lang_x, + self.vis_x, + media_locations=self.media_locations, + attend_previous=self.attend_previous, + ) + lang_x = self.decoder_layer( + lang_x, attention_mask=attention_mask, **decoder_layer_kwargs + ) + return lang_x + + +class FlamingoLMMixin(nn.Module): + """ + Mixin to add cross-attention layers to a language model. + """ + + def set_decoder_layers_attr_name(self, decoder_layers_attr_name): + self.decoder_layers_attr_name = decoder_layers_attr_name + + def _get_decoder_layers(self): + return getattr_recursive(self, self.decoder_layers_attr_name) + + def _set_decoder_layers(self, value): + setattr_recursive(self, self.decoder_layers_attr_name, value) + + def init_flamingo( + self, + media_token_id, + vis_hidden_size, + cross_attn_every_n_layers, + use_media_placement_augmentation, + ): + """ + Initialize Flamingo by adding a new gated cross attn to the decoder. Store the media token id for computing the media locations. 
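+
+        For example (purely illustrative), with cross_attn_every_n_layers=4 a
+        24-layer decoder gets a GatedCrossAttentionBlock after its 4th, 8th,
+        ..., 24th layers (6 blocks in total); the remaining entries of
+        gated_cross_attn_layers are None and those FlamingoLayers simply run
+        the wrapped decoder layer unchanged.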
+ """ + + self.gated_cross_attn_layers = nn.ModuleList( + [ + GatedCrossAttentionBlock( + dim=self.config.hidden_size, dim_visual=vis_hidden_size + ) + if (layer_idx + 1) % cross_attn_every_n_layers == 0 + else None + for layer_idx, _ in enumerate(self._get_decoder_layers()) + ] + ) + self._set_decoder_layers( + nn.ModuleList( + [ + FlamingoLayer(gated_cross_attn_layer, decoder_layer) + for gated_cross_attn_layer, decoder_layer in zip( + self.gated_cross_attn_layers, self._get_decoder_layers() + ) + ] + ) + ) + self.media_token_id = media_token_id + self.use_media_placement_augmentation = use_media_placement_augmentation + self.initialized_flamingo = True + + def forward(self, *input, **kwargs): + """Condition the Flamingo layers on the media locations before forward()""" + if not self.initialized_flamingo: + raise ValueError( + "Flamingo layers are not initialized. Please call `init_flamingo` first." + ) + + input_ids = kwargs["input_ids"] if "input_ids" in kwargs else input[0] + media_locations = input_ids == self.media_token_id + attend_previous = ( + (random.random() < 0.5) if self.use_media_placement_augmentation else False + ) + + for layer in self.get_decoder().layers: + layer.condition_media_locations(media_locations) + layer.condition_attend_previous(attend_previous) + + return super().forward( + *input, **kwargs + ) # Call the other parent's forward method + + def is_conditioned(self) -> bool: + """Check whether all decoder layers are already conditioned.""" + return all(l.is_conditioned() for l in self._get_decoder_layers()) + + def clear_conditioned_layers(self): + for layer in self._get_decoder_layers(): + layer.condition_vis_x(None) + layer.condition_media_locations(None) + layer.condition_attend_previous(None) diff --git a/PyTorch/contrib/others/OpenFlamingo_ for PyTorch/open_flamingo/src/helpers.py b/PyTorch/contrib/others/OpenFlamingo_ for PyTorch/open_flamingo/src/helpers.py new file mode 100644 index 0000000000..78b4896980 --- /dev/null +++ b/PyTorch/contrib/others/OpenFlamingo_ for PyTorch/open_flamingo/src/helpers.py @@ -0,0 +1,275 @@ +""" +Taken from https://github.com/lucidrains/flamingo-pytorch +""" + +import torch +from einops import rearrange, repeat +from einops_exts import rearrange_many +from torch import einsum, nn + + +def exists(val): + return val is not None + + +def FeedForward(dim, mult=4): + inner_dim = int(dim * mult) + return nn.Sequential( + nn.LayerNorm(dim), + nn.Linear(dim, inner_dim, bias=False), + nn.GELU(), + nn.Linear(inner_dim, dim, bias=False), + ) + + +class PerceiverAttention(nn.Module): + def __init__(self, *, dim, dim_head=64, heads=8): + super().__init__() + self.scale = dim_head**-0.5 + self.heads = heads + inner_dim = dim_head * heads + + self.norm_media = nn.LayerNorm(dim) + self.norm_latents = nn.LayerNorm(dim) + + self.to_q = nn.Linear(dim, inner_dim, bias=False) + self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False) + self.to_out = nn.Linear(inner_dim, dim, bias=False) + + def forward(self, x, latents): + """ + Args: + x (torch.Tensor): image features + shape (b, T, n1, D) + latent (torch.Tensor): latent features + shape (b, T, n2, D) + """ + x = self.norm_media(x) + latents = self.norm_latents(latents) + + h = self.heads + + q = self.to_q(latents) + kv_input = torch.cat((x, latents), dim=-2) + k, v = self.to_kv(kv_input).chunk(2, dim=-1) + q, k, v = rearrange_many((q, k, v), "b t n (h d) -> b h t n d", h=h) + q = q * self.scale + + # attention + sim = einsum("... i d, ... j d -> ... 
i j", q, k) + sim = sim - sim.amax(dim=-1, keepdim=True).detach() + attn = sim.softmax(dim=-1) + + out = einsum("... i j, ... j d -> ... i d", attn, v) + out = rearrange(out, "b h t n d -> b t n (h d)", h=h) + return self.to_out(out) + + +class PerceiverResampler(nn.Module): + def __init__( + self, + *, + dim, + depth=6, + dim_head=64, + heads=8, + num_latents=64, + max_num_media=None, + max_num_frames=None, + ff_mult=4, + ): + super().__init__() + self.latents = nn.Parameter(torch.randn(num_latents, dim)) + self.frame_embs = ( + nn.Parameter(torch.randn(max_num_frames, dim)) + if exists(max_num_frames) + else None + ) + self.media_time_embs = ( + nn.Parameter(torch.randn(max_num_media, 1, dim)) + if exists(max_num_media) + else None + ) + + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append( + nn.ModuleList( + [ + PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads), + FeedForward(dim=dim, mult=ff_mult), + ] + ) + ) + + self.norm = nn.LayerNorm(dim) + + def forward(self, x): + """ + Args: + x (torch.Tensor): image features + shape (b, T, F, v, D) + Returns: + shape (b, T, n, D) where n is self.num_latents + """ + b, T, F, v = x.shape[:4] + + # frame and media time embeddings + if exists(self.frame_embs): + frame_embs = repeat(self.frame_embs[:F], "F d -> b T F v d", b=b, T=T, v=v) + x = x + frame_embs + x = rearrange( + x, "b T F v d -> b T (F v) d" + ) # flatten the frame and spatial dimensions + if exists(self.media_time_embs): + x = x + self.media_time_embs[:T] + + # blocks + latents = repeat(self.latents, "n d -> b T n d", b=b, T=T) + for attn, ff in self.layers: + latents = attn(x, latents) + latents + latents = ff(latents) + latents + return self.norm(latents) + + +# gated cross attention + + +class MaskedCrossAttention(nn.Module): + def __init__( + self, + *, + dim, + dim_visual, + dim_head=64, + heads=8, + only_attend_immediate_media=True, + ): + super().__init__() + self.scale = dim_head**-0.5 + self.heads = heads + inner_dim = dim_head * heads + + self.norm = nn.LayerNorm(dim) + + self.to_q = nn.Linear(dim, inner_dim, bias=False) + self.to_kv = nn.Linear(dim_visual, inner_dim * 2, bias=False) + self.to_out = nn.Linear(inner_dim, dim, bias=False) + + # whether for text to only attend to immediate preceding image, or all previous images + self.only_attend_immediate_media = only_attend_immediate_media + + def forward(self, x, media, media_locations=None, attend_previous=True): + """ + Args: + x (torch.Tensor): text features + shape (B, T_txt, D_txt) + media (torch.Tensor): image features + shape (B, T_img, n, D_img) where n is the dim of the latents + media_locations: boolean mask identifying the media tokens in x + shape (B, T_txt) + attend_previous: bool + If false, ignores immediately preceding image and starts attending when following image + """ + _, T_img, n = media.shape[:3] + h = self.heads + + x = self.norm(x) + + q = self.to_q(x) + media = rearrange(media, "b t n d -> b (t n) d") + + k, v = self.to_kv(media).chunk(2, dim=-1) + q, k, v = rearrange_many((q, k, v), "b n (h d) -> b h n d", h=h) + + q = q * self.scale + + sim = einsum("... i d, ... j d -> ... 
i j", q, k) + + if exists(media_locations): + # at each boolean of True, increment the time counter (relative to media time) + text_time = media_locations.cumsum(dim=-1) + media_time = torch.arange(T_img, device=x.device) + 1 + + if not attend_previous: + text_time[~media_locations] += 1 + # make sure max is still the number of images in the sequence + text_time[ + text_time + > repeat( + torch.count_nonzero(media_locations, dim=1), + "b -> b i", + i=text_time.shape[1], + ) + ] = 0 + + # text time must equal media time if only attending to most immediate image + # otherwise, as long as text time is greater than media time (if attending to all previous images / media) + mask_op = torch.eq if self.only_attend_immediate_media else torch.ge + + text_to_media_mask = mask_op( + rearrange(text_time, "b i -> b 1 i 1"), + repeat(media_time, "j -> 1 1 1 (j n)", n=n), + ) + sim = sim.masked_fill(~text_to_media_mask, -torch.finfo(sim.dtype).max) + + sim = sim - sim.amax(dim=-1, keepdim=True).detach() + attn = sim.softmax(dim=-1) + + if exists(media_locations) and self.only_attend_immediate_media: + # any text without a preceding media needs to have attention zeroed out + text_without_media_mask = text_time == 0 + text_without_media_mask = rearrange( + text_without_media_mask, "b i -> b 1 i 1" + ) + attn = attn.masked_fill(text_without_media_mask, 0.0) + + out = einsum("... i j, ... j d -> ... i d", attn, v) + out = rearrange(out, "b h n d -> b n (h d)") + return self.to_out(out) + + +class GatedCrossAttentionBlock(nn.Module): + def __init__( + self, + *, + dim, + dim_visual, + dim_head=64, + heads=8, + ff_mult=4, + only_attend_immediate_media=True, + ): + super().__init__() + self.attn = MaskedCrossAttention( + dim=dim, + dim_visual=dim_visual, + dim_head=dim_head, + heads=heads, + only_attend_immediate_media=only_attend_immediate_media, + ) + self.attn_gate = nn.Parameter(torch.tensor([0.0])) + + self.ff = FeedForward(dim, mult=ff_mult) + self.ff_gate = nn.Parameter(torch.tensor([0.0])) + + def forward( + self, + x, + media, + media_locations=None, + attend_previous=True, + ): + x = ( + self.attn( + x, + media, + media_locations=media_locations, + attend_previous=attend_previous, + ) + * self.attn_gate.tanh() + + x + ) + x = self.ff(x) * self.ff_gate.tanh() + x + + return x diff --git a/PyTorch/contrib/others/OpenFlamingo_ for PyTorch/open_flamingo/src/utils.py b/PyTorch/contrib/others/OpenFlamingo_ for PyTorch/open_flamingo/src/utils.py new file mode 100644 index 0000000000..815c70016c --- /dev/null +++ b/PyTorch/contrib/others/OpenFlamingo_ for PyTorch/open_flamingo/src/utils.py @@ -0,0 +1,31 @@ +def extend_instance(obj, mixin): + """Apply mixins to a class instance after creation""" + base_cls = obj.__class__ + base_cls_name = obj.__class__.__name__ + obj.__class__ = type( + base_cls_name, (mixin, base_cls), {} + ) # mixin needs to go first for our forward() logic to work + + +def getattr_recursive(obj, att): + """ + Return nested attribute of obj + Example: getattr_recursive(obj, 'a.b.c') is equivalent to obj.a.b.c + """ + if att == "": + return obj + i = att.find(".") + if i < 0: + return getattr(obj, att) + else: + return getattr_recursive(getattr(obj, att[:i]), att[i + 1 :]) + + +def setattr_recursive(obj, att, val): + """ + Set nested attribute of obj + Example: setattr_recursive(obj, 'a.b.c', val) is equivalent to obj.a.b.c = val + """ + if "." 
in att: + obj = getattr_recursive(obj, ".".join(att.split(".")[:-1])) + setattr(obj, att.split(".")[-1], val) diff --git a/PyTorch/contrib/others/OpenFlamingo_ for PyTorch/open_flamingo/train/__init__.py b/PyTorch/contrib/others/OpenFlamingo_ for PyTorch/open_flamingo/train/__init__.py new file mode 100644 index 0000000000..8b13789179 --- /dev/null +++ b/PyTorch/contrib/others/OpenFlamingo_ for PyTorch/open_flamingo/train/__init__.py @@ -0,0 +1 @@ + diff --git a/PyTorch/contrib/others/OpenFlamingo_ for PyTorch/open_flamingo/train/convert_mmc4_to_wds.py b/PyTorch/contrib/others/OpenFlamingo_ for PyTorch/open_flamingo/train/convert_mmc4_to_wds.py new file mode 100644 index 0000000000..6dfc2c7766 --- /dev/null +++ b/PyTorch/contrib/others/OpenFlamingo_ for PyTorch/open_flamingo/train/convert_mmc4_to_wds.py @@ -0,0 +1,72 @@ +import argparse +import base64 +import json +import os +import tarfile +import uuid +import zipfile + +import braceexpand +import webdataset as wds + +arg_parser = argparse.ArgumentParser() +arg_parser.add_argument("--output_dir", type=str) +arg_parser.add_argument( + "--image_shards", + type=str, + help="Pass in a list of shards in the format path_to_shard/shard_{0..23098}_images_v2.tar", +) +arg_parser.add_argument( + "--doc_shards", + type=str, + help="Pass in a list of shards in the format path_to_shard/docs_shard_{0..23098}_v2.jsonl.zip", +) +args = arg_parser.parse_args() + + +def main(): + os.makedirs(args.output_dir, exist_ok=True) + + doc_shards = list(braceexpand.braceexpand(args.doc_shards)) + image_shards = list(braceexpand.braceexpand(args.image_shards)) + + assert len(doc_shards) == len( + image_shards + ), "Each doc shards must have a corresponding image shard" + + with wds.ShardWriter(args.output_dir + "/%09d.tar", maxcount=1000) as sink: + for idx in range(len(doc_shards)): + image_tar = tarfile.open(image_shards[idx]) + + # Open the ZIP archive and extract the JSON file + with zipfile.ZipFile(doc_shards[idx], "r") as zip_file: + # Assumes the JSON file is the first file in the archive + json_filename = zip_file.namelist()[0] + with zip_file.open(json_filename, "r") as json_file: + for sample_data in json_file: + # get image names from json + sample_data = json.loads(sample_data) + image_info = sample_data["image_info"] + image_names = [image["image_name"] for image in image_info] + + # Add each image to the tar file + for img_idx, image_name in enumerate(image_names): + image = image_tar.extractfile( + f"{image_tar.getnames()[0]}/{image_name}" + ) + + # convert to base64 + image_bytes = image.read() + image_base64 = base64.b64encode(image_bytes).decode("utf-8") + sample_data["image_info"][img_idx][ + "image_base64" + ] = image_base64 + + key_str = uuid.uuid4().hex + sink.write({"__key__": key_str, "json": sample_data}) + + image_tar.close() + + +if __name__ == "__main__": + main() diff --git a/PyTorch/contrib/others/OpenFlamingo_ for PyTorch/open_flamingo/train/data.py b/PyTorch/contrib/others/OpenFlamingo_ for PyTorch/open_flamingo/train/data.py new file mode 100644 index 0000000000..e48a5891e7 --- /dev/null +++ b/PyTorch/contrib/others/OpenFlamingo_ for PyTorch/open_flamingo/train/data.py @@ -0,0 +1,576 @@ +import ast +import functools +import io +import json +import logging +import math +import os +import random +import sys +import tarfile +from dataclasses import dataclass +from multiprocessing import Value + +import braceexpand +import torch +import torchvision +import webdataset as wds +from PIL import Image +from torch.utils.data import DataLoader, 
IterableDataset, get_worker_info +from torch.utils.data.distributed import DistributedSampler +from webdataset.filters import _shuffle +from webdataset.tariterators import ( + base_plus_ext, + tar_file_expander, + url_opener, + valid_sample, +) +import base64 + +Image.MAX_IMAGE_PIXELS = 1000000000 +MAX_NUM_TOKENS = 256 +MAX_NUM_IMAGES = 5 +TINY_IMAGE_SIZE_THRESHOLD = 1 +N_CHANNELS = 3 +INTERLEAVED_IMAGE_SIZE = 224 + +import numpy as np +np.fft.ifft2() + +try: + import horovod.torch as hvd +except ImportError: + hvd = None + + +class SharedEpoch: + def __init__(self, epoch: int = 0): + self.shared_epoch = Value("i", epoch) + + def set_value(self, epoch): + self.shared_epoch.value = epoch + + def get_value(self): + return self.shared_epoch.value + + +@dataclass +class DataInfo: + dataloader: DataLoader + sampler: DistributedSampler = None + shared_epoch: SharedEpoch = None + + def set_epoch(self, epoch): + if self.shared_epoch is not None: + self.shared_epoch.set_value(epoch) + if self.sampler is not None and isinstance(self.sampler, DistributedSampler): + self.sampler.set_epoch(epoch) + + +def get_dataset_size(shards): + shards_list = list(braceexpand.braceexpand(shards)) + shards_list = shards + dir_path = os.path.dirname(shards[0]) + sizes_filename = os.path.join(dir_path, "sizes.json") + len_filename = os.path.join(dir_path, "__len__") + if os.path.exists(sizes_filename): + sizes = json.load(open(sizes_filename, "r")) + total_size = sum( + [ + int(sizes[os.path.basename(shard)]) + if os.path.basename(shard) in sizes + else 0 + for shard in shards_list + ] + ) + elif os.path.exists(len_filename): + # FIXME this used to be eval(open(...)) but that seemed rather unsafe + total_size = ast.literal_eval(open(len_filename, "r").read()) + else: + total_size = None # num samples undefined + # some common dataset sizes (at time of authors last download) + # CC3M (train): 2905954 + # CC12M: 10968539 + # LAION-400M: 407332084 + # LAION-2B (english): 2170337258 + num_shards = len(shards_list) + return total_size, num_shards + + +def count_samples(dataloader): + os.environ["WDS_EPOCH"] = "0" + n_elements, n_batches = 0, 0 + for images, texts in dataloader: + n_batches += 1 + n_elements += len(images) + assert len(images) == len(texts) + return n_elements, n_batches + + +def filter_no_caption_or_no_image(sample): + return ("txt" in sample) and ( + "png" in sample or "jpg" in sample or "jpeg" in sample + ) + + +def log_and_continue(exn): + """Call in an exception handler to ignore any exception, issue a warning, and continue.""" + if "No images in sample" in str(exn) or "Only one image in sample" in str( + exn + ): # Avoid spamming logs with these + return True + logging.warning(f"Handling webdataset error ({repr(exn)}). Ignoring.") + return True + + +def group_by_keys_nothrow( + data, keys=base_plus_ext, lcase=True, suffixes=None, handler=None +): + """Return function over iterator that groups key, value pairs into samples. 
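+
+    Illustrative example: tar members "000123.jpg" and "000123.json" share the
+    base key "000123", so they end up grouped into a single sample dict of the
+    form {"__key__": "000123", "jpg": ..., "json": ...}.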
+ + :param keys: function that splits the key into key and extension (base_plus_ext) + :param lcase: convert suffixes to lower case (Default value = True) + """ + current_sample = None + for filesample in data: + assert isinstance(filesample, dict) + fname, value = filesample["fname"], filesample["data"] + prefix, suffix = keys(fname) + if prefix is None: + continue + if lcase: + suffix = suffix.lower() + # FIXME webdataset version throws if suffix in current_sample, but we have a potential for + # this happening in the current LAION400m dataset if a tar ends with same prefix as the next + # begins, rare, but can happen since prefix aren't unique across tar files in that dataset + if ( + current_sample is None + or prefix != current_sample["__key__"] + or suffix in current_sample + ): + if valid_sample(current_sample): + yield current_sample + current_sample = dict(__key__=prefix, __url__=filesample["__url__"]) + if suffixes is None or suffix in suffixes: + current_sample[suffix] = value + if valid_sample(current_sample): + yield current_sample + + +def tarfile_to_samples_nothrow(src, handler=log_and_continue): + # NOTE this is a re-impl of the webdataset impl with group_by_keys that doesn't throw + streams = url_opener(src, handler=handler) + files = tar_file_expander(streams, handler=handler) + samples = group_by_keys_nothrow(files, handler=handler) + return samples + + +def pytorch_worker_seed(increment=0): + """get dataloader worker seed from pytorch""" + worker_info = get_worker_info() + if worker_info is not None: + # favour using the seed already created for pytorch dataloader workers if it exists + seed = worker_info.seed + if increment: + # space out seed increments so they can't overlap across workers in different iterations + seed += increment * max(1, worker_info.num_workers) + return seed + # fallback to wds rank based seed + return wds.utils.pytorch_worker_seed() + + +_SHARD_SHUFFLE_SIZE = 2000 +_SHARD_SHUFFLE_INITIAL = 500 +_SAMPLE_SHUFFLE_SIZE = 5000 +_SAMPLE_SHUFFLE_INITIAL = 1000 + + +class detshuffle2(wds.PipelineStage): + def __init__( + self, + bufsize=1000, + initial=100, + seed=0, + epoch=-1, + ): + self.bufsize = bufsize + self.initial = initial + self.seed = seed + self.epoch = epoch + + def run(self, src): + if isinstance(self.epoch, SharedEpoch): + epoch = self.epoch.get_value() + else: + # NOTE: this is epoch tracking is problematic in a multiprocess (dataloader workers or train) + # situation as different workers may wrap at different times (or not at all). + self.epoch += 1 + epoch = self.epoch + rng = random.Random() + if self.seed < 0: + # If seed is negative, we use the worker's seed, this will be different across all nodes/workers + seed = pytorch_worker_seed(epoch) + else: + # This seed to be deterministic AND the same across all nodes/workers in each epoch + seed = self.seed + epoch + rng.seed(seed) + return _shuffle(src, self.bufsize, self.initial, rng) + + +class ResampledShards2(IterableDataset): + """An iterable dataset yielding a list of urls.""" + + def __init__( + self, + urls, + nshards=sys.maxsize, + worker_seed=None, + deterministic=False, + epoch=-1, + ): + """Sample shards from the shard list with replacement. 
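+
+        Illustrative example: urls="/data/shards/shard-{0000..0999}.tar" (a
+        hypothetical path) is brace-expanded into 1000 shard paths, and each
+        iteration draws nshards of them uniformly at random with replacement.
+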
+ :param urls: a list of URLs as a Python list or brace notation string + """ + super().__init__() + urls = wds.shardlists.expand_urls(urls) + self.urls = urls + assert isinstance(self.urls[0], str) + self.nshards = nshards + self.rng = random.Random() + self.worker_seed = worker_seed + self.deterministic = deterministic + self.epoch = epoch + + def __iter__(self): + """Return an iterator over the shards.""" + if isinstance(self.epoch, SharedEpoch): + epoch = self.epoch.get_value() + else: + # NOTE: this is epoch tracking is problematic in a multiprocess (dataloader workers or train) + # situation as different workers may wrap at different times (or not at all). + self.epoch += 1 + epoch = self.epoch + + if self.deterministic: + # reset seed w/ epoch if deterministic + if self.worker_seed is None: + # pytorch worker seed should be deterministic due to being init by arg.seed + rank + worker id + seed = pytorch_worker_seed(epoch) + else: + seed = self.worker_seed() + epoch + self.rng.seed(seed) + for _ in range(self.nshards): + yield dict(url=self.rng.choice(self.urls)) + + +def preprocess_image(sample, image_processor): + image = [image_processor(s).unsqueeze(0) for s in sample] + image = torch.cat(image, dim=0) + # apply random horizontal flip + image = torchvision.transforms.RandomHorizontalFlip(p=0.5)(image) + # NOTE: potentially move jitter into the image_preprocessor before normalization + # image = torchvision.transforms.ColorJitter(brightness=0.5, hue=0.3)(image) + return image + + +def preprocess_text(sample, tokenizer): + tokenizer.padding_side = "right" + sample = [ + (f"{s.strip()}<|endofchunk|>{tokenizer.eos_token}") for s in sample + ] + text = tokenizer( + sample, + max_length=32, + padding="longest", + truncation="only_first", + return_tensors="pt", + ) + return text["input_ids"], text["attention_mask"] + + +MIN_KB = 10 + + +def preprocess_interleaved(sample, tokenizer, clip_processor, sim_threshold): + info = json.loads(sample[0]) + sentences = info["text_list"] + + images, sentence_ixs = [], [] + for sample_image in info["image_info"]: + image_base64 = sample_image["image_base64"] + rawbytes = base64.b64decode(image_base64) + + # filter to images >= 10KB + if len(rawbytes) // 1000 <= MIN_KB: + continue + if sample_image["matched_sim"] < sim_threshold: + continue + image = Image.open(io.BytesIO(rawbytes)).convert("RGB") + + images.append(image) + sentence_ixs.append(sample_image["matched_text_index"]) + + if len(images) == 0: + raise ValueError("No images in sample") + + # images -> tensors + images_tensors = preprocess_image(images, clip_processor) + keep_ixs = range(min(len(images_tensors), MAX_NUM_IMAGES)) + images_tensors = images_tensors[keep_ixs] + sentence_ixs = [sentence_ixs[ix] for ix in keep_ixs] + + # pad to 5 images + if len(images_tensors) < MAX_NUM_IMAGES: + zero_padding = torch.zeros( + (MAX_NUM_IMAGES - len(images_tensors), 3, 224, 224), dtype=torch.float + ) + images_tensors = torch.cat((images_tensors, zero_padding), dim=0) + + # add in and tokens + # eoc after sentence = "sentence loss" + for ix in sentence_ixs: + sentences[ix] = f"<|endofchunk|>{sentences[ix]}" + + text = " ".join(sentences) + text = text.replace("<|endofchunk|>", "", 1) # but remove first eoc + # whitespace cleanup + text = ( + text.replace(" <|endofchunk|>", "<|endofchunk|>") + .replace(" ", "") + .replace(" ", "") + ) + text = f"{text}<|endofchunk|>{tokenizer.eos_token}" + tokenizer.padding_side = "right" + text_tensor = tokenizer( + text, max_length=256, truncation=True, 
padding="max_length", return_tensors="pt" + ) + + # reject sequences with too few images (after truncation) + num_images = torch.count_nonzero( + text_tensor["input_ids"] + == tokenizer.additional_special_tokens_ids[ + tokenizer.additional_special_tokens.index("") + ] + ) + + if num_images == 0: + raise ValueError("No images in sample") + elif ( + num_images == 1 and random.random() <= 0.5 + ): # 50% chance of keeping single image samples + raise ValueError("Only one image in sample") + + return ( + images_tensors, + (text_tensor["input_ids"], text_tensor["attention_mask"]), + ) + + +def get_mmc4_dataset(args, image_processor, tokenizer, epoch=0, floor=False): + input_shards = args.mmc4_shards + assert input_shards is not None + resampled = getattr(args, "dataset_resampled", False) + + num_samples, num_shards = get_dataset_size(input_shards) + num_samples = None + if not num_samples: + num_samples = args.train_num_samples_mmc4 + if not num_samples: + raise RuntimeError( + "Currently, number of dataset samples must be specified for training dataset. " + "Please specify via `--train-num-samples` if no dataset length info present." + ) + + # create a shared epoch store to sync epoch to dataloader worker proc + shared_epoch = SharedEpoch(epoch=epoch) + if resampled: + pipeline = [ + ResampledShards2(input_shards, deterministic=True, epoch=shared_epoch) + ] + else: + pipeline = [wds.SimpleShardList(input_shards)] + + preprocess_fn = functools.partial( + preprocess_interleaved, + clip_processor=image_processor, + tokenizer=tokenizer, + sim_threshold=args.mmc4_textsim_threshold, + ) + + # at this point we have an iterator over all the shards + if not resampled: + pipeline.extend( + [ + detshuffle2( + bufsize=_SHARD_SHUFFLE_SIZE, + initial=_SHARD_SHUFFLE_INITIAL, + seed=args.seed, + epoch=shared_epoch, + ), + wds.split_by_node, + wds.split_by_worker, + ] + ) + pipeline.extend( + [ + # at this point, we have an iterator over the shards assigned to each worker at each node + # wds.tarfile_to_samples(handler=log_and_continue), + tarfile_to_samples_nothrow, + wds.shuffle( + bufsize=_SAMPLE_SHUFFLE_SIZE, + initial=_SAMPLE_SHUFFLE_INITIAL, + ), + ] + ) + + pipeline.extend( + [ + wds.to_tuple("json", handler=log_and_continue), + wds.map(preprocess_fn, handler=log_and_continue), + wds.batched(args.batch_size_mmc4, partial=False), + ] + ) + + dataset = wds.DataPipeline(*pipeline) + if not resampled: + assert ( + num_shards >= args.workers * args.world_size + ), "number of shards must be >= total workers" + # roll over and repeat a few samples to get same number of full batches on each node + round_fn = math.floor if floor else math.ceil + global_batch_size = args.batch_size_mmc4 * args.world_size + num_batches = round_fn(num_samples / global_batch_size) + num_workers = max(1, args.workers) + num_worker_batches = round_fn(num_batches / num_workers) # per dataloader worker + num_batches = num_worker_batches * num_workers + num_samples = num_batches * global_batch_size + # each worker is iterating over this + dataset = dataset.with_epoch(num_worker_batches) + + dataloader = wds.WebLoader( + dataset, + batch_size=None, + shuffle=False, + num_workers=args.workers, + persistent_workers=True, + ) + + # add meta-data to dataloader instance for convenience + dataloader.num_batches = num_batches + dataloader.num_samples = num_samples + + return DataInfo(dataloader=dataloader, shared_epoch=shared_epoch) + + +def get_laion_dataset(args, image_processor, tokenizer, epoch=0, floor=False): + input_shards = 
args.laion_shards + assert input_shards is not None + resampled = getattr(args, "dataset_resampled", False) + + num_samples, num_shards = get_dataset_size(input_shards) + num_samples = None + if not num_samples: + num_samples = args.train_num_samples_laion + if not num_samples: + raise RuntimeError( + "Currently, number of dataset samples must be specified for training dataset. " + "Please specify via `--train-num-samples` if no dataset length info present." + ) + + # create a shared epoch store to sync epoch to dataloader worker proc + shared_epoch = SharedEpoch(epoch=epoch) + if resampled: + pipeline = [ + ResampledShards2(input_shards, deterministic=True, epoch=shared_epoch) + ] + else: + pipeline = [wds.SimpleShardList(input_shards)] + + # create two preprocess functions that take in the passed in image_processor and tokenizer + preprocess_image_fn = functools.partial( + preprocess_image, image_processor=image_processor + ) + preprocess_text_fn = functools.partial(preprocess_text, tokenizer=tokenizer) + + # at this point we have an iterator over all the shards + if not resampled: + pipeline.extend( + [ + detshuffle2( + bufsize=_SHARD_SHUFFLE_SIZE, + initial=_SHARD_SHUFFLE_INITIAL, + seed=args.seed, + epoch=shared_epoch, + ), + wds.split_by_node, + wds.split_by_worker, + ] + ) + pipeline.extend( + [ + # at this point, we have an iterator over the shards assigned to each worker at each node + # wds.tarfile_to_samples(handler=log_and_continue), + tarfile_to_samples_nothrow, + wds.shuffle( + bufsize=_SAMPLE_SHUFFLE_SIZE, + initial=_SAMPLE_SHUFFLE_INITIAL, + ), + ] + ) + + pipeline.extend( + [ + wds.select(filter_no_caption_or_no_image), + wds.decode("pilrgb", handler=log_and_continue), + wds.to_tuple("jpg;png;jpeg", "txt", handler=log_and_continue), + wds.batched(args.batch_size_laion, partial=False), + wds.map_tuple( + preprocess_image_fn, preprocess_text_fn, handler=log_and_continue + ), + ] + ) + + dataset = wds.DataPipeline(*pipeline) + if not resampled: + assert ( + num_shards >= args.workers * args.world_size + ), "number of shards must be >= total workers" + # roll over and repeat a few samples to get same number of full batches on each node + round_fn = math.floor if floor else math.ceil + global_batch_size = args.batch_size_laion * args.world_size + num_batches = round_fn(num_samples / global_batch_size) + num_workers = max(1, args.workers) + num_worker_batches = round_fn(num_batches / num_workers) # per dataloader worker + num_batches = num_worker_batches * num_workers + num_samples = num_batches * global_batch_size + # each worker is iterating over this + dataset = dataset.with_epoch(num_worker_batches) + + dataloader = wds.WebLoader( + dataset, + batch_size=None, + shuffle=False, + num_workers=args.workers, + persistent_workers=True, + ) + + # add meta-data to dataloader instance for convenience + dataloader.num_batches = num_batches + dataloader.num_samples = num_samples + + return DataInfo(dataloader=dataloader, shared_epoch=shared_epoch) + + +def get_dataset_fn(dataset_type): + if dataset_type == "image_text": + return get_laion_dataset + elif dataset_type == "mmc4": + return get_mmc4_dataset + else: + raise ValueError(f"Unsupported dataset type: {dataset_type}") + + +def get_data(args, image_processor, tokenizer, dataset_type, epoch=0): + return get_dataset_fn(dataset_type)( + args, image_processor=image_processor, epoch=epoch, tokenizer=tokenizer + ) diff --git a/PyTorch/contrib/others/OpenFlamingo_ for PyTorch/open_flamingo/train/distributed.py 
b/PyTorch/contrib/others/OpenFlamingo_ for PyTorch/open_flamingo/train/distributed.py new file mode 100644 index 0000000000..3938d063d5 --- /dev/null +++ b/PyTorch/contrib/others/OpenFlamingo_ for PyTorch/open_flamingo/train/distributed.py @@ -0,0 +1,128 @@ +import os + +import torch + +try: + import horovod.torch as hvd +except ImportError: + hvd = None + + +def is_global_master(args): + return args.rank == 0 + + +def is_local_master(args): + return args.local_rank == 0 + + +def is_master(args, local=False): + return is_local_master(args) if local else is_global_master(args) + + +def is_using_horovod(): + # NOTE w/ horovod run, OMPI vars should be set, but w/ SLURM PMI vars will be set + # Differentiating between horovod and DDP use via SLURM may not be possible, so horovod arg still required... + ompi_vars = ["OMPI_COMM_WORLD_RANK", "OMPI_COMM_WORLD_SIZE"] + pmi_vars = ["PMI_RANK", "PMI_SIZE"] + if all([var in os.environ for var in ompi_vars]) or all( + [var in os.environ for var in pmi_vars] + ): + return True + else: + return False + + +def is_using_distributed(): + if "WORLD_SIZE" in os.environ: + return int(os.environ["WORLD_SIZE"]) > 1 + if "SLURM_NTASKS" in os.environ: + return int(os.environ["SLURM_NTASKS"]) > 1 + return False + + +def world_info_from_env(): + local_rank = 0 + for v in ( + "LOCAL_RANK", + "MPI_LOCALRANKID", + "SLURM_LOCALID", + "OMPI_COMM_WORLD_LOCAL_RANK", + ): + if v in os.environ: + local_rank = int(os.environ[v]) + break + global_rank = 0 + for v in ("RANK", "PMI_RANK", "SLURM_PROCID", "OMPI_COMM_WORLD_RANK"): + if v in os.environ: + global_rank = int(os.environ[v]) + break + world_size = 1 + for v in ("WORLD_SIZE", "PMI_SIZE", "SLURM_NTASKS", "OMPI_COMM_WORLD_SIZE"): + if v in os.environ: + world_size = int(os.environ[v]) + break + + return local_rank, global_rank, world_size + + +def init_distributed_device(args): + # Distributed training = training on more than one GPU. + # Works in both single and multi-node scenarios. 
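+    # Resolution order implemented below: the explicit --horovod flag wins,
+    # then SLURM-style env vars (SLURM_PROCID), then torchrun/launch env vars,
+    # and finally a single-process fallback group so the same code path also
+    # works on a single device.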
+ args.distributed = False + args.world_size = 1 + args.rank = 0 # global rank + args.local_rank = 0 + if args.horovod: + assert hvd is not None, "Horovod is not installed" + hvd.init() + args.local_rank = int(hvd.local_rank()) + args.rank = hvd.rank() + args.world_size = hvd.size() + args.distributed = True + os.environ["LOCAL_RANK"] = str(args.local_rank) + os.environ["RANK"] = str(args.rank) + os.environ["WORLD_SIZE"] = str(args.world_size) + elif is_using_distributed(): + if "SLURM_PROCID" in os.environ: + # DDP via SLURM + args.local_rank, args.rank, args.world_size = world_info_from_env() + # SLURM var -> torch.distributed vars in case needed + os.environ["LOCAL_RANK"] = str(args.local_rank) + os.environ["RANK"] = str(args.rank) + os.environ["WORLD_SIZE"] = str(args.world_size) + torch.distributed.init_process_group( + backend=args.dist_backend, + init_method=args.dist_url, + world_size=args.world_size, + rank=args.rank, + ) + else: + # DDP via torchrun, torch.distributed.launch + args.local_rank, _, _ = world_info_from_env() + torch.distributed.init_process_group( + backend=args.dist_backend, init_method=args.dist_url + ) + args.world_size = torch.distributed.get_world_size() + args.rank = torch.distributed.get_rank() + args.distributed = True + else: + # needed to run on single gpu + torch.distributed.init_process_group( + backend=args.dist_backend, + init_method=args.dist_url, + world_size=1, + rank=0, + ) + + if torch.cuda.is_available(): + if args.distributed and not args.no_set_device_rank: + device = "cuda:%d" % args.local_rank + else: + device = "cuda:0" + torch.cuda.set_device(device) + else: + device = "cpu" + args.device = device + device = torch.device(device) + return device diff --git a/PyTorch/contrib/others/OpenFlamingo_ for PyTorch/open_flamingo/train/train.py b/PyTorch/contrib/others/OpenFlamingo_ for PyTorch/open_flamingo/train/train.py new file mode 100644 index 0000000000..d9febbac4e --- /dev/null +++ b/PyTorch/contrib/others/OpenFlamingo_ for PyTorch/open_flamingo/train/train.py @@ -0,0 +1,490 @@ +""" Main training script """ + +import argparse +import copy +import glob +import os +import random + +import numpy as np +import torch +import torch_npu +from torch_npu.contrib import transfer_to_npu + +from torch import Tensor +from typing import List +from torch.optim.optimizer import Optimizer + +import wandb +from data import get_data +from distributed import init_distributed_device, world_info_from_env +from torch.nn.parallel import DistributedDataParallel as DDP +from train_utils import get_checkpoint, train_one_epoch +from transformers import ( + get_constant_schedule_with_warmup, + get_cosine_schedule_with_warmup, + get_linear_schedule_with_warmup, +) + +from open_flamingo import create_model_and_transforms + +torch_npu.npu.set_compile_mode(jit_compile=False) + +def random_seed(seed=42, rank=0): + torch.manual_seed(seed + rank) + np.random.seed(seed + rank) + random.seed(seed + rank) + +def adamw(params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + max_exp_avg_sqs: List[Tensor], + state_steps: List[int], + *, + amsgrad: bool, + beta1: float, + beta2: float, + lr: float, + weight_decay: float, + eps: float, + maximize: bool): + r"""Functional API that performs AdamW algorithm computation. + See :class:`~torch.optim.AdamW` for details. 
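+    Note: this NPU port hands the whole parameter update to the fused
+    torch_npu.npu_apply_adam_w operator, passing weight_decay directly to it;
+    the reference implementation's explicit stepweight-decay line is therefore
+    left commented out below, on the assumption that the fused op applies the
+    decay itself.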
+ """ + for i, param in enumerate(params): + grad = grads[i] + exp_avg = exp_avgs[i] + exp_avg_sq = exp_avg_sqs[i] + step = state_steps[i] + + # Perform stepweight decay + ## param.mul_(1 - lr * weight_decay) + bias_correction1 = beta1 ** step + bias_correction2 = beta2 ** step + + param.data, exp_avg, exp_avg_sq = torch_npu.npu_apply_adam_w( + bias_correction1, + bias_correction2, + lr, + weight_decay, + beta1, + beta2, + eps, + grad, + None, + amsgrad, + maximize, + out=(param.data, exp_avg, exp_avg_sq) + ) + +class AdamW(Optimizer): + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, + weight_decay=1e-2, amsgrad=False, *, maximize: bool = False): + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + if not 0.0 <= weight_decay: + raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) + defaults = dict(lr=lr, betas=betas, eps=eps, + weight_decay=weight_decay, amsgrad=amsgrad, maximize=maximize) + super(AdamW, self).__init__(params, defaults) + + def __setstate__(self, state): + super(AdamW, self).__setstate__(state) + for group in self.param_groups: + group.setdefault('amsgrad', False) + group.setdefault('maximize', False) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + Args: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params_with_grad = [] + grads = [] + exp_avgs = [] + exp_avg_sqs = [] + state_sums = [] + max_exp_avg_sqs = [] + state_steps = [] + amsgrad = group['amsgrad'] + beta1, beta2 = group['betas'] + + for p in group['params']: + if p.grad is None: + continue + params_with_grad.append(p) + if p.grad.is_sparse: + raise RuntimeError('AdamW does not support sparse gradients') + grads.append(p.grad) + + state = self.state[p] + + # State initialization + if len(state) == 0: + state['step'] = 0 + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) + if amsgrad: + # Maintains max of all exp. moving avg. of sq. grad. 
values + state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) + + exp_avgs.append(state['exp_avg']) + exp_avg_sqs.append(state['exp_avg_sq']) + + if amsgrad: + max_exp_avg_sqs.append(state['max_exp_avg_sq']) + + # update the steps for each param group update + state['step'] += 1 + # record the step after step update + state_steps.append(state['step']) + + # adamw_torch(params_with_grad, + adamw(params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + amsgrad=amsgrad, + beta1=beta1, + beta2=beta2, + lr=group['lr'], + weight_decay=group['weight_decay'], + eps=group['eps'], + maximize=group['maximize']) + + return loss + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--vision_encoder_path", default="ViT-L-14", type=str) + parser.add_argument("--vision_encoder_pretrained", default="openai", type=str) + parser.add_argument("--lm_path", default="facebook/opt-1.3b", type=str) + parser.add_argument( + "--tokenizer_path", + default="facebook/opt-30b", + type=str, + help="path to tokenizer", + ) + parser.add_argument( + "--cross_attn_every_n_layers", + type=int, + default=1, + help="how often to add a cross-attention layer after each transformer layer", + ) + parser.add_argument( + "--run_name", + type=str, + default="openflamingo3B", + help="used to name saving directory and wandb run", + ) + parser.add_argument("--use_media_placement_augmentation", action="store_true") + parser.add_argument("--offline", action="store_true") + parser.add_argument("--num_epochs", type=int, default=1) + parser.add_argument( + "--logging_steps", type=int, default=100, help="log loss every n steps" + ) + # Sum of gradient optimization batch size + parser.add_argument("--batch_size_mmc4", type=int, default=8) + parser.add_argument("--batch_size_laion", type=int, default=128) + parser.add_argument("--gradient_accumulation_steps", type=int, default=1) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + help="path to checkpoint to resume from, this should contain model, optimizer, and lr_scheduler states", + default=None, + ) + parser.add_argument( + "--delete_previous_checkpoint", + action="store_true", + help="delete previous checkpoint when saving new checkpoint", + ) + parser.add_argument( + "--laion_shards", + type=str, + help="path to laion shards, this should be a glob pattern such as /path/to/shards/shard-{0000..0999}.tar", + ) + parser.add_argument( + "--mmc4_shards", + default=None, + type=str, + help="path to c4 shards, this should be a glob pattern such as /path/to/shards/shard-{0000..0999}.tar", + ) + parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--learning_rate", default=1e-4, type=float) + parser.add_argument( + "--lr_scheduler", + default="constant", + type=str, + help="constant, linear, or cosine", + ) + parser.add_argument("--loss_multiplier_mmc4", type=float, default=1.0) + parser.add_argument("--loss_multiplier_laion", type=float, default=1.0) + parser.add_argument("--warmup_steps", default=5000, type=int) + parser.add_argument("--weight_decay", default=0.1, type=float) + parser.add_argument( + "--precision", + choices=["amp_bf16", "amp_bfloat16", "bf16", "fp16", "fp32"], + default="fp32", + help="Floating point precision.", + ) + # data args + parser.add_argument("--workers", type=int, default=1) + parser.add_argument("--train_num_samples_mmc4", type=int, default=10000) + parser.add_argument("--train_num_samples_laion", type=int, default=10000) + 
parser.add_argument("--dataset_resampled", action="store_true") + # distributed training args + parser.add_argument( + "--dist-url", + default="env://", + type=str, + help="url used to set up distributed training", + ) + parser.add_argument( + "--dist-backend", default="nccl", type=str, help="distributed backend" + ) + parser.add_argument( + "--horovod", + default=False, + action="store_true", + help="Use horovod for distributed training.", + ) + parser.add_argument( + "--no-set-device-rank", + default=False, + action="store_true", + help="Don't set device index from local rank (when CUDA_VISIBLE_DEVICES restricted to one per proc).", + ) + # wandb args + parser.add_argument("--report_to_wandb", default=False, action="store_true") + parser.add_argument( + "--wandb_project", + type=str, + ) + parser.add_argument( + "--wandb_entity", + type=str, + ) + parser.add_argument( + "--save_checkpoints_to_wandb", + default=False, + action="store_true", + help="save checkpoints to wandb", + ) + parser.add_argument( + "--mmc4_textsim_threshold", + default=0.32, + type=float, + help="threshold for filtering images in mmc4 based on image-text similarity", + ) + + args = parser.parse_args() + + if args.laion_shards.startswith("s3"): + args.laion_shards = f"pipe:aws s3 cp {args.laion_shards} -" + + if args.mmc4_shards is not None: + if args.mmc4_shards.startswith("s3"): + args.mmc4_shards = f"pipe:aws s3 cp {args.mmc4_shards} -" + + if args.save_checkpoints_to_wandb and not args.report_to_wandb: + raise ValueError("save_checkpoints_to_wandb requires report_to_wandb") + + if args.mmc4_shards is not None: + assert (args.train_num_samples_laion // args.batch_size_laion) == ( + args.train_num_samples_mmc4 // args.batch_size_mmc4 + ), "number of samples per epoch must be equal for mmc4 and laion" + + if args.offline: + os.environ["WANDB_MODE"] = "offline" + os.environ["TRANSFORMERS_OFFLINE"] = "1" + + args.local_rank, args.rank, args.world_size = world_info_from_env() + + device_id = init_distributed_device(args) + + random_seed(args.seed) + + model, image_processor, tokenizer = create_model_and_transforms( + args.vision_encoder_path, + args.vision_encoder_pretrained, + args.lm_path, + args.tokenizer_path if args.tokenizer_path else args.lm_path, + cross_attn_every_n_layers=args.cross_attn_every_n_layers, + use_local_files=args.offline, + use_media_placement_augmentation=args.use_media_placement_augmentation, + ) + + random_seed(args.seed, args.rank) + + print(f"Start running training on rank {args.rank}.") + + if args.rank == 0 and args.report_to_wandb: + wandb.init( + project=args.wandb_project, + entity=args.wandb_entity, + name=args.run_name, + config=vars(args), + ) + + device_id = args.rank % torch.cuda.device_count() + # print(model) + model = model.npu() + + ddp_model = DDP(model, device_ids=[device_id]) + + laion_dataset = get_data(args, image_processor, tokenizer, "image_text") + if args.mmc4_shards is not None: + mmc4_dataset = get_data(args, image_processor, tokenizer, "mmc4") + + def get_grouped_params(model): + params_with_wd, params_without_wd = [], [] + + def apply_decay(x): + return ( + "gated_cross_attn_layer" in x + and "ff_gate" not in x + and "attn_gate" not in x + and "norm" not in x + and "bias" not in x + ) + + for n, p in model.named_parameters(): + # if p.requires_grad: + if apply_decay(n): + params_with_wd.append(p) + else: + params_without_wd.append(p) + + return [ + {"params": params_with_wd, "weight_decay": args.weight_decay}, + {"params": params_without_wd, "weight_decay": 0.0}, + 
] + + optimizer = AdamW(get_grouped_params(ddp_model), lr=args.learning_rate) + + if args.mmc4_shards is not None: + total_training_steps = ( + (args.train_num_samples_mmc4) // (args.batch_size_mmc4 * args.world_size) + ) * args.num_epochs + else: + total_training_steps = ( + (args.train_num_samples_laion) // (args.batch_size_laion * args.world_size) + ) * args.num_epochs + + if args.rank == 0: + print(f"Total training steps: {total_training_steps}") + + if args.lr_scheduler == "linear": + lr_scheduler = get_linear_schedule_with_warmup( + optimizer, + num_warmup_steps=args.warmup_steps, + num_training_steps=total_training_steps, + ) + elif args.lr_scheduler == "cosine": + lr_scheduler = get_cosine_schedule_with_warmup( + optimizer, + num_warmup_steps=args.warmup_steps, + num_training_steps=total_training_steps, + ) + else: + lr_scheduler = get_constant_schedule_with_warmup( + optimizer, num_warmup_steps=args.warmup_steps + ) + + # check if a checkpoint exists for this run + if os.path.exists(f"{args.run_name}") and args.resume_from_checkpoint is None: + checkpoint_list = glob.glob(f"{args.run_name}/checkpoint_*.pt") + if len(checkpoint_list) == 0: + print(f"Found no checkpoints for run {args.run_name}.") + else: + args.resume_from_checkpoint = sorted( + checkpoint_list, key=lambda x: int(x.split("_")[-1].split(".")[0]) + )[-1] + print( + f"Found checkpoint {args.resume_from_checkpoint} for run {args.run_name}." + ) + + resume_from_epoch = 0 + if args.resume_from_checkpoint is not None: + if args.rank == 0: + print(f"Loading checkpoint from {args.resume_from_checkpoint}") + checkpoint = torch.load(args.resume_from_checkpoint, map_location="cpu") + ddp_model.load_state_dict(checkpoint["model_state_dict"], False) + optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) + lr_scheduler.load_state_dict(checkpoint["lr_scheduler_state_dict"]) + resume_from_epoch = checkpoint["epoch"] + 1 + + ddp_model.train() + + for epoch in range(resume_from_epoch, args.num_epochs): + laion_dataset.set_epoch(epoch) + laion_loader = laion_dataset.dataloader + if args.mmc4_shards is not None: + mmc4_dataset.set_epoch(epoch) + mmc4_loader = mmc4_dataset.dataloader + else: + mmc4_loader = None + + train_one_epoch( + args=args, + model=ddp_model, + epoch=epoch, + tokenizer=tokenizer, + optimizer=optimizer, + lr_scheduler=lr_scheduler, + laion_loader=laion_loader, + mmc4_loader=mmc4_loader, + device_id=device_id, + wandb=wandb, + ) + + if args.rank == 0: + if not os.path.exists(args.run_name): + os.makedirs(args.run_name) + + checkpoint_dict = { + "epoch": epoch, + "model_state_dict": get_checkpoint(ddp_model), + "optimizer_state_dict": optimizer.state_dict(), + "lr_scheduler_state_dict": lr_scheduler.state_dict(), + } + + print(f"Saving checkpoint to {args.run_name}/checkpoint_{epoch}.pt") + torch.save(checkpoint_dict, f"{args.run_name}/checkpoint_{epoch}.pt") + if args.report_to_wandb and args.save_checkpoints_to_wandb: + wandb.save(f"{args.run_name}/checkpoint_{epoch}.pt") + + if args.delete_previous_checkpoint: + if epoch > 0: + os.remove(f"{args.run_name}/checkpoint_{epoch-1}.pt") + + if args.rank == 0: + if not os.path.exists(args.run_name): + os.makedirs(args.run_name) + + torch.save(get_checkpoint(ddp_model), f"{args.run_name}/final_weights.pt") + if args.report_to_wandb and args.save_checkpoints_to_wandb: + wandb.save(f"{args.run_name}/final_weights.pt") + + +if __name__ == "__main__": + main() diff --git a/PyTorch/contrib/others/OpenFlamingo_ for PyTorch/open_flamingo/train/train_utils.py 
b/PyTorch/contrib/others/OpenFlamingo_ for PyTorch/open_flamingo/train/train_utils.py new file mode 100644 index 0000000000..8a65cd6503 --- /dev/null +++ b/PyTorch/contrib/others/OpenFlamingo_ for PyTorch/open_flamingo/train/train_utils.py @@ -0,0 +1,398 @@ +import time +from contextlib import suppress + +import torch +from tqdm import tqdm + +import torch_npu +from torch_npu.contrib import transfer_to_npu + + +def get_cast_dtype(precision: str): + cast_dtype = None + if precision == "bf16": + cast_dtype = torch.bfloat16 + elif precision == "fp16": + cast_dtype = torch.float16 + return cast_dtype + + +def get_autocast(precision): + if precision == "amp": + return torch.cuda.amp.autocast + elif precision == "amp_bfloat16" or precision == "amp_bf16": + # amp_bfloat16 is more stable than amp float16 for clip training + return lambda: torch.cuda.amp.autocast(dtype=torch.bfloat16) + else: + return suppress + + +# prof = torch.npu.profile(profiler_result_path="./perf", use_e2e_profiler=True) + +def train_one_epoch( + args, + model, + epoch, + laion_loader, + mmc4_loader, + tokenizer, + optimizer, + lr_scheduler, + device_id, + wandb, +): + num_batches_per_epoch_laion = laion_loader.num_batches + if mmc4_loader is not None: + num_batches_per_epoch_mmc4 = mmc4_loader.num_batches + + assert ( + num_batches_per_epoch_laion == num_batches_per_epoch_mmc4 + ), "Number of batches in laion and mmc4 datasets must be the same" + num_batches_per_epoch = num_batches_per_epoch_mmc4 + else: + num_batches_per_epoch = num_batches_per_epoch_laion + total_training_steps = num_batches_per_epoch * args.num_epochs + + autocast = get_autocast(args.precision) + cast_dtype = get_cast_dtype(args.precision) + + media_token_id = tokenizer("<image>", add_special_tokens=False)["input_ids"][-1] + endofchunk_token_id = tokenizer("<|endofchunk|>", add_special_tokens=False)[ + "input_ids" + ][-1] + + model.train() + + # setup logging + step_time_m = ( + AverageMeter() + ) # time for one optimizer step (> 1 batch if using gradient accum) + data_time_m = ( + AverageMeter() + ) # avg time to load one batch of both C4 AND laion (= 1 batch regardless of gradient accum) + end = time.time() + + # experimental_config = torch_npu.profiler._ExperimentalConfig( + # profiler_level=torch_npu.profiler.ProfilerLevel.Level1, + # aic_metrics=torch_npu.profiler.AiCMetrics.MemoryL0, + # l2_cache=True + # ) + + # with torch_npu.profiler.profile( + # activities=[torch_npu.profiler.ProfilerActivity.CPU, + # torch_npu.profiler.ProfilerActivity.NPU], + # with_stack=True, + # record_shapes=True, + # profile_memory=True, + # schedule=torch_npu.profiler.schedule(wait=0, warmup=0, active=2, repeat=1, skip_first=0), + # # experimental_config=experimental_config, + # on_trace_ready=torch_npu.profiler.tensorboard_trace_handler("./prof_npu")) as prof: + + if mmc4_loader is not None: + # loop through dataloader + for num_steps, (batch_laion, batch_mmc4) in tqdm( + enumerate(zip(laion_loader, mmc4_loader)), + disable=args.rank != 0, + total=total_training_steps, + initial=(epoch * num_batches_per_epoch), + ): + end, step_time_m, data_time_m, optimizer, lr_scheduler = train_step( + batch_laion=batch_laion, + batch_mmc4=batch_mmc4, + model=model, + tokenizer=tokenizer, + optimizer=optimizer, + lr_scheduler=lr_scheduler, + args=args, + device_id=device_id, + num_steps=num_steps, + epoch=epoch, + num_batches_per_epoch=num_batches_per_epoch, + autocast=autocast, + cast_dtype=cast_dtype, + media_token_id=media_token_id, + endofchunk_token_id=endofchunk_token_id, + end=end, +
step_time_m=step_time_m, + data_time_m=data_time_m, + ) + # prof.step() + else: + for num_steps, batch_laion in tqdm( + enumerate(laion_loader), + disable=args.rank != 0, + total=total_training_steps, + initial=(epoch * num_batches_per_epoch), + ): + # if num_steps == 3: + # prof.__enter__() + end, step_time_m, data_time_m, optimizer, lr_scheduler = train_step( + batch_laion=batch_laion, + model=model, + tokenizer=tokenizer, + optimizer=optimizer, + lr_scheduler=lr_scheduler, + args=args, + device_id=device_id, + num_steps=num_steps, + epoch=epoch, + num_batches_per_epoch=num_batches_per_epoch, + autocast=autocast, + cast_dtype=cast_dtype, + media_token_id=media_token_id, + endofchunk_token_id=endofchunk_token_id, + end=end, + step_time_m=step_time_m, + data_time_m=data_time_m, + wandb=wandb, + ) + # prof.step() + + # if num_steps == 3: + # prof.__exit__(None, None, None) + # return + + +def train_step( + batch_laion, + model, + tokenizer, + optimizer, + lr_scheduler, + args, + device_id, + num_steps, + epoch, + num_batches_per_epoch, + autocast, + cast_dtype, + media_token_id, + endofchunk_token_id, + end, + step_time_m, + data_time_m, + wandb, + batch_mmc4=None, +): + data_time_m.update(time.time() - end) + + global_step = num_steps + epoch * num_batches_per_epoch + + #### LAION FORWARD PASS #### + images = ( + batch_laion[0] + .to(device_id, dtype=cast_dtype, non_blocking=True) + .unsqueeze(1) + .unsqueeze(1) + ) + + input_ids = batch_laion[1][0].to(device_id, dtype=cast_dtype, non_blocking=True) + attention_mask = batch_laion[1][1].to( + device_id, dtype=cast_dtype, non_blocking=True + ) + + labels = input_ids.clone() + labels[labels == tokenizer.pad_token_id] = -100 + labels[:, 0] = -100 + labels[labels == media_token_id] = -100 + labels.to(device_id) + + with autocast(): + loss_laion = model( + vision_x=images, + lang_x=input_ids, + attention_mask=attention_mask, + labels=labels, + )[0] + divided_loss_laion = loss_laion / args.gradient_accumulation_steps + + if batch_mmc4 is not None: + #### C4 FORWARD PASS #### + images = ( + batch_mmc4[0] + .to(device_id, dtype=cast_dtype, non_blocking=True) + .unsqueeze(2) + ) + input_ids = torch.stack([x[0] for x in batch_mmc4[1]]).squeeze(1) + attention_mask = torch.stack([x[1] for x in batch_mmc4[1]]).squeeze(1) + + # NOTE: irena: expected shape of clip_text_input_ids / attention_mask is (N, I, max_seq_len) + labels = input_ids.clone() + labels[labels == tokenizer.pad_token_id] = -100 + labels[:, 0] = -100 + + for i in range(labels.shape[0]): + # remove loss for any token before the first token + label_idx = 0 + while ( + label_idx < labels.shape[1] and labels[i][label_idx] != media_token_id + ): + labels[i][label_idx] = -100 + label_idx += 1 + + # get index of all endofchunk tokens in the sequence + endofchunk_idxs = torch.where(labels[i] == endofchunk_token_id)[0] + for endofchunk_idx in endofchunk_idxs: + token_idx = endofchunk_idx + 1 + while ( + token_idx < labels.shape[1] + and labels[i][token_idx] != media_token_id + ): + labels[i][token_idx] = -100 + token_idx += 1 + + labels[labels == media_token_id] = -100 + labels.to(device_id) + + with autocast(): + loss_mmc4 = model( + vision_x=images, + lang_x=input_ids, + attention_mask=attention_mask, + labels=labels, + )[0] + + # if loss is nan, skip this batch + if torch.isnan(loss_mmc4): + print("loss is nan, skipping this batch") + print("input_ids: ", tokenizer.batch_decode(input_ids)) + print("labels: ", labels) + print("images: ", images) + optimizer.zero_grad() + return end, step_time_m, 
data_time_m, optimizer, lr_scheduler + + divided_loss_mmc4 = loss_mmc4 / args.gradient_accumulation_steps + + #### BACKWARD PASS #### + loss = ( + divided_loss_laion * args.loss_multiplier_laion + + divided_loss_mmc4 * args.loss_multiplier_mmc4 + ) + else: + loss = divided_loss_laion + # loss.backward(retain_graph=True) + loss.backward() + + #### MASK GRADIENTS FOR EMBEDDINGS #### + # Note (anas): Do not apply weight decay to embeddings as it will break this function. + def mask_embedding(m): + if isinstance(m, torch.nn.Embedding) and m.weight.requires_grad: + zero_mask = torch.zeros_like(m.weight.grad) + zero_mask[media_token_id] = torch.ones_like(zero_mask[media_token_id]) + zero_mask[endofchunk_token_id] = torch.ones_like( + zero_mask[endofchunk_token_id] + ) + m.weight.grad = m.weight.grad * zero_mask + + model.apply(mask_embedding) + + torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) + + # step optimizer and log + if (((num_steps + 1) % args.gradient_accumulation_steps) == 0) or ( + num_steps == num_batches_per_epoch - 1 + ): + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + # step time and reset end outside of rank 0 + step_time_m.update(time.time() - end) + end = time.time() + + if args.rank == 0 and args.report_to_wandb: + # compute within rank 0 + laion_samples_per_second = ( + args.gradient_accumulation_steps + * args.batch_size_laion + * args.world_size + / step_time_m.val + ) + laion_samples_per_second_per_gpu = ( + args.gradient_accumulation_steps + * args.batch_size_laion + / step_time_m.val + ) + + c4_samples_per_second = ( + args.gradient_accumulation_steps + * args.batch_size_mmc4 + * args.world_size + / step_time_m.val + ) + c4_samples_per_second_per_gpu = ( + args.gradient_accumulation_steps + * args.batch_size_mmc4 + / step_time_m.val + ) + + wandb.log( + { + "data_time": data_time_m.avg, + "step_time": step_time_m.avg, + "laion_samples_per_second": laion_samples_per_second, + "laion_samples_per_second_per_gpu": laion_samples_per_second_per_gpu, + "c4_samples_per_second": c4_samples_per_second, + "c4_samples_per_second_per_gpu": c4_samples_per_second_per_gpu, + "lr": optimizer.param_groups[0]["lr"], + }, + commit=False, + ) + step_time_m.reset() + data_time_m.reset() + + if batch_mmc4 is not None: + wandb.log( + {"loss_mmc4": divided_loss_mmc4.item(), "global_step": global_step}, + commit=False, + ) + + wandb.log( + { + "loss_laion": divided_loss_laion.item(), + "global_step": global_step, + }, + commit=True, + ) + + # Log loss to console + if ((num_steps + 1) % args.logging_steps == 0) and args.rank == 0: + if batch_mmc4 is not None: + print( + f"Step {num_steps + 1}/{num_batches_per_epoch} of epoch {epoch + 1}/{args.num_epochs} complete. Loss LAION: {loss_laion.item():.3f} // Loss MMC4: {loss_mmc4.item():.3f}" + ) + else: + print( + f"Step {num_steps + 1}/{num_batches_per_epoch} of epoch {epoch + 1}/{args.num_epochs} complete. 
Loss LAION: {loss_laion.item():.3f}" + ) + + return end, step_time_m, data_time_m, optimizer, lr_scheduler + + +def get_checkpoint(model): + state_dict = model.state_dict() + + for name, p in model.named_parameters(): + if not p.requires_grad: + del state_dict[name] + + return state_dict + + +class AverageMeter(object): + """Computes and stores the average and current value""" + + def __init__(self): + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count \ No newline at end of file diff --git a/PyTorch/contrib/others/OpenFlamingo_ for PyTorch/requirements-dev.txt b/PyTorch/contrib/others/OpenFlamingo_ for PyTorch/requirements-dev.txt new file mode 100644 index 0000000000..429f646f46 --- /dev/null +++ b/PyTorch/contrib/others/OpenFlamingo_ for PyTorch/requirements-dev.txt @@ -0,0 +1,5 @@ +black +mypy +pylint +pytest +requests \ No newline at end of file diff --git a/PyTorch/contrib/others/OpenFlamingo_ for PyTorch/requirements.txt b/PyTorch/contrib/others/OpenFlamingo_ for PyTorch/requirements.txt new file mode 100644 index 0000000000..75e8a03466 --- /dev/null +++ b/PyTorch/contrib/others/OpenFlamingo_ for PyTorch/requirements.txt @@ -0,0 +1,19 @@ +einops +einops-exts +transformers==4.28.1 +tokenizers==0.13.3 +torch +torchvision +pillow +more-itertools +datasets +braceexpand +webdataset +wandb +nltk +scipy +inflection +sentencepiece==0.1.98 +pycocoevalcap +pycocotools +open_clip_torch>=2.16.0 diff --git a/PyTorch/contrib/others/OpenFlamingo_ for PyTorch/setup.py b/PyTorch/contrib/others/OpenFlamingo_ for PyTorch/setup.py new file mode 100644 index 0000000000..989ee19715 --- /dev/null +++ b/PyTorch/contrib/others/OpenFlamingo_ for PyTorch/setup.py @@ -0,0 +1,57 @@ +from pathlib import Path + +from setuptools import find_packages, setup + +if __name__ == "__main__": + with Path(Path(__file__).parent, "README.md").open(encoding="utf-8") as file: + long_description = file.read() + + # TODO: This is a hack to get around the fact that we can't read the requirements.txt file, we should fix this. 
+ # def _read_reqs(relpath): + # fullpath = os.path.join(Path(__file__).parent, relpath) + # with open(fullpath) as f: + # return [ + # s.strip() + # for s in f.readlines() + # if (s.strip() and not s.startswith("#")) + # ] + + REQUIREMENTS = [ + "einops", + "einops-exts", + "transformers>=4.28.1", + "torch", + "torchvision", + "pillow", + "more-itertools", + "datasets", + "braceexpand", + "webdataset", + "wandb", + "nltk", + "scipy", + "inflection", + "sentencepiece", + "open_clip_torch", + ] + + setup( + name="open_flamingo", + packages=find_packages(), + include_package_data=True, + version="0.0.3", + license="MIT", + description="An open-source framework for training large multimodal models", + long_description=long_description, + long_description_content_type="text/markdown", + data_files=[(".", ["README.md"])], + keywords=["machine learning"], + install_requires=REQUIREMENTS, + classifiers=[ + "Development Status :: 4 - Beta", + "Intended Audience :: Developers", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "License :: OSI Approved :: MIT License", + "Programming Language :: Python :: 3.9", + ], + ) diff --git a/PyTorch/contrib/others/OpenFlamingo_ for PyTorch/tests/test_flamingo_model.py b/PyTorch/contrib/others/OpenFlamingo_ for PyTorch/tests/test_flamingo_model.py new file mode 100644 index 0000000000..164192d2a4 --- /dev/null +++ b/PyTorch/contrib/others/OpenFlamingo_ for PyTorch/tests/test_flamingo_model.py @@ -0,0 +1,77 @@ +# import unittest + +# import requests +# from PIL import Image + +# from open_flamingo import create_model_and_transforms + + +# class TestFlamingoModel(unittest.TestCase): +# def test_forward_pass(self): +# model, image_processor, tokenizer = create_model_and_transforms( +# clip_vision_encoder_path="hf-internal-testing/tiny-random-clip-zero-shot-image-classification", +# clip_processor_path="hf-internal-testing/tiny-random-clip-zero-shot-image-classification", +# lang_encoder_path="hf-internal-testing/tiny-random-OPTModel", +# tokenizer_path="hf-internal-testing/tiny-random-OPTModel", +# ) + +# image = Image.open( +# requests.get( +# "http://images.cocodataset.org/val2017/000000039769.jpg", stream=True +# ).raw +# ) +# vis_x = image_processor(images=[image, image], return_tensors="pt")[ +# "pixel_values" +# ] +# vis_x = vis_x.unsqueeze(1).unsqueeze(1) +# lang_x = tokenizer( +# [" A dog", " A cat"], +# max_length=10, +# padding=True, +# truncation=True, +# return_tensors="pt", +# ) + +# # try batched forward pass +# model(vis_x, lang_x["input_ids"], attention_mask=lang_x["attention_mask"]) + +# def test_generate(self): +# model, image_processor, tokenizer = create_model_and_transforms( +# clip_vision_encoder_path="hf-internal-testing/tiny-random-clip-zero-shot-image-classification", +# clip_processor_path="hf-internal-testing/tiny-random-clip-zero-shot-image-classification", +# lang_encoder_path="hf-internal-testing/tiny-random-OPTModel", +# tokenizer_path="hf-internal-testing/tiny-random-OPTModel", +# ) + +# tokenizer.padding_side = ( +# "left" # we want to pad on the left side for generation +# ) + +# image = Image.open( +# requests.get( +# "http://images.cocodataset.org/val2017/000000039769.jpg", stream=True +# ).raw +# ) +# vis_x = image_processor(images=[image, image], return_tensors="pt")[ +# "pixel_values" +# ] +# vis_x = vis_x.unsqueeze(1).unsqueeze(1) +# lang_x = tokenizer( +# [" A dog", " A cat <|endofchunk|>"], +# max_length=10, +# padding=True, +# truncation=True, +# return_tensors="pt", +# ) + +# # try batched generation +# 
model.generate( +# vis_x, +# lang_x["input_ids"], +# attention_mask=lang_x["attention_mask"], +# max_new_tokens=20, +# ) + + +# if __name__ == "__main__": +# unittest.main() diff --git a/PyTorch/contrib/others/OpenFlamingo_ for PyTorch/train_4_npus.sh b/PyTorch/contrib/others/OpenFlamingo_ for PyTorch/train_4_npus.sh new file mode 100644 index 0000000000..c3db0ae562 --- /dev/null +++ b/PyTorch/contrib/others/OpenFlamingo_ for PyTorch/train_4_npus.sh @@ -0,0 +1,17 @@ +torchrun --nnodes=1 --nproc_per_node=4 open_flamingo/train/train.py \ +--run_name new \ +--vision_encoder_pretrained "" \ +--lm_path facebook/opt-1.3b \ +--tokenizer_path facebook/opt-1.3b \ +--dataset_resampled \ +--laion_shards "/home/linzheyuan/open_flamingo/datasets/laion2b/00000.tar" \ +--batch_size_laion 4 \ +--train_num_samples_laion 50000 \ +--logging_steps 1 \ +--learning_rate 2.5e-5 \ +--num_epochs 5 \ +--loss_multiplier_laion 0.2 \ +--workers=6 \ +--lr_scheduler constant \ +--warmup_steps 0 \ +--use_media_placement_augmentation -- Gitee
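
The custom AdamW optimizer in open_flamingo/train/train.py above delegates the parameter update to the fused torch_npu.npu_apply_adam_w call. The sketch below is a plain-PyTorch reference for the decoupled AdamW update that call is expected to perform; it is illustrative only and not part of the patch. The helper name adamw_reference_step is hypothetical, and the assumption that the patch's beta1 ** step / beta2 ** step arguments correspond to the kernel's beta-power inputs (from which the 1 - beta ** step bias corrections below are derived) should be checked against the torch_npu documentation.

    # Plain-PyTorch sketch of one decoupled AdamW update (illustration only;
    # not part of the patch). Hyperparameter names mirror train.py.
    import torch

    def adamw_reference_step(param, grad, exp_avg, exp_avg_sq, step,
                             lr=1e-4, beta1=0.9, beta2=0.999,
                             eps=1e-8, weight_decay=0.1):
        # Decoupled weight decay: shrink the parameter directly instead of
        # adding weight_decay * param to the gradient.
        param.mul_(1 - lr * weight_decay)

        # Exponential moving averages of the gradient and its square.
        exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

        # Bias corrections; train.py passes the raw powers beta1 ** step and
        # beta2 ** step to the fused kernel, from which these terms follow.
        bias_correction1 = 1 - beta1 ** step
        bias_correction2 = 1 - beta2 ** step

        denom = (exp_avg_sq / bias_correction2).sqrt().add_(eps)
        param.addcdiv_(exp_avg, denom, value=-lr / bias_correction1)
        return param, exp_avg, exp_avg_sq

    # Minimal CPU usage example:
    p, g = torch.zeros(4), torch.full((4,), 0.1)
    m, v = torch.zeros(4), torch.zeros(4)
    adamw_reference_step(p, g, m, v, step=1)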