diff --git a/mxRAG/PocValidation/README.md b/mxRAG/PocValidation/README.md
index 0b3ff8025bd116000a2ea7749cb0351e460ace5f..c2c563e0936eb0d9802e4f3c7c91b7f9ddf7506c 100644
--- a/mxRAG/PocValidation/README.md
+++ b/mxRAG/PocValidation/README.md
@@ -4,4 +4,5 @@
 | chat_with_ascend | Q&A scenario reference sample |
 | code_with_ascend | Code completion reference sample |
 | prompt_compressor | Prompt compression reference sample |
-| rag_recursive_tree_demo.oy | Recursive tree retrieval reference sample |
+| rag_recursive_tree_demo.py | Recursive tree retrieval reference sample |
+| embedding_finetune.py | Embedding fine-tuning reference sample |
diff --git a/mxRAG/PocValidation/embedding_finetune.py b/mxRAG/PocValidation/embedding_finetune.py
new file mode 100644
index 0000000000000000000000000000000000000000..2ac2be673540bcacc19dd65496c10b00b075fc51
--- /dev/null
+++ b/mxRAG/PocValidation/embedding_finetune.py
@@ -0,0 +1,209 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+
+import argparse
+import os
+
+from loguru import logger
+import torch
+import torch_npu  # noqa: F401  # registers the "npu" device with torch
+from datasets import load_dataset
+from sentence_transformers import (
+    SentenceTransformer,
+    SentenceTransformerTrainer,
+    SentenceTransformerTrainingArguments,
+)
+from sentence_transformers.evaluation import InformationRetrievalEvaluator
+from sentence_transformers.losses import MultipleNegativesRankingLoss
+from sentence_transformers.training_args import BatchSamplers
+
+from mx_rag.llm import Text2TextLLM
+from mx_rag.reranker.local import LocalReranker
+from mx_rag.tools.finetune.generator import TrainDataGenerator, DataProcessConfig
+from mx_rag.utils import ClientParam
+from mx_rag.utils.file_check import FileCheck
+
+DEFAULT_LLM_TIMEOUT = 10 * 60
+
+
+class Finetune:
+    def __init__(self,
+                 document_path: str,
+                 generate_dataset_path: str,
+                 llm: Text2TextLLM,
+                 embed_model_path: str,
+                 reranker: LocalReranker,
+                 finetune_output_path: str,
+                 featured_percentage: float,
+                 llm_threshold_score: float,
+                 train_question_number: int,
+                 query_rewrite_number: int,
+                 eval_data_path: str,
+                 log_path: str,
+                 max_iter: int,
+                 increase_rate: float):
+        self.document_path = document_path
+        self.generate_dataset_path = generate_dataset_path
+        self.llm = llm
+        self.embed_model_path = embed_model_path
+        self.reranker = reranker
+        self.finetune_output_path = finetune_output_path
+
+        self.featured_percentage = featured_percentage
+        self.llm_threshold_score = llm_threshold_score
+        self.train_question_number = train_question_number
+        self.query_rewrite_number = query_rewrite_number
+
+        self.eval_data_path = eval_data_path
+
+        self.log_path = log_path
+        self.max_iter = max_iter
+        self.increase_rate = increase_rate
+
+    def start(self):
+        # Configure the log file
+        logger.add(self.log_path, rotation="1 MB", retention="10 days", level="INFO",
+                   format="{time} {level} {message}")
+        train_data_generator = TrainDataGenerator(self.llm, self.generate_dataset_path, self.reranker)
+        logger.info("--------------------Processing origin document--------------------")
+        split_doc_list = train_data_generator.generate_origin_document(self.document_path)
+        logger.info("--------------------Calculate origin embedding model recall--------------------")
+        origin_recall5 = self.evaluate("origin_model", self.embed_model_path)
+        logger.info(f"origin_recall@5: {origin_recall5}")
+        config = DataProcessConfig(question_number=self.train_question_number,
+                                   featured_percentage=self.featured_percentage,
+                                   llm_threshold_score=self.llm_threshold_score,
+                                   query_rewrite_numer=self.query_rewrite_number)
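+        # Iterative fine-tuning: each round trains on a growing prefix of the
+        # split documents, re-evaluates recall@5 against the original model,
+        # and stops early once the relative gain exceeds `increase_rate` or
+        # recall@5 reaches 0.95.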
+        iter_num = 1
+        while iter_num <= self.max_iter:
+            logger.info(f"iteration {iter_num} beginning")
+            per_data_len = len(split_doc_list) // self.max_iter
+            end_index = len(split_doc_list) if iter_num == self.max_iter else iter_num * per_data_len
+            train_doc_list = split_doc_list[:end_index]
+            logger.info("--------------------Generating training data--------------------")
+            train_data_generator.generate_train_data(train_doc_list, config)
+
+            logger.info("--------------------Fine-tuning embedding--------------------")
+            train_data_path = os.path.join(self.generate_dataset_path, "train_data.jsonl")
+            output_embed_model_path = os.path.join(self.finetune_output_path, "embedding", str(iter_num))
+            if not os.path.exists(output_embed_model_path):
+                os.makedirs(output_embed_model_path)
+            FileCheck.dir_check(output_embed_model_path)
+            self.train_embedding(train_data_path, output_embed_model_path)
+            logger.info("--------------------Calculate fine-tuned embedding model recall--------------------")
+            finetune_recall5 = self.evaluate("finetune_model", output_embed_model_path)
+            logger.info(f"finetune_recall@5: {finetune_recall5}")
+            # Relative improvement over the original model; guard against a zero baseline
+            recall_increase = (finetune_recall5 - origin_recall5) / max(origin_recall5, 1e-9) * 100
+            logger.info(f"The recall rate of iteration {iter_num} increased by {recall_increase:.2f}%.")
+            iter_num += 1
+            if recall_increase > self.increase_rate or finetune_recall5 >= 0.95:
+                break
+            if iter_num < self.max_iter:
+                self.delete_dataset_file()
+
+    def train_embedding(self, train_data_path, output_path):
+        torch.npu.set_device(torch.device("npu:0"))
+        model = SentenceTransformer(self.embed_model_path, device="npu" if torch.npu.is_available() else "cpu")
+        train_loss = MultipleNegativesRankingLoss(model)
+        train_dataset = load_dataset("json", data_files=train_data_path, split="train")
+        args = SentenceTransformerTrainingArguments(
+            output_dir=output_path,  # output directory
+            num_train_epochs=4,  # number of epochs
+            per_device_train_batch_size=4,  # train batch size
+            gradient_accumulation_steps=16,  # for an effective batch size of 64 per device
+            warmup_ratio=0.1,  # warmup ratio
+            learning_rate=2e-5,  # 2e-5 is a reasonable starting point
+            lr_scheduler_type="cosine",  # use cosine learning rate scheduler
+            optim="adamw_torch_fused",  # use fused adamw optimizer
+            batch_sampler=BatchSamplers.NO_DUPLICATES,
+            # MultipleNegativesRankingLoss benefits from no duplicate samples in a batch
+            logging_steps=10,  # log every 10 steps
+        )
+        trainer = SentenceTransformerTrainer(
+            model=model,
+            args=args,
+            train_dataset=train_dataset.select_columns(["anchor", "positive"]),
+            loss=train_loss,
+        )
+        trainer.train()
+        trainer.save_model()
+        torch.npu.empty_cache()
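+    # Both train_data.jsonl and the evaluation file are JSONL with an "anchor"
+    # (query) field and a "positive" (passage) field, as consumed by
+    # select_columns() above and the evaluator below. An illustrative line
+    # (placeholder values, not real data):
+    #   {"anchor": "What is recall@5?", "positive": "Recall@5 measures ..."}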
+    def evaluate(self, model_name, model_path):
+        torch.npu.set_device(torch.device("npu:0"))
+        model = SentenceTransformer(model_path, device="npu" if torch.npu.is_available() else "cpu")
+        eval_data = load_dataset("json", data_files=self.eval_data_path, split="train")
+        eval_data = eval_data.add_column("id", range(len(eval_data)))
+        corpus = dict(zip(eval_data["id"], eval_data["positive"]))
+        queries = dict(zip(eval_data["id"], eval_data["anchor"]))
+        # Each query's only relevant document is its own source passage
+        relevant_docs = {q_id: [q_id] for q_id in queries}
+        evaluator = InformationRetrievalEvaluator(queries=queries,
+                                                  corpus=corpus,
+                                                  relevant_docs=relevant_docs,
+                                                  name=model_name)
+        result = evaluator(model)
+        return result[model_name + "_cosine_recall@5"]
+
+    def delete_dataset_file(self):
+        # Delete all generated files under the dataset directory
+        for filename in os.listdir(self.generate_dataset_path):
+            file_path = os.path.join(self.generate_dataset_path, filename)
+            # Only remove regular files
+            if os.path.isfile(file_path):
+                try:
+                    os.remove(file_path)
+                    logger.info(f"delete file success: {file_path}")
+                except OSError as e:
+                    logger.error(f"delete file failed: {file_path} - {e}")
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--document_path", type=str, default="")
+    parser.add_argument("--generate_dataset_path", type=str, default="")
+    parser.add_argument("--llm_url", type=str, default="")
+    parser.add_argument("--llm_model_name", type=str, default="")
+    parser.add_argument("--use_http", action="store_true")
+    parser.add_argument("--embedding_model_path", type=str, default="")
+    parser.add_argument("--reranker_model_path", type=str, default="")
+    parser.add_argument("--finetune_output_path", type=str, default="")
+
+    parser.add_argument("--featured_percentage", type=float, default=0.8)
+    parser.add_argument("--llm_threshold_score", type=float, default=0.8)
+    parser.add_argument("--train_question_number", type=int, default=2)
+    parser.add_argument("--query_rewrite_number", type=int, default=1)
+
+    parser.add_argument("--eval_data_path", type=str, default="")
+
+    parser.add_argument("--log_path", type=str, default='./app.log')
+    parser.add_argument("--max_iter", type=int, default=5)
+    parser.add_argument("--increase_rate", type=float, default=20)
+
+    args = parser.parse_args()
+
+    logger.info("Fine-tuning beginning")
+    client_param = ClientParam(timeout=DEFAULT_LLM_TIMEOUT, use_http=args.use_http)
+    text_llm = Text2TextLLM(base_url=args.llm_url, model_name=args.llm_model_name, client_param=client_param)
+    local_reranker = LocalReranker(args.reranker_model_path, dev_id=1)
+
+    finetune = Finetune(args.document_path,
+                        args.generate_dataset_path,
+                        text_llm,
+                        args.embedding_model_path,
+                        local_reranker,
+                        args.finetune_output_path,
+                        args.featured_percentage,
+                        args.llm_threshold_score,
+                        args.train_question_number,
+                        args.query_rewrite_number,
+                        args.eval_data_path,
+                        args.log_path,
+                        args.max_iter,
+                        args.increase_rate)
+    finetune.start()
+    logger.info("Fine-tuning ending")
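+
+# Example invocation; every path, URL, and model name below is a placeholder:
+#   python embedding_finetune.py \
+#       --document_path ./docs \
+#       --generate_dataset_path ./dataset \
+#       --llm_url https://127.0.0.1:8443 \
+#       --llm_model_name chat-model \
+#       --embedding_model_path ./models/embedding \
+#       --reranker_model_path ./models/reranker \
+#       --finetune_output_path ./output \
+#       --eval_data_path ./dataset/eval_data.jsonl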