From 99a708196d4e9424797caefdc1ba2489da5b7403 Mon Sep 17 00:00:00 2001 From: gitee_code_template Date: Mon, 29 Jan 2024 19:37:47 +0800 Subject: [PATCH 1/3] =?UTF-8?q?=E6=96=B0=E5=A2=9Ecogvlm=E5=A4=9A=E6=A8=A1?= =?UTF-8?q?=E6=80=81=E6=A8=A1=E5=9E=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- PyTorch/built-in/foundation/CogVLM/README.md | 62 ++ .../CogVLM/cogvlm_utils/eva_clip_model.py | 159 ++++ .../CogVLM/cogvlm_utils/finetune_cogvlm.sh | 57 ++ .../cogvlm_utils/finetune_cogvlm_lora.sh | 59 ++ .../CogVLM/cogvlm_utils/inference.py | 33 + .../foundation/CogVLM/cogvlm_utils/mixin.py | 294 +++++++ .../CogVLM/cogvlm_utils/modeling_cogvlm.py | 832 ++++++++++++++++++ .../CogVLM/cogvlm_utils/rotary_embeddings.py | 128 +++ .../cogvlm_utils/triton_rotary_embeddings.py | 1 + .../foundation/CogVLM/cogvlm_utils/visual.py | 189 ++++ 10 files changed, 1814 insertions(+) create mode 100644 PyTorch/built-in/foundation/CogVLM/README.md create mode 100644 PyTorch/built-in/foundation/CogVLM/cogvlm_utils/eva_clip_model.py create mode 100644 PyTorch/built-in/foundation/CogVLM/cogvlm_utils/finetune_cogvlm.sh create mode 100644 PyTorch/built-in/foundation/CogVLM/cogvlm_utils/finetune_cogvlm_lora.sh create mode 100644 PyTorch/built-in/foundation/CogVLM/cogvlm_utils/inference.py create mode 100644 PyTorch/built-in/foundation/CogVLM/cogvlm_utils/mixin.py create mode 100644 PyTorch/built-in/foundation/CogVLM/cogvlm_utils/modeling_cogvlm.py create mode 100644 PyTorch/built-in/foundation/CogVLM/cogvlm_utils/rotary_embeddings.py create mode 100644 PyTorch/built-in/foundation/CogVLM/cogvlm_utils/triton_rotary_embeddings.py create mode 100644 PyTorch/built-in/foundation/CogVLM/cogvlm_utils/visual.py diff --git a/PyTorch/built-in/foundation/CogVLM/README.md b/PyTorch/built-in/foundation/CogVLM/README.md new file mode 100644 index 0000000000..5f04469f84 --- /dev/null +++ b/PyTorch/built-in/foundation/CogVLM/README.md @@ -0,0 +1,62 @@ +一、CogVLM官网地址:https://github.com/THUDM/CogVLM + +二、三方件安装 + +1、CogVLM官方项目下 pip install -r requirements.txt + +2、下载并安装en_core_web_sm-any-py3-none-any.whl :https://huggingface.co/spacy/en_core_web_sm/tree/main + +三、三方件文件替换: + +找到自己三方件安装路径,例如:xxx/xxx/lib/python3.8/site-packages + +sat/model/position_embedding/triton_rotary_embeddings.py 替换为model_zoo项目下cogvlm_utils/triton_rotary_embeddings.py + +四、cogvlm项目文件替换: + +1、utils/models/eva_clip_model.py 替换为model_zoo项目下 cogvlm_utils/eva_clip_model.py + +2、utils/models/mixin.py 替换为 model_zoo项目下 cogvlm_utils/mixin.py + +五、权重文件下载 + +1、微调权重下载:https://huggingface.co/THUDM/CogVLM/tree/main 下载cogvlm-base-224.zip + +2、推理权重下载:https://huggingface.co/THUDM/cogvlm-base-224-hf/tree/main + +3、分词器权重下载:https://huggingface.co/lmsys/vicuna-7b-v1.5/tree/main + + +六、推理文件修改: + +1、modeling_cogvlm.py 替换为项目下cogvlm_utils/modeling_cogvlm.py 和 cogvlm_utils/rotary_embeddings.py + +2、visual.py 替换为项目下 cogvlm_utils/visual.py + + +七、微调数据下载: + +下载路径:https://www.kaggle.com/datasets/aadhavvignesh/captcha-images + +根据CogVLM官网要求,处理数据 python utils/split_dataset.py + +八、执行微调 + +微调 + +1、cd finetune_demo + +2、bash finetune_cogvlm.sh + +lora微调: + +1、cd finetune_demo + +2、bash finetune_cogvlm_lora.sh + + +九、推理 + +新增inference.py 到 CogVLM项目finetune_demo文件夹下 + +执行python inference.py \ No newline at end of file diff --git a/PyTorch/built-in/foundation/CogVLM/cogvlm_utils/eva_clip_model.py b/PyTorch/built-in/foundation/CogVLM/cogvlm_utils/eva_clip_model.py new file mode 100644 index 0000000000..c05bdca722 --- /dev/null +++ b/PyTorch/built-in/foundation/CogVLM/cogvlm_utils/eva_clip_model.py @@ -0,0 +1,159 @@ +import torch +import torch_npu +from sat.model.base_model import BaseModel +from sat.model.mixins import BaseMixin +from sat.model.official.vit_model import ViTProperty, ImagePatchEmbeddingMixin, InterpolatedPositionEmbeddingMixin, gelu +from sat import mpu + +class FlashSelfAttention(torch.nn.Module): + """Implement the scaled dot product attention with softmax. + Arguments + --------- + softmax_scale: The temperature to use for the softmax attention. + (default: 1/sqrt(d_keys) where d_keys is computed at + runtime) + attention_dropout: The dropout rate to apply to the attention + (default: 0.0) + """ + + def __init__(self, causal=False, softmax_scale=1., attention_dropout=0.): + super().__init__() + self.causal = causal + self.softmax_scale = softmax_scale + self.dropout_p = attention_dropout + + def forward(self, q, k, v, n, attention_mask, pse): + + if self.causal: + output = torch_npu.npu_fusion_attention( + q, k, v, n, "BSND",# SBH + pse=pse, + padding_mask=None, + atten_mask=attention_mask, + scale=self.softmax_scale, + pre_tockens=k.shape[0], # seq_len + next_tockens=0, # 0 + keep_prob=1 - self.dropout_p, + )[0] + return output + raise Exception("the attention type {} is not support!".format(self.attention_type)) + +class IdentityMixin(BaseMixin): + def __init__(self): + super().__init__() + + def final_forward(self, logits, **kwargs): + return logits[:, 1:] + +import xformers.ops as xops +class XAttn(BaseMixin): + def __init__(self, head_dim): + super().__init__() + self.scale = head_dim ** -0.5 + self.core_attention_flash = FlashSelfAttention( + causal=True, softmax_scale=self.scale, attention_dropout=0.) + + def attention_fn(self, query_layer, key_layer, value_layer, attention_mask, + attention_dropout=None, log_attention_weights=None, scaling_attention_score=True, **kwargs): + dropout_p = 0. # xformers does not support dropout for eva hidden size + + query_layer = query_layer.permute(0, 2, 1, 3) # B, num_heads, N, C -> B, N, num_heads, C + key_layer = key_layer.permute(0, 2, 1, 3) + value_layer = value_layer.permute(0, 2, 1, 3) + # 替换为NPU 的FA + out = self.core_attention_flash(query_layer, key_layer, value_layer, query_layer.shape[2], None, None) + return out + + def attention_forward(self, hidden_states, mask, **kw_args): + self = self.transformer.layers[kw_args['layer_id']].attention + attention_fn = self.hooks['attention_fn'] + + mixed_raw_layer = self.query_key_value(hidden_states) + + B, N, C = hidden_states.shape + mixed_raw_layer = mixed_raw_layer.reshape(B, N, 3, self.num_attention_heads_per_partition, -1).permute(2, 0, 3, 1, 4) # 3, B, num_heads, N, C + query_layer, key_layer, value_layer = mixed_raw_layer[0], mixed_raw_layer[1], mixed_raw_layer[2] + + dropout_fn = self.attention_dropout if self.training else None + + context_layer = attention_fn(query_layer, key_layer, value_layer, mask, dropout_fn, **kw_args) + + context_layer = context_layer.view(B, N, -1) + output = self.dense(context_layer) + + if self.training: + output = self.output_dropout(output) + return output + +class NewLayerForward(BaseMixin): + def __init__(self): + super().__init__() + + def layer_forward(self, hidden_states, mask, *args, **kw_args): + ''' + hidden_states: [batch, seq_len, hidden_size] + mask: [(1, 1), seq_len, seq_len] + ''' + self = self.transformer.layers[kw_args['layer_id']] + + attention_input = hidden_states + + # Self attention. + attention_output = self.input_layernorm(self.attention(attention_input, mask, **kw_args)) + + # DropPath for attention + if self.training and self.drop_path > 0.: + if mpu.get_cuda_rng_tracker is not None: + # drop_path must use model parallel rng tracker + # the tracker is initialized as seed of `seed + model_parallel_rank` + # deepspeed act-ckpt record the model parallel tracker states + with mpu.get_cuda_rng_tracker().fork(): + # drop_path percentage 0, others 1/(1-p) + random_tensor = (1-self.drop_path + + torch.rand((attention_output.shape[0],), dtype=attention_output.dtype, device=attention_output.device)).floor_() / (1-self.drop_path) + attention_output = random_tensor.view(-1, 1, 1) * attention_output + + # Residual connection. + hidden_states = attention_input + attention_output + mlp_input = hidden_states + + # MLP. + mlp_output = self.post_attention_layernorm(self.mlp(mlp_input, **kw_args)) + + # DropPath for mlp + if self.training and self.drop_path > 0.: + if mpu.get_cuda_rng_tracker is not None: + with mpu.get_cuda_rng_tracker().fork(): + random_tensor = (1-self.drop_path + + torch.rand((mlp_output.shape[0],), dtype=mlp_output.dtype, device=mlp_output.device)).floor_() / (1-self.drop_path) + mlp_output = random_tensor.view(-1, 1, 1) * mlp_output + + # Second residual connection. + output = mlp_input + mlp_output + + return output + +class EVA2CLIPModel(BaseModel): + def __init__(self, args, transformer=None, parallel_output=True, **kwargs): + property = ViTProperty(args.image_size, args.patch_size, args.pre_len, args.post_len) + args.max_sequence_length = property.pre_len + property.num_patches + property.post_len + if 'activation_func' not in kwargs: + kwargs['activation_func'] = gelu + super().__init__(args, transformer=transformer, parallel_output=parallel_output, **kwargs) + self.transformer.property = property + self.add_mixin("patch_embedding", ImagePatchEmbeddingMixin(args.in_channels, args.hidden_size, property)) + self.add_mixin("pos_embedding", InterpolatedPositionEmbeddingMixin()) + self.add_mixin("final", IdentityMixin()) + self.add_mixin("newpost", NewLayerForward()) + self.add_mixin("xattn", XAttn(args.hidden_size // args.num_attention_heads)) + + @classmethod + def add_model_specific_args(cls, parser): + group = parser.add_argument_group('EVA2CLIP', 'EVA2CLIP Configurations') + group.add_argument('--image-size', nargs='+', type=int, default=[224, 224]) + group.add_argument('--pre-len', type=int, default=1) # [cls] by default + group.add_argument('--post-len', type=int, default=0) # empty by default, but sometimes with special tokens, such as [det] in yolos. + group.add_argument('--in-channels', type=int, default=3) + group.add_argument('--patch-size', type=int, default=16) + return parser + diff --git a/PyTorch/built-in/foundation/CogVLM/cogvlm_utils/finetune_cogvlm.sh b/PyTorch/built-in/foundation/CogVLM/cogvlm_utils/finetune_cogvlm.sh new file mode 100644 index 0000000000..f060d5dd02 --- /dev/null +++ b/PyTorch/built-in/foundation/CogVLM/cogvlm_utils/finetune_cogvlm.sh @@ -0,0 +1,57 @@ +#! /bin/bash +# export PATH=/usr/local/cuda/bin:$PATH +# export LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH + +NUM_GPUS_PER_WORKER=8 +MP_SIZE=1 + +script_path=$(realpath $0) +script_dir=$(dirname $script_path) +main_dir=$(dirname $script_dir) +MODEL_TYPE="微调权重路径" +VERSION="base" +MODEL_ARGS="--from_pretrained $MODEL_TYPE \ + --max_length 1288 \ + --local_tokenizer 分析器权重路径 \ + --version $VERSION" +# Tips: If training models of resolution 244, you can set --max_length smaller + +OPTIONS_SAT="SAT_HOME=~/.sat_models" +OPTIONS_NCCL="HCCL_DEBUG=info HCCL_IB_DISABLE=0 HCCL_NET_GDR_LEVEL=2 LOCAL_WORLD_SIZE=$NUM_GPUS_PER_WORKER" +HOST_FILE_PATH="hostfile" + +train_data="./archive_split/train" +valid_data="./archive_split/valid" + +gpt_options=" \ + --experiment-name finetune-$MODEL_TYPE \ + --model-parallel-size ${MP_SIZE} \ + --mode finetune \ + --train-iters 800 \ + --resume-dataloader \ + $MODEL_ARGS \ + --train-data ${train_data} \ + --valid-data ${valid_data} \ + --distributed-backend hccl \ + --lr-decay-style cosine \ + --warmup .02 \ + --checkpoint-activations \ + --vit_checkpoint_activations \ + --save-interval 200 \ + --eval-interval 200 \ + --save "./checkpoints" \ + --eval-iters 10 \ + --eval-batch-size 1 \ + --split 1. \ + --deepspeed_config test_config_bf16.json \ + --skip-init \ + --seed 2023 +" + + + +run_cmd="${OPTIONS_NCCL} deepspeed --master_port 16666 --hostfile ${HOST_FILE_PATH} finetune_cogvlm_demo.py ${gpt_options}" +echo ${run_cmd} +eval ${run_cmd} + +set +x diff --git a/PyTorch/built-in/foundation/CogVLM/cogvlm_utils/finetune_cogvlm_lora.sh b/PyTorch/built-in/foundation/CogVLM/cogvlm_utils/finetune_cogvlm_lora.sh new file mode 100644 index 0000000000..1f26c35fdb --- /dev/null +++ b/PyTorch/built-in/foundation/CogVLM/cogvlm_utils/finetune_cogvlm_lora.sh @@ -0,0 +1,59 @@ +#! /bin/bash +# export PATH=/usr/local/cuda/bin:$PATH +# export LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH + +NUM_GPUS_PER_WORKER=8 +MP_SIZE=1 + +script_path=$(realpath $0) +script_dir=$(dirname $script_path) +main_dir=$(dirname $script_dir) +MODEL_TYPE="微调权重路径" +VERSION="base" +MODEL_ARGS="--from_pretrained $MODEL_TYPE \ + --max_length 1288 \ + --lora_rank 10 \ + --use_lora \ + --local_tokenizer 分析器权重路径 \ + --version $VERSION" +# Tips: If training models of resolution 244, you can set --max_length smaller + +OPTIONS_SAT="SAT_HOME=~/.sat_models" +OPTIONS_NCCL="HCCL_DEBUG=info HCCL_IB_DISABLE=0 HCCL_NET_GDR_LEVEL=2 LOCAL_WORLD_SIZE=$NUM_GPUS_PER_WORKER" +HOST_FILE_PATH="hostfile" + +train_data="./archive_split/train" +valid_data="./archive_split/valid" + +gpt_options=" \ + --experiment-name finetune-$MODEL_TYPE \ + --model-parallel-size ${MP_SIZE} \ + --mode finetune \ + --train-iters 800 \ + --resume-dataloader \ + $MODEL_ARGS \ + --train-data ${train_data} \ + --valid-data ${valid_data} \ + --distributed-backend hccl \ + --lr-decay-style cosine \ + --warmup .02 \ + --checkpoint-activations \ + --vit_checkpoint_activations \ + --save-interval 200 \ + --eval-interval 200 \ + --save "./checkpoints" \ + --eval-iters 10 \ + --eval-batch-size 1 \ + --split 1. \ + --deepspeed_config test_config_bf16.json \ + --skip-init \ + --seed 2023 +" + + + +run_cmd="${OPTIONS_NCCL} deepspeed --master_port 16666 --hostfile ${HOST_FILE_PATH} finetune_cogvlm_demo.py ${gpt_options}" +echo ${run_cmd} +eval ${run_cmd} + +set +x diff --git a/PyTorch/built-in/foundation/CogVLM/cogvlm_utils/inference.py b/PyTorch/built-in/foundation/CogVLM/cogvlm_utils/inference.py new file mode 100644 index 0000000000..5809c76205 --- /dev/null +++ b/PyTorch/built-in/foundation/CogVLM/cogvlm_utils/inference.py @@ -0,0 +1,33 @@ +import torch +import torch_npu +from torch_npu.contrib import transfer_to_npu +import requests +from PIL import Image +from transformers import AutoModelForCausalLM, LlamaTokenizer + +tokenizer = LlamaTokenizer.from_pretrained('分词器权重路径') +model = AutoModelForCausalLM.from_pretrained( + '推理权重路径', + torch_dtype=torch.bfloat16, + low_cpu_mem_usage=True, + trust_remote_code=True +).to('cuda').eval() + +image = Image.open("图片路径").convert('RGB') +inputs = model.build_conversation_input_ids(tokenizer, query='How many people', images=[image]) +inputs = { + 'input_ids': inputs['input_ids'].unsqueeze(0).to('cuda'), + 'token_type_ids': inputs['token_type_ids'].unsqueeze(0).to('cuda'), + 'attention_mask': inputs['attention_mask'].unsqueeze(0).to('cuda'), + 'images': [[inputs['images'][0].to('cuda').to(torch.bfloat16)]], +} + +gen_kwargs = {"max_length": 2048, "do_sample": False} + +with torch.no_grad(): + print("Begin inference") + outputs = model.generate(**inputs, **gen_kwargs) + print("Inference End") + outputs = outputs[:, inputs['input_ids'].shape[1]:] + response = tokenizer.decode(outputs[0]) + print("\nCog:", response) diff --git a/PyTorch/built-in/foundation/CogVLM/cogvlm_utils/mixin.py b/PyTorch/built-in/foundation/CogVLM/cogvlm_utils/mixin.py new file mode 100644 index 0000000000..3be8822446 --- /dev/null +++ b/PyTorch/built-in/foundation/CogVLM/cogvlm_utils/mixin.py @@ -0,0 +1,294 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from sat.transformer_defaults import attention_fn_default +from sat.model.base_model import BaseMixin, non_conflict +from sat.mpu.layers import ColumnParallelLinear, RowParallelLinear +from sat.mpu.utils import split_tensor_along_last_dim +from sat import mpu + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2:] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids): + gather_indices = position_ids[:, None, :, None] # [bs, 1, seq_len, 1] + gather_indices = gather_indices.repeat(1, cos.shape[1], 1, cos.shape[2]) + cos = torch.gather(cos.cpu().repeat(gather_indices.shape[0], 1, 1, 1).npu( + torch.npu.current_device()), 2, gather_indices) + sin = torch.gather(sin.cpu().repeat(gather_indices.shape[0], 1, 1, 1).npu( + torch.npu.current_device()), 2, gather_indices) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + +class LlamaVisionExpertFCMixin(BaseMixin): + def __init__(self, in_features, hidden_features, num_layers=32, num_vision_layers=0, vision_layer_range=None, + params_dtype=torch.float, device=torch.device('cpu')): + super().__init__() + + self.num_layers = num_layers + self.num_vision_layers = num_vision_layers + if vision_layer_range is None: + vision_layer_range = [i for i in range(min(num_vision_layers, num_layers))] + self.vision_layer_range = vision_layer_range + self.gate_proj = nn.ModuleList([ColumnParallelLinear( + in_features, + hidden_features, + gather_output=False, + init_method=None, + bias=False, + params_dtype=params_dtype, + module=self, + name="dense_h_to_4h_gate", + skip_init=True, + device=device + ) for i in range(num_layers)]) + # Trainable vision expert parameters + vision_dense_h_to_4h_list = [] + vision_dense_4h_to_h_list = [] + gate_proj_list = [] + + + for i in vision_layer_range: + vision_dense_h_to_4h = ColumnParallelLinear( + in_features, + hidden_features, + gather_output=False, + init_method=None, + bias=False, + params_dtype=params_dtype, + module=self, + name="vision_dense_h_to_4h", + skip_init=True, + device=device + ) + + # Project back to h. + vision_dense_4h_to_h = RowParallelLinear( + hidden_features, + in_features, + input_is_parallel=True, + init_method=None, + bias=False, + params_dtype=params_dtype, + module=self, + name="vision_dense_4h_to_h", + skip_init=True, + device=device + ) + + gate_proj = ColumnParallelLinear( + in_features, + hidden_features, + gather_output=False, + init_method=None, + bias=False, + params_dtype=params_dtype, + module=self, + name="vision_gate_proj", + skip_init=True, + device=device + ) + + vision_dense_h_to_4h_list.append(vision_dense_h_to_4h) + vision_dense_4h_to_h_list.append(vision_dense_4h_to_h) + gate_proj_list.append(gate_proj) + + self.vision_dense_h_to_4h_list = nn.ModuleDict([ + (str(layer_id), vision_dense_h_to_4h) + for layer_id, vision_dense_h_to_4h in zip(vision_layer_range, vision_dense_h_to_4h_list) + ]) + self.vision_dense_4h_to_h_list = nn.ModuleDict([ + (str(layer_id), vision_dense_4h_to_h) + for layer_id, vision_dense_4h_to_h in zip(vision_layer_range, vision_dense_4h_to_h_list) + ]) + self.vision_gate_proj = nn.ModuleDict([ + (str(layer_id), gate_proj) + for layer_id, gate_proj in zip(vision_layer_range, gate_proj_list) + ]) + + def mlp_forward(self, hidden_states, **kw_args): + mixin_self = self + self = self.transformer.layers[kw_args['layer_id']].mlp + if "vision_expert_mask" in kw_args: + vision_expert_mask = kw_args['vision_expert_mask'] + else: + vision_expert_mask = None + + layer_id_key = str(int(kw_args['layer_id'])) + + if kw_args['layer_id'] in mixin_self.vision_layer_range and (vision_expert_mask is not None) and vision_expert_mask.any(): + vision_dense_h_to_4h = mixin_self.vision_dense_h_to_4h_list[layer_id_key] + vision_dense_4h_to_h = mixin_self.vision_dense_4h_to_h_list[layer_id_key] + vision_gate_proj = mixin_self.vision_gate_proj[layer_id_key] + output = torch.empty(hidden_states.shape, dtype=hidden_states.dtype, device=hidden_states.device) + + language_hidden_state = hidden_states[~vision_expert_mask.bool()] + language_intermediate_parallel = self.activation_func(mixin_self.gate_proj[kw_args['layer_id']](language_hidden_state)) * self.dense_h_to_4h(language_hidden_state) + output[~vision_expert_mask.bool()] = self.dense_4h_to_h(language_intermediate_parallel) # language_output + + vision_hidden_state = hidden_states[vision_expert_mask.bool()] + vision_intermediate_parallel = vision_dense_h_to_4h(vision_hidden_state) + gate_output = vision_gate_proj(vision_hidden_state) + + vision_intermediate_parallel *= self.activation_func(gate_output) + output[vision_expert_mask.bool()] = vision_dense_4h_to_h(vision_intermediate_parallel) # vision_output + else: + intermediate_parallel = self.activation_func(mixin_self.gate_proj[kw_args['layer_id']](hidden_states)) * self.dense_h_to_4h(hidden_states) + output = self.dense_4h_to_h(intermediate_parallel) + + return output.contiguous() + + def copy_param(self): + with torch.no_grad(): + for i in self.vision_layer_range: + self.vision_gate_proj[str(i)].weight.data.copy_(self.gate_proj[i].weight.data) + self.vision_dense_4h_to_h_list[str(i)].weight.data.copy_(self.transformer.layers[i].mlp.dense_4h_to_h.weight.data) + self.vision_dense_h_to_4h_list[str(i)].weight.data.copy_(self.transformer.layers[i].mlp.dense_h_to_4h.weight.data) + +from sat.mpu import get_model_parallel_world_size +from sat.mpu.utils import divide +from sat.model.position_embedding.triton_rotary_embeddings import FastRotaryEmbedding + +class LlamaVisionExpertAttnMixin(BaseMixin): + def __init__(self, hidden_size, num_heads, num_layers=28, num_vision_layers=0, use_vision_expert=True, vision_layer_range=None, + params_dtype=torch.float, device=torch.device('cpu')): + super().__init__() + + world_size = get_model_parallel_world_size() + self.hidden_size = hidden_size + self.num_attention_heads = num_heads + self.hidden_size_per_attention_head = divide(hidden_size, num_heads) + self.num_attention_heads_per_partition = divide(num_heads, world_size) + self.inner_hidden_size = num_heads * self.hidden_size_per_attention_head + + self.rotary_emb = FastRotaryEmbedding( + hidden_size // num_heads + ) + + self.num_vision_layers = num_vision_layers + self.num_layers = num_layers + if vision_layer_range is None: + vision_layer_range = [i for i in range(min(num_vision_layers, num_layers))] + self.vision_layer_range = vision_layer_range + + self.use_vision_expert = use_vision_expert + # Trainable vision expert parameters + + if self.use_vision_expert: + vision_query_key_value_list = [] + vision_dense_list = [] + for i in vision_layer_range: + vision_query_key_value = ColumnParallelLinear( + hidden_size, + 3 * hidden_size, + stride=3, + gather_output=False, + init_method=None, + bias=False, + params_dtype=params_dtype, + module=self, + name="vision_query_key_value", + skip_init=True, + device=device + ) + + vision_dense = RowParallelLinear( + self.inner_hidden_size, + hidden_size, + input_is_parallel=True, + init_method=None, + bias=False, + params_dtype=params_dtype, + module=self, + name="vision_dense", + skip_init=True, + device=device, + final_bias=False + ) + + vision_query_key_value_list.append(vision_query_key_value) + vision_dense_list.append(vision_dense) + + self.vision_query_key_value_list = nn.ModuleDict([ + (str(layer_id), vision_query_key_value) + for layer_id, vision_query_key_value in zip(vision_layer_range, vision_query_key_value_list) + ]) + self.vision_dense_list = nn.ModuleDict([ + (str(layer_id), vision_dense) + for layer_id, vision_dense in zip(vision_layer_range, vision_dense_list) + ]) + + def attention_forward(self, hidden_states, mask, **kw_args): + mixin_self = self + self = self.transformer.layers[kw_args['layer_id']].attention + attention_fn = attention_fn_default + if 'attention_fn' in self.hooks: + attention_fn = self.hooks['attention_fn'] + if "vision_expert_mask" in kw_args: + vision_expert_mask = kw_args['vision_expert_mask'] + else: + vision_expert_mask = None + + layer_id_key = str(int(kw_args['layer_id'])) + if mixin_self.use_vision_expert and kw_args['layer_id'] in mixin_self.vision_layer_range and ( + vision_expert_mask is not None) and vision_expert_mask.any(): + shape = list(hidden_states.shape) + parallel_size = mpu.get_model_parallel_world_size() + shape[-1] = shape[-1] * 3 // parallel_size + vision_query_key_value = mixin_self.vision_query_key_value_list[layer_id_key] + mixed_raw_layer = torch.empty(shape, dtype=hidden_states.dtype, device=hidden_states.device) + language_hidden_states = hidden_states[~vision_expert_mask.bool()] + vision_hidden_states = hidden_states[vision_expert_mask.bool()] + mixed_raw_layer[~vision_expert_mask.bool()] = self.query_key_value( + language_hidden_states) # language_mixed_raw_layer + mixed_raw_layer[vision_expert_mask.bool()] = vision_query_key_value( + vision_hidden_states) # vision_mixed_raw_layer + else: + mixed_raw_layer = self.query_key_value(hidden_states) + + (mixed_query_layer, + mixed_key_layer, + mixed_value_layer) = split_tensor_along_last_dim(mixed_raw_layer, 3) + + dropout_fn = self.attention_dropout if self.training else None + + query_layer = self._transpose_for_scores(mixed_query_layer) + key_layer = self._transpose_for_scores(mixed_key_layer) + value_layer = self._transpose_for_scores(mixed_value_layer) + kv_seq_len = key_layer.shape[-2] + cos, sin = mixin_self.rotary_emb(value_layer, seq_len=kv_seq_len) + query_layer, key_layer = apply_rotary_pos_emb(query_layer, key_layer, cos, sin, kw_args['position_ids']) + + #query_layer, key_layer = mixin_self.rotary_emb(query_layer,key_layer, kw_args['position_ids'], max_seqlen=kw_args['position_ids'].max()+1, layer_id=kw_args['layer_id']) + + context_layer = attention_fn(query_layer, key_layer, value_layer, mask, dropout_fn, **kw_args) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,) + context_layer = context_layer.view(*new_context_layer_shape) + + if mixin_self.use_vision_expert and kw_args['layer_id'] in mixin_self.vision_layer_range and ( + vision_expert_mask is not None) and vision_expert_mask.any(): + vision_dense = mixin_self.vision_dense_list[layer_id_key] + parallel_size = mpu.get_model_parallel_world_size() + target_shape = context_layer.shape[:-1] + (context_layer.shape[-1] * parallel_size,) + output = torch.empty(target_shape, dtype=hidden_states.dtype, device=hidden_states.device) + output[~vision_expert_mask.bool()] = self.dense(context_layer[~vision_expert_mask.bool()]) # language + output[vision_expert_mask.bool()] = vision_dense(context_layer[vision_expert_mask.bool()]) # vision + else: + output = self.dense(context_layer) + + if self.training: + output = self.output_dropout(output) + return output.contiguous() + + def copy_param(self): + with torch.no_grad(): + for i in self.vision_layer_range: + self.vision_query_key_value_list[str(i)].weight.data.copy_(self.transformer.layers[i].attention.query_key_value.weight.data) + self.vision_dense_list[str(i)].weight.data.copy_(self.transformer.layers[i].attention.dense.weight.data) diff --git a/PyTorch/built-in/foundation/CogVLM/cogvlm_utils/modeling_cogvlm.py b/PyTorch/built-in/foundation/CogVLM/cogvlm_utils/modeling_cogvlm.py new file mode 100644 index 0000000000..7fa1dd1e41 --- /dev/null +++ b/PyTorch/built-in/foundation/CogVLM/cogvlm_utils/modeling_cogvlm.py @@ -0,0 +1,832 @@ +"""largely copy from llama and adapt for cogvlm""" +import warnings +from typing import TYPE_CHECKING, Optional, Tuple, List, Union, Literal, Dict, Any + +import math +import torch +from torch import nn +from torch.nn import CrossEntropyLoss +from torchvision import transforms +from einops import rearrange + +from transformers import PreTrainedModel, PreTrainedTokenizer +from transformers.utils.logging import get_logger +from transformers.activations import ACT2FN +from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast + +from .configuration_cogvlm import CogVLMConfig +#from .util import FastRotaryEmbedding +from .rotary_embeddings import RotaryEmbedding as FastRotaryEmbedding +from .visual import EVA2CLIPModel + +if TYPE_CHECKING: + from transformers.utils import ModelOutput + +logger = get_logger(__name__) + +LANGUAGE_TOKEN_TYPE = 0 +VISION_TOKEN_TYPE = 1 + + +# Copied from transformers.models.bart.modeling_bart._make_causal_mask +def _make_causal_mask( + input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0 +): + """ + Make causal mask used for bi-directional self-attention. + """ + bsz, tgt_len = input_ids_shape + mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) + mask_cond = torch.arange(mask.size(-1), device=device) + mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) + mask = mask.to(dtype) + + if past_key_values_length > 0: + mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) + return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) + + +# Copied from transformers.models.bart.modeling_bart._expand_mask +def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): + """ + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. + """ + bsz, src_len = mask.size() + tgt_len = tgt_len if tgt_len is not None else src_len + + expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) + + inverted_mask = 1.0 - expanded_mask + + return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) + + +class RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return (self.weight * hidden_states).to(input_dtype) + + +class MLP(nn.Module): + def __init__(self, config): + super().__init__() + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +def get_expert_mask(token_type_ids: "torch.LongTensor(B, L)") -> "[torch.BoolTensor(B, L), torch.BoolTensor(B, L)]": + vision_token_mask = torch.zeros_like(token_type_ids, dtype=torch.bool) + vision_token_mask[:, :-1] = (token_type_ids[:, :-1] == VISION_TOKEN_TYPE) & (token_type_ids[:, 1:] == VISION_TOKEN_TYPE) + language_token_mask = ~vision_token_mask + return vision_token_mask, language_token_mask + + +class VisionExpertMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.language_mlp = MLP(config) + self.vision_mlp = MLP(config) + + def forward(self, hidden_states: "torch.Tensor(B, L, D)", token_type_ids: "torch.LongTensor(B, L)"): + output = torch.empty(hidden_states.shape, dtype=hidden_states.dtype, device=hidden_states.device) + vision_token_mask, language_token_mask = get_expert_mask(token_type_ids) + output[vision_token_mask] = self.vision_mlp(hidden_states[vision_token_mask]) + output[language_token_mask] = self.language_mlp(hidden_states[language_token_mask]) + return output + + +def attention_fn( + query_layer: "torch.tensor(B, H, L, HD)", + key_layer: "torch.tensor(B, H, L, HD)", + value_layer: "torch.tensor(B, H, L, HD)", + attention_mask: "torch.tensor(B, H, L, HD)", + *, + scaling_attention_score: bool = True, + attention_dropout: nn.Module = None +): + attention_mask_bool = (attention_mask == 0) + is_low_triangle = (attention_mask_bool == torch.ones_like(attention_mask_bool, dtype=torch.float).tril()).all() + is_full = (attention_mask_bool > 0).all() + if not (int(torch.__version__.split('.')[0]) >= 2): + warnings.warn("It's recommended to use torch2.0 or higher.") + if int(torch.__version__.split('.')[0]) >= 2 and scaling_attention_score and (is_full or is_low_triangle): + dropout_p = 0. if attention_dropout is None or not attention_dropout.training else attention_dropout.p + return torch.nn.functional.scaled_dot_product_attention( + query_layer, key_layer, value_layer, + attn_mask=None, + dropout_p=dropout_p, + is_causal=not is_full + ) + else: + if scaling_attention_score: + query_layer = query_layer / math.sqrt(query_layer.shape[-1]) + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + attention_scores = attention_scores + attention_mask + attention_scores = nn.functional.softmax(attention_scores, dim=-1, dtype=torch.float32).to(query_layer.dtype) + if attention_dropout is not None: + attention_scores = attention_dropout(attention_scores) + context_layer = torch.matmul(attention_scores, value_layer) + return context_layer + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2:] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids): + gather_indices = position_ids[:, None, :, None] # [bs, 1, seq_len, 1] + gather_indices = gather_indices.repeat(1, cos.shape[1], 1, cos.shape[2]) + cos = torch.gather(cos.cpu().repeat(gather_indices.shape[0], 1, 1, 1).npu( + torch.npu.current_device()), 2, gather_indices) + sin = torch.gather(sin.cpu().repeat(gather_indices.shape[0], 1, 1, 1).npu( + torch.npu.current_device()), 2, gather_indices) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class VisionExpertAttention(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.max_position_embeddings = config.max_position_embeddings + + #self.rotary_emb = RotaryEmbedding(self.hidden_size // self.num_heads) + #self.rotary_emb = FastRotaryEmbedding(dim=self.head_dim, pos_idx_in_fp32=False) + self.rotary_emb = FastRotaryEmbedding(self.hidden_size // self.num_heads) + self.vision_expert_query_key_value = nn.Linear(self.hidden_size, self.hidden_size * 3, bias=False) + self.vision_expert_dense = nn.Linear(self.hidden_size, self.hidden_size, bias=False) + self.language_expert_query_key_value = nn.Linear(self.hidden_size, self.hidden_size * 3, bias=False) + self.language_expert_dense = nn.Linear(self.hidden_size, self.hidden_size, bias=False) + + def _transpose_for_scores(self, tensor): + """Transpose a 3D tensor [B, L, H*HD] into a 4D tensor with size [B H L HD].""" + new_tensor_shape = tensor.size()[:-1] + (self.num_heads, self.head_dim) + tensor = tensor.view(*new_tensor_shape) + return tensor.permute(0, 2, 1, 3) + + def forward( + self, + hidden_states: torch.Tensor, + token_type_ids: torch.LongTensor, + position_ids: torch.LongTensor, + attention_mask: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + vision_token_mask, language_token_mask = get_expert_mask(token_type_ids) + + shape = list(hidden_states.shape) + shape[-1] = shape[-1] * 3 + mixed_raw_layer = torch.empty(shape, dtype=hidden_states.dtype, device=hidden_states.device) + mixed_raw_layer[vision_token_mask] = self.vision_expert_query_key_value(hidden_states[vision_token_mask]) + mixed_raw_layer[language_token_mask] = self.language_expert_query_key_value(hidden_states[language_token_mask]) + + query_states, key_states, value_states = torch.split(mixed_raw_layer, self.hidden_size, dim=-1) + query_states = self._transpose_for_scores(query_states) # B, H, L, HD + key_states = self._transpose_for_scores(key_states) # B, H, L, HD + value_states = self._transpose_for_scores(value_states) # B, H, L, HD + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + #torch.save(query_states.cpu(), "/home/gpt_neox/d00816453/cogvlm/attn_tensor_224/query_states_rotary_atten_gpu.pth") + #torch.save(key_states.cpu(), "/home/gpt_neox/d00816453/cogvlm/attn_tensor_224/key_states_rotary_atten_gpu.pth") + #torch.save(value_states.cpu(), "/home/gpt_neox/d00816453/cogvlm/attn_tensor_224/value_states_rotary_atten_gpu.pth") + #breakpoint() + #print(f"==========query_states0:{query_states.shape}=========") + #print(f"===========key_states0:{key_states.shape}============") + #print(f"===========value_states0:{value_states.shape}=============") + #query_states = torch.load("/home/gpt_neox/d00816453/cogvlm/attn_tensor_224/query_states_before_rotary_atten_gpu.pth", map_location="cpu").to(device=query_states.device, dtype=query_states.dtype) + #key_states = torch.load("/home/gpt_neox/d00816453/cogvlm/attn_tensor_224/key_states_before_rotary_atten_gpu.pth", map_location="cpu").to(device=key_states.device, dtype=key_states.dtype) + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + #query_states, key_states = self.rotary_emb(query_states, key_states, position_ids=position_ids, max_seqlen=position_ids.max() + 1) + #torch.save(query_states.cpu(), "/home/gpt_neox/d00816453/cogvlm/attn_tensor_224/query_states_rotary_after_atten_npu.pth") + #torch.save(key_states.cpu(), "/home/gpt_neox/d00816453/cogvlm/attn_tensor_224/key_states_rotary_after_atten_npu.pth") + #breakpoint() + #print(f"==========query_states1:{query_states.shape}=========") + #print(f"===========key_states1:{key_states.shape}============") + #print(f"===========value_states1:{value_states.shape}=============") + #query_states = torch.load("/home/gpt_neox/d00816453/cogvlm/attn_tensor_224/query_states_rotary_after_atten_gpu.pth", map_location="cpu").to(device=query_states.device, dtype=query_states.dtype) + #key_states = torch.load("/home/gpt_neox/d00816453/cogvlm/attn_tensor_224/key_states_rotary_after_atten_gpu.pth", map_location="cpu").to(device=key_states.device, dtype=key_states.dtype) + #value_states = torch.load("/home/gpt_neox/d00816453/cogvlm/attn_tensor_224/value_states_rotary_after_atten_gpu.pth", map_location="cpu").to(device=value_states.device, dtype=value_states.dtype) + #print(f"==========query_states2:{query_states.shape}=========") + #print(f"===========key_states2:{key_states.shape}============") + #print(f"===========value_states2:{value_states.shape}=============") + #print(f"============qkv==================") + + if past_key_value is not None: + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + + past_key_value = (key_states, value_states) if use_cache else None + + context_layer = attention_fn( + query_layer=query_states, key_layer=key_states, value_layer=value_states, attention_mask=attention_mask, + scaling_attention_score=True, attention_dropout=None) + if context_layer.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {context_layer.size()}" + ) + context_layer = context_layer.transpose(1, 2).contiguous().reshape(bsz, q_len, self.hidden_size) + + attn_output = torch.empty(context_layer.shape, dtype=hidden_states.dtype, device=hidden_states.device) + attn_output[vision_token_mask] = self.vision_expert_dense(context_layer[vision_token_mask]) + attn_output[language_token_mask] = self.language_expert_dense(context_layer[language_token_mask]) + + if output_attentions: + warnings.warn("output_attentions is not implemented.") + + return attn_output, None, past_key_value + + +class CogVLMDecoderLayer(nn.Module): + def __init__(self, config): + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = VisionExpertAttention(config=config) + self.mlp = VisionExpertMLP(config) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + token_type_ids: torch.LongTensor, + position_ids: torch.LongTensor, + attention_mask: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + token_type_ids=token_type_ids, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states, token_type_ids=token_type_ids) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs # type: ignore + + +class CogVLMPreTrainedModel(PreTrainedModel): + config_class = CogVLMConfig + base_model_prefix = "model" + supports_gradient_checkpointing = False + _no_split_modules = ["CogVLMDecoderLayer"] + _skip_keys_device_placement = "past_key_values" + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +def is_empty(images_list: Optional[List[List[torch.Tensor]]]): + if images_list is None or len(images_list) == 0: + return True + for image_list in images_list: + if len(image_list): + return False + return True + + +def build_position_ids(x: "torch.BoolTensor(B, L)", attention_mask: Optional["torch.BoolTensor(B, L)"] = None) -> "torch.LongTensor(B, L)": + if attention_mask is not None: + tmp = x.clone() + tmp[~(attention_mask.bool())] = -1 + else: + tmp = x.clone() + # image boi eoi token as LANGUAGE_TOKEN_TYPE + is_boi_eoi = torch.zeros_like(x, dtype=torch.bool) + is_boi_eoi[:, 1:] |= (tmp[:, 1:] == VISION_TOKEN_TYPE) & (tmp[:, :-1] == LANGUAGE_TOKEN_TYPE) + is_boi_eoi[:, 0] |= (tmp[:, 0] == VISION_TOKEN_TYPE) + is_boi_eoi[:, :-1] |= (tmp[:, :-1] == VISION_TOKEN_TYPE) & (tmp[:, 1:] == LANGUAGE_TOKEN_TYPE) + is_boi_eoi[:, -1] |= (tmp[:, -1] == VISION_TOKEN_TYPE) + tmp[is_boi_eoi] = LANGUAGE_TOKEN_TYPE + # final position ids + y = torch.zeros_like(x, dtype=torch.long) + y[:, 1:] = (tmp[:, 1:] == LANGUAGE_TOKEN_TYPE) | ((tmp[:, 1:] == VISION_TOKEN_TYPE) & (tmp[:, :-1] == LANGUAGE_TOKEN_TYPE)) + y = y.cumsum(dim=-1) + return y + + +class CogVLMModel(CogVLMPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList([CogVLMDecoderLayer(config) for _ in range(config.num_hidden_layers)]) + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + self.vision = EVA2CLIPModel(config) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def encode_images(self, images: List[List[torch.Tensor]]) -> torch.Tensor: + images_list, images = images, [] + + images = [] + for image_list in images_list: + for image in image_list: + images.append(image) + + images = torch.stack(images) + images_features = self.vision(images) + return images_features + + def forward( + self, + input_ids: torch.LongTensor = None, + images: List[List[torch.Tensor]] = None, + token_type_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + """take care of image_encode, token_type_ids, position_ids and (attention_mask = None is fine)""" + + if past_key_values is not None: + pass # generate mode with past_key_values. the image features are already mapped + else: + # not allow for inputs_embeds, because we want to process image feature + assert input_ids is not None and inputs_embeds is None, f"{input_ids} {inputs_embeds}" + if not is_empty(images): # multi-modality + assert token_type_ids is not None, f"multi-modality requires `token_type_ids`!" + assert len(input_ids) == len(images), f"{len(input_ids)} {len(images)}" + inputs_embeds = self.embed_tokens(input_ids) + images_features = self.encode_images(images) + images_features = rearrange(images_features, 'b n d -> (b n) d') + images_features = images_features.to(dtype=inputs_embeds.dtype, device=inputs_embeds.device) + inputs_embeds = inputs_embeds.index_put([token_type_ids == VISION_TOKEN_TYPE], images_features) + else: # single-modality + if token_type_ids is None: + token_type_ids = torch.ones_like(input_ids, dtype=torch.long, device=input_ids.device) * LANGUAGE_TOKEN_TYPE + assert not (token_type_ids == VISION_TOKEN_TYPE).any(), f"{(token_type_ids == VISION_TOKEN_TYPE).sum()}" + inputs_embeds = self.embed_tokens(input_ids) + + if position_ids is None: + position_ids = build_position_ids(token_type_ids, attention_mask) + input_ids = None + + return self.llm_forward( + input_ids=input_ids, + token_type_ids=token_type_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + def llm_forward( + self, + input_ids: torch.LongTensor = None, + token_type_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + """largely copy from llama forward and adapt for cogvlm with `token_type_ids`""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + seq_length_with_past = seq_length + past_key_values_length = 0 + + if past_key_values is not None: + past_key_values_length = past_key_values[0][0].shape[2] + seq_length_with_past = seq_length_with_past + past_key_values_length + + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + ) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + # embed positions + if attention_mask is None: + attention_mask = torch.ones( + (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device + ) + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length + ) + + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = () if use_cache else None + + for idx, decoder_layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + past_key_value = past_key_values[idx] if past_key_values is not None else None + layer_outputs = decoder_layer( + hidden_states, + token_type_ids=token_type_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + # noinspection PyMethodMayBeStatic + # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask + def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + combined_attention_mask = None + if input_shape[-1] > 1: + combined_attention_mask = _make_causal_mask( + input_shape, + inputs_embeds.dtype, + device=inputs_embeds.device, + past_key_values_length=past_key_values_length, + ) + + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to( + inputs_embeds.device + ) + combined_attention_mask = ( + expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask + ) + + return combined_attention_mask + + +def chat_history_to_prompt(history, query): + prompt = " [INST] " + for i, (old_query, response) in enumerate(history): + prompt += old_query + " [/INST] " + response + " [INST] " + prompt += query + " [/INST] " + return prompt + + +def base_history_to_prompt(history, query): + prompt = query + return prompt + + +_history_to_prompt = { + "base": base_history_to_prompt, + "chat": chat_history_to_prompt +} + + +class CogVLMForCausalLM(CogVLMPreTrainedModel): + _auto_class = "AutoModelForCausalLM" + + def __init__(self, config): + super().__init__(config) + self.model = CogVLMModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + def forward( + self, + input_ids: torch.LongTensor = None, + images: List[List[torch.Tensor]] = None, + token_type_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + labels: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + images=images, + token_type_ids=token_type_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + logits = logits.float() + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def _prepare_attention_mask_for_generation( + self, + inputs: torch.Tensor, + pad_token_id: Optional[int], + eos_token_id: Optional[Union[int, List[int]]], + ) -> torch.LongTensor: + return torch.ones(inputs.shape[:2], dtype=torch.long, device=inputs.device) # type: ignore + + def prepare_inputs_for_generation( + self, input_ids, token_type_ids, images=None, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs + ): + # build position_ids if needed + position_ids = kwargs.get("position_ids", None) + if position_ids is None: + position_ids = build_position_ids(token_type_ids, attention_mask) + + if past_key_values: + input_ids = input_ids[:, -1:] + token_type_ids = token_type_ids[:, -1:] + position_ids = position_ids[:, -1:] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "token_type_ids": token_type_ids, + "images": images, + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "attention_mask": attention_mask, + } + ) + return model_inputs + + def _update_model_kwargs_for_generation( + self, + outputs: "ModelOutput", + model_kwargs: Dict[str, Any], + is_encoder_decoder: bool = False, + standardize_cache_format: bool = False, + ) -> Dict[str, Any]: + # update past_key_values + model_kwargs["past_key_values"] = self._extract_past_from_model_output( + outputs, standardize_cache_format=standardize_cache_format + ) + if getattr(outputs, "state", None) is not None: + model_kwargs["state"] = outputs.state + + # update token_type_ids with last value + if "token_type_ids" in model_kwargs: + token_type_ids = model_kwargs["token_type_ids"] + new_token_type_ids = torch.ones(size=(token_type_ids.shape[0], 1), dtype=token_type_ids.dtype, device=token_type_ids.device) * LANGUAGE_TOKEN_TYPE + model_kwargs["token_type_ids"] = torch.cat([token_type_ids, new_token_type_ids], dim=-1) + + if not is_encoder_decoder: + # update attention mask + if "attention_mask" in model_kwargs: + attention_mask = model_kwargs["attention_mask"] + model_kwargs["attention_mask"] = torch.cat( + [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1 + ) + else: + # update decoder attention mask + if "decoder_attention_mask" in model_kwargs: + decoder_attention_mask = model_kwargs["decoder_attention_mask"] + model_kwargs["decoder_attention_mask"] = torch.cat( + [decoder_attention_mask, decoder_attention_mask.new_ones((decoder_attention_mask.shape[0], 1))], + dim=-1, + ) + + return model_kwargs + + def _reorder_cache(self, past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + reordered_past += ( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), + ) + return reordered_past + + def build_conversation_input_ids( + self, + tokenizer: "PreTrainedTokenizer", + *, + query: str, + history: Optional[List[Tuple[str, str]]] = None, + images: Optional[List["PIL.Image"]] = None, + template_version: Optional[Literal["base", "chat"]] = None, + ): + image_size: int = self.config.vision_config['image_size'] + patch_size: int = self.config.vision_config['patch_size'] + template_version = template_version or self.config.template_version + assert images is None or len(images) <= 1, f"not support multi images by now." + history = history or [] + text = _history_to_prompt[template_version](history, query) + + input_ids = [tokenizer.bos_token_id] + token_type_ids = [LANGUAGE_TOKEN_TYPE] + if images is not None and len(images) == 1: + # vision + transform = transforms.Compose( + [ + transforms.Resize( + (image_size, image_size), interpolation=transforms.InterpolationMode.BICUBIC + ), + transforms.ToTensor(), + transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), + ] + ) + images = [transform(images[0])] + # language + vision_token_num = (image_size // patch_size) * (image_size // patch_size) + 2 + input_ids += [tokenizer.pad_token_id] * vision_token_num + token_type_ids += [VISION_TOKEN_TYPE] * vision_token_num + text_ids = tokenizer.encode(text, add_special_tokens=False) + + input_ids += text_ids + token_type_ids += [LANGUAGE_TOKEN_TYPE] * len(text_ids) + attention_mask = [1] * len(input_ids) + + return { + 'input_ids': torch.tensor(input_ids, dtype=torch.long), + 'token_type_ids': torch.tensor(token_type_ids, dtype=torch.long), + 'attention_mask': torch.tensor(attention_mask, dtype=torch.long), + 'images': images, + } diff --git a/PyTorch/built-in/foundation/CogVLM/cogvlm_utils/rotary_embeddings.py b/PyTorch/built-in/foundation/CogVLM/cogvlm_utils/rotary_embeddings.py new file mode 100644 index 0000000000..5d3715d249 --- /dev/null +++ b/PyTorch/built-in/foundation/CogVLM/cogvlm_utils/rotary_embeddings.py @@ -0,0 +1,128 @@ +# Extracted from: https://github.com/EleutherAI/gpt-neox +import torch +import torch.nn.functional as F + + +class RotaryEmbedding(torch.nn.Module): + + def __init__(self, dim, base=10000, precision=torch.half, learnable=False, device=torch.device('cpu')): + super().__init__() + inv_freq = 1. / (base ** (torch.arange(0, dim, 2, device=device).float() / dim)) + # inv_freq = inv_freq.half() + self.learnable = learnable + if learnable: + self.inv_freq = torch.nn.Parameter(inv_freq) + self.max_seq_len_cached = None + else: + self.register_buffer('inv_freq', inv_freq) + self.max_seq_len_cached = None + self.cos_cached = None + self.sin_cached = None + self.precision = precision + + def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): + pass + + def forward(self, x, seq_dim=1, seq_len=None): + if seq_len is None: + seq_len = x.shape[seq_dim] + if self.max_seq_len_cached is None or (seq_len > self.max_seq_len_cached): + self.max_seq_len_cached = None if self.learnable else seq_len + t = torch.arange(seq_len, device=x.device, dtype=torch.float32) + freqs = torch.einsum('i,j->ij', t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1).to(x.device) + if self.precision == torch.bfloat16: + emb = emb.float() + + # [sx, 1 (b * np), hn] + cos_cached = emb.cos()[:, None, :] + sin_cached = emb.sin()[:, None, :] + cos_cached = cos_cached.to(x.dtype) + sin_cached = sin_cached.to(x.dtype) + if self.learnable: + return cos_cached, sin_cached + self.cos_cached, self.sin_cached = cos_cached, sin_cached + return self.cos_cached[:seq_len, ...], self.sin_cached[:seq_len, ...] + + +class RotaryPositionalEmbeddingFunction(torch.autograd.Function): + + @staticmethod + def forward(ctx, q, cos, sin): + import rotary_positional_embedding_cuda + + q_ = q.contiguous() + cos_ = cos.contiguous() + sin_ = sin.contiguous() + output = rotary_positional_embedding_cuda.forward(*q.shape, q_, cos_, sin_) + ctx.save_for_backward(cos_, sin_) + + return output + + @staticmethod + def backward(ctx, grad_output): + import rotary_positional_embedding_cuda + + cos_, sin_ = ctx.saved_tensors + grad_q = rotary_positional_embedding_cuda.backward(*grad_output.shape, grad_output, cos_, sin_) + + return grad_q, None, None + +# rotary pos emb helpers: + +def rotate_half(x): + x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:] + return torch.cat((-x2, x1), dim=x1.ndim - 1) # dim=-1 triggers a bug in earlier torch versions + + +@torch.jit.script +def apply_rotary_pos_emb(q, k, cos, sin, offset: int = 0): + cos, sin = cos[offset:q.shape[0] + offset, ...], sin[offset:q.shape[0] + offset, ...] + return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin) + + +def apply_rotary_pos_emb_torch(q, k, cos, sin, offset: int = 0): # jitting fails with bf16 + cos, sin = cos[offset:q.shape[0] + offset, ...], sin[offset:q.shape[0] + offset, ...] + return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin) + + +def apply_rotary_pos_emb_fused(q, k, cos, sin, offset: int = 0): + cos, sin = cos[offset:q.shape[0] + offset, ...], sin[offset:q.shape[0] + offset, ...] + q = RotaryPositionalEmbeddingFunction.apply(q, cos, sin) + k = RotaryPositionalEmbeddingFunction.apply(k, cos, sin) + return q, k + + +@torch.jit.script +def apply_rotary_pos_emb_index_single(q, cos, sin, position_id): + # position_id: [sq, b], q, k: [sq, b, np, hn], cos: [sq, 1, hn] -> [sq, b, 1, hn] + cos, sin = F.embedding(position_id, cos.squeeze(1)).unsqueeze(2), \ + F.embedding(position_id, sin.squeeze(1)).unsqueeze(2) + return (q * cos) + (rotate_half(q) * sin) + + +@torch.jit.script +def apply_rotary_pos_emb_index(q, k, cos, sin, position_id): + # position_id: [sq, b], q, k: [sq, b, np, hn], cos: [sq, 1, hn] -> [sq, b, 1, hn] + cos, sin = F.embedding(position_id, cos.squeeze(1)).unsqueeze(2), \ + F.embedding(position_id, sin.squeeze(1)).unsqueeze(2) + q, k = (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin) + return q, k + + +def apply_rotary_pos_emb_index_torch(q, k, cos, sin, position_id): # jitting fails with bf16 + # position_id: [sq, b], q, k: [sq, b, np, hn], cos: [sq, 1, hn] -> [sq, b, 1, hn] + cos, sin = F.embedding(position_id, cos.squeeze(1)).unsqueeze(2), \ + F.embedding(position_id, sin.squeeze(1)).unsqueeze(2) + q, k = (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin) + return q, k + + +def apply_rotary_pos_emb_index_fused(q, k, cos, sin, position_id): + # position_id: [sq, b], q, k: [sq, b, np, hn], cos: [sq, 1, hn] -> [sq, b, 1, hn] + cos, sin = F.embedding(position_id, cos.squeeze(1)).unsqueeze(2), \ + F.embedding(position_id, sin.squeeze(1)).unsqueeze(2) + q = RotaryPositionalEmbeddingFunction.apply(q, cos, sin) + k = RotaryPositionalEmbeddingFunction.apply(k, cos, sin) + return q, k diff --git a/PyTorch/built-in/foundation/CogVLM/cogvlm_utils/triton_rotary_embeddings.py b/PyTorch/built-in/foundation/CogVLM/cogvlm_utils/triton_rotary_embeddings.py new file mode 100644 index 0000000000..411b0343f8 --- /dev/null +++ b/PyTorch/built-in/foundation/CogVLM/cogvlm_utils/triton_rotary_embeddings.py @@ -0,0 +1 @@ +from .rotary_embeddings import RotaryEmbedding as FastRotaryEmbedding diff --git a/PyTorch/built-in/foundation/CogVLM/cogvlm_utils/visual.py b/PyTorch/built-in/foundation/CogVLM/cogvlm_utils/visual.py new file mode 100644 index 0000000000..6567e01fff --- /dev/null +++ b/PyTorch/built-in/foundation/CogVLM/cogvlm_utils/visual.py @@ -0,0 +1,189 @@ +import torch +import torch_npu +from torch import nn +from argparse import Namespace +import xformers.ops as xops +from transformers.activations import ACT2FN + +class FlashSelfAttention(torch.nn.Module): + """Implement the scaled dot product attention with softmax. + Arguments + --------- + softmax_scale: The temperature to use for the softmax attention. + (default: 1/sqrt(d_keys) where d_keys is computed at + runtime) + attention_dropout: The dropout rate to apply to the attention + (default: 0.0) + """ + + def __init__(self, causal=False, softmax_scale=1., attention_dropout=0.): + super().__init__() + self.causal = causal + self.softmax_scale = softmax_scale + self.dropout_p = attention_dropout + + def forward(self, q, k, v, n, attention_mask, pse): + + if self.causal: + output = torch_npu.npu_fusion_attention( + q, k, v, n, "BSND",# SBH + pse=pse, + padding_mask=None, + atten_mask=attention_mask, + scale=self.softmax_scale, + pre_tockens=k.shape[0], # seq_len + next_tockens=0, # 0 + keep_prob=1 - self.dropout_p, + )[0] + return output + raise Exception("the attention type {} is not support!".format(self.attention_type)) + + +class PatchEmbedding(nn.Module): + def __init__(self, config): + super().__init__() + self.proj = nn.Conv2d(config.in_channels, config.hidden_size, kernel_size=config.patch_size, stride=config.patch_size) + self.cls_embedding = nn.Parameter(torch.zeros(1, config.hidden_size)) + self.position_embedding = nn.Embedding(config.num_positions, config.hidden_size) + + def forward(self, images: "tensor(B, C, H, W)") -> "tensor(B, L, D)": + x = self.proj(images) + x = x.flatten(2).transpose(1, 2) + cls_token = self.cls_embedding.expand(x.shape[0], -1, -1) + x = torch.cat((cls_token, x), dim=1) + x += self.position_embedding.weight.unsqueeze(0) + return x + + +class Attention(nn.Module): + def __init__(self, config): + super().__init__() + self.num_heads = config.num_heads + head_dim = config.hidden_size // config.num_heads + self.scale = head_dim ** -0.5 + self.query_key_value = nn.Linear(config.hidden_size, config.hidden_size * 3) + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.output_dropout = torch.nn.Dropout(config.dropout_prob) + self.core_attention_flash = FlashSelfAttention( + causal=True, softmax_scale=self.scale, attention_dropout=config.dropout_prob + ) + + def forward(self, x: "tensor(B, L, D)") -> "tensor(B, L, D)": + B, L, _ = x.shape + qkv = self.query_key_value(x) + qkv = qkv.reshape(B, L, 3, self.num_heads, -1).permute(2, 0, 1, 3, 4) # 3, B, L, H, D + q, k, v = qkv[0], qkv[1], qkv[2] + + #out = xops.memory_efficient_attention( + # q, k, v, scale=self.scale, + #) + #torch.save(q.cpu(), "/home/gpt_neox/d00816453/cogvlm/attn_tensor_224/q_atten_npu.pth") + #torch.save(k.cpu(), "/home/gpt_neox/d00816453/cogvlm/attn_tensor_224/k_atten_npu.pth") + #torch.save(v.cpu(), "/home/gpt_neox/d00816453/cogvlm/attn_tensor_224/v_atten_npu.pth") + #q = torch.load("/home/gpt_neox/d00816453/cogvlm/attn_tensor_224_gpu/q_atten.pth", map_location="cpu").to(device=q.device, dtype=q.dtype) + #k = torch.load("/home/gpt_neox/d00816453/cogvlm/attn_tensor_224_gpu/k_atten.pth", map_location="cpu").to(device=k.device, dtype=k.dtype) + #v = torch.load("/home/gpt_neox/d00816453/cogvlm/attn_tensor_224_gpu/v_atten.pth", map_location="cpu").to(device=v.device, dtype=v.dtype) + #print(f"=======q shape:{q.shape}=============") + #print(f"=======k shape:{k.shape}==========") + #print(f"=======v shape:{v.shape}========") + out = self.core_attention_flash(q, k, v, self.num_heads, None, None) + #out = self.attention(q, k ,v) + #torch.save(out.cpu(), "/home/gpt_neox/d00816453/cogvlm/attn_tensor_224/npu_out_atten.pth") + #print(f"=======out shape:{out.shape}==============") + #print(f"=======out:{out}===========") + #breakpoint() + output = self.dense(out.view(B, L, -1)) + output = self.output_dropout(output) + return output + + def attention(self, q, k, v): + attn_weights = torch.matmul(q * self.scale, k.transpose(-2, -1)) + attn_weights = attn_weights.softmax(dim=-1) + output = torch.matmul(attn_weights, v) + return output + + +class MLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.activation_fn = ACT2FN[config.hidden_act] + self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) + self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.fc1(x) + x = self.activation_fn(x) + x = self.fc2(x) + return x + + +class TransformerLayer(nn.Module): + def __init__(self, config): + super().__init__() + self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.attention = Attention(config) + self.mlp = MLP(config) + self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward(self, hidden_states): + attention_input = hidden_states + attention_output = self.input_layernorm(self.attention(attention_input)) + hidden_states = attention_input + attention_output + mlp_input = hidden_states + mlp_output = self.post_attention_layernorm(self.mlp(mlp_input)) + output = mlp_input + mlp_output + return output + + +class Transformer(nn.Module): + def __init__(self, config): + super().__init__() + self.layers = nn.ModuleList([TransformerLayer(config) for _ in range(config.num_hidden_layers)]) + + def forward(self, hidden_states): + for layer_module in self.layers: + hidden_states = layer_module(hidden_states) + return hidden_states + + +class GLU(nn.Module): + def __init__(self, config, in_features): + super().__init__() + self.linear_proj = nn.Linear(in_features, config.hidden_size, bias=False) + self.norm1 = nn.LayerNorm(config.hidden_size) + self.act1 = nn.GELU() + self.act2 = nn.functional.silu + self.dense_h_to_4h = nn.Linear(config.hidden_size, config.intermediate_size, bias=False) + self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False) + self.dense_4h_to_h = nn.Linear(config.intermediate_size, config.hidden_size, bias=False) + + def forward(self, x): + x = self.linear_proj(x) + x = self.act1(self.norm1(x)) + x = self.act2(self.gate_proj(x)) * self.dense_h_to_4h(x) + x = self.dense_4h_to_h(x) + return x + + +class EVA2CLIPModel(nn.Module): + def __init__(self, config): + super().__init__() + vision_config = Namespace(**config.vision_config) + self.patch_embedding = PatchEmbedding(vision_config) + self.transformer = Transformer(vision_config) + self.linear_proj = GLU(config, in_features=vision_config.hidden_size) + self.boi = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) + self.eoi = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) + + def forward(self, images: "tensor(B, C, H, W)") -> "tensor(B, L, D)": + #print(f"======images shape:{images.shape}===========") + x = self.patch_embedding(images) + #print(f"========x.shape:{x.shape}===========") + x = self.transformer(x) + x = x[:, 1:] + x = self.linear_proj(x) + boi = self.boi.expand(x.shape[0], -1, -1) + eoi = self.eoi.expand(x.shape[0], -1, -1) + x = torch.cat((boi, x, eoi), dim=1) + return x -- Gitee From 0b1423d052af26dedfb15960a9222c46166be9f8 Mon Sep 17 00:00:00 2001 From: gitee_code_template Date: Mon, 29 Jan 2024 19:52:53 +0800 Subject: [PATCH 2/3] =?UTF-8?q?=E6=96=B0=E5=A2=9Ecogvlm=E5=A4=9A=E6=A8=A1?= =?UTF-8?q?=E6=80=81=E6=A8=A1=E5=9E=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- PyTorch/built-in/foundation/CogVLM/README.md | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/PyTorch/built-in/foundation/CogVLM/README.md b/PyTorch/built-in/foundation/CogVLM/README.md index 5f04469f84..22de21e085 100644 --- a/PyTorch/built-in/foundation/CogVLM/README.md +++ b/PyTorch/built-in/foundation/CogVLM/README.md @@ -18,6 +18,13 @@ sat/model/position_embedding/triton_rotary_embeddings.py 替换为model_zoo项 2、utils/models/mixin.py 替换为 model_zoo项目下 cogvlm_utils/mixin.py +3、finetune_demo/finetune_cogvlm_demo.py文件的import torch引入下添加如下两个引用 + +import torch_npu + +from torch_npu.contrib import transfer_to_npu + + 五、权重文件下载 1、微调权重下载:https://huggingface.co/THUDM/CogVLM/tree/main 下载cogvlm-base-224.zip -- Gitee From 3c9154ed27d95152f27cecc0e3ede2b0223b7690 Mon Sep 17 00:00:00 2001 From: gitee_code_template Date: Mon, 5 Feb 2024 15:49:17 +0800 Subject: [PATCH 3/3] =?UTF-8?q?=E4=BF=AE=E6=94=B9=E6=97=8B=E8=BD=AC?= =?UTF-8?q?=E7=BC=96=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../CogVLM/cogvlm_utils/modeling_cogvlm.py | 51 ++----------------- .../CogVLM/cogvlm_utils/rotary_embeddings.py | 8 +++ 2 files changed, 12 insertions(+), 47 deletions(-) diff --git a/PyTorch/built-in/foundation/CogVLM/cogvlm_utils/modeling_cogvlm.py b/PyTorch/built-in/foundation/CogVLM/cogvlm_utils/modeling_cogvlm.py index 7fa1dd1e41..e4ad335c82 100644 --- a/PyTorch/built-in/foundation/CogVLM/cogvlm_utils/modeling_cogvlm.py +++ b/PyTorch/built-in/foundation/CogVLM/cogvlm_utils/modeling_cogvlm.py @@ -15,8 +15,8 @@ from transformers.activations import ACT2FN from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from .configuration_cogvlm import CogVLMConfig -#from .util import FastRotaryEmbedding from .rotary_embeddings import RotaryEmbedding as FastRotaryEmbedding +from .rotary_embeddings import rotate_half,apply_rotary_pos_emb_index_bhs from .visual import EVA2CLIPModel if TYPE_CHECKING: @@ -144,25 +144,6 @@ def attention_fn( context_layer = torch.matmul(attention_scores, value_layer) return context_layer -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2:] - return torch.cat((-x2, x1), dim=-1) - - -def apply_rotary_pos_emb(q, k, cos, sin, position_ids): - gather_indices = position_ids[:, None, :, None] # [bs, 1, seq_len, 1] - gather_indices = gather_indices.repeat(1, cos.shape[1], 1, cos.shape[2]) - cos = torch.gather(cos.cpu().repeat(gather_indices.shape[0], 1, 1, 1).npu( - torch.npu.current_device()), 2, gather_indices) - sin = torch.gather(sin.cpu().repeat(gather_indices.shape[0], 1, 1, 1).npu( - torch.npu.current_device()), 2, gather_indices) - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - return q_embed, k_embed - - class VisionExpertAttention(nn.Module): def __init__(self, config): super().__init__() @@ -172,8 +153,6 @@ class VisionExpertAttention(nn.Module): self.head_dim = self.hidden_size // self.num_heads self.max_position_embeddings = config.max_position_embeddings - #self.rotary_emb = RotaryEmbedding(self.hidden_size // self.num_heads) - #self.rotary_emb = FastRotaryEmbedding(dim=self.head_dim, pos_idx_in_fp32=False) self.rotary_emb = FastRotaryEmbedding(self.hidden_size // self.num_heads) self.vision_expert_query_key_value = nn.Linear(self.hidden_size, self.hidden_size * 3, bias=False) self.vision_expert_dense = nn.Linear(self.hidden_size, self.hidden_size, bias=False) @@ -217,31 +196,9 @@ class VisionExpertAttention(nn.Module): kv_seq_len = key_states.shape[-2] if past_key_value is not None: kv_seq_len += past_key_value[0].shape[-2] - #torch.save(query_states.cpu(), "/home/gpt_neox/d00816453/cogvlm/attn_tensor_224/query_states_rotary_atten_gpu.pth") - #torch.save(key_states.cpu(), "/home/gpt_neox/d00816453/cogvlm/attn_tensor_224/key_states_rotary_atten_gpu.pth") - #torch.save(value_states.cpu(), "/home/gpt_neox/d00816453/cogvlm/attn_tensor_224/value_states_rotary_atten_gpu.pth") - #breakpoint() - #print(f"==========query_states0:{query_states.shape}=========") - #print(f"===========key_states0:{key_states.shape}============") - #print(f"===========value_states0:{value_states.shape}=============") - #query_states = torch.load("/home/gpt_neox/d00816453/cogvlm/attn_tensor_224/query_states_before_rotary_atten_gpu.pth", map_location="cpu").to(device=query_states.device, dtype=query_states.dtype) - #key_states = torch.load("/home/gpt_neox/d00816453/cogvlm/attn_tensor_224/key_states_before_rotary_atten_gpu.pth", map_location="cpu").to(device=key_states.device, dtype=key_states.dtype) - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) - #query_states, key_states = self.rotary_emb(query_states, key_states, position_ids=position_ids, max_seqlen=position_ids.max() + 1) - #torch.save(query_states.cpu(), "/home/gpt_neox/d00816453/cogvlm/attn_tensor_224/query_states_rotary_after_atten_npu.pth") - #torch.save(key_states.cpu(), "/home/gpt_neox/d00816453/cogvlm/attn_tensor_224/key_states_rotary_after_atten_npu.pth") - #breakpoint() - #print(f"==========query_states1:{query_states.shape}=========") - #print(f"===========key_states1:{key_states.shape}============") - #print(f"===========value_states1:{value_states.shape}=============") - #query_states = torch.load("/home/gpt_neox/d00816453/cogvlm/attn_tensor_224/query_states_rotary_after_atten_gpu.pth", map_location="cpu").to(device=query_states.device, dtype=query_states.dtype) - #key_states = torch.load("/home/gpt_neox/d00816453/cogvlm/attn_tensor_224/key_states_rotary_after_atten_gpu.pth", map_location="cpu").to(device=key_states.device, dtype=key_states.dtype) - #value_states = torch.load("/home/gpt_neox/d00816453/cogvlm/attn_tensor_224/value_states_rotary_after_atten_gpu.pth", map_location="cpu").to(device=value_states.device, dtype=value_states.dtype) - #print(f"==========query_states2:{query_states.shape}=========") - #print(f"===========key_states2:{key_states.shape}============") - #print(f"===========value_states2:{value_states.shape}=============") - #print(f"============qkv==================") + + cos, sin = self.rotary_emb(value_states, seq_len=position_ids.max()+1) + query_states, key_states = apply_rotary_pos_emb_index_bhs(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: key_states = torch.cat([past_key_value[0], key_states], dim=2) diff --git a/PyTorch/built-in/foundation/CogVLM/cogvlm_utils/rotary_embeddings.py b/PyTorch/built-in/foundation/CogVLM/cogvlm_utils/rotary_embeddings.py index 5d3715d249..7486200f8d 100644 --- a/PyTorch/built-in/foundation/CogVLM/cogvlm_utils/rotary_embeddings.py +++ b/PyTorch/built-in/foundation/CogVLM/cogvlm_utils/rotary_embeddings.py @@ -75,6 +75,14 @@ def rotate_half(x): x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:] return torch.cat((-x2, x1), dim=x1.ndim - 1) # dim=-1 triggers a bug in earlier torch versions +def apply_rotary_pos_emb_index_bhs(q, k, cos, sin, position_id): + # batch_size, num_head, seq_len, hidden_size + cos, sin = F.embedding(position_id, cos.squeeze(1)).unsqueeze(1), \ + F.embedding(position_id, sin.squeeze(1)).unsqueeze(1) + q = (q * cos) + (rotate_half(q) * sin) + k = (k * cos) + (rotate_half(k) * sin) + return q, k + @torch.jit.script def apply_rotary_pos_emb(q, k, cos, sin, offset: int = 0): -- Gitee