From a10a79c236e3bbe156de469aa7685d17249779c1 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E7=8E=8B=E6=B3=A2?=
Date: Fri, 20 Sep 2024 13:33:52 +0800
Subject: [PATCH] minicpm v2.6 finetune

---
 .../built-in/mm/MiniCPM-V/finetune/dataset.py | 281 +++++++++++++-----
 .../MiniCPM-V/finetune/ds_config_zero2.json   |   2 +-
 .../mm/MiniCPM-V/finetune/finetune.py         |  22 +-
 .../mm/MiniCPM-V/finetune/finetune_ds.sh      |  10 +-
 .../mm/MiniCPM-V/finetune/finetune_lora.sh    |   8 +-
 .../mm/MiniCPM-V/web_demo_streamlit-2_5.py    |   1 +
 6 files changed, 235 insertions(+), 89 deletions(-)

diff --git a/PyTorch/built-in/mm/MiniCPM-V/finetune/dataset.py b/PyTorch/built-in/mm/MiniCPM-V/finetune/dataset.py
index 1cf15cdcda..dbf6bbca68 100644
--- a/PyTorch/built-in/mm/MiniCPM-V/finetune/dataset.py
+++ b/PyTorch/built-in/mm/MiniCPM-V/finetune/dataset.py
@@ -3,6 +3,8 @@
 import json
 import logging
 import math
 import os
+import re
+import random
 from dataclasses import dataclass, field
 from typing import Dict, List, Optional
@@ -12,6 +14,9 @@
 from PIL import Image
 from torch.nn.utils.rnn import pad_sequence
 from torch.utils.data import Dataset
 from transformers import AutoProcessor, AutoTokenizer
+import logging
+
+logger = logging.getLogger(__name__)
 
 llama3_chat_template = "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}"
@@ -28,6 +33,7 @@ class SupervisedDataset(Dataset):
         patch_size=14,
         query_nums=64,
         batch_vision=False,
+        max_length=2048,
     ):
         super(SupervisedDataset, self).__init__()
         self.raw_data = raw_data
@@ -38,35 +44,46 @@ class SupervisedDataset(Dataset):
         self.patch_size = patch_size
         self.query_nums=query_nums
         self.batch_vision = batch_vision
+        self.max_length = max_length
 
     def __len__(self):
         return len(self.raw_data)
 
     def __getitem__(self, i) -> Dict[str, torch.Tensor]:
-        image = Image.open(self.raw_data[i]["image"]).convert("RGB")
-        ret = preprocess(
-            image,
-            self.raw_data[i]["conversations"],
-            self.tokenizer,
-            self.transform,
-            query_nums=self.query_nums,
-            slice_config=self.slice_config,
-            llm_type=self.llm_type,
-            patch_size=self.patch_size,
-            batch_vision=self.batch_vision,
-        )
-        ret = dict(
-            input_ids=ret["input_ids"],
-            position_ids=ret["position_ids"],
-            labels=ret["target"],
-            attention_mask=torch.ones_like(ret["input_ids"], dtype=torch.bool),
-            pixel_values=ret["pixel_values"],
-            tgt_sizes=ret["tgt_sizes"],
-            image_bound=ret["image_bound"],
-        )
-
+        try:
+            if isinstance(self.raw_data[i]["image"], str):
+                images_dict = { "<image>" : Image.open(self.raw_data[i]["image"]).convert("RGB") }
+            elif isinstance(self.raw_data[i]["image"], Dict):
+                ### for multi-images input, the template for every image is <image_xx>, such as <image_00>, <image_01>
+                images_dict = {img_name : Image.open(img_path).convert("RGB") for img_name, img_path in self.raw_data[i]["image"].items()}
+
+            ret = preprocess(
+                images_dict,
+                self.raw_data[i]["conversations"],
+                self.tokenizer,
+                self.transform,
+                query_nums=self.query_nums,
+                slice_config=self.slice_config,
+                llm_type=self.llm_type,
+                patch_size=self.patch_size,
+                batch_vision=self.batch_vision,
+                max_length=self.max_length
+            )
+            ret = dict(
+                input_ids=ret["input_ids"],
+                position_ids=ret["position_ids"],
+                labels=ret["target"],
+                attention_mask=torch.ones_like(ret["input_ids"], dtype=torch.bool),
+                pixel_values=ret["pixel_values"],
+                tgt_sizes=ret["tgt_sizes"],
+                image_bound=ret["image_bound"],
+            )
+        except:
+            logger.error(f"data fetch error")
+            return self.__getitem__(random.randint(0, len(self)))
         return ret
+
 
 def data_collator(examples, padding_value=0, max_length=2048):
     def trim_and_pad(seq, batch_first, padding_value):
         return pad_sequence([s[:max_length] for s in seq], batch_first=True, padding_value=padding_value)
@@ -105,7 +122,7 @@ def data_collator(examples, padding_value=0, max_length=2048):
     }
 
 
-def conversation_to_ids(conversation, tokenizer, llm_type=None):
+def conversation_to_ids(conversation, tokenizer, llm_type=None, new_schema=False, max_length=2048):
     """
     for single image multi-turn conversation
     conversation: [{'role': 'user', 'content': 'Describe this image'},
@@ -115,6 +132,10 @@ def conversation_to_ids(conversation, tokenizer, llm_type=None):
         input_ids, context, raw_msg = conversation_to_ids_llama3(
             conversation, tokenizer
         )
+    elif llm_type == "qwen2":
+        input_ids, context, raw_msg = conversation_to_ids_qwen2(
+            conversation, tokenizer
+        )
     else:
         input_ids, context, raw_msg = conversation_to_ids_minicpm(
             conversation, tokenizer
@@ -122,27 +143,42 @@
         )
     ids = torch.from_numpy(np.hstack(input_ids, dtype=np.int32))
     context = torch.from_numpy(np.hstack(context, dtype=np.int8))
+    if input_ids.shape[-1] > max_length:
+        ids = ids[:max_length]
+        context = context[:max_length]
+        logger.warning(f"The input length ({input_ids.shape[-1]}) exceeds the model's maximum length ({max_length}), so it has been truncated")
+
+    if torch.all(context):
+        logger.error("No tokens available to compute loss.")
+        raise Exception("No tokens available to compute loss.")
     # build target
     target = torch.full_like(ids, -100, dtype=torch.int32)
-    mask_zero = context == 0
-    target[:-1][mask_zero[1:]] = ids[1:][mask_zero[1:]]
-
-    mask_one_zero = (context == 1) & (torch.roll(context, 1, 0) == 0)
-    mask_one_zero = mask_one_zero[1:]
-
-    if hasattr(tokenizer, "eot_id"):
-        eot_or_eos_id = tokenizer.eot_id
-    else:
-        eot_or_eos_id = tokenizer.eos_id
-
+
+    for i in range(1, len(ids)):
+        if context[i] == 0:
+            target[i - 1] = ids[i]
+        if context[i] == 1 and context[i - 1] == 0:
+            if hasattr(tokenizer, "eot_id"):
+                target[i - 1] = tokenizer.eot_id
+            else:
+                target[i - 1] = tokenizer.eos_id
+
     # build image bound
-    image_start_tokens = torch.where(ids == tokenizer.im_start_id)[0]
-    image_start_tokens += 1
-    image_end_tokens = torch.where(ids == tokenizer.im_end_id)[0]
+    if new_schema:
+        start_cond = (ids == tokenizer.im_start_id) | (ids == tokenizer.slice_start_id)
+        end_cond = (ids == tokenizer.im_end_id) | (ids == tokenizer.slice_end_id)
+        image_start_tokens = torch.where(start_cond)[0]
+        image_start_tokens += 1
+        image_end_tokens = torch.where(end_cond)[0]
+    else:
+        image_start_tokens = torch.where(ids == tokenizer.im_start_id)[0]
+        image_start_tokens += 1
+        image_end_tokens = torch.where(ids == tokenizer.im_end_id)[0]
     if len(image_start_tokens) != len(image_end_tokens):
-        print("image start token != image end tokens")
-
+        logger.error("image start token != image end tokens")
+        raise Exception("image start token != image end tokens")
+
     if len(image_start_tokens) > 0:
         image_bound = torch.hstack(
             [image_start_tokens.unsqueeze(-1), image_end_tokens.unsqueeze(-1)]
         )
@@ -232,9 +268,48 @@
     return input_ids, context, raw_msg
 
 
+def conversation_to_ids_qwen2(conversation, tokenizer):
+    raw_msg = ""
+    chat = []
+    context = []
+    for idx, msg in enumerate(conversation):
+        role = msg["role"]
+        message = msg["content"]
+        assert role in ["user", "assistant"]
+        if role == "user":
+            prefix = "user"
+        else:
+            prefix = "assistant"
+        chat.append({"role":prefix, "content":message})
+        raw_msg += prefix + message
+    assert set([i['role'] for i in chat]) & set(['assistant'])
+
+    ret = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=False)
+    input_ids = tokenizer.apply_chat_template(chat, tokenize=True, add_generation_prompt=False)
+    input_ids = np.array(input_ids)
+
+    start_idxs = np.where(input_ids == tokenizer.convert_tokens_to_ids('<|im_start|>'))[0]
+    assistant_idxs = np.where(input_ids == tokenizer.convert_tokens_to_ids('assistant'))[0]
+    end_idxs = np.where(input_ids == tokenizer.convert_tokens_to_ids('<|im_end|>'))[0]
+
+    context = np.ones_like(input_ids, dtype=np.int8)
+
+    for assistant_idx in assistant_idxs:
+        if assistant_idx-1 in set(start_idxs):
+            st = assistant_idx + 1
+            for end_idx in end_idxs:
+                if end_idx > st:
+                    context[st: end_idx + 1] = 0
+                    break
+
+    input_ids = np.hstack(input_ids)
+    context = np.hstack(context)
+    return input_ids, context, raw_msg
+
+
 def preprocess(
-    image,
-    conversation,
+    images_dict,
+    conversations,
     tokenizer,
     transform,
     query_nums=64,
@@ -242,13 +317,14 @@
     llm_type=None,
     patch_size=14,
     batch_vision=False,
+    max_length=2048,
 ):
     """
-    single image preprocess, the image will be placed at the top of the conversation
+    single(multi) image(s) preprocess, the image(s) will be placed at the top of the conversation
     """
-    conversation = copy.deepcopy(conversation)
-    assert len(conversation) > 1, "conversation length must large than 2"
-    assert conversation[0]["role"] == "user", "the first role must be user"
+    conversations = copy.deepcopy(conversations)
+    assert len(conversations) > 1, "conversations length must large than 2"
+    assert conversations[0]["role"] == "user", "the first role must be user"
 
     if slice_config is not None:
         assert isinstance(slice_config, Dict)
@@ -258,37 +334,74 @@
     default_image_placeholder = (
         tokenizer.im_start + tokenizer.unk_token * query_nums + tokenizer.im_end
     )
-    if slice_config:
-        images = []
-        source_image, patches, best_grid = slice_image(
-            image,
-            slice_config["max_slice_nums"],
-            slice_config["scale_resolution"],
-            slice_config["patch_size"],
-        )
-        images.append(source_image)
-        image_placeholder = default_image_placeholder
-        if len(patches) > 0:
-            for i in range(len(patches)):
-                for j in range(len(patches[0])):
-                    images.append(patches[i][j])
-
-            image_placeholder += get_grid_placeholder(
-                tokenizer, best_grid, query_nums)
-        images = [transform(i) for i in images]
-    else:
-        images = [transform(image)]
-        image_placeholder = default_image_placeholder
-    if "<image>" in conversation[0]["content"]:
-        conversation[0]["content"] = conversation[0]["content"].replace(
-            "<image>", image_placeholder
-        )
+    new_schema = False
+    use_image_id = False
+    if llm_type=='qwen2':
+        new_schema = True
+        use_image_id = True
+    image_placeholder_dict = {}
+    images = []
+    image_id_cnt = 0
+    for img_name, image in images_dict.items():
+        if slice_config:
+            source_image, patches, best_grid = slice_image(
+                image,
+                slice_config["max_slice_nums"],
+                slice_config["scale_resolution"],
+                slice_config["patch_size"],
+            )
+            images.append(source_image)
+            image_placeholder = default_image_placeholder
+            if len(patches) > 0:
+                for i in range(len(patches)):
+                    for j in range(len(patches[0])):
+                        images.append(patches[i][j])
+                if use_image_id:
+                    image_placeholder = f'{tokenizer.im_id_start}{image_id_cnt}{tokenizer.im_id_end}' + image_placeholder
+                    image_id_cnt += 1
+                image_placeholder += get_grid_placeholder(
+                    tokenizer, best_grid, query_nums, new_schema = new_schema)
+            image_placeholder_dict[img_name] = image_placeholder
+        else:
+            images.append(image)
+            if use_image_id:
+                image_placeholder = f'{tokenizer.im_id_start}{image_id_cnt}{tokenizer.im_id_end}' + image_placeholder
+                image_id_cnt += 1
+            else:
+                image_placeholder = default_image_placeholder
+            image_placeholder_dict[img_name] = image_placeholder
+
+    images = [transform(i) for i in images]
+
+    if len(images_dict) == 1 and "<image>" in images_dict:
+        if "<image>" in conversations[0]["content"]:
+            conversations[0]["content"] = conversations[0]["content"].replace(
+                "<image>", image_placeholder
+            )
+        else:
+            conversations[0]["content"] = (
+                image_placeholder + "\n" + conversations[0]["content"]
+            )
+        input_dict = conversation_to_ids(conversations, tokenizer, llm_type, new_schema, max_length)
     else:
-        conversation[0]["content"] = (
-            image_placeholder + "\n" + conversation[0]["content"]
-        )
-
-    input_dict = conversation_to_ids(conversation, tokenizer, llm_type)
+        pattern = r'<image_\d+>'
+        new_conversations = []
+        for conversation in conversations:
+            content = conversation['content']
+            parts = re.split(f'({pattern})', content)
+            for i, part in enumerate(parts):
+                if not part.strip():
+                    continue
+                if re.match(pattern, part):
+                    if part in image_placeholder_dict:
+                        parts[i] = image_placeholder_dict[part]
+                    else:
+                        raise Exception(f"not found {part} in image dict")
+            conversation['content'] = '\n'.join(parts)
+            new_conversations.append(conversation)
+        conversations = new_conversations
+
+        input_dict = conversation_to_ids(conversations, tokenizer, llm_type, new_schema, max_length)
 
     if batch_vision:
         tgt_sizes = []
@@ -426,10 +539,15 @@ def split_to_patches(image, grid):
     return patches
 
 
-def get_grid_placeholder(tokenizer, grid, query_num):
-    image_placeholder = (
-        tokenizer.im_start + tokenizer.unk_token * query_num + tokenizer.im_end
-    )
+def get_grid_placeholder(tokenizer, grid, query_num, new_schema=False):
+    if new_schema:
+        image_placeholder = (
+            tokenizer.slice_start + tokenizer.unk_token * query_num + tokenizer.slice_end
+        )
+    else:
+        image_placeholder = (
+            tokenizer.im_start + tokenizer.unk_token * query_num + tokenizer.im_end
+        )
     cols = grid[0]
     rows = grid[1]
 
@@ -439,7 +557,10 @@
         for j in range(cols):
             lines.append(image_placeholder)
         slices.append("".join(lines))
-    slice_placeholder = tokenizer.slice_start + \
+    if new_schema:
+        slice_placeholder = '\n'.join(slices)
+    else:
+        slice_placeholder = tokenizer.slice_start + \
         "\n".join(slices) + tokenizer.slice_end
     return slice_placeholder
 
diff --git a/PyTorch/built-in/mm/MiniCPM-V/finetune/ds_config_zero2.json b/PyTorch/built-in/mm/MiniCPM-V/finetune/ds_config_zero2.json
index a16674d460..4d42d440b4 100644
--- a/PyTorch/built-in/mm/MiniCPM-V/finetune/ds_config_zero2.json
+++ b/PyTorch/built-in/mm/MiniCPM-V/finetune/ds_config_zero2.json
@@ -1,7 +1,7 @@
 {
     "fp16": {
         "enabled": "auto",
-        "loss_scale": 32,
+        "loss_scale": 0,
         "loss_scale_window": 1000,
         "initial_scale_power": 16,
         "hysteresis": 2,
diff --git a/PyTorch/built-in/mm/MiniCPM-V/finetune/finetune.py b/PyTorch/built-in/mm/MiniCPM-V/finetune/finetune.py
index 6f79b36187..1023e05dfc 100644
--- a/PyTorch/built-in/mm/MiniCPM-V/finetune/finetune.py
+++ b/PyTorch/built-in/mm/MiniCPM-V/finetune/finetune.py
@@ -6,6 +6,8 @@ from dataclasses import dataclass, field
 from functools import partial
 from typing import Dict, List, Optional, Union, Literal, Tuple
 from types import MethodType
+from torchvision import transforms
+
 import torch
 import transformers
 from accelerate.utils import DistributedType
@@ -118,6 +120,7 @@ def make_supervised_data_module(
         patch_size=patch_size,
         query_nums=query_nums,
         batch_vision=batch_vision,
+        max_length=max_length,
     )
 
     if data_args.eval_data_path:
@@ -131,6 +134,7 @@ def make_supervised_data_module(
             patch_size=patch_size,
             query_nums=query_nums,
             batch_vision=batch_vision,
+            max_length=max_length,
         )
     else:
         eval_dataset = None
@@ -142,6 +146,18 @@ def make_supervised_data_module(
     )
 
 
+def build_transform():
+    IMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5)  # timm.data.IMAGENET_INCEPTION_MEAN
+    IMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5)  # timm.data.IMAGENET_INCEPTION_STD
+    return transforms.Compose(
+        [
+            transforms.ToTensor(),
+            transforms.Normalize(
+                mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD
+            ),
+        ]
+    )
+
 def get_parameter_number(model):
     trainable_params, all_param = 0, 0
     for param in model.parameters():
@@ -261,10 +277,11 @@ def train():
     else:
         batch_vision = False
 
+    transform_func = build_transform()
     data_module = make_supervised_data_module(
         tokenizer=tokenizer,
         data_args=data_args,
-        transform=model.transform,
+        transform=transform_func,
         data_collator=data_collator,
         slice_config=slice_config,
        llm_type=llm_type,
@@ -273,7 +290,8 @@ def train():
         batch_vision=batch_vision,
         max_length=training_args.model_max_length,
     )
-
+
+    training_args.gradient_checkpointing_kwargs={"use_reentrant":False}
     trainer = CPMTrainer(
         model=model,
         tokenizer=tokenizer,
diff --git a/PyTorch/built-in/mm/MiniCPM-V/finetune/finetune_ds.sh b/PyTorch/built-in/mm/MiniCPM-V/finetune/finetune_ds.sh
index 156dcd3f58..2ed365f680 100644
--- a/PyTorch/built-in/mm/MiniCPM-V/finetune/finetune_ds.sh
+++ b/PyTorch/built-in/mm/MiniCPM-V/finetune/finetune_ds.sh
@@ -20,12 +20,14 @@ NODE_RANK=0
 MASTER_ADDR=localhost
 MASTER_PORT=6001
 
-MODEL="openbmb/MiniCPM-Llama3-V-2_5" # or openbmb/MiniCPM-V-2
+MODEL="openbmb/MiniCPM-V-2_6"
+# or openbmb/MiniCPM-V-2, openbmb/MiniCPM-Llama3-V-2_5
 # ATTENTION: specify the path to your training data, which should be a json file consisting of a list of conversations.
 # See the section for finetuning in README for more information.
 DATA="path/to/trainging_data"
 EVAL_DATA="path/to/test_data"
-LLM_TYPE="llama3" # if use openbmb/MiniCPM-V-2, please set LLM_TYPE=minicpm
+LLM_TYPE="qwen2" # if use openbmb/MiniCPM-V-2, please set LLM_TYPE=minicpm, if use openbmb/MiniCPM-Llama3-V-2_5, please set LLM_TYPE="llama3"
+MODEL_MAX_Length=2048 # if conduct multi-images sft, please set MODEL_MAX_Length=4096
 
 DISTRIBUTED_ARGS="
     --nproc_per_node $NPUS_PER_NODE \
@@ -55,8 +57,8 @@ torchrun $DISTRIBUTED_ARGS finetune/finetune.py \
     --max_slice_nums 9 \
     --max_steps 10000 \
     --eval_steps 1000 \
-    --output_dir finetune/output/output_minicpmv2 \
-    --logging_dir finetune/output/output_minicpmv2 \
+    --output_dir output/output_minicpmv26 \
+    --logging_dir output/output_minicpmv26 \
     --logging_strategy "steps" \
     --per_device_train_batch_size 1 \
     --per_device_eval_batch_size 1 \
diff --git a/PyTorch/built-in/mm/MiniCPM-V/finetune/finetune_lora.sh b/PyTorch/built-in/mm/MiniCPM-V/finetune/finetune_lora.sh
index a7c458cac6..2c130e6484 100644
--- a/PyTorch/built-in/mm/MiniCPM-V/finetune/finetune_lora.sh
+++ b/PyTorch/built-in/mm/MiniCPM-V/finetune/finetune_lora.sh
@@ -20,12 +20,16 @@ NODE_RANK=0
 MASTER_ADDR=localhost
 MASTER_PORT=6001
 
-MODEL="openbmb/MiniCPM-Llama3-V-2_5" # or openbmb/MiniCPM-V-2
+MODEL="openbmb/MiniCPM-V-2_6" # or openbmb/MiniCPM-V-2, openbmb/MiniCPM-Llama3-V-2_5
 # ATTENTION: specify the path to your training data, which should be a json file consisting of a list of conversations.
 # See the section for finetuning in README for more information.
 DATA="path/to/trainging_data"
 EVAL_DATA="path/to/test_data"
-LLM_TYPE="llama3" # if use openbmb/MiniCPM-V-2, please set LLM_TYPE=minicpm
+LLM_TYPE="qwen2"
+# if use openbmb/MiniCPM-V-2, please set LLM_TYPE=minicpm
+# if use openbmb/MiniCPM-Llama3-V-2_5, please set LLM_TYPE=llama3
+
+MODEL_MAX_Length=2048 # if conduct multi-images sft, please set MODEL_MAX_Length=4096
 
 DISTRIBUTED_ARGS="
     --nproc_per_node $NPUS_PER_NODE \
diff --git a/PyTorch/built-in/mm/MiniCPM-V/web_demo_streamlit-2_5.py b/PyTorch/built-in/mm/MiniCPM-V/web_demo_streamlit-2_5.py
index 23d8817c17..4941e832ba 100644
--- a/PyTorch/built-in/mm/MiniCPM-V/web_demo_streamlit-2_5.py
+++ b/PyTorch/built-in/mm/MiniCPM-V/web_demo_streamlit-2_5.py
@@ -95,6 +95,7 @@ if user_text:
     # Generate reply using the model
     model = st.session_state.model
     tokenizer = st.session_state.tokenizer
+    imagefile = None
 
     with st.chat_message(A_NAME, avatar="assistant"):
         # If the previous message contains an image, pass the image to the model
-- 
Gitee
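
Note on the data format: after this patch, dataset.py reads each training sample's "image" field either as a single path (string) or as a dict mapping <image_xx> placeholder names to paths, with the same placeholders appearing in the conversation text. The sketch below is a hypothetical example of that layout; the file name, image paths and prompts are invented for illustration and are not part of this patch.

# Minimal sketch of the training-data JSON consumed by SupervisedDataset.__getitem__
# after this patch. Paths, prompts and the output file name are placeholders.
import json

train_data = [
    {
        # "image" as a plain string: wrapped internally as {"<image>": ...},
        # so the first user turn may contain a single <image> slot
        "image": "path/to/cat.jpg",
        "conversations": [
            {"role": "user", "content": "<image>\nWhat is in the picture?"},
            {"role": "assistant", "content": "A cat sitting on a sofa."},
        ],
    },
    {
        # "image" as a dict: keys are <image_xx> placeholders that must also
        # appear in the conversation content (matched by the <image_\d+> pattern)
        "image": {
            "<image_00>": "path/to/left.jpg",
            "<image_01>": "path/to/right.jpg",
        },
        "conversations": [
            {"role": "user", "content": "<image_00>\n<image_01>\nCompare the two images."},
            {"role": "assistant", "content": "Both show the same street from different angles."},
        ],
    },
]

with open("train.json", "w", encoding="utf-8") as f:
    json.dump(train_data, f, ensure_ascii=False, indent=2)

For multi-image samples the updated scripts suggest raising MODEL_MAX_Length to 4096 so the extra image tokens are not cut off by the new max_length truncation in conversation_to_ids.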