From a485a8a071cedaad7728eb761dadee698e07b4e5 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Mr=E9=9C=96?= <798948055@qq.com>
Date: Thu, 29 Aug 2024 00:57:03 +0800
Subject: [PATCH] [PyTorch][built-in][MiniCPM-V] adapted MiniCPM-V to npu

---
 PyTorch/built-in/mlm/MiniCPM-V/README.md      | 147 ++++++++++++++++++
 .../mlm/MiniCPM-V/finetune/dataset.py         |  18 ++-
 .../MiniCPM-V/finetune/ds_config_zero2.json   |   2 +-
 .../mlm/MiniCPM-V/finetune/finetune.py        |  14 +-
 .../mlm/MiniCPM-V/finetune/finetune_ds.sh     |  45 ++++--
 .../mlm/MiniCPM-V/finetune/finetune_lora.sh   |  43 +++--
 .../mlm/MiniCPM-V/finetune/trainer.py         |  63 +++-----
 .../configuration_minicpm.py                  |   2 +
 .../huggingface_modify/modeling_minicpmv.py   |   4 +-
 .../MiniCPM-V/huggingface_modify/resampler.py |   9 +-
 .../MiniCPM-V/npu_monkey_patch/__init__.py    |  14 ++
 .../idefics2_conv_monkey_patch.py             |  31 +++-
 .../idefics2_flash_attn_monkey_patch.py       |  91 +++++------
 .../llama_flash_attn_monkey_patch.py          |  78 ++++++----
 .../llama_rmsnorm_monkey_patch.py             |  16 +-
 .../llama_rope_monkey_patch.py                | 138 ++++++----------
 ...nsformers_check_flash_attn_monkey_patch.py |  17 ++
 .../mlm/MiniCPM-V/npu_monkey_patch/utils.py   |  16 ++
 .../built-in/mlm/MiniCPM-V/requirements.txt   |   9 +-
 PyTorch/built-in/mlm/MiniCPM-V/web_demo.py    |  11 +-
 .../built-in/mlm/MiniCPM-V/web_demo_2.5.py    |  11 +-
 .../mlm/MiniCPM-V/web_demo_streamlit-2_5.py   |   7 +
 .../mlm/MiniCPM-V/web_demo_streamlit.py       |   7 +
 23 files changed, 523 insertions(+), 270 deletions(-)
 create mode 100644 PyTorch/built-in/mlm/MiniCPM-V/README.md
 create mode 100644 PyTorch/built-in/mlm/MiniCPM-V/npu_monkey_patch/__init__.py
 create mode 100644 PyTorch/built-in/mlm/MiniCPM-V/npu_monkey_patch/transformers_check_flash_attn_monkey_patch.py

diff --git a/PyTorch/built-in/mlm/MiniCPM-V/README.md b/PyTorch/built-in/mlm/MiniCPM-V/README.md
new file mode 100644
index 0000000000..65fcf6ce89
--- /dev/null
+++ b/PyTorch/built-in/mlm/MiniCPM-V/README.md
@@ -0,0 +1,147 @@
+
+# MiniCPM-V for PyTorch
+
+
+
+# Contents
+- [MiniCPM-V](#minicpm-v-for-pytorch)
+  - [Overview](#overview)
+  - [Preparing the Training Environment](#preparing-the-training-environment)
+    - [Creating the Python Environment](#creating-the-python-environment)
+    - [Preparing the Dataset](#preparing-the-dataset)
+    - [Preparing the Pretrained Weights](#preparing-the-pretrained-weights)
+  - [Quick Start](#quick-start)
+    - [Model Training](#model-training)
+    - [Results](#results)
+    - [Model Inference](#model-inference)
+  - [Public Network Address Statement](#public-network-address-statement)
+  - [Change Log](#change-log)
+  - [FAQ](#faq)
+
+
+
+## Overview
+
+### Model Introduction
+
+MiniCPM-V is a series of on-device multimodal large language models for image-text understanding. The models take images and text as input and produce high-quality text output. In overall multimodal performance, MiniCPM-Llama3-V 2.5 surpasses proprietary models such as GPT-4V-1106, Gemini Pro, Claude 3 and Qwen-VL-Max, further improves OCR and instruction-following capability, and supports multimodal interaction in more than 30 languages.
+### Supported Tasks
+This repository currently supports the following tasks for the model:
+
+| Model | Task | Supported |
+|:------------:|:------:|:-----:|
+| MiniCPM-V | Full-parameter fine-tuning | ✔ |
+| MiniCPM-V | LoRA fine-tuning | ✔ |
+| MiniCPM-V | Online inference | ✔ |
+
+### Code Reference
+- Reference implementation:
+
+  ```
+  url=https://github.com/OpenBMB/MiniCPM-V.git
+  commit_id=6a5f9a4d6556e47767e7b653a9279281d2ef7062
+  ```
+
+- Implementation adapted to the Ascend AI Processor:
+  ```shell
+  url=https://gitee.com/ascend/ModelZoo-PyTorch.git
+  code_path=PyTorch/built-in/mlm/MiniCPM-V
+  ```
+
+## Preparing the Training Environment
+
+### Creating the Python Environment
+
+- Clone the remote repository:
+  ```shell
+  git clone https://gitee.com/ascend/ModelZoo-PyTorch.git
+  cd PyTorch/built-in/mlm/MiniCPM-V
+  ```
+
+- Create a Python environment and install the Python dependencies:
+  ```shell
+  conda create -n MiniCPM-V python=3.10 -y
+  conda activate MiniCPM-V
+  pip install -r requirements.txt
+  ```
+
+- Environment setup guide
+
+  See [Preparing the PyTorch Framework Training Environment](https://www.hiascend.com/document/detail/zh/ModelZoo/pytorchframework/ptes).
+
+  **Table 1** Supported Ascend software versions
+
+  | Software | Supported Version |
+  |:-----------:|:--------:|
+  | FrameworkPTAdapter | In-development version |
+  | CANN | In-development version |
+  | Ascend NPU firmware | In-development version |
+  | Ascend NPU driver | In-development version |
+
+
+### Preparing the Dataset
+
+- Download the TextVQA dataset manually. The expected dataset layout is:
+  ```
+  TextVQA
+  ├── TextVQA_0.5.1_train.json
+  ├── TextVQA_0.5.1_val.json
+  └── train_images
+  ```
+  For the JSON file format, see the data preparation section of https://github.com/OpenBMB/MiniCPM-V/blob/main/finetune/readme.md.
+### Preparing the Pretrained Weights
+
+1. With network access, the pretrained model is downloaded automatically.
+
+2. Without network access, download the weights from the Hugging Face Hub under the following namespace:
+   ```shell
+   openbmb/MiniCPM-Llama3-V-2_5
+   ```
+## Quick Start
+
+### Model Training
+
+1. The full-parameter fine-tuning script is finetune/finetune_ds.sh and the LoRA fine-tuning script is finetune/finetune_lora.sh. Pass the dataset and weight paths to the corresponding variables; the paths below are examples only, adjust them to your environment.
+   ```shell
+   MODEL="openbmb/MiniCPM-Llama3-V-2_5" # path to the MiniCPM-V weights
+   DATA="path/to/training_data" # path to the training data
+   EVAL_DATA="path/to/test_data" # path to the evaluation data
+   ```
+
+2. Run the training script. The model supports single-node 8-card training.
+
+   ```shell
+   bash finetune/finetune_ds.sh    # full-parameter fine-tuning
+   bash finetune/finetune_lora.sh  # LoRA fine-tuning
+   ```
+   After training finishes, the checkpoints are saved under the path given by `--output_dir` (finetune/output/... by default).
+### Results
+
+**Table 2** Training results
+
+| Device | Cards | Training time for steps 50-200 (s) | batch_size | Data_Type | Torch_Version |
+|:------------------:|:---:|:--------------:|:----------:|:---------:|:---:|
+| Competitor A - full-parameter fine-tuning | 8p | 847 | 12 | bf16 | 2.1 |
+| Atlas 800T A2 - full-parameter fine-tuning | 8p | 1046 | 12 | bf16 | 2.1 |
+| Competitor A - LoRA fine-tuning | 8p | 490 | 8 | bf16 | 2.1 |
+| Atlas 800T A2 - LoRA fine-tuning | 8p | 603 | 8 | bf16 | 2.1 |
+
+### Model Inference
+
+  Run the following command to start inference.
+  ```
+  python web_demo_2.5.py --device npu
+  ```
+
+## Public Network Address Statement
+
+For the public network addresses referenced by the code, see public_address_statement.md.
+
+
+## Change Log
+2024.08.26: Initial release.
+
+2024.08.29: Added NPU adaptation code.
+
+## FAQ
+None
diff --git a/PyTorch/built-in/mlm/MiniCPM-V/finetune/dataset.py b/PyTorch/built-in/mlm/MiniCPM-V/finetune/dataset.py
index 92807c38b4..1cf15cdcda 100644
--- a/PyTorch/built-in/mlm/MiniCPM-V/finetune/dataset.py
+++ b/PyTorch/built-in/mlm/MiniCPM-V/finetune/dataset.py
@@ -125,14 +125,16 @@ def conversation_to_ids(conversation, tokenizer, llm_type=None):
 
     # build target
     target = torch.full_like(ids, -100, dtype=torch.int32)
-    for i in range(1, len(ids)):
-        if context[i] == 0:
-            target[i - 1] = ids[i]
-        if context[i] == 1 and context[i - 1] == 0:
-            if hasattr(tokenizer, "eot_id"):
-                target[i - 1] = tokenizer.eot_id
-            else:
-                target[i - 1] = tokenizer.eos_id
+    mask_zero = context == 0
+    target[:-1][mask_zero[1:]] = ids[1:][mask_zero[1:]]
+
+    mask_one_zero = (context == 1) & (torch.roll(context, 1, 0) == 0)
+    mask_one_zero = mask_one_zero[1:]
+
+    if hasattr(tokenizer, "eot_id"):
+        eot_or_eos_id = tokenizer.eot_id
+    else:
+        eot_or_eos_id = tokenizer.eos_id
 
     # build image bound
     image_start_tokens = torch.where(ids == tokenizer.im_start_id)[0]
diff --git a/PyTorch/built-in/mlm/MiniCPM-V/finetune/ds_config_zero2.json b/PyTorch/built-in/mlm/MiniCPM-V/finetune/ds_config_zero2.json
index 4d42d440b4..a16674d460 100644
--- a/PyTorch/built-in/mlm/MiniCPM-V/finetune/ds_config_zero2.json
+++ b/PyTorch/built-in/mlm/MiniCPM-V/finetune/ds_config_zero2.json
@@ -1,7 +1,7 @@
 {
     "fp16": {
         "enabled": "auto",
-        "loss_scale": 0,
+        "loss_scale": 32,
         "loss_scale_window": 1000,
         "initial_scale_power": 16,
         "hysteresis": 2,
diff --git a/PyTorch/built-in/mlm/MiniCPM-V/finetune/finetune.py b/PyTorch/built-in/mlm/MiniCPM-V/finetune/finetune.py
index 44760495a0..3bf1fe8b63 100644
--- a/PyTorch/built-in/mlm/MiniCPM-V/finetune/finetune.py
+++ b/PyTorch/built-in/mlm/MiniCPM-V/finetune/finetune.py
@@ -20,6 +20,17 @@
 from dataset import SupervisedDataset, data_collator
 from trainer import CPMTrainer
 from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
+from PIL import ImageFile
+from npu_monkey_patch.utils import is_npu_available
+
+if is_npu_available():
+    import torch_npu
+    from torch_npu.contrib import transfer_to_npu
+    import npu_monkey_patch
+    torch.npu.config.allow_internal_format = False
+ +ImageFile.LOAD_TRUNCATED_IMAGES = True + @dataclass class ModelArguments: @@ -150,7 +161,6 @@ local_rank = 0 def train(): - seed_all(is_gpu=False, mode=False) global local_rank parser = transformers.HfArgumentParser( (ModelArguments, DataArguments, TrainingArguments, LoraArguments) @@ -183,11 +193,13 @@ def train(): "FSDP or ZeRO3 are not incompatible with QLoRA." ) + use_flash_attention_2 = os.getenv("use_flash_attention_2") == 'true' model = AutoModel.from_pretrained( model_args.model_name_or_path, trust_remote_code=True, torch_dtype=compute_dtype, device_map=device_map, + use_flash_attention_2=use_flash_attention_2 ) tokenizer = AutoTokenizer.from_pretrained( diff --git a/PyTorch/built-in/mlm/MiniCPM-V/finetune/finetune_ds.sh b/PyTorch/built-in/mlm/MiniCPM-V/finetune/finetune_ds.sh index 5dc3a3e67e..156dcd3f58 100644 --- a/PyTorch/built-in/mlm/MiniCPM-V/finetune/finetune_ds.sh +++ b/PyTorch/built-in/mlm/MiniCPM-V/finetune/finetune_ds.sh @@ -1,6 +1,20 @@ #!/bin/bash -GPUS_PER_NODE=8 +# 获取当前目录的名字 +current_dir=$(basename "$PWD") + +# 判断当前目录名字是否为 "finetune" +if [ "$current_dir" = "finetune" ]; then + # 如果在finetune目录下,则返回上一级目录 + cd .. +fi + +source /path/to/cann/ascend-toolkit/set_env.sh + +USE_FLASH_ATTENTION_2=true +export use_flash_attention_2=$USE_FLASH_ATTENTION_2 + +NPUS_PER_NODE=8 NNODES=1 NODE_RANK=0 MASTER_ADDR=localhost @@ -14,13 +28,13 @@ EVAL_DATA="path/to/test_data" LLM_TYPE="llama3" # if use openbmb/MiniCPM-V-2, please set LLM_TYPE=minicpm DISTRIBUTED_ARGS=" - --nproc_per_node $GPUS_PER_NODE \ + --nproc_per_node $NPUS_PER_NODE \ --nnodes $NNODES \ --node_rank $NODE_RANK \ --master_addr $MASTER_ADDR \ --master_port $MASTER_PORT " -torchrun $DISTRIBUTED_ARGS finetune.py \ +torchrun $DISTRIBUTED_ARGS finetune/finetune.py \ --model_name_or_path $MODEL \ --llm_type $LLM_TYPE \ --data_path $DATA \ @@ -28,20 +42,21 @@ torchrun $DISTRIBUTED_ARGS finetune.py \ --remove_unused_columns false \ --label_names "labels" \ --prediction_loss_only false \ - --bf16 false \ - --bf16_full_eval false \ - --fp16 true \ - --fp16_full_eval true \ + --bf16 true \ + --bf16_full_eval true \ + --fp16 false \ + --fp16_full_eval false \ --do_train \ --do_eval \ --tune_vision true \ --tune_llm true \ + --dataloader_num_workers 1 \ --model_max_length 2048 \ --max_slice_nums 9 \ --max_steps 10000 \ --eval_steps 1000 \ - --output_dir output/output_minicpmv2 \ - --logging_dir output/output_minicpmv2 \ + --output_dir finetune/output/output_minicpmv2 \ + --logging_dir finetune/output/output_minicpmv2 \ --logging_strategy "steps" \ --per_device_train_batch_size 1 \ --per_device_eval_batch_size 1 \ @@ -57,5 +72,13 @@ torchrun $DISTRIBUTED_ARGS finetune.py \ --lr_scheduler_type "cosine" \ --logging_steps 1 \ --gradient_checkpointing true \ - --deepspeed ds_config_zero2.json \ - --report_to "tensorboard" + --deepspeed finetune/ds_config_zero2.json \ + --report_to "tensorboard" 2>&1 | tee finetune/ds.log 2>&1 & + +wait + +# 输出50步-200步训练耗时 +start=$(grep -m1 -oP '50/\d+\s+\[\K\d{2}:\d{2}' finetune/ds.log) +end=$(grep -m1 -oP '200/\d+\s+\[\K\d{2}:\d{2}' finetune/ds.log) +time=$(( (${end%:*} * 60 + ${end#*:}) - (${start%:*} * 60 + ${start#*:}) )) +echo "50-200step training time : $time" diff --git a/PyTorch/built-in/mlm/MiniCPM-V/finetune/finetune_lora.sh b/PyTorch/built-in/mlm/MiniCPM-V/finetune/finetune_lora.sh index 22cf5a2902..3d00d2875a 100644 --- a/PyTorch/built-in/mlm/MiniCPM-V/finetune/finetune_lora.sh +++ b/PyTorch/built-in/mlm/MiniCPM-V/finetune/finetune_lora.sh @@ -1,6 +1,20 @@ #!/bin/bash -GPUS_PER_NODE=8 +# 
获取当前目录的名字 +current_dir=$(basename "$PWD") + +# 判断当前目录名字是否为 "finetune" +if [ "$current_dir" = "finetune" ]; then + # 如果在finetune目录下,则返回上一级目录 + cd .. +fi + +source /path/to/cann/ascend-toolkit/set_env.sh + +USE_FLASH_ATTENTION_2=true +export use_flash_attention_2=$USE_FLASH_ATTENTION_2 + +NPUS_PER_NODE=8 NNODES=1 NODE_RANK=0 MASTER_ADDR=localhost @@ -20,7 +34,7 @@ DISTRIBUTED_ARGS=" --master_addr $MASTER_ADDR \ --master_port $MASTER_PORT " -torchrun $DISTRIBUTED_ARGS finetune.py \ +torchrun $DISTRIBUTED_ARGS finetune/finetune.py \ --model_name_or_path $MODEL \ --llm_type $LLM_TYPE \ --data_path $DATA \ @@ -28,22 +42,23 @@ torchrun $DISTRIBUTED_ARGS finetune.py \ --remove_unused_columns false \ --label_names "labels" \ --prediction_loss_only false \ - --bf16 false \ - --bf16_full_eval false \ - --fp16 true \ - --fp16_full_eval true \ + --bf16 true \ + --bf16_full_eval true \ + --fp16 false \ + --fp16_full_eval false \ --do_train \ --do_eval \ --tune_vision true \ --tune_llm false \ --use_lora true \ --lora_target_modules "llm\..*layers\.\d+\.self_attn\.(q_proj|k_proj|v_proj|o_proj)" \ + --dataloader_num_workers 1 \ --model_max_length 2048 \ --max_slice_nums 9 \ --max_steps 10000 \ --eval_steps 1000 \ - --output_dir output/output_minicpmv2_lora \ - --logging_dir output/output_minicpmv2_lora \ + --output_dir finetune/output/output_minicpmv2_lora \ + --logging_dir finetune/output/output_minicpmv2_lora \ --logging_strategy "steps" \ --per_device_train_batch_size 2 \ --per_device_eval_batch_size 1 \ @@ -59,5 +74,13 @@ torchrun $DISTRIBUTED_ARGS finetune.py \ --lr_scheduler_type "cosine" \ --logging_steps 1 \ --gradient_checkpointing true \ - --deepspeed ds_config_zero2.json \ - --report_to "tensorboard" # wandb + --deepspeed finetune/ds_config_zero2.json \ + --report_to "tensorboard" 2>&1 | tee finetune/lora.log 2>&1 & + +wait + +# 输出50步-200步训练耗时 +start=$(grep -m1 -oP '50/\d+\s+\[\K\d{2}:\d{2}' finetune/lora.log) +end=$(grep -m1 -oP '200/\d+\s+\[\K\d{2}:\d{2}' finetune/lora.log) +time=$(( (${end%:*} * 60 + ${end#*:}) - (${start%:*} * 60 + ${start#*:}) )) +echo "50-200step training time : $time" diff --git a/PyTorch/built-in/mlm/MiniCPM-V/finetune/trainer.py b/PyTorch/built-in/mlm/MiniCPM-V/finetune/trainer.py index fa57bd047b..03c786dbe1 100644 --- a/PyTorch/built-in/mlm/MiniCPM-V/finetune/trainer.py +++ b/PyTorch/built-in/mlm/MiniCPM-V/finetune/trainer.py @@ -14,21 +14,13 @@ class CPMTrainer(Trainer): labels = inputs.pop("labels") else: labels = None - self.model.resampler.pos_embed = self.model.resampler.pos_embed.to(self.model.device) - if is_deepspeed_zero3_enabled(): - with deepspeed.zero.GatheredParameters(self.model.resampler.attn.parameters(), modifier_rank=0): - if not self.args.use_lora: - outputs = self.model(data = inputs, use_cache=False) - else: - with self.model._enable_peft_forward_hooks(**inputs): - outputs = self.model.base_model(data = inputs, use_cache=False) + + if not self.args.use_lora: + outputs = self.model(data=inputs, use_cache=False) else: - if not self.args.use_lora: - outputs = self.model(data = inputs, use_cache=False) - else: - with self.model._enable_peft_forward_hooks(**inputs): - outputs = self.model.base_model(data = inputs, use_cache=False) - + with self.model._enable_peft_forward_hooks(**inputs): + outputs = self.model.base_model(data=inputs, use_cache=False) + if labels is not None: # Flatten the tokens loss_fct = nn.CrossEntropyLoss() @@ -50,11 +42,11 @@ class CPMTrainer(Trainer): return (loss, outputs) if return_outputs else loss def prediction_step( - 
self, - model: nn.Module, - inputs: Dict[str, Union[torch.Tensor, Any]], - prediction_loss_only: bool, - ignore_keys: Optional[List[str]] = None, + self, + model: nn.Module, + inputs: Dict[str, Union[torch.Tensor, Any]], + prediction_loss_only: bool, + ignore_keys: Optional[List[str]] = None, ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: """ Perform an evaluation step on `model` using `inputs`. @@ -106,7 +98,7 @@ class CPMTrainer(Trainer): # labels may be popped when computing the loss (label smoothing for instance) so we grab them first. if has_labels or loss_without_labels: labels = nested_detach(tuple(inputs.get(name) - for name in self.label_names)) + for name in self.label_names)) if len(labels) == 1: labels = labels[0] else: @@ -176,7 +168,7 @@ class CPMTrainer(Trainer): logits = logits[0] return (loss, logits, labels) - + def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor: """ Perform a training step on a batch of inputs. @@ -205,9 +197,6 @@ class CPMTrainer(Trainer): with self.compute_loss_context_manager(): loss = self.compute_loss(model, inputs) - del inputs - torch.cuda.empty_cache() - if self.args.n_gpu > 1: loss = loss.mean() # mean() to average on multi-gpu parallel training @@ -215,14 +204,10 @@ class CPMTrainer(Trainer): with amp.scale_loss(loss, self.optimizer) as scaled_loss: scaled_loss.backward() else: - if is_deepspeed_zero3_enabled(): - with deepspeed.zero.GatheredParameters(self.model.resampler.attn.parameters(), modifier_rank=0): - self.accelerator.backward(loss) - else: - self.accelerator.backward(loss) + self.accelerator.backward(loss) return loss.detach() / self.args.gradient_accumulation_steps - + def _save(self, output_dir: Optional[str] = None, state_dict=None): # If we are executing this function, we are the process zero, so we don't check for that. 
output_dir = output_dir if output_dir is not None else self.args.output_dir @@ -249,20 +234,10 @@ class CPMTrainer(Trainer): else: torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME)) else: - if self.args.use_lora: - from collections import OrderedDict - state_dict_vision = OrderedDict() - for key, values in state_dict.items(): - if 'vpm' in key or 'resampler' in key or 'embed_tokens' in key: - state_dict_vision[key] = values - self.model.save_pretrained( - output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors - ) - torch.save(state_dict_vision, f"{output_dir}/vpm_resampler_embedtokens.pt", ) - else: - self.model.save_pretrained( - output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors - ) + + self.model.save_pretrained( + output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors + ) if self.tokenizer is not None: self.tokenizer.save_pretrained(output_dir) diff --git a/PyTorch/built-in/mlm/MiniCPM-V/huggingface_modify/configuration_minicpm.py b/PyTorch/built-in/mlm/MiniCPM-V/huggingface_modify/configuration_minicpm.py index ac8a95a909..1502d9dce7 100644 --- a/PyTorch/built-in/mlm/MiniCPM-V/huggingface_modify/configuration_minicpm.py +++ b/PyTorch/built-in/mlm/MiniCPM-V/huggingface_modify/configuration_minicpm.py @@ -108,6 +108,8 @@ class MiniCPMVConfig(LlamaConfig): elif isinstance(vision_config, Idefics2VisionConfig): self.vision_config = vision_config + if os.getenv('use_flash_attention_2') == 'true': + self.vision_config._attn_implementation = 'flash_attention_2' self.patch_size = self.vision_config.patch_size super().__init__(**kwargs) diff --git a/PyTorch/built-in/mlm/MiniCPM-V/huggingface_modify/modeling_minicpmv.py b/PyTorch/built-in/mlm/MiniCPM-V/huggingface_modify/modeling_minicpmv.py index d0c5186c2c..59250f1322 100644 --- a/PyTorch/built-in/mlm/MiniCPM-V/huggingface_modify/modeling_minicpmv.py +++ b/PyTorch/built-in/mlm/MiniCPM-V/huggingface_modify/modeling_minicpmv.py @@ -106,11 +106,11 @@ class MiniCPMV(MiniCPMVPreTrainedModel): patch_attn_mask = torch.zeros((B, 1, max_patches), dtype=torch.bool, device=device) for i in range(B): - patch_attn_mask[i, :tgt_sizes[i][0] * tgt_sizes[i][1]] = True + patch_attn_mask[i, 0, :tgt_sizes[i][0] * tgt_sizes[i][1]] = True vision_embedding = self.vpm(all_pixel_values.type(dtype), patch_attention_mask=patch_attn_mask).last_hidden_state - vision_embedding = self.resampler(vision_embedding, tgt_sizes) + vision_embedding = self.resampler(vision_embedding, (tgt_sizes, patch_attn_mask)) else: # get vision_embedding foreach vision_embedding = [] diff --git a/PyTorch/built-in/mlm/MiniCPM-V/huggingface_modify/resampler.py b/PyTorch/built-in/mlm/MiniCPM-V/huggingface_modify/resampler.py index c0d13540ae..84ce15d983 100644 --- a/PyTorch/built-in/mlm/MiniCPM-V/huggingface_modify/resampler.py +++ b/PyTorch/built-in/mlm/MiniCPM-V/huggingface_modify/resampler.py @@ -129,24 +129,21 @@ class Resampler(nn.Module): nn.init.constant_(m.weight, 1.0) def forward(self, x, tgt_sizes=None): + tgt_sizes, mask = tgt_sizes + key_padding_mask = ~mask.squeeze(1) + assert x.shape[0] == tgt_sizes.shape[0] bs = x.shape[0] device = x.device dtype = x.dtype - patch_len = tgt_sizes[:, 0] * tgt_sizes[:, 1] - self._adjust_pos_cache(tgt_sizes, device=device) - max_patch_len = torch.max(patch_len) - key_padding_mask = torch.zeros((bs, max_patch_len), dtype=torch.bool, device=device) - pos_embed = [] for i in range(bs): tgt_h, tgt_w = tgt_sizes[i] pos_embed.append(self.pos_embed[:tgt_h, :tgt_w, 
:].reshape((tgt_h * tgt_w, -1)).to(dtype)) # patches * D - key_padding_mask[i, patch_len[i]:] = True pos_embed = torch.nn.utils.rnn.pad_sequence( pos_embed, batch_first=True, padding_value=0.0).permute(1, 0, 2) # BLD => L * B * D diff --git a/PyTorch/built-in/mlm/MiniCPM-V/npu_monkey_patch/__init__.py b/PyTorch/built-in/mlm/MiniCPM-V/npu_monkey_patch/__init__.py new file mode 100644 index 0000000000..7cc6703ef2 --- /dev/null +++ b/PyTorch/built-in/mlm/MiniCPM-V/npu_monkey_patch/__init__.py @@ -0,0 +1,14 @@ +from npu_monkey_patch.idefics2_conv_monkey_patch import replace_with_torch_npu_idefics2_visionembeddings +from npu_monkey_patch.idefics2_flash_attn_monkey_patch import replace_with_torch_npu_idefics2_flash_attention +from npu_monkey_patch.llama_rmsnorm_monkey_patch import replace_with_torch_npu_llama_rmsnorm +from npu_monkey_patch.llama_rope_monkey_patch import replace_with_torch_npu_llama_rope +from npu_monkey_patch.llama_flash_attn_monkey_patch import replace_with_torch_npu_llama_flash_attention +from npu_monkey_patch.transformers_check_flash_attn_monkey_patch import replace_with_torch_npu_check_flash_attn_2 + + +replace_with_torch_npu_idefics2_visionembeddings() +replace_with_torch_npu_idefics2_flash_attention() +replace_with_torch_npu_llama_rmsnorm() +replace_with_torch_npu_llama_rope() +replace_with_torch_npu_llama_flash_attention() +replace_with_torch_npu_check_flash_attn_2() diff --git a/PyTorch/built-in/mlm/MiniCPM-V/npu_monkey_patch/idefics2_conv_monkey_patch.py b/PyTorch/built-in/mlm/MiniCPM-V/npu_monkey_patch/idefics2_conv_monkey_patch.py index a752532d39..24fa79666b 100644 --- a/PyTorch/built-in/mlm/MiniCPM-V/npu_monkey_patch/idefics2_conv_monkey_patch.py +++ b/PyTorch/built-in/mlm/MiniCPM-V/npu_monkey_patch/idefics2_conv_monkey_patch.py @@ -1,9 +1,10 @@ import torch from torch import nn +import transformers from transformers.models.idefics2.configuration_idefics2 import Idefics2VisionConfig -class Idefics2VisionEmbeddings(nn.Module): +class NpuIdefics2VisionEmbeddings(nn.Module): """ This is a modified version of `siglip.modelign_siglip.SiglipVisionEmbeddings` to enable images of variable resolution. 
@@ -19,6 +20,7 @@ class Idefics2VisionEmbeddings(nn.Module): self.embed_dim = config.hidden_size self.image_size = config.image_size self.patch_size = config.patch_size + self.split_size = config.patch_size * 10 self.patch_embedding = nn.Conv2d( in_channels=config.num_channels, @@ -29,18 +31,31 @@ class Idefics2VisionEmbeddings(nn.Module): ) self.num_patches_per_side = self.image_size // self.patch_size - self.num_patches = self.num_patches_per_side**2 + self.num_patches = self.num_patches_per_side ** 2 + self.boundaries = torch.arange(1 / self.num_patches_per_side, 1.0, 1 / self.num_patches_per_side) self.num_positions = self.num_patches self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) def forward(self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.BoolTensor) -> torch.Tensor: batch_size, _, max_im_h, max_im_w = pixel_values.shape - patch_embeds = self.patch_embedding(pixel_values) + patches = list(pixel_values.split(self.split_size, dim=-1)) + + pad_size = self.split_size - patches[-1].shape[-1] + if pad_size: + patches[-1] = torch.nn.functional.pad(patches[-1], (0, pad_size), mode='constant', value=0) + + combined_patches = torch.cat(patches, dim=0) + patch_embeds = self.patch_embedding(combined_patches) + patch_embeds = torch.split(patch_embeds, split_size_or_sections=pixel_values.shape[0], dim=0) + patch_embeds = torch.cat(patch_embeds, dim=-1) + + if pad_size: + patch_embeds = patch_embeds[..., :-int(pad_size / self.patch_size)] + embeddings = patch_embeds.flatten(2).transpose(1, 2) max_nb_patches_h, max_nb_patches_w = max_im_h // self.patch_size, max_im_w // self.patch_size - boundaries = torch.arange(1 / self.num_patches_per_side, 1.0, 1 / self.num_patches_per_side) position_ids = torch.full(size=(batch_size, max_nb_patches_h * max_nb_patches_w), fill_value=0) for batch_idx, p_attn_mask in enumerate(patch_attention_mask): @@ -50,8 +65,8 @@ class Idefics2VisionEmbeddings(nn.Module): fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h) fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w) - bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True) - bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True) + bucket_coords_h = torch.bucketize(fractional_coords_h, self.boundaries, right=True) + bucket_coords_w = torch.bucketize(fractional_coords_w, self.boundaries, right=True) pos_ids = (bucket_coords_h[:, None] * self.num_patches_per_side + bucket_coords_w).flatten() position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids @@ -59,3 +74,7 @@ class Idefics2VisionEmbeddings(nn.Module): position_ids = position_ids.to(self.position_embedding.weight.device) embeddings = embeddings + self.position_embedding(position_ids) return embeddings + + +def replace_with_torch_npu_idefics2_visionembeddings(): + transformers.models.idefics2.modeling_idefics2.Idefics2VisionEmbeddings = NpuIdefics2VisionEmbeddings diff --git a/PyTorch/built-in/mlm/MiniCPM-V/npu_monkey_patch/idefics2_flash_attn_monkey_patch.py b/PyTorch/built-in/mlm/MiniCPM-V/npu_monkey_patch/idefics2_flash_attn_monkey_patch.py index 010aedbca0..c153d68bd6 100644 --- a/PyTorch/built-in/mlm/MiniCPM-V/npu_monkey_patch/idefics2_flash_attn_monkey_patch.py +++ b/PyTorch/built-in/mlm/MiniCPM-V/npu_monkey_patch/idefics2_flash_attn_monkey_patch.py @@ -5,15 +5,14 @@ import torch_npu from npu_monkey_patch.utils import index_first_axis, pad_input, unpad_input import transformers from transformers.cache_utils import Cache -from 
transformers.utils import is_flash_attn_greater_or_equal_2_10, logging, is_flash_attn_2_available -from transformers.models.idefics2.modeling_idefics2 import Idefics2VisionAttention, _get_unpad_data -if is_flash_attn_2_available(): - from flash_attn import flash_attn_func, flash_attn_varlen_func +from transformers.utils import is_flash_attn_greater_or_equal_2_10, logging +from transformers.models.idefics2.modeling_idefics2 import Idefics2VisionAttention, _get_unpad_data, \ + IDEFICS_VISION_ATTENTION_CLASSES logger = logging.get_logger(__name__) -class Idefics2VisionFlashAttention2(Idefics2VisionAttention): +class NpuIdefics2VisionFlashAttention2(Idefics2VisionAttention): """ Idefics2Vision flash attention module. This module inherits from `Idefics2VisionAttention` as the weights of the module stays untouched. The only required change would be on the forward pass where it needs to correctly call the public API of @@ -30,14 +29,14 @@ class Idefics2VisionFlashAttention2(Idefics2VisionAttention): self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - **kwargs, + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: output_attentions = False @@ -47,23 +46,14 @@ class Idefics2VisionFlashAttention2(Idefics2VisionAttention): key_states = self.k_proj(hidden_states) value_states = self.v_proj(hidden_states) - # Flash attention requires the input to have the shape - # batch_size x seq_length x head_dim x hidden_dim - # therefore we just need to keep the original shape - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim) + key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim) + value_states = value_states.view(bsz, q_len, self.num_heads, self.head_dim) - kv_seq_len = key_states.shape[-2] + kv_seq_len = key_states.shape[1] if past_key_value is not None: kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache - # to be able to avoid many of these transpose/reshape/view. 
- query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - dropout_rate = self.dropout if self.training else 0.0 # In PEFT, usually we cast the layer norms in float32 for training stability reasons @@ -106,7 +96,7 @@ class Idefics2VisionFlashAttention2(Idefics2VisionAttention): # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward def _flash_attention_forward( - self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None + self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None ): """ Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token @@ -141,26 +131,35 @@ class Idefics2VisionFlashAttention2(Idefics2VisionAttention): ) cu_seqlens_q, cu_seqlens_k = cu_seq_lens - max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens - - attn_output_unpad = flash_attn_varlen_func( - query_states, - key_states, - value_states, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k=cu_seqlens_k, - max_seqlen_q=max_seqlen_in_batch_q, - max_seqlen_k=max_seqlen_in_batch_k, - dropout_p=dropout, - softmax_scale=softmax_scale, - causal=causal, - ) + attn_output_unpad = torch_npu.npu_fusion_attention( + query_states, key_states, value_states, self.num_heads, + pse=None, + padding_mask=None, + atten_mask=None, + scale=self.scale, + keep_prob=1, + input_layout="TND", + actual_seq_qlen=tuple(cu_seqlens_q[1:].cpu().numpy().tolist()), + actual_seq_kvlen=tuple(cu_seqlens_k[1:].cpu().numpy().tolist()), + pre_tockens=2147483647, + next_tockens=2147483647, + sparse_mode=0)[0] attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) else: - attn_output = flash_attn_func( - query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal - ) + attn_output = torch_npu.npu_fusion_attention(query_states, + key_states, + value_states, + self.num_heads, + "BSND", + pse=None, + keep_prob=1. 
- dropout, + scale=self.scale, + pre_tockens=65536, + next_tockens=65536, + sync=False, + inner_precise=0, + )[0] return attn_output @@ -202,3 +201,7 @@ class Idefics2VisionFlashAttention2(Idefics2VisionAttention): (cu_seqlens_q, cu_seqlens_k), (max_seqlen_in_batch_q, max_seqlen_in_batch_k), ) + + +def replace_with_torch_npu_idefics2_flash_attention(): + IDEFICS_VISION_ATTENTION_CLASSES['flash_attention_2'] = NpuIdefics2VisionFlashAttention2 diff --git a/PyTorch/built-in/mlm/MiniCPM-V/npu_monkey_patch/llama_flash_attn_monkey_patch.py b/PyTorch/built-in/mlm/MiniCPM-V/npu_monkey_patch/llama_flash_attn_monkey_patch.py index 10e2a643fd..63e5c738b9 100644 --- a/PyTorch/built-in/mlm/MiniCPM-V/npu_monkey_patch/llama_flash_attn_monkey_patch.py +++ b/PyTorch/built-in/mlm/MiniCPM-V/npu_monkey_patch/llama_flash_attn_monkey_patch.py @@ -1,20 +1,20 @@ from typing import Optional, Tuple -import warnings import torch -import torch.nn as nn import torch.utils.checkpoint +import torch_npu import transformers from transformers.cache_utils import Cache -from .utils import index_first_axis, pad_input, unpad_input -from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb, _get_unpad_data +from npu_monkey_patch.utils import index_first_axis, pad_input, unpad_input +from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb, _get_unpad_data, LLAMA_ATTENTION_CLASSES from transformers.utils import is_flash_attn_greater_or_equal_2_10, logging, is_flash_attn_2_available +from npu_monkey_patch.llama_rope_monkey_patch import apply_fused_rotary_pos_emb if is_flash_attn_2_available(): from flash_attn import flash_attn_func, flash_attn_varlen_func logger = logging.get_logger(__name__) -class LlamaFlashAttention2(LlamaAttention): +class NpuLlamaFlashAttention2(LlamaAttention): """ Llama flash attention module. This module inherits from `LlamaAttention` as the weights of the module stays untouched. The only required change would be on the forward pass where it needs to correctly call the public API of @@ -28,6 +28,7 @@ class LlamaFlashAttention2(LlamaAttention): # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). 
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + self.scale = self.head_dim ** -0.5 def forward( self, @@ -51,13 +52,13 @@ class LlamaFlashAttention2(LlamaAttention): # Flash attention requires the input to have the shape # batch_size x seq_length x head_dim x hidden_dim # therefore we just need to keep the original shape - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim) cos, sin = self.rotary_emb(value_states, position_ids) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states, key_states = apply_fused_rotary_pos_emb(query_states, key_states, cos, sin) past_key_value = getattr(self, "past_key_value", past_key_value) @@ -66,12 +67,6 @@ class LlamaFlashAttention2(LlamaAttention): cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache - # to be able to avoid many of these transpose/reshape/view. - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - dropout_rate = self.attention_dropout if self.training else 0.0 # In PEFT, usually we cast the layer norms in float32 for training stability reasons @@ -113,7 +108,7 @@ class LlamaFlashAttention2(LlamaAttention): return attn_output, attn_weights, past_key_value def _flash_attention_forward( - self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None + self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None ): """ Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token @@ -148,26 +143,39 @@ class LlamaFlashAttention2(LlamaAttention): ) cu_seqlens_q, cu_seqlens_k = cu_seq_lens - max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens - - attn_output_unpad = flash_attn_varlen_func( - query_states, - key_states, - value_states, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k=cu_seqlens_k, - max_seqlen_q=max_seqlen_in_batch_q, - max_seqlen_k=max_seqlen_in_batch_k, - dropout_p=dropout, - softmax_scale=softmax_scale, - causal=causal, - ) + + attn_output_unpad = torch_npu.npu_fusion_attention( + query_states, key_states, value_states, self.num_heads, + pse=None, + padding_mask=None, + atten_mask=None, + scale=self.scale, + keep_prob=1, + input_layout="TND", + actual_seq_qlen=tuple(cu_seqlens_q[1:].cpu().numpy().tolist()), + actual_seq_kvlen=tuple(cu_seqlens_k[1:].cpu().numpy().tolist()), + pre_tockens=2147483647, + next_tockens=2147483647, + sparse_mode=0)[0] attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) else: - attn_output = flash_attn_func( - query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, 
causal=causal - ) + attn_mask_npu = torch.triu(torch.ones((query_length, query_length), device=query_states.device), + diagonal=1).bool() + attn_output = torch_npu.npu_fusion_attention(query_states, + key_states, + value_states, + self.num_heads, + "BSND", + pse=None, + keep_prob=1. - dropout, + scale=self.scale, + atten_mask=attn_mask_npu, + pre_tockens=65536, + next_tockens=65536, + sync=False, + inner_precise=0, + )[0] return attn_output @@ -208,3 +216,7 @@ class LlamaFlashAttention2(LlamaAttention): (cu_seqlens_q, cu_seqlens_k), (max_seqlen_in_batch_q, max_seqlen_in_batch_k), ) + + +def replace_with_torch_npu_llama_flash_attention(): + LLAMA_ATTENTION_CLASSES["flash_attention_2"] = NpuLlamaFlashAttention2 diff --git a/PyTorch/built-in/mlm/MiniCPM-V/npu_monkey_patch/llama_rmsnorm_monkey_patch.py b/PyTorch/built-in/mlm/MiniCPM-V/npu_monkey_patch/llama_rmsnorm_monkey_patch.py index a74bb38d0a..5517f1242f 100644 --- a/PyTorch/built-in/mlm/MiniCPM-V/npu_monkey_patch/llama_rmsnorm_monkey_patch.py +++ b/PyTorch/built-in/mlm/MiniCPM-V/npu_monkey_patch/llama_rmsnorm_monkey_patch.py @@ -1,8 +1,10 @@ import torch import torch.nn as nn +import torch_npu +import transformers -class LlamaRMSNorm(nn.Module): +class NpuLlamaRMSNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): """ LlamaRMSNorm is equivalent to T5LayerNorm @@ -14,6 +16,12 @@ class LlamaRMSNorm(nn.Module): def forward(self, hidden_states): input_dtype = hidden_states.dtype hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) + return torch_npu.npu_rms_norm( + hidden_states, + self.weight.to(torch.float32), + epsilon=self.variance_epsilon + )[0].to(input_dtype) + + +def replace_with_torch_npu_llama_rmsnorm(): + transformers.models.llama.modeling_llama.LlamaRMSNorm = NpuLlamaRMSNorm diff --git a/PyTorch/built-in/mlm/MiniCPM-V/npu_monkey_patch/llama_rope_monkey_patch.py b/PyTorch/built-in/mlm/MiniCPM-V/npu_monkey_patch/llama_rope_monkey_patch.py index 3eb6cba8b0..adb6cd3f23 100644 --- a/PyTorch/built-in/mlm/MiniCPM-V/npu_monkey_patch/llama_rope_monkey_patch.py +++ b/PyTorch/built-in/mlm/MiniCPM-V/npu_monkey_patch/llama_rope_monkey_patch.py @@ -1,11 +1,14 @@ +import os import torch import torch.nn as nn +import torch_npu +import transformers from transformers.utils import logging logger = logging.get_logger(__name__) -class LlamaRotaryEmbedding(nn.Module): +class NpuLlamaRotaryEmbedding(nn.Module): def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): super().__init__() self.scaling_factor = scaling_factor @@ -14,107 +17,54 @@ class LlamaRotaryEmbedding(nn.Module): self.base = base inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) self.register_buffer("inv_freq", inv_freq, persistent=False) - # For BC we register cos and sin cached - self.max_seq_len_cached = max_position_embeddings + self._set_cos_sin_cache(seq_len=max_position_embeddings, device=self.inv_freq.device) + + def _set_cos_sin_cache(self, seq_len, device): + self.max_seq_len_cached = seq_len t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) t = t / self.scaling_factor freqs = torch.outer(t, self.inv_freq) # Different from paper, but it uses a different permutation in order to obtain the same calculation emb = torch.cat((freqs, 
freqs), dim=-1) - self.register_buffer("_cos_cached", emb.cos().to(torch.get_default_dtype()), persistent=False) - self.register_buffer("_sin_cached", emb.sin().to(torch.get_default_dtype()), persistent=False) - - @property - def sin_cached(self): - logger.warning_once( - "The sin_cached attribute will be removed in 4.39. Bear in mind that its contents changed in v4.38. Use " - "the forward method of RoPE from now on instead. It is not used in the `LlamaAttention` class" - ) - return self._sin_cached - - @property - def cos_cached(self): - logger.warning_once( - "The cos_cached attribute will be removed in 4.39. Bear in mind that its contents changed in v4.38. Use " - "the forward method of RoPE from now on instead. It is not used in the `LlamaAttention` class" - ) - return self._cos_cached + self.register_buffer("cos_cached", emb.cos().unsqueeze(0).to(torch.get_default_dtype()), persistent=False) + self.register_buffer("sin_cached", emb.sin().unsqueeze(0).to(torch.get_default_dtype()), persistent=False) @torch.no_grad() def forward(self, x, position_ids): - # x: [bs, num_attention_heads, seq_len, head_size] - inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) - position_ids_expanded = position_ids[:, None, :].float() - # Force float32 since bfloat16 loses precision on long contexts - # See https://github.com/huggingface/transformers/pull/29285 - device_type = x.device.type - device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): - freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) - emb = torch.cat((freqs, freqs), dim=-1) - cos = emb.cos() - sin = emb.sin() - return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) - - -class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding): - """LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" - - def forward(self, x, position_ids): - # difference to the original RoPE: a scaling factor is aplied to the position ids - position_ids = position_ids.float() / self.scaling_factor - cos, sin = super().forward(x, position_ids) - return cos, sin - - -class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding): - """LlamaRotaryEmbedding extended with Dynamic NTK scaling. 
Credits to the Reddit users /u/bloc97 and /u/emozilla""" - - def forward(self, x, position_ids): - # difference to the original RoPE: inv_freq is recomputed when the sequence length > original length - seq_len = torch.max(position_ids) + 1 - if seq_len > self.max_position_embeddings: - base = self.base * ( - (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1) - ) ** (self.dim / (self.dim - 2)) - inv_freq = 1.0 / ( - base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(x.device) / self.dim) + _, seq_len = position_ids.shape + + if seq_len != 1: + if seq_len > self.max_seq_len_cached: + self._set_cos_sin_cache(seq_len=seq_len, device=x.device) + return ( + self.cos_cached[:, :seq_len, ...].to(dtype=x.dtype), + self.sin_cached[:, :seq_len, ...].to(dtype=x.dtype), ) - self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: this may break with compilation - - cos, sin = super().forward(x, position_ids) - return cos, sin - - -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - -def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): - """Applies Rotary Position Embedding to the query and key tensors. - - Args: - q (`torch.Tensor`): The query tensor. - k (`torch.Tensor`): The key tensor. - cos (`torch.Tensor`): The cosine part of the rotary embedding. - sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`, *optional*): - Deprecated and unused. - unsqueeze_dim (`int`, *optional*, defaults to 1): - The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and - sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note - that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and - k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes - cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have - the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. - Returns: - `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
- """ + else: + # x: [bs, num_attention_heads, seq_len, head_size] + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 since bfloat16 loses precision on long contexts + # See https://github.com/huggingface/transformers/pull/29285 + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +unsqueeze_dim = 2 if os.getenv('use_flash_attention_2') == 'true' else 1 + + +def apply_fused_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=unsqueeze_dim): cos = cos.unsqueeze(unsqueeze_dim) sin = sin.unsqueeze(unsqueeze_dim) - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - return q_embed, k_embed + return torch_npu.npu_rotary_mul(q, cos, sin), torch_npu.npu_rotary_mul(k, cos, sin) + + +def replace_with_torch_npu_llama_rope(): + transformers.models.llama.modeling_llama.LlamaRotaryEmbedding = NpuLlamaRotaryEmbedding + transformers.models.llama.modeling_llama.apply_rotary_pos_emb = apply_fused_rotary_pos_emb diff --git a/PyTorch/built-in/mlm/MiniCPM-V/npu_monkey_patch/transformers_check_flash_attn_monkey_patch.py b/PyTorch/built-in/mlm/MiniCPM-V/npu_monkey_patch/transformers_check_flash_attn_monkey_patch.py new file mode 100644 index 0000000000..897d613efd --- /dev/null +++ b/PyTorch/built-in/mlm/MiniCPM-V/npu_monkey_patch/transformers_check_flash_attn_monkey_patch.py @@ -0,0 +1,17 @@ +from typing import Optional, Union, Dict +import torch +import transformers + + +def check_flash_attn_2( + config, + torch_dtype: Optional[torch.dtype] = None, + device_map: Optional[Union[str, Dict[str, int]]] = None, + check_device_map: bool = True, + hard_check_only: bool = False, +): + return config + + +def replace_with_torch_npu_check_flash_attn_2(): + transformers.modeling_utils.PreTrainedModel._check_and_enable_flash_attn_2 = check_flash_attn_2 diff --git a/PyTorch/built-in/mlm/MiniCPM-V/npu_monkey_patch/utils.py b/PyTorch/built-in/mlm/MiniCPM-V/npu_monkey_patch/utils.py index 8d3a0c5b42..18f5b1eae8 100644 --- a/PyTorch/built-in/mlm/MiniCPM-V/npu_monkey_patch/utils.py +++ b/PyTorch/built-in/mlm/MiniCPM-V/npu_monkey_patch/utils.py @@ -1,6 +1,7 @@ import torch import torch.nn.functional as F from einops import rearrange, repeat +import importlib class IndexFirstAxis(torch.autograd.Function): @@ -211,3 +212,18 @@ def pad_input(hidden_states, indices, batch, seqlen): # output[indices] = hidden_states output = index_put_first_axis(hidden_states, indices, batch * seqlen) return rearrange(output, "(b s) ... 
-> b s ...", b=batch) + + +def is_npu_available(): + "Checks if `torch_npu` is installed and potentially if a NPU is in the environment" + if importlib.util.find_spec("torch") is None or importlib.util.find_spec("torch_npu") is None: + return False + + import torch_npu + + try: + # Will raise a RuntimeError if no NPU is found + _ = torch.npu.device_count() + return torch.npu.is_available() + except RuntimeError: + return False diff --git a/PyTorch/built-in/mlm/MiniCPM-V/requirements.txt b/PyTorch/built-in/mlm/MiniCPM-V/requirements.txt index dce85f9984..f2cac408ab 100644 --- a/PyTorch/built-in/mlm/MiniCPM-V/requirements.txt +++ b/PyTorch/built-in/mlm/MiniCPM-V/requirements.txt @@ -17,8 +17,8 @@ seaborn==0.13.0 shortuuid==1.0.11 spacy==3.7.2 timm==0.9.10 -torch==2.1.2 -torchvision==0.16.2 +torch==2.1.0 +torchvision==0.16.0 tqdm==4.66.1 protobuf==4.25.0 transformers==4.40.0 @@ -31,3 +31,8 @@ accelerate==0.30.1 socksio==1.0.0 gradio gradio_client +scipy==1.14.0 +decorator==5.1.1 +deepspeed==0.14.4 +peft==0.12.0 +tensorboardX==2.6.2.2 diff --git a/PyTorch/built-in/mlm/MiniCPM-V/web_demo.py b/PyTorch/built-in/mlm/MiniCPM-V/web_demo.py index 668dcf4ddd..c46a1fa5c4 100644 --- a/PyTorch/built-in/mlm/MiniCPM-V/web_demo.py +++ b/PyTorch/built-in/mlm/MiniCPM-V/web_demo.py @@ -7,6 +7,13 @@ import re import torch import argparse from transformers import AutoModel, AutoTokenizer +from npu_monkey_patch.utils import is_npu_available + +if is_npu_available(): + import torch_npu + from torch_npu.contrib import transfer_to_npu + import npu_monkey_patch + torch.npu.config.allow_internal_format = False # README, How to run demo on different devices # For Nvidia GPUs support BF16 (like A100, H100, RTX3090) @@ -20,11 +27,11 @@ from transformers import AutoModel, AutoTokenizer # Argparser parser = argparse.ArgumentParser(description='demo') -parser.add_argument('--device', type=str, default='cuda', help='cuda or mps') +parser.add_argument('--device', type=str, default='cuda', help='cuda or mps or npu') parser.add_argument('--dtype', type=str, default='bf16', help='bf16 or fp16') args = parser.parse_args() device = args.device -assert device in ['cuda', 'mps'] +assert device in ['cuda', 'mps', 'npu'] if args.dtype == 'bf16': if device == 'mps': print('Warning: MPS does not support bf16, will use fp16 instead') diff --git a/PyTorch/built-in/mlm/MiniCPM-V/web_demo_2.5.py b/PyTorch/built-in/mlm/MiniCPM-V/web_demo_2.5.py index 6f6b81af37..a02014edfe 100644 --- a/PyTorch/built-in/mlm/MiniCPM-V/web_demo_2.5.py +++ b/PyTorch/built-in/mlm/MiniCPM-V/web_demo_2.5.py @@ -7,6 +7,13 @@ import re import torch import argparse from transformers import AutoModel, AutoTokenizer +from npu_monkey_patch.utils import is_npu_available + +if is_npu_available(): + import torch_npu + from torch_npu.contrib import transfer_to_npu + import npu_monkey_patch + torch.npu.config.allow_internal_format = False # README, How to run demo on different devices @@ -18,10 +25,10 @@ from transformers import AutoModel, AutoTokenizer # Argparser parser = argparse.ArgumentParser(description='demo') -parser.add_argument('--device', type=str, default='cuda', help='cuda or mps') +parser.add_argument('--device', type=str, default='cuda', help='cuda or mps or npu') args = parser.parse_args() device = args.device -assert device in ['cuda', 'mps'] +assert device in ['cuda', 'mps', 'npu'] # Load model model_path = 'openbmb/MiniCPM-Llama3-V-2_5' diff --git a/PyTorch/built-in/mlm/MiniCPM-V/web_demo_streamlit-2_5.py 
b/PyTorch/built-in/mlm/MiniCPM-V/web_demo_streamlit-2_5.py index e3d67e1b9b..80e5d18c96 100644 --- a/PyTorch/built-in/mlm/MiniCPM-V/web_demo_streamlit-2_5.py +++ b/PyTorch/built-in/mlm/MiniCPM-V/web_demo_streamlit-2_5.py @@ -2,6 +2,13 @@ import streamlit as st from PIL import Image import torch from transformers import AutoModel, AutoTokenizer +from npu_monkey_patch.utils import is_npu_available + +if is_npu_available(): + import torch_npu + from torch_npu.contrib import transfer_to_npu + import npu_monkey_patch + torch.npu.config.allow_internal_format = False # Model path model_path = "openbmb/MiniCPM-Llama3-V-2_5" diff --git a/PyTorch/built-in/mlm/MiniCPM-V/web_demo_streamlit.py b/PyTorch/built-in/mlm/MiniCPM-V/web_demo_streamlit.py index 204a495133..f4a3da0997 100644 --- a/PyTorch/built-in/mlm/MiniCPM-V/web_demo_streamlit.py +++ b/PyTorch/built-in/mlm/MiniCPM-V/web_demo_streamlit.py @@ -2,6 +2,13 @@ import streamlit as st from PIL import Image import torch from transformers import AutoModel, AutoTokenizer +from npu_monkey_patch.utils import is_npu_available + +if is_npu_available(): + import torch_npu + from torch_npu.contrib import transfer_to_npu + import npu_monkey_patch + torch.npu.config.allow_internal_format = False # Model path model_path = "openbmb/MiniCPM-V-2" -- Gitee
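
For a quick sanity check of the NPU adaptation above, the sketch below runs a single offline inference on an Ascend NPU, mirroring the device handling this patch adds to `web_demo_2.5.py` and `finetune.py`. It is a minimal sketch and not part of the patch: the checkpoint path, the image file `example.jpg`, and the sampling settings are placeholder assumptions, while the `model.chat(...)` call follows the upstream MiniCPM-V usage. Run it from `PyTorch/built-in/mlm/MiniCPM-V` so that `npu_monkey_patch` is importable.

```python
import torch
from PIL import Image
from transformers import AutoModel, AutoTokenizer

from npu_monkey_patch.utils import is_npu_available

if is_npu_available():
    import torch_npu                               # registers the NPU device backend
    from torch_npu.contrib import transfer_to_npu  # transparently maps cuda.* calls to npu.*
    import npu_monkey_patch                        # importing applies the NPU monkey patches
    torch.npu.config.allow_internal_format = False

model_path = 'openbmb/MiniCPM-Llama3-V-2_5'        # or a local checkpoint directory
model = AutoModel.from_pretrained(
    model_path, trust_remote_code=True, torch_dtype=torch.bfloat16
).to('npu').eval()
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

image = Image.open('example.jpg').convert('RGB')   # placeholder test image
msgs = [{'role': 'user', 'content': 'Describe this image.'}]

# `chat` is the conversational entry point exposed by the MiniCPM-V remote code.
answer = model.chat(image=image, msgs=msgs, tokenizer=tokenizer,
                    sampling=True, temperature=0.7)
print(answer)
```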