# text-classification **Repository Path**: busixianyu/text-classification ## Basic Information - **Project Name**: text-classification - **Description**: 本项目设计目的致力于打造一个通用的文本分类项目,让大家尽量写少的代码就能进行文本分类模型训练。 - **Primary Language**: Python - **License**: Not specified - **Default Branch**: master - **Homepage**: None - **GVP Project**: No ## Statistics - **Stars**: 0 - **Forks**: 1 - **Created**: 2022-07-17 - **Last Updated**: 2025-02-26 ## Categories & Tags **Categories**: Uncategorized **Tags**: PyTorch, bert, transformers, text-classification, Nlp ## README # text-classification #### 介绍 文本分类项目在自然语言处理中属于比较简单的任务,可以使用机器学习和深度学习方法进行分类,本项目将提供一个统一的框架,方便使用深度学习进行分类任务, 既可以使用例如TextCNN等方法进行分类,也可以使用各种bert模型进行分类。本项目设计目的致力于打造一个通用的文本分类项目, 让大家尽量写少的代码就能进行文本分类模型训练。此项目通过统一的日志输出,图表展示,数据分析处理,模型训练,测试以及预测统一方法,只需要写最少的代码, 就可以完成文本分类任务。 #### 软件架构 主要使用pytorch和transformers框架进行代码编码。 数据集默认存放data目录下面,可以在启动参数中进行修改。数据集需要四个文件,训练集,测试集,验证集和分类文本。 如果数据缺失这四个文件,可以使用utils包下面的工具类进行数据的转换与生成。 1. conf: 配置文件目录,目前包含的配置文件如下: | 配置文件 | 说明 | |-------------|-----------------------------------------| | logging.yml | 日志配置文件,目前日志会进行两次输出,一个是控制台,另一个是日志文件 | | bert.yml | bert配置文件,项目配置的bert都在这个文件,如果想自己添加在此文件中添加 | | model.yml | 模型配置文件,自定义的模型在这个文件进行配置 | 2. data: 这个是默认存放数据的目录,在启动项配置可以修改 `--data_dir`。 3. logs: 日志存放目录,在日志配置文件中指定了输出文件中,生成的日志文件就在这个目录下,可以在启动项配置中修改 `--logging_path`。 4. utils: 工具类存放目录,目前有一些简单工具可以使用: 1. split_text_data.py:文本分割工具,如果只有训练集,需要拆分成训练集和验证集,或者拆分成三个训练集都可以使用这个文件,生成分类文本也使用这个文件。目前只支持csv格式文件。 2. txt2csv.py:将txt文件转换成csv文件,项目中使用csv更方便操作。 3. model_util.py:模型相关工具类,与模型操作项目可以使用这个工具类 4. utils.py:常用工具类,模型验证、测试等方法在这个工具类中。 5. 其他工具类可以自行查看,这里不一一介绍。 5. config.py:配置文件,启动默认参数可以在这个文件查找与修改,也可运行文件时指定。 6. train.py: 模型训练文件。 7. test.py: 模型测试文件。 8. predict.py: 模型预测文件。 9. word_sequence.py: 模型词表生成类,非bert模型需要先生成词表,放入savepoint相应目录下面 ```shell python train.py --help ``` 运行上面命令后帮助结果如下: ``` text --data_dir DATA_DIR 数据存储目录:news; THUCNews; xinwen_multi_label (default: news) --data_suffix DATA_SUFFIX 数据存储目录 (default: ) --root_dir ROOT_DIR 项目根目录 (default: ) --file_type FILE_TYPE 数据集文件类型:cvs,txt等 (default: csv) --train_file_path TRAIN_FILE_PATH 训练集文件路径 (default: train.csv) --dev_file_path DEV_FILE_PATH 验证集文件路径 (default: dev.csv) --test_file_path TEST_FILE_PATH 测试集文件路径 (default: test.csv) --class_path CLASS_PATH 分类文件路径 (default: class.txt) --is_split_dataset [IS_SPLIT_DATASET] 是否拆分数据集 (default: False) --test_out_path TEST_OUT_PATH 测试集文件路径 (default: test_out.csv) --predict_file_path PREDICT_FILE_PATH 预测文件路径 (default: test.csv) --dev_ratio DEV_RATIO 验证集拆分比例 (default: 0.1) --logging_step LOGGING_STEP 打印日志频率 (default: 50) --logging_path LOGGING_PATH 日志保存路径 (default: logging.log) --label_name LABEL_NAME 分类标签列名称 (default: label) --raw_label_name RAW_LABEL_NAME 原始分类标签名称 (default: label) --text_name TEXT_NAME 文本列名称 (default: text) --num_labels NUM_LABELS 类别数量 (default: 0) --problem_type PROBLEM_TYPE 单标签分类还是多标签分类 multi_label_classification, single_label_classification (default: single_label_classification) --model_name MODEL_NAME 模型名称: ernie (default: bert) --hidden_size HIDDEN_SIZE 隐藏层大小 (default: 768) --model_dir MODEL_DIR 模型存储目录 (default: savepoint/) --model_suffix MODEL_SUFFIX 实际模型名称 (default: ) --checkpoint CHECKPOINT 检查点文件路径 (default: model_best.pth.tar) --ws_path WS_PATH word sequence 文件路径 (default: ws.json) --embedding_type EMBEDDING_TYPE embedding 文件类型:weibo (default: None) --embedding_path EMBEDDING_PATH embedding path (default: sgns.weibo.word.bz2) --custom CUSTOM 自定义模型名称 (default: text_cnn) --is_bert [IS_BERT] 是否使用bert模型 (default: False) --is_tokenizer [IS_TOKENIZER] 是否使用bert模型分词 (default: False) --freeze_bert_head [FREEZE_BERT_HEAD] 是否冻结bert参数 (default: False) --max_grad_norm MAX_GRAD_NORM 梯度修剪 (default: 10) --batch_size BATCH_SIZE batch_size大小 (default: 64) --max_length MAX_LENGTH 最大长度 (default: 65) --epochs EPOCHS 迭代次数 (default: 100) --learning_rate LEARNING_RATE 学习率 (default: 5e-05) --add_special_tokens [ADD_SPECIAL_TOKENS] 是否添加特别标记 (default: True) --no_add_special_tokens 是否添加特别标记 (default: False) --tokenizer TOKENIZER 分词器 (default: None) --filter_sizes FILTER_SIZES [FILTER_SIZES ...] cnn卷积核大小 (default: []) --num_filters NUM_FILTERS cnn输出通道数 (default: 256) --dropout DROPOUT dropout比例 (default: 0.3) --seed SEED 随机种子 (default: 1) --device_name DEVICE_NAME 使用cpu,cuda或mps (default: cpu) --device DEVICE 设备对象,cpu/cuda/mps (default: None) --label2id LABEL2ID --id2label ID2LABEL --local_rank LOCAL_RANK For distributed training: local_rank (default: -1) --model_dict MODEL_DICT 模型字典数据 (default: {}) --patience PATIENCE early stop迭代数 (default: 20) --inner_patience INNER_PATIENCE 内层忍耐数 (default: 400) --writer WRITER tensorboard->SummaryWriter (default: None) --model MODEL 模型 (default: None) --optimizer OPTIMIZER 优化器 (default: None) --ws WS word sequence 类 (default: None) --embedding EMBEDDING embedding (default: None) --vocab_size VOCAB_SIZE 词表大小 (default: None) --only_dev [ONLY_DEV] 是否只是验证 (default: False) --top_k TOP_K 是否只是验证 (default: (1, 2)) ``` 以上命令随着项目的完善会有变动,具体帮助名称请自己运行查看。 #### 安装教程 1. 使用requirements.txt进行依赖安装 #### 使用说明 1. 本项目使用的数据集为头条新闻分类数据集,存放在data/news目录下面。主要包含15个类别,具体类别如下: ```text news_agriculture news_car news_culture news_edu news_entertainment news_finance news_game news_house news_military news_sports news_stock news_story news_tech news_travel news_world ``` 训练集中每个数据集的数量柱状图如下: ![](assets/label_count.png) jieba分词的词云效果如下: 训练集准确率图像: 训练集损失图像: 验证集和训练集精度图像: 验证集和训练集损失图像: | model_name | bert模型名称 | |------------|----------| | bert | bert-base-chinese | | bert_wwm | hfl/chinese-bert-wwm | | raw_bert | bert-base-chinese | | roberta | hfl/chinese-roberta-wwm-ext | |ernie | nghuyong/ernie-1.0 | | ernie_healthy |nghuyong/ernie-health-zh | |albert | voidful/albert_chinese_tiny | |reformer | junnyu/roformer_chinese_base| 本数据运行步骤: 1. 本数据集只有原始数据,没有验证集,测试集也是只有文本,没有标签。 首先我们要先把数据拆分成训练集和验证集。可以使用process_data.py统一方法进行数据处理。 2. 进行模型训练。在模型训练过程中每轮训练后都会对验证集进行验证,如果验证集的最好分数一直没有被更新,到达忍耐度就会自动停止。 3. 进行模型测试。预测的结果会保存文本,正确标签,错误标签以及错误预测的概率。 4. 进行模型预测。模型预测中有两个方法,一个是对csv文件进行预测,一个是对单个句子进行预测。 ##### 下面是对news数据集使用bert以上四个步骤的整体处理命令: ```shell python process_data.py \ --is_bert true \ --split_path raw.csv \ --is_split true \ --is_split_dev true \ --is_split_test true \ --is_create_class_txt true \ --is_create_class_index true \ --is_remove_punc false \ --is_analysis true \ --label_column label_desc \ --data_dir news \ --split char \ --is_save_analysis false \ --raw_label_name label \ --text_name sentence # 训练 python train.py \ --is_bert true \ --data_dir news \ --model_name bert \ --freeze_bert_head false \ --raw_label_name label \ --text_name text \ --hidden_size 768 \ --is_tokenizer true \ --batch_size 64 \ --max_length 65 \ --epochs 100 \ --logging_step 50 \ --patience 3 \ --device_name cpu # 测试 python test.py \ --is_bert true \ --data_dir news \ --model_name bert \ --freeze_bert_head true \ --raw_label_name label \ --text_name text \ --hidden_size 768 \ --is_tokenizer true \ --batch_size 64 \ --max_length 65 \ --epochs 100 \ --logging_step 50 \ --patience 3 \ --device_name cpu # 预测 python predict.py \ --is_bert true \ --data_dir news \ --model_name bert \ --freeze_bert_head true \ --raw_label_name label \ --text_name text \ --hidden_size 768 \ --is_tokenizer true \ --batch_size 64 \ --max_length 65 \ --epochs 100 \ --logging_step 50 \ --patience 3 \ --device_name cpu \ --predict_file_path raw_test.csv ``` 其他参数可以根据自己的使用情况进行指定,如果运行默认的数据集,无需指定参数。