# text-classification
**Repository Path**: busixianyu/text-classification
## Basic Information
- **Project Name**: text-classification
- **Description**: 本项目设计目的致力于打造一个通用的文本分类项目,让大家尽量写少的代码就能进行文本分类模型训练。
- **Primary Language**: Python
- **License**: Not specified
- **Default Branch**: master
- **Homepage**: None
- **GVP Project**: No
## Statistics
- **Stars**: 0
- **Forks**: 1
- **Created**: 2022-07-17
- **Last Updated**: 2025-02-26
## Categories & Tags
**Categories**: Uncategorized
**Tags**: PyTorch, bert, transformers, text-classification, Nlp
## README
# text-classification
#### 介绍
文本分类项目在自然语言处理中属于比较简单的任务,可以使用机器学习和深度学习方法进行分类,本项目将提供一个统一的框架,方便使用深度学习进行分类任务,
既可以使用例如TextCNN等方法进行分类,也可以使用各种bert模型进行分类。本项目设计目的致力于打造一个通用的文本分类项目,
让大家尽量写少的代码就能进行文本分类模型训练。此项目通过统一的日志输出,图表展示,数据分析处理,模型训练,测试以及预测统一方法,只需要写最少的代码,
就可以完成文本分类任务。
#### 软件架构
主要使用pytorch和transformers框架进行代码编码。
数据集默认存放data目录下面,可以在启动参数中进行修改。数据集需要四个文件,训练集,测试集,验证集和分类文本。
如果数据缺失这四个文件,可以使用utils包下面的工具类进行数据的转换与生成。
1. conf: 配置文件目录,目前包含的配置文件如下:
| 配置文件 | 说明 |
|-------------|-----------------------------------------|
| logging.yml | 日志配置文件,目前日志会进行两次输出,一个是控制台,另一个是日志文件 |
| bert.yml | bert配置文件,项目配置的bert都在这个文件,如果想自己添加在此文件中添加 |
| model.yml | 模型配置文件,自定义的模型在这个文件进行配置 |
2. data: 这个是默认存放数据的目录,在启动项配置可以修改 `--data_dir`。
3. logs: 日志存放目录,在日志配置文件中指定了输出文件中,生成的日志文件就在这个目录下,可以在启动项配置中修改 `--logging_path`。
4. utils: 工具类存放目录,目前有一些简单工具可以使用:
1. split_text_data.py:文本分割工具,如果只有训练集,需要拆分成训练集和验证集,或者拆分成三个训练集都可以使用这个文件,生成分类文本也使用这个文件。目前只支持csv格式文件。
2. txt2csv.py:将txt文件转换成csv文件,项目中使用csv更方便操作。
3. model_util.py:模型相关工具类,与模型操作项目可以使用这个工具类
4. utils.py:常用工具类,模型验证、测试等方法在这个工具类中。
5. 其他工具类可以自行查看,这里不一一介绍。
5. config.py:配置文件,启动默认参数可以在这个文件查找与修改,也可运行文件时指定。
6. train.py: 模型训练文件。
7. test.py: 模型测试文件。
8. predict.py: 模型预测文件。
9. word_sequence.py: 模型词表生成类,非bert模型需要先生成词表,放入savepoint相应目录下面
```shell
python train.py --help
```
运行上面命令后帮助结果如下:
``` text
--data_dir DATA_DIR 数据存储目录:news; THUCNews; xinwen_multi_label (default: news)
--data_suffix DATA_SUFFIX
数据存储目录 (default: )
--root_dir ROOT_DIR 项目根目录 (default: )
--file_type FILE_TYPE
数据集文件类型:cvs,txt等 (default: csv)
--train_file_path TRAIN_FILE_PATH
训练集文件路径 (default: train.csv)
--dev_file_path DEV_FILE_PATH
验证集文件路径 (default: dev.csv)
--test_file_path TEST_FILE_PATH
测试集文件路径 (default: test.csv)
--class_path CLASS_PATH
分类文件路径 (default: class.txt)
--is_split_dataset [IS_SPLIT_DATASET]
是否拆分数据集 (default: False)
--test_out_path TEST_OUT_PATH
测试集文件路径 (default: test_out.csv)
--predict_file_path PREDICT_FILE_PATH
预测文件路径 (default: test.csv)
--dev_ratio DEV_RATIO
验证集拆分比例 (default: 0.1)
--logging_step LOGGING_STEP
打印日志频率 (default: 50)
--logging_path LOGGING_PATH
日志保存路径 (default: logging.log)
--label_name LABEL_NAME
分类标签列名称 (default: label)
--raw_label_name RAW_LABEL_NAME
原始分类标签名称 (default: label)
--text_name TEXT_NAME
文本列名称 (default: text)
--num_labels NUM_LABELS
类别数量 (default: 0)
--problem_type PROBLEM_TYPE
单标签分类还是多标签分类 multi_label_classification, single_label_classification (default: single_label_classification)
--model_name MODEL_NAME
模型名称: ernie (default: bert)
--hidden_size HIDDEN_SIZE
隐藏层大小 (default: 768)
--model_dir MODEL_DIR
模型存储目录 (default: savepoint/)
--model_suffix MODEL_SUFFIX
实际模型名称 (default: )
--checkpoint CHECKPOINT
检查点文件路径 (default: model_best.pth.tar)
--ws_path WS_PATH word sequence 文件路径 (default: ws.json)
--embedding_type EMBEDDING_TYPE
embedding 文件类型:weibo (default: None)
--embedding_path EMBEDDING_PATH
embedding path (default: sgns.weibo.word.bz2)
--custom CUSTOM 自定义模型名称 (default: text_cnn)
--is_bert [IS_BERT] 是否使用bert模型 (default: False)
--is_tokenizer [IS_TOKENIZER]
是否使用bert模型分词 (default: False)
--freeze_bert_head [FREEZE_BERT_HEAD]
是否冻结bert参数 (default: False)
--max_grad_norm MAX_GRAD_NORM
梯度修剪 (default: 10)
--batch_size BATCH_SIZE
batch_size大小 (default: 64)
--max_length MAX_LENGTH
最大长度 (default: 65)
--epochs EPOCHS 迭代次数 (default: 100)
--learning_rate LEARNING_RATE
学习率 (default: 5e-05)
--add_special_tokens [ADD_SPECIAL_TOKENS]
是否添加特别标记 (default: True)
--no_add_special_tokens
是否添加特别标记 (default: False)
--tokenizer TOKENIZER
分词器 (default: None)
--filter_sizes FILTER_SIZES [FILTER_SIZES ...]
cnn卷积核大小 (default: [])
--num_filters NUM_FILTERS
cnn输出通道数 (default: 256)
--dropout DROPOUT dropout比例 (default: 0.3)
--seed SEED 随机种子 (default: 1)
--device_name DEVICE_NAME
使用cpu,cuda或mps (default: cpu)
--device DEVICE 设备对象,cpu/cuda/mps (default: None)
--label2id LABEL2ID
--id2label ID2LABEL
--local_rank LOCAL_RANK
For distributed training: local_rank (default: -1)
--model_dict MODEL_DICT
模型字典数据 (default: {})
--patience PATIENCE early stop迭代数 (default: 20)
--inner_patience INNER_PATIENCE
内层忍耐数 (default: 400)
--writer WRITER tensorboard->SummaryWriter (default: None)
--model MODEL 模型 (default: None)
--optimizer OPTIMIZER
优化器 (default: None)
--ws WS word sequence 类 (default: None)
--embedding EMBEDDING
embedding (default: None)
--vocab_size VOCAB_SIZE
词表大小 (default: None)
--only_dev [ONLY_DEV]
是否只是验证 (default: False)
--top_k TOP_K 是否只是验证 (default: (1, 2))
```
以上命令随着项目的完善会有变动,具体帮助名称请自己运行查看。
#### 安装教程
1. 使用requirements.txt进行依赖安装
#### 使用说明
1. 本项目使用的数据集为头条新闻分类数据集,存放在data/news目录下面。主要包含15个类别,具体类别如下:
```text
news_agriculture
news_car
news_culture
news_edu
news_entertainment
news_finance
news_game
news_house
news_military
news_sports
news_stock
news_story
news_tech
news_travel
news_world
```
训练集中每个数据集的数量柱状图如下:

jieba分词的词云效果如下:
训练集准确率图像:
训练集损失图像:
验证集和训练集精度图像:
验证集和训练集损失图像:
| model_name | bert模型名称 |
|------------|----------|
| bert | bert-base-chinese |
| bert_wwm | hfl/chinese-bert-wwm |
| raw_bert | bert-base-chinese |
| roberta | hfl/chinese-roberta-wwm-ext |
|ernie | nghuyong/ernie-1.0 |
| ernie_healthy |nghuyong/ernie-health-zh |
|albert | voidful/albert_chinese_tiny |
|reformer | junnyu/roformer_chinese_base|
本数据运行步骤:
1. 本数据集只有原始数据,没有验证集,测试集也是只有文本,没有标签。
首先我们要先把数据拆分成训练集和验证集。可以使用process_data.py统一方法进行数据处理。
2. 进行模型训练。在模型训练过程中每轮训练后都会对验证集进行验证,如果验证集的最好分数一直没有被更新,到达忍耐度就会自动停止。
3. 进行模型测试。预测的结果会保存文本,正确标签,错误标签以及错误预测的概率。
4. 进行模型预测。模型预测中有两个方法,一个是对csv文件进行预测,一个是对单个句子进行预测。
##### 下面是对news数据集使用bert以上四个步骤的整体处理命令:
```shell
python process_data.py \
--is_bert true \
--split_path raw.csv \
--is_split true \
--is_split_dev true \
--is_split_test true \
--is_create_class_txt true \
--is_create_class_index true \
--is_remove_punc false \
--is_analysis true \
--label_column label_desc \
--data_dir news \
--split char \
--is_save_analysis false \
--raw_label_name label \
--text_name sentence
# 训练
python train.py \
--is_bert true \
--data_dir news \
--model_name bert \
--freeze_bert_head false \
--raw_label_name label \
--text_name text \
--hidden_size 768 \
--is_tokenizer true \
--batch_size 64 \
--max_length 65 \
--epochs 100 \
--logging_step 50 \
--patience 3 \
--device_name cpu
# 测试
python test.py \
--is_bert true \
--data_dir news \
--model_name bert \
--freeze_bert_head true \
--raw_label_name label \
--text_name text \
--hidden_size 768 \
--is_tokenizer true \
--batch_size 64 \
--max_length 65 \
--epochs 100 \
--logging_step 50 \
--patience 3 \
--device_name cpu
# 预测
python predict.py \
--is_bert true \
--data_dir news \
--model_name bert \
--freeze_bert_head true \
--raw_label_name label \
--text_name text \
--hidden_size 768 \
--is_tokenizer true \
--batch_size 64 \
--max_length 65 \
--epochs 100 \
--logging_step 50 \
--patience 3 \
--device_name cpu \
--predict_file_path raw_test.csv
```
其他参数可以根据自己的使用情况进行指定,如果运行默认的数据集,无需指定参数。