From 680f8f8e412481df0011d7f38f039918cbf554f8 Mon Sep 17 00:00:00 2001 From: zhangyi Date: Tue, 19 Apr 2022 15:40:43 +0800 Subject: [PATCH] modify the files --- tutorials/application/source_en/index.rst | 5 +- .../source_en/nlp/sentiment_analysis.md | 613 ++++++++++++++++++ .../source_en/nlp/sequence_labeling.md | 467 +++++++++++++ tutorials/application/source_en/nlp/test2.md | 3 - tutorials/source_en/beginner/infer.md | 32 +- tutorials/source_en/beginner/train.md | 8 +- 6 files changed, 1101 insertions(+), 27 deletions(-) create mode 100644 tutorials/application/source_en/nlp/sentiment_analysis.md create mode 100644 tutorials/application/source_en/nlp/sequence_labeling.md delete mode 100644 tutorials/application/source_en/nlp/test2.md diff --git a/tutorials/application/source_en/index.rst b/tutorials/application/source_en/index.rst index 1225ea8552..c08037fbb1 100644 --- a/tutorials/application/source_en/index.rst +++ b/tutorials/application/source_en/index.rst @@ -12,10 +12,11 @@ Application :caption: CV cv/test1 - + .. toctree:: :glob: :maxdepth: 1 :caption: NLP - nlp/test2 + nlp/sentiment_analysis + nlp/sequence_labeling diff --git a/tutorials/application/source_en/nlp/sentiment_analysis.md b/tutorials/application/source_en/nlp/sentiment_analysis.md new file mode 100644 index 0000000000..3daf4b7113 --- /dev/null +++ b/tutorials/application/source_en/nlp/sentiment_analysis.md @@ -0,0 +1,613 @@ +# Using RNN for Sentiment Classification + + + +## Overview + +Sentiment classification is a classic task in natural language processing. It is a typical classification problem. The following uses MindSpore to implement an RNN-based sentimental classification model to achieve the following effects: + +```text +Input: This film is terrible +Correct label: Negative +Forecast label: Negative + +Input: This film is great +Correct label: Positive +Forecast label: Positive +``` + +## Data Preparation + +This section uses the classic [IMDB Movie Review Dataset](https://ai.stanford.edu/~amaas/data/sentiment/) for sentimental classification. The dataset contains positive and negative data. The following is an example: + +| Review | Label | +|:---|:---:| +| "Quitting" may be as much about exiting a pre-ordained identity as about drug withdrawal. As a rural guy coming to Beijing, class and success must have struck this young artist face on as an appeal to separate from his roots and far surpass his peasant parents' acting success. Troubles arise, however, when the new man is too new, when it demands too big a departure from family, history, nature, and personal identity. The ensuing splits, and confusion between the imaginary and the real and the dissonance between the ordinary and the heroic are the stuff of a gut check on the one hand or a complete escape from self on the other. | Negative | +| This movie is amazing because the fact that the real people portray themselves and their real life experience and do such a good job it's like they're almost living the past over again. Jia Hongsheng plays himself an actor who quit everything except music and drugs struggling with depression and searching for the meaning of life while being angry at everyone especially the people who care for him most. | Positive | + +In addition, the pre-trained word vectors are used to encode natural language words to obtain semantic features of text. In this section, the Global Vectors for Word Representation ([GloVe](https://nlp.stanford.edu/projects/glove/)) are selected as embeddings. 
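As a quick preview of the format (a minimal sketch; the actual download and parsing code is implemented in the following sections), each line of the assumed `glove.6B.100d.txt` file stores a word followed by the components of its vector, which can be split apart as follows:

```python
import numpy as np

# A single, truncated line in GloVe text format; a full glove.6B.100d line has 100 components.
line = "the 0.418 0.24968 -0.41242 0.1217 0.34527 -0.044457 -0.49688 -0.17862"
word, vector_text = line.split(maxsplit=1)
vector = np.fromstring(vector_text, dtype=np.float32, sep=' ')
print(word, vector.shape)  # 'the' (8,) for this truncated line; (100,) for a complete one
```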
+ +### Data Downloading Module + +To facilitate the download of datasets and pre-trained word vectors, a data download module is designed to implement a visualized download process and save the data to a specified path. The data download module uses the `requests` library to send HTTP requests and uses the `tqdm` library to visualize the download percentage. To ensure download security, temporary files are downloaded in I/O mode, saved to a specified path, and returned. + +> The `tqdm` and `requests` libraries need to be manually installed. The command is `pip install tqdm requests`. + +```python +import os +import shutil +import requests +import tempfile +from tqdm import tqdm +from typing import IO +from pathlib import Path + +# Set the storage path to `home_path/.mindspore_examples`. +cache_dir = Path.home() / '.mindspore_examples' + +def http_get(url: str, temp_file: IO): + """Download data using the requests library and visualize the process using the tqdm library."" + req = requests.get(url, stream=True) + content_length = req.headers.get('Content-Length') + total = int(content_length) if content_length is not None else None + progress = tqdm(unit='B', total=total) + for chunk in req.iter_content(chunk_size=1024): + if chunk: + progress.update(len(chunk)) + temp_file.write(chunk) + progress.close() + +def download(file_name: str, url: str): + """Download data and save it with the specified name."" + if not os.path.exists(cache_dir): + os.makedirs(cache_dir) + cache_path = os.path.join(cache_dir, file_name) + cache_exist = os.path.exists(cache_path) + if not cache_exist: + with tempfile.NamedTemporaryFile() as temp_file: + http_get(url, temp_file) + temp_file.flush() + temp_file.seek(0) + with open(cache_path, 'wb') as cache_file: + shutil.copyfileobj(temp_file, cache_file) + return cache_path +``` + +After the data download module is complete, download the IMDB dataset for testing. (The HUAWEI CLOUD image is used to improve the download speed.) The download process and storage path are as follows: + +```python +imdb_path = download('aclImdb_v1.tar.gz', 'https://mindspore-website.obs.myhuaweicloud.com/notebook/datasets/aclImdb_v1.tar.gz') +imdb_path +``` + +```text + '/root/.mindspore_examples/aclImdb_v1.tar.gz' +``` + +### Loading the IMDB Dataset + +The downloaded IMDB dataset is a `tar.gz` file. Use the `tarfile` library of Python to read the dataset and store all data and labels separately. The decompression directory of the original IMDB dataset is as follows: + +```text + ├── aclImdb + │ ├── imdbEr.txt + │ ├── imdb.vocab + │ ├── README + │ ├── test + │ └── train + │ ├── neg + │ ├── pos + ... +``` + +The dataset has been divided into two parts: train and test. Each part contains the neg and pos folders. You need to use the train and test parts to read and process data and labels, respectively. + +```python +import re +import six +import string +import tarfile + +class IMDBData(): + """IMDB dataset loader. + + Load the IMDB dataset and processes it as a Python iteration object. + + """ + label_map = { + "pos": 1, + "neg": 0 + } + def __init__(self, path, mode="train"): + self.mode = mode + self.path = path + self.docs, self.labels = [], [] + + self._load("pos") + self._load("neg") + + def _load(self, label): + pattern = re.compile(r"aclImdb/{}/{}/.*\.txt$".format(self.mode, label)) + # Load data to the memory. 
+ with tarfile.open(self.path) as tarf: + tf = tarf.next() + while tf is not None: + if bool(pattern.match(tf.name)): + # Segment text, remove punctuations and special characters, and convert text to lowercase. + self.docs.append(str(tarf.extractfile(tf).read().rstrip(six.b("\n\r")) + .translate(None, six.b(string.punctuation)).lower()).split()) + self.labels.append([self.label_map[label]]) + tf = tarf.next() + + def __getitem__(self, idx): + return self.docs[idx], self.labels[idx] + + def __len__(self): + return len(self.docs) +``` + +After the IMDB data is loaded, load the training dataset for testing and output the number of datasets. + +```python +imdb_train = IMDBData(imdb_path, 'train') +len(imdb_train) +``` + +```text + 25000 +``` + +After the IMDB dataset is loaded to the memory and built as an iteration object, you can use the `GeneratorDataset` API provided by `mindspore.dataset` to load the dataset iteration object and then perform data processing. The following encapsulates a function to load train and test using `GeneratorDataset`, and set `column_name` of the text and label in the dataset to `text` and `label`, respectively. + +```python +import mindspore.dataset as dataset + +def load_imdb(imdb_path): + imdb_train = dataset.GeneratorDataset(IMDBData(imdb_path, "train"), column_names=["text", "label"]) + imdb_test = dataset.GeneratorDataset(IMDBData(imdb_path, "test"), column_names=["text", "label"]) + return imdb_train, imdb_test +``` + +Load the IMDB dataset. You can see that `imdb_train` is a GeneratorDataset object. + +```python +imdb_train, imdb_test = load_imdb(imdb_path) +imdb_train +``` + +### Loading Pre-trained Word Vectors + +A pre-trained word vector is a numerical representation of an input word. The `nn.Embedding` layer uses the table lookup mode to input the index in the vocabulary corresponding to the word to obtain the corresponding expression vector. +Therefore, before model build, word vectors and vocabulary required by the Embedding layer need to be built. Here, we use the classic pre-trained word vectors, GloVe. +The data format is as follows: + +| Word | Vector | +|:---|:---:| +| the | 0.418 0.24968 -0.41242 0.1217 0.34527 -0.044457 -0.49688 -0.17862 -0.00066023 ...| +| , | 0.013441 0.23682 -0.16899 0.40951 0.63812 0.47709 -0.42852 -0.55641 -0.364 ... | + +The words in the first column are used as the vocabulary, and `dataset.text.Vocab` is used to load the words in sequence. In addition, the vector of each row is read and converted into `numpy.array` for the `nn.Embedding` to load weights. The sample code is as follows: + +```python +import zipfile +import numpy as np + +def load_glove(glove_path): + glove_100d_path = os.path.join(cache_dir, 'glove.6B.100d.txt') + if not os.path.exists(glove_100d_path): + glove_zip = zipfile.ZipFile(glove_path) + glove_zip.extractall(cache_dir) + + embeddings = [] + tokens = [] + with open(glove_100d_path, encoding='utf-8') as gf: + for glove in gf: + word, embedding = glove.split(maxsplit=1) + tokens.append(word) + embeddings.append(np.fromstring(embedding, dtype=np.float32, sep=' ')) + # Add the embeddings corresponding to the special placeholders and . + embeddings.append(np.random.rand(100)) + embeddings.append(np.zeros((100,), np.float32)) + + vocab = dataset.text.Vocab.from_list(tokens, special_tokens=["", ""], special_first=False) + embeddings = np.array(embeddings).astype(np.float32) + return vocab, embeddings +``` + +The dataset may contain words that are not covered by the vocabulary. 
Therefore, the `` token needs to be added. In addition, because the input lengths are different, the `` tokens need to be added to short text when the text is packed into a batch. The length of the completed vocabulary is the length of the original vocabulary plus 2. + +Download and load GloVe to generate a vocabulary and a word vector weight matrix. + +```python +glove_path = download('glove.6B.zip', 'https://mindspore-website.obs.myhuaweicloud.com/notebook/datasets/glove.6B.zip') +vocab, embeddings = load_glove(glove_path) +len(vocab.vocab()) +``` + +```text + 400002 +``` + +Use a vocabulary to convert `the` into an index ID, and query a word vector corresponding to the word vector matrix: + +```python +idx = vocab.tokens_to_ids('the') +embedding = embeddings[idx] +idx, embedding +``` + +```text + (0, + array([-0.038194, -0.24487 , 0.72812 , -0.39961 , 0.083172, 0.043953, + -0.39141 , 0.3344 , -0.57545 , 0.087459, 0.28787 , -0.06731 , + 0.30906 , -0.26384 , -0.13231 , -0.20757 , 0.33395 , -0.33848 , + -0.31743 , -0.48336 , 0.1464 , -0.37304 , 0.34577 , 0.052041, + 0.44946 , -0.46971 , 0.02628 , -0.54155 , -0.15518 , -0.14107 , + -0.039722, 0.28277 , 0.14393 , 0.23464 , -0.31021 , 0.086173, + 0.20397 , 0.52624 , 0.17164 , -0.082378, -0.71787 , -0.41531 , + 0.20335 , -0.12763 , 0.41367 , 0.55187 , 0.57908 , -0.33477 , + -0.36559 , -0.54857 , -0.062892, 0.26584 , 0.30205 , 0.99775 , + -0.80481 , -3.0243 , 0.01254 , -0.36942 , 2.2167 , 0.72201 , + -0.24978 , 0.92136 , 0.034514, 0.46745 , 1.1079 , -0.19358 , + -0.074575, 0.23353 , -0.052062, -0.22044 , 0.057162, -0.15806 , + -0.30798 , -0.41625 , 0.37972 , 0.15006 , -0.53212 , -0.2055 , + -1.2526 , 0.071624, 0.70565 , 0.49744 , -0.42063 , 0.26148 , + -1.538 , -0.30223 , -0.073438, -0.28312 , 0.37104 , -0.25217 , + 0.016215, -0.017099, -0.38984 , 0.87424 , -0.72569 , -0.51058 , + -0.52028 , -0.1459 , 0.8278 , 0.27062 ], dtype=float32)) +``` + +## Dataset Preprocessing + +Word segmentation is performed on the IMDB dataset loaded by the loader, but the dataset does not meet the requirements for building training data. Therefore, extra preprocessing is required. The preprocessing is as follows: + +- Use the Vocab to convert all tokens to index IDs. +- The length of the text sequence is unified. If the length is insufficient, `` is used to supplement the length. If the length exceeds the limit, the excess part is truncated. + +Here, the API provided in `mindspore.dataset` is used for preprocessing. The APIs used here are designed for MindSpore high-performance data engines. The operations corresponding to each API are considered as a part of the data pipeline. For details, see [MindSpore Data Engine](https://www.mindspore.cn/docs/en/r1.7/design/data_engine.html). +For the table query operation from a token to an index ID, use the `text.Lookup` API to load the built vocabulary and specify `unknown_token`. The `PadEnd` API is used to unify the length of the text sequence. This API defines the maximum length and padding value (`pad_value`). In this example, the maximum length is 500, and the padding value corresponds to the index ID of `` in the vocabulary. + +> In addition to pre-processing the `text` data in the dataset, the `label` data needs to be converted to the float32 format to meet the subsequent model training requirements. 
+ +```python +import mindspore + +lookup_op = dataset.text.Lookup(vocab, unknown_token='') +pad_op = dataset.transforms.c_transforms.PadEnd([500], pad_value=vocab.tokens_to_ids('')) +type_cast_op = dataset.transforms.c_transforms.TypeCast(mindspore.float32) +``` + +After the preprocessing is complete, you need to add data to the dataset processing pipeline and use the `map` API to add operations to the specified column. + +```python +imdb_train = imdb_train.map(operations=[lookup_op, pad_op], input_columns=['text']) +imdb_train = imdb_train.map(operations=[type_cast_op], input_columns=['label']) + +imdb_test = imdb_test.map(operations=[lookup_op, pad_op], input_columns=['text']) +imdb_test = imdb_test.map(operations=[type_cast_op], input_columns=['label']) +``` + +The IMDB dataset does not contain the validation set. Therefore, you need to manually divide the dataset into training and validation parts, with the ratio of 0.7 to 0.3. + +```python +imdb_train, imdb_valid = imdb_train.split([0.7, 0.3]) +``` + +Finally, specify the batch size of the dataset by using the `batch` API and determine whether to discard the remaining data that cannot be exactly divided by the batch size. + +> Call the `map`, `split`, and `batch` APIs of the dataset to add corresponding operations to the dataset processing pipeline. The return value is of the new dataset type. Currently, only the pipeline operation is defined. During execution, the data processing pipeline is executed to obtain the processed data and send the data to the model for training. + +```python +imdb_train = imdb_train.batch(64, drop_remainder=True) +imdb_valid = imdb_valid.batch(64, drop_remainder=True) +``` + +## Model Building + +After the dataset is processed, we design the model structure for sentimental classification. First, the input text (that is, the serialized index ID list) needs to be converted into a vectorized representation through table lookup. In this case, the `nn.Embedding` layer needs to be used to load the GloVe, and then the RNN is used to perform feature extraction. Finally, the RNN is connected to a fully-connected layer, that is, `nn.Dense`, to convert the feature into a size that is the same as the number of classifications for subsequent model optimization training. The overall model structure is as follows: + +```text +nn.Embedding -> nn.RNN -> nn.Dense +``` + +The long short term memory (LSTM) variant that can avoid the RNN gradient vanishing problem is used as the feature extraction layer. The model is described as follows: + +### Embedding + +The Embedding layer may also be referred to as an EmbeddingLookup layer. A function of the Embedding layer is to use an index ID to search for a vector of an ID corresponding to the weight matrix. When an input is a sequence including index IDs, a matrix with a same length is searched for and returned. For example: + +```text +embedding = nn.Embedding (1000, 100) # The size of the vocabulary (the value range of index) is 1000, and the size of the vector is 100. +input shape: (1, 16) # The sequence length is 16. +output shape: (1, 16, 100) +``` + +Here, the processed GloVe word vector matrix is used. `embedding_table` of `nn.Embedding` is set to the pre-trained word vector matrix. The vocabulary size `vocab_size` is 400002, and `embedding_size` is the size of the selected `glove.6B.100d` vector, that is, 100. 
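The following minimal sketch (assuming the `vocab` and `embeddings` objects produced by `load_glove` above) builds such an Embedding layer from the pre-trained weight matrix and checks the output shape for a short sequence of index IDs:

```python
import mindspore
import mindspore.nn as nn
from mindspore import Tensor

# vocab_size is 400002 and embedding_size is 100 for the glove.6B.100d matrix loaded above.
vocab_size, embedding_size = embeddings.shape
embedding = nn.Embedding(vocab_size, embedding_size, embedding_table=Tensor(embeddings))
ids = Tensor([[0, 1, 2, 3]], mindspore.int32)  # a batch containing one sequence of 4 index IDs
print(embedding(ids).shape)                    # expected output shape: (1, 4, 100)
```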
+ +### Recurrent Neural Network (RNN) + +RNN is a type of neural network that uses sequence data as an input, performs recursion in the evolution direction of a sequence, and connects all nodes (circulating units) in a chain. The following figure shows the general RNN structure. + +![RNN-0](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r1.7/tutorials/application/source_zh_cn/nlp/images/0-RNN-0.png) + +> The left part of the figure shows an RNN Cell cycle, and the right part shows the RNN chain connection. Actually, there is only one Cell parameter regardless of a single RNN Cell or an RNN network, and the parameter is updated in continuous cyclic calculation. + +The recurrent feature of the RNN matches the sequence feature (a sentence is a sequence composed of words) of the natural language text. Therefore, the RNN is widely used in the research of natural language processing. The following figure shows the disassembled RNN structure. + +![RNN](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r1.7/tutorials/application/source_zh_cn/nlp/images/0-RNN.png) + +A structure of a single RNN Cell is simple, causing the gradient vanishing problem. Specifically, when a sequence in the RNN is relatively long, information of a sequence header is basically lost at a tail of the sequence. To solve this problem, the long short term memory (LSTM) is proposed. The gating mechanism is used to control the retention and discarding of information flows in each cycle. The following figure shows the disassembled LSTM structure. + +![LSTM](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r1.7/tutorials/application/source_zh_cn/nlp/images/0-LSTM.png) + +In this section, the LSTM variant instead of the classic RNN is used for feature extraction to avoid the gradient vanishing problem and obtain a better model effect. The formula corresponding to `nn.LSTM` in MindSpore is as follows: + +$$h_{0:t}, (h_t, c_t) = \text{LSTM}(x_{0:t}, (h_0, c_0))$$ + +Herein, `nn.LSTM` hides a cycle of the entire recurrent neural network on a sequence time step. After the input sequence and the initial state are sent, you can obtain a matrix formed by splicing hidden states of each time step and a hidden state corresponding to the last time step. We use the hidden state of the last time step as the encoding feature of the input sentence and send it to the next layer. + +> Time step: Each cycle calculated by the recurrent neural network is a time step. When a text sequence is sent, a time step corresponds to a word. Therefore, in this example, the output $h_{0:t}$ of the LSTM corresponds to the hidden state set of each word, and $h_t$ and $c_t$ correspond to the hidden state corresponding to the last word. + +### Dense + +After the sentence feature is obtained through LSTM encoding, the sentence feature is sent to a fully-connected layer, that is, `nn.Dense`. The feature dimension is converted into dimension 1 required for binary classification. The output after passing through the Dense layer is the model prediction result. + +> The `sigmoid` operation is performed after the Dense layer to normalize the predicted value to the `[0,1]` range. The normalized value is used together with `BCELoss`(BinaryCrossEntropyLoss) to calculate the binary cross entropy loss. 
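To make the tensor shapes at this step explicit, here is a minimal standalone sketch (the batch size of 64, the `hidden_dim` of 256, and the bidirectional forward/backward concatenation are assumptions matching the hyperparameters used later):

```python
import numpy as np
import mindspore
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore import Tensor

hidden_dim = 256
# Stand-in for the concatenated forward/backward LSTM hidden state: (batch_size, hidden_dim * 2).
feature = Tensor(np.random.randn(64, hidden_dim * 2).astype(np.float32))
fc = nn.Dense(hidden_dim * 2, 1)          # map the sentence feature to a single score
prediction = ops.Sigmoid()(fc(feature))   # normalize the score to the [0, 1] range
print(prediction.shape)                   # (64, 1): one probability per sample
```

The complete model below assembles the Embedding, LSTM, Dense, and Sigmoid layers into a single `nn.Cell`: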
+ +```python +import mindspore +import mindspore.nn as nn +import mindspore.numpy as mnp +import mindspore.ops as ops +from mindspore import Tensor + +class RNN(nn.Cell): + def __init__(self, embeddings, hidden_dim, output_dim, n_layers, + bidirectional, dropout, pad_idx): + super().__init__() + vocab_size, embedding_dim = embeddings.shape + self.embedding = nn.Embedding(vocab_size, embedding_dim, embedding_table=Tensor(embeddings), padding_idx=pad_idx) + self.rnn = nn.LSTM(embedding_dim, + hidden_dim, + num_layers=n_layers, + bidirectional=bidirectional, + dropout=dropout, + batch_first=True) + self.fc = nn.Dense(hidden_dim * 2, output_dim) + self.dropout = nn.Dropout(1 - dropout) + self.sigmoid = ops.Sigmoid() + + def construct(self, inputs): + embedded = self.dropout(self.embedding(inputs)) + _, (hidden, _) = self.rnn(embedded) + hidden = self.dropout(mnp.concatenate((hidden[-2, :, :], hidden[-1, :, :]), axis=1)) + output = self.fc(hidden) + return self.sigmoid(output) +``` + +### Loss Function and Optimizer + +After the model body is built, instantiate the network based on the specified parameters, select the loss function and optimizer, and encapsulate them using `nn.TrainOneStepCell`. For a feature of the sentimental classification problem in this section, that is, a binary classification problem for predicting positive or negative, `nn.BCELoss` (binary cross entropy loss function) is selected. Herein, `nn.BCEWithLogitsLoss` may also be selected, and includes a `sigmoid` operation, that is: + +```text +BCEWithLogitsLoss = Sigmoid + BCELoss +``` + +If `BECLoss` is used, the `reduction` parameter must be set to the average value. It is then associated with the instantiated network object using `nn.WithLossCell`. + +After selecting a proper loss function and the `Adam` optimizer, pass them to `TrainOneStepCell`. + +> MindSpore is designed to calculate and optimize the entire graph. Therefore, the loss function and optimizer are considered as a part of the computational graph. Therefore, `TrainOneStepCell` is built as the Wrapper. + +```python +hidden_size = 256 +output_size = 1 +num_layers = 2 +bidirectional = True +dropout = 0.5 +lr = 0.001 +pad_idx = vocab.tokens_to_ids('') + +net = RNN(embeddings, hidden_size, output_size, num_layers, bidirectional, dropout, pad_idx) +loss = nn.BCELoss(reduction='mean') +net_with_loss = nn.WithLossCell(net, loss) +optimizer = nn.Adam(net.trainable_params(), learning_rate=lr) +train_one_step = nn.TrainOneStepCell(net_with_loss, optimizer) +``` + +### Training Logic + +After the model is built, design the training logic. Generally, the training logic consists of the following steps: + +1. Read the data of a batch. +2. Send the data to the network for forward computation and backward propagation, and update the weight. +3. Return the loss. + +Based on this logic, use the `tqdm` library to design an epoch training function for visualization of the training process and loss. + +```python +def train_one_epoch(model, train_dataset, epoch=0): + model.set_train() + total = train_dataset.get_dataset_size() + loss_total = 0 + step_total = 0 + with tqdm(total=total) as t: + t.set_description('Epoch %i' % epoch) + for i in train_dataset.create_tuple_iterator(): + loss = model(*i) + loss_total += loss.asnumpy() + step_total += 1 + t.set_postfix(loss=loss_total/step_total) + t.update(1) +``` + +### Evaluation Metrics and Logic + +After the training logic is complete, you need to evaluate the model. 
That is, compare the prediction result of the model with the correct label of the test set to obtain the prediction accuracy. Because sentimental classification of the IMDB is a binary classification problem, you can directly round off the predicted value to obtain a classification label (0 or 1), and then determine whether the classification label is equal to a correct label. The following describes the implementation of the function for calculating the binary classification accuracy: + +```python +def binary_accuracy(preds, y): + """ + Calculate the accuracy of each batch. + """ + + # Round off the predicted value. + rounded_preds = np.around(preds) + correct = (rounded_preds == y).astype(np.float32) + acc = correct.sum() / len(correct) + return acc +``` + +After the accuracy calculation function is available, similar to the training logic, the evaluation logic is designed in the following steps: + +1. Read the data of a batch. +2. Send the data to the network for forward computation to obtain the prediction result. +3. Calculate the accuracy. + +Similar to the training logic, `tqdm` is used to visualize the loss and process. In addition, the loss evaluation result is returned for determining the model quality when the model is saved. + +> During the evaluation, the model used is the network body that does not contain the loss function and optimizer. +> Before evaluation, you need to use `model.set_train(False)` to set the model to the evaluation state. In this case, Dropout does not take effect. + +```python +def evaluate(model, test_dataset, criterion, epoch=0): + total = test_dataset.get_dataset_size() + epoch_loss = 0 + epoch_acc = 0 + step_total = 0 + model.set_train(False) + + with tqdm(total=total) as t: + t.set_description('Epoch %i' % epoch) + for i in test_dataset.create_tuple_iterator(): + predictions = model(i[0]) + loss = criterion(predictions, i[1]) + epoch_loss += loss.asnumpy() + + acc = binary_accuracy(predictions.asnumpy(), i[1].asnumpy()) + epoch_acc += acc + + step_total += 1 + t.set_postfix(loss=epoch_loss/step_total, acc=epoch_acc/step_total) + t.update(1) + + return epoch_loss / total +``` + +## Model Training and Saving + +The model building, training, and evaluation logic design are complete. The following describes how to train a model. In this example, the number of training epochs is set to 5. In addition, maintain the `best_valid_loss` variable for saving the optimal model. Based on the loss value of each epoch of evaluation, select the epoch with the minimum loss value and save the model. + +By default, MindSpore uses the static graph mode (Define and Run) for training. In the first step, computational graph is built, which is time-consuming but improve the overall training efficiency. 
To perform single-step debugging or use the dynamic graph mode, you can use the following code: + +```python +from mindspore import context +context.set_context(mode=context.PYNATIVE_MODE) +``` + +```python +from mindspore import save_checkpoint + +num_epochs = 5 +best_valid_loss = float('inf') +ckpt_file_name = os.path.join(cache_dir, 'sentiment-analysis.ckpt') + +for epoch in range(num_epochs): + train_one_epoch(train_one_step, imdb_train, epoch) + valid_loss = evaluate(net, imdb_valid, loss, epoch) + + if valid_loss < best_valid_loss: + best_valid_loss = valid_loss + save_checkpoint(net, ckpt_file_name) +``` + +```text + Epoch 0: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 273/273 [00:48<00:00, 5.59it/s, loss=0.681] + Epoch 0: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████| 117/117 [00:42<00:00, 2.72it/s, acc=0.581, loss=0.674] + Epoch 1: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 273/273 [00:44<00:00, 6.15it/s, loss=0.661] + Epoch 1: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████| 117/117 [00:41<00:00, 2.81it/s, acc=0.759, loss=0.519] + Epoch 2: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 273/273 [00:44<00:00, 6.15it/s, loss=0.487] + Epoch 2: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████| 117/117 [00:41<00:00, 2.82it/s, acc=0.836, loss=0.383] + Epoch 3: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 273/273 [00:44<00:00, 6.15it/s, loss=0.35] + Epoch 3: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████| 117/117 [00:41<00:00, 2.83it/s, acc=0.868, loss=0.305] + Epoch 4: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 273/273 [00:44<00:00, 6.18it/s, loss=0.298] + Epoch 4: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████| 117/117 [00:41<00:00, 2.82it/s, acc=0.916, loss=0.219] +``` + +You can see that the loss decreases gradually in each epoch and the accuracy of the verification set increases gradually. + +## Model Loading and Testing + +After model training is complete, you need to test or deploy the model. In this case, you need to load the saved optimal model (that is, checkpoint) for subsequent tests. The checkpoint loading and network weight loading APIs provided by MindSpore are used to load the saved model checkpoint to the memory and load the checkpoint to the model. + +> The `load_param_into_net` API returns the weight name that does not match the checkpoint in the model. If the weight name matches the checkpoint, an empty list is returned. + +```python +from mindspore import load_checkpoint, load_param_into_net + +param_dict = load_checkpoint(ckpt_file_name) +load_param_into_net(net, param_dict) +``` + +```text + [] +``` + +Batch the test set, and then use the evaluation method to evaluate the effect of the model on the test set. 
+ +```python +imdb_test = imdb_test.batch(64) +evaluate(net, imdb_test, loss) +``` + +```text + Epoch 0: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████| 391/391 [00:31<00:00, 12.36it/s, acc=0.873, loss=0.322] + + 0.32153618827347863 +``` + +## Custom Input Test + +Finally, we design a prediction function to implement the effect described at the beginning. Enter a comment to obtain the sentimental classification of the comment. Specifically, the following steps are included: + +1. Perform word segmentation on an input sentence. +2. Obtain index ID sequence by using the vocabulary. +3. Convert the index IDs sequence into tensors. +4. Send tensors to the model to obtain the prediction result. +5. Print the prediction result. + +The sample code is as follows: + +```python +score_map = { + 1: "Positive", + 0: "Negative" +} + +def predict_sentiment(model, vocab, sentence): + model.set_train(False) + tokenized = sentence.lower().split() + indexed = vocab.tokens_to_ids(tokenized) + tensor = mindspore.Tensor(indexed, mindspore.int32) + tensor = tensor.expand_dims(0) + prediction = model(tensor) + return score_map[int(np.round(prediction.asnumpy()))] +``` + +Finally, predict the examples in the preceding section. It shows that the model can classify the sentiments of the statements. + +```python +predict_sentiment(net, vocab, "This film is terrible") +``` + +```text + 'Negative' +``` + +```python +predict_sentiment(net, vocab, "This film is great") +``` + +```text + 'Positive' +``` \ No newline at end of file diff --git a/tutorials/application/source_en/nlp/sequence_labeling.md b/tutorials/application/source_en/nlp/sequence_labeling.md new file mode 100644 index 0000000000..998cebb21b --- /dev/null +++ b/tutorials/application/source_en/nlp/sequence_labeling.md @@ -0,0 +1,467 @@ +# Sequence Labeling Implementation Using LSTM+CRF + + + +## Overview + +Sequence labeling refers to the process of labeling each token for a given input sequence. Sequence labeling is usually used to extract information from text, including word segmentation, position tagging, and named entity recognition (NER). The following uses NER as an example: + +| Input Sequence| the | wall | street | journal | reported | today | that | apple | corporation | made | money | +| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | +|Output Labeling| B | I | I | I | O | O | O | B | I | O | O | + +As shown in the preceding table, `the wall street journal` and `apple corporation` are place names and need to be identified. We predict the label of each input word and identify the entity based on the label. +> A common labeling method for NER is used, that is, BIOE labeling. The beginning of an entity is labeled as B, other parts are labeled as I, and non-entity is labeled as O. + +## Conditional Random Field (CRF) + +It can be learned from the preceding example that labeling a sequence is actually performing label prediction on each token in the sequence, and may be directly considered as a simple multi-classification problem. However, sequence labeling not only needs to classify and predict a single token, but also directly associates adjacent tokens. The `the wall street journal` is used as an example. 
+ +| Input Sequence| the | wall | street | journal | | +| --- | --- | --- | --- | --- | --- | +| Output Labeling| B | I | I | I | √ | +| Output Labeling| O | I | I | I | × | + +As shown in the preceding table, the four tokens contained in the correct entity depend on each other. A word before I must be B or I. However, in the error output, the token `the` is marked as O, which violates the dependency. If NER is regarded as a multi-classification problem, the prediction probability of each word is independent and similar problems may occur. Therefore, an algorithm that can learn the association relationship is introduced to ensure the correctness of the prediction result. CRF is a [probabilistic graphical model](https://en.wikipedia.org/wiki/Graphical_model) suitable for this scenario. The definition and parametric form of conditional random field are briefly analyzed in the following. + +> Considering the linear sequence feature of the sequence labeling problem, the CRF described in this section refers to the linear chain CRF. + +Assume that $x=\{x_0, ..., x_n\}$ indicates the input sequence, $y=\{y_0, ..., y_n\}, y \in Y$ indicates the output labeling sequence, $n$ indicates the maximum length of the sequence, and $Y$ indicates the set of all possible output sequences corresponding to $x$. The probability of the output sequence $y$ is as follows: + +$$\begin{align}P(y|x) = \frac{\exp{(\text{Score}(x, y)})}{\sum_{y' \in Y} \exp{(\text{Score}(x, y')})} \qquad (1)\end{align}$$ + +If $x_i$ and $y_i$ are the $i$th token and the corresponding label in the sequence, $\text{Score}$ must calculate the mapping between $x_i$ and $y_i$ and capture the relationship between adjacent labels $y_{i-1}$ and $y_{i}$. Therefore, two probability functions are defined: + +1. The emission probability function $\psi_\text{EMIT}$ indicates the probability of $x_i \rightarrow y_i$. +2. The transition probability function $\psi_\text{TRANS}$ indicates the probability of $y_{i-1} \rightarrow y_i$. + +The formula for calculating $\text{Score}$ is as follows: + +$$\begin{align}\text{Score}(x,y) = \sum_i \log \psi_\text{EMIT}(x_i \rightarrow y_i) + \log \psi_\text{TRANS}(y_{i-1} \rightarrow y_i) \qquad (2)\end{align} $$ + +Assume that the label set is $T$. Build a matrix $\textbf{P}$ with a size of $|T|x|T|$ to store the transition probability between labels. A hidden state $h$ output by the encoding layer (which may be Dense, LSTM, or the like) may be directly considered as an emission probability. In this case, the formula for calculating $\text{Score}$ can be converted into the following: + +$$\begin{align}\text{Score}(x,y) = \sum_i h_i[y_i] + \textbf{P}_{y_{i-1}, y_{i}} \qquad (3)\end{align}$$ + +> For details about the complete CRF-based deduction, see [Log-Linear Models, MEMMs, and CRFs](http://www.cs.columbia.edu/~mcollins/crf.pdf). + +Next, we use MindSpore to implement the CRF parameterization based on the preceding formula. First, a forward training part of a CRF layer is implemented, the CRF and a loss function are combined, and a negative log likelihood (NLL) function commonly used for a classification problem is selected. 
+ +$$\begin{align}\text{Loss} = -log(P(y|x)) \qquad (4)\end{align} $$ + +According to the formula $(1)$, + +$$\begin{align}\text{Loss} = -log(\frac{\exp{(\text{Score}(x, y)})}{\sum_{y' \in Y} \exp{(\text{Score}(x, y')})}) \qquad (5)\end{align} $$ + +$$\begin{align}= log(\sum_{y' \in Y} \exp{(\text{Score}(x, y')}) - \text{Score}(x, y) \end{align}$$ + +According to the formula $(5)$, the minuend is called Normalizer, and the subtrahend is called Score. The final loss is obtained after the subtraction. + +### Score Calculation + +First, the score corresponding to the correct label sequence is calculated according to the formula $(3)$. It should be noted that, in addition to the transition probability matrix $\textbf{P}$, two vectors whose sizes are $|T|$ need to be maintained, and are respectively used as transition probabilities at the beginning and the end of the sequence. In addition, a mask matrix $mask$ is introduced. When multiple sequences are packed into a batch, the filled values are ignored. In this way, the $\text{Score}$ calculation contains only valid tokens. + +```python +def compute_score(emissions, tags, seq_ends, mask, trans, start_trans, end_trans): + # emissions: (seq_length, batch_size, num_tags) + # tags: (seq_length, batch_size) + # mask: (seq_length, batch_size) + + seq_length, batch_size = tags.shape + mask = mask.astype(emissions.dtype) + + # Set score to the initial transition probability. + # shape: (batch_size,) + score = start_trans[tags[0]] + # score += Probability of the first emission + # shape: (batch_size,) + score += emissions[0, mnp.arange(batch_size), tags[0]] + + for i in range(1, seq_length): + # Probability that the label is transited from i-1 to i (valid when mask == 1). + # shape: (batch_size,) + score += trans[tags[i - 1], tags[i]] * mask[i] + + # Emission probability of tags[i] prediction(valid when mask == 1). + # shape: (batch_size,) + score += emissions[i, mnp.arange(batch_size), tags[i]] * mask[i] + + # End the transition. + # shape: (batch_size,) + last_tags = tags[seq_ends, mnp.arange(batch_size)] + # score += End transition probability + # shape: (batch_size,) + score += end_trans[last_tags] + + return score +``` + +### Normalizer Calculation + +According to the formula $(5)$, Normalizer is the Log-Sum-Exp of scores of all possible output sequences corresponding to $x$. In this case, if the enumeration method is used for calculation, each possible output sequence score needs to be calculated, and there are $|T|^{n}$ results in total. Here, we use the dynamic programming algorithm to improve the efficiency by reusing the calculation result. + +Assume that you need to calculate the scores $\text{Score}{i}$ of all possible output sequences from token $0$ to token $i$. In this case, scores $\text{Score}{i-1}$ of all possible output sequences from the $0$th token to the $i-1$th token may be calculated first. Therefore, the Normalizer can be rewritten as follows: + +$$log(\sum_{y'_{0,i} \in Y} \exp{(\text{Score}_i})) = log(\sum_{y'_{0,i-1} \in Y} \exp{(\text{Score}_{i-1} + h_{i} + \textbf{P}})) \qquad (6)$$ + +$h_i$ is the emission probability of the $i$th token, and $\textbf{P}$ is the transition matrix. 
Because the emission probability matrix $h$ and the transition probability matrix $\textbf{P}$ are independent of the sequence path calculation of $y$, we can obtain that: + +$$log(\sum_{y'_{0,i} \in Y} \exp{(\text{Score}_i})) = log(\sum_{y'_{0,i-1} \in Y} \exp{(\text{Score}_{i-1}})) + h_{i} + \textbf{P} \qquad (7)$$ + +According to formula (7), the Normalizer is implemented as follows: + +```python +def compute_normalizer(emissions, mask, trans, start_trans, end_trans): + # emissions: (seq_length, batch_size, num_tags) + # mask: (seq_length, batch_size) + + seq_length = emissions.shape[0] + + # Set score to the initial transition probability and add the first emission probability. + # shape: (batch_size, num_tags) + score = start_trans + emissions[0] + + for i in range(1, seq_length): + # The score dimension is extended to calculate the total score. + # shape: (batch_size, num_tags, 1) + broadcast_score = score.expand_dims(2) + + # The emission dimension is extended to calculate the total score. + # shape: (batch_size, 1, num_tags) + broadcast_emissions = emissions[i].expand_dims(1) + + # Calculate score_i according to formula (7). + # In this case, broadcast_score indicates all possible paths from token 0 to the current token. + # log_sum_exp corresponding to score + # shape: (batch_size, num_tags, num_tags) + next_score = broadcast_score + trans + broadcast_emissions + + # Perform the log_sum_exp operation on score_i to calculate the score of the next token. + # shape: (batch_size, num_tags) + next_score = mnp.log(mnp.sum(mnp.exp(next_score), axis=1)) + + # The score changes only when mask == 1. + # shape: (batch_size, num_tags) + score = mnp.where(mask[i].expand_dims(1), next_score, score) + + # Add the end transition probability. + # shape: (batch_size, num_tags) + score += end_trans + # Calculate log_sum_exp based on the scores of all possible paths. + # shape: (batch_size,) + return mnp.log(mnp.sum(mnp.exp(score), axis=1)) +``` + +### Viterbi Algorithm + +After the forward training part is completed, the decoding part needs to be implemented. Here we select the [Viterbi algorithm](https://en.wikipedia.org/wiki/Viterbi_algorithm) that is suitable for finding the optimal path of the sequence. Similar to calculating Normalizer, dynamic programming is used to solve all possible prediction sequence scores. The difference is that the label with the maximum score corresponding to token $i$ needs to be saved during decoding. The label is used by the Viterbi algorithm to calculate the optimal prediction sequence. + +After obtaining the maximum probability score $\text{Score}$ and the label history $\text{History}$ corresponding to each token, use the Viterbi algorithm to calculate the following formula: + +$$P_{0,i} = max(P_{0, i-1}) + P_{i-1, i}$$ + +The 0th token to the $i$th token correspond to sequences with a maximum probability. Only sequences with a maximum probability corresponding to the 0th token to the $i-1$th token and labels with a maximum probability corresponding to the $i$th token to the $i-1$th token need to be considered. Therefore, we solve each label with the highest probability in reverse order to form the optimal prediction sequence. + +> Due to the syntax restrictions of static graphs, the Viterbi algorithm is used to solve the optimal prediction sequence as a post-processing function and is not included in the implementation of the CRF layer. 
+ +```python +def viterbi_decode(emissions, mask, trans, start_trans, end_trans): + # emissions: (seq_length, batch_size, num_tags) + # mask: (seq_length, batch_size) + + seq_length = mask.shape[0] + + score = start_trans + emissions[0] + history = () + + for i in range(1, seq_length): + broadcast_score = score.expand_dims(2) + broadcast_emission = emissions[i].expand_dims(1) + next_score = broadcast_score + trans + broadcast_emission + + # Obtain the label with the maximum score corresponding to the current token and save the label. + indices = next_score.argmax(axis=1) + history += (indices,) + + next_score = next_score.max(axis=1) + score = mnp.where(mask[i].expand_dims(1), next_score, score) + + score += end_trans + + return score, history + +def post_decode(score, history, seq_length): + # Use Score and History to calculate the optimal prediction sequence. + batch_size = seq_length.shape[0] + seq_ends = seq_length - 1 + # shape: (batch_size,) + best_tags_list = [] + + # Decode each sample in a batch in sequence. + for idx in range(batch_size): + # Search for the label that maximizes the prediction probability corresponding to the last token. + # Add it to the list of best prediction sequence stores. + best_last_tag = score[idx].argmax(axis=0) + best_tags = [int(best_last_tag.asnumpy())] + + # Repeatedly search for the label with the maximum prediction probability corresponding to each token and add the label to the list. + for hist in reversed(history[:seq_ends[idx]]): + best_last_tag = hist[idx][best_tags[-1]] + best_tags.append(int(best_last_tag.asnumpy())) + + # Reset the solved label sequence in reverse order to the positive sequence. + best_tags.reverse() + best_tags_list.append(best_tags) + + return best_tags_list +``` + +### CRF Layer + +After the code of the forward training part and the code of the decoding part are completed, a complete CRF layer is assembled. Considering that the input sequence may be padded, the actual length of the input sequence needs to be considered during CRF input. Therefore, in addition to the emissions matrix and label, the `seq_length` parameter is added to transfer the length of the sequence before padding and implement the `sequence_mask` method for generating the mask matrix. + +Based on the preceding code, `nn.Cell` is used for encapsulation. The complete CRF layer is implemented as follows: + +```python +import mindspore +import mindspore.nn as nn +import mindspore.numpy as mnp +from mindspore import Parameter +from mindspore.common.initializer import initializer, Uniform + +def sequence_mask(seq_length, max_length, batch_first=False): + """Generate the mask matrix based on the actual length and maximum length of the sequence." 
+ range_vector = mnp.arange(0, max_length, 1, seq_length.dtype) + result = range_vector < seq_length.view(seq_length.shape + (1,)) + if batch_first: + return result.astype(mindspore.int64) + return result.astype(mindspore.int64).swapaxes(0, 1) + +class CRF(nn.Cell): + def __init__(self, num_tags: int, batch_first: bool = False, reduction: str = 'sum') -> None: + if num_tags <= 0: + raise ValueError(f'invalid number of tags: {num_tags}') + super().__init__() + if reduction not in ('none', 'sum', 'mean', 'token_mean'): + raise ValueError(f'invalid reduction: {reduction}') + self.num_tags = num_tags + self.batch_first = batch_first + self.reduction = reduction + self.start_transitions = Parameter(initializer(Uniform(0.1), (num_tags,)), name='start_transitions') + self.end_transitions = Parameter(initializer(Uniform(0.1), (num_tags,)), name='end_transitions') + self.transitions = Parameter(initializer(Uniform(0.1), (num_tags, num_tags)), name='transitions') + + def construct(self, emissions, tags=None, seq_length=None): + if tags is None: + return self._decode(emissions, seq_length) + return self._forward(emissions, tags, seq_length) + + def _forward(self, emissions, tags=None, seq_length=None): + if self.batch_first: + batch_size, max_length = tags.shape + emissions = emissions.swapaxes(0, 1) + tags = tags.swapaxes(0, 1) + else: + max_length, batch_size = tags.shape + + if seq_length is None: + seq_length = mnp.full((batch_size,), max_length, mindspore.int64) + + mask = sequence_mask(seq_length, max_length) + + # shape: (batch_size,) + numerator = compute_score(emissions, tags, seq_length-1, mask, self.transitions, self.start_transitions, self.end_transitions) + # shape: (batch_size,) + denominator = compute_normalizer(emissions, mask, self.transitions, self.start_transitions, self.end_transitions) + # shape: (batch_size,) + llh = denominator - numerator + + if self.reduction == 'none': + return llh + if self.reduction == 'sum': + return llh.sum() + if self.reduction == 'mean': + return llh.mean() + return llh.sum() / mask.astype(emissions.dtype).sum() + + def _decode(self, emissions, seq_length=None): + if self.batch_first: + batch_size, max_length = emissions.shape[:2] + emissions = emissions.swapaxes(0, 1) + else: + batch_size, max_length = emissions.shape[:2] + + if seq_length is None: + seq_length = mnp.full((batch_size,), max_length, mindspore.int64) + + mask = sequence_mask(seq_length, max_length) + + return viterbi_decode(emissions, mask, self.transitions, self.start_transitions, self.end_transitions) +``` + +## BiLSTM+CRF Model + +After CRF is implemented, a bidirectional LSTM+CRF model is designed to train NER tasks. The model structure is as follows: + +```text +nn.Embedding -> nn.LSTM -> nn.Dense -> CRF +``` + +The LSTM extracts a sequence feature, obtains an emission probability matrix by means of Dense layer transformation, and finally sends the emission probability matrix to the CRF layer. 
The sample code is as follows: + +```python +class BiLSTM_CRF(nn.Cell): + def __init__(self, vocab_size, embedding_dim, hidden_dim, num_tags, padding_idx=0): + super().__init__() + self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=padding_idx) + self.lstm = nn.LSTM(embedding_dim, hidden_dim // 2, bidirectional=True, batch_first=True) + self.hidden2tag = nn.Dense(hidden_dim, num_tags, 'he_uniform') + self.crf = CRF(num_tags, batch_first=True) + + def construct(self, inputs, seq_length, tags=None): + embeds = self.embedding(inputs) + outputs, _ = self.lstm(embeds, seq_length=seq_length) + feats = self.hidden2tag(outputs) + + crf_outs = self.crf(feats, tags, seq_length) + return crf_outs +``` + +After the model design is complete, two examples and corresponding labels are generated, and a vocabulary and a label table are built. + +```python +embedding_dim = 16 +hidden_dim = 32 + +training_data = [( + "the wall street journal reported today that apple corporation made money".split(), + "B I I I O O O B I O O".split() +), ( + "georgia tech is a university in georgia".split(), + "B I O O O O B".split() +)] + +word_to_idx = {} +word_to_idx[''] = 0 +for sentence, tags in training_data: + for word in sentence: + if word not in word_to_idx: + word_to_idx[word] = len(word_to_idx) + +tag_to_idx = {"B": 0, "I": 1, "O": 2} +``` + +```python +len(word_to_idx) +``` + +```text + 21 +``` + +Instantiate the model, select an optimizer, and send the model and optimizer to the Wrapper. + +> The NLLLoss has been calculated at the CRF layer. Therefore, you do not need to set Loss. + +```python +model = BiLSTM_CRF(len(word_to_idx), embedding_dim, hidden_dim, len(tag_to_idx)) +optimizer = nn.SGD(model.trainable_params(), learning_rate=0.01, weight_decay=1e-4) +``` + +```python +train_one_step = nn.TrainOneStepCell(model, optimizer) +``` + +Pack the generated data into a batch, pad the sequence with insufficient length based on the maximum sequence length, and return tensors consisting of the input sequence, output label, and sequence length. + +```python +def prepare_sequence(seqs, word_to_idx, tag_to_idx): + seq_outputs, label_outputs, seq_length = [], [], [] + max_len = max([len(i[0]) for i in seqs]) + + for seq, tag in seqs: + seq_length.append(len(seq)) + idxs = [word_to_idx[w] for w in seq] + labels = [tag_to_idx[t] for t in tag] + idxs.extend([word_to_idx[''] for i in range(max_len - len(seq))]) + labels.extend([tag_to_idx['O'] for i in range(max_len - len(seq))]) + seq_outputs.append(idxs) + label_outputs.append(labels) + + return mindspore.Tensor(seq_outputs, mindspore.int64), \ + mindspore.Tensor(label_outputs, mindspore.int64), \ + mindspore.Tensor(seq_length, mindspore.int64) +``` + +```python +data, label, seq_length = prepare_sequence(training_data, word_to_idx, tag_to_idx) +data.shape, label.shape, seq_length.shape +``` + +```text + ((2, 11), (2, 11), (2,)) +``` + +After the model is precompiled, 500 steps are trained. + +> Training process visualization depends on the `tqdm` library, which can be installed by running the `pip install tqdm` command. + +```python +train_one_step.compile(data, seq_length, label) +``` + +```python +from tqdm import tqdm + +steps = 500 +with tqdm(total=steps) as t: + for i in range(steps): + loss = train_one_step(data, seq_length, label) + t.set_postfix(loss=loss) + t.update(1) +``` + +Finally, let's observe the model effect after 500 steps of training. First, use the model to predict possible path scores and candidate sequences. 
+ +```python +score, history = model(data, seq_length) +score +``` + +```text + Tensor(shape=[2, 3], dtype=Float32, value= + [[ 2.28206425e+01, 2.73965702e+01, 2.17514019e+01], + [ 2.22613220e+01, 2.26435127e+01, 2.83346519e+01]]) +``` + +Perform post-processing on the predicted score. + +```python +predict = post_decode(score, history, seq_length) +predict +``` + +```text + [[0, 1, 1, 1, 2, 2, 2, 0, 1, 2, 2], [0, 1, 2, 2, 2, 2, 0, 2, 2]] +``` + +Finally, convert the predicted index sequence into a label sequence, print the output result, and view the effect. + +```python +idx_to_tag = {idx: tag for tag, idx in tag_to_idx.items()} + +def sequence_to_tag(sequences, idx_to_tag): + outputs = [] + for seq in sequences: + outputs.append([idx_to_tag[i] for i in seq]) + return outputs +``` + +```python +sequence_to_tag(predict, idx_to_tag) +``` + +```text + [['B', 'I', 'I', 'I', 'O', 'O', 'O', 'B', 'I', 'O', 'O'], + ['B', 'I', 'O', 'O', 'O', 'O', 'B', 'O', 'O']] +``` diff --git a/tutorials/application/source_en/nlp/test2.md b/tutorials/application/source_en/nlp/test2.md deleted file mode 100644 index f9ccad0a13..0000000000 --- a/tutorials/application/source_en/nlp/test2.md +++ /dev/null @@ -1,3 +0,0 @@ -# Test2 - -Coming soon. \ No newline at end of file diff --git a/tutorials/source_en/beginner/infer.md b/tutorials/source_en/beginner/infer.md index 7bfbe3f3ae..26209cda14 100644 --- a/tutorials/source_en/beginner/infer.md +++ b/tutorials/source_en/beginner/infer.md @@ -323,29 +323,29 @@ The following describes how to convert the model file format: 1. Use MindSpore Lite Converter to convert the file format in Linux. [Linux-x86_64 tool download link](https://www.mindspore.cn/lite/docs/en/r1.7/use/downloads.html) -```shell -# Download and decompress the software package and set the path of the software package. {converter_path} indicates the path of the decompressed tool package, and PACKAGE_ROOT_PATH indicates the environment variable. -export PACKAGE_ROOT_PATH={converter_path} + ```shell + # Download and decompress the software package and set the path of the software package. {converter_path} indicates the path of the decompressed tool package, and PACKAGE_ROOT_PATH indicates the environment variable. + export PACKAGE_ROOT_PATH={converter_path} -# Add the dynamic link library required by the converter to the environment variable LD_LIBRARY_PATH. -export LD_LIBRARY_PATH=${PACKAGE_ROOT_PATH}/tools/converter/lib:${LD_LIBRARY_PATH} + # Add the dynamic link library required by the converter to the environment variable LD_LIBRARY_PATH. + export LD_LIBRARY_PATH=${PACKAGE_ROOT_PATH}/tools/converter/lib:${LD_LIBRARY_PATH} -# Run the conversion command on the mindspore-lite-linux-x64/tools/converter/converter. -./converter_lite --fmk=MINDIR --modelFile=mobilenet_v2_1.0_224.mindir --outputFile=mobilenet_v2_1.0_224 -``` + # Run the conversion command on the mindspore-lite-linux-x64/tools/converter/converter. + ./converter_lite --fmk=MINDIR --modelFile=mobilenet_v2_1.0_224.mindir --outputFile=mobilenet_v2_1.0_224 + ``` 2. Use MindSpore Lite Converter to convert the file format in Windows. [Windows-x64 tool download link](https://www.mindspore.cn/lite/docs/en/r1.7/use/downloads.html) -```shell -# Download and decompress the software package and set the path of the software package. {converter_path} indicates the path of the decompressed tool package, and PACKAGE_ROOT_PATH indicates the environment variable. 
-set PACKAGE_ROOT_PATH={converter_path} + ```shell + # Download and decompress the software package and set the path of the software package. {converter_path} indicates the path of the decompressed tool package, and PACKAGE_ROOT_PATH indicates the environment variable. + set PACKAGE_ROOT_PATH={converter_path} -# Add the dynamic link library required by the converter to the environment variable PATH. -set PATH=%PACKAGE_ROOT_PATH%\tools\converter\lib;%PATH% + # Add the dynamic link library required by the converter to the environment variable PATH. + set PATH=%PACKAGE_ROOT_PATH%\tools\converter\lib;%PATH% -# Run the following command in the mindspore-lite-win-x64\tools\converter\converter directory: -call converter_lite --fmk=MINDIR --modelFile=mobilenet_v2_1.0_224.mindir --outputFile=mobilenet_v2_1.0_224 -``` + # Run the following command in the mindspore-lite-win-x64\tools\converter\converter directory: + call converter_lite --fmk=MINDIR --modelFile=mobilenet_v2_1.0_224.mindir --outputFile=mobilenet_v2_1.0_224 + ``` After the conversion is successful, `CONVERT RESULT SUCCESS:0` is displayed, and the `mobilenet_v2_1.0_224.ms` file is generated in the current directory. diff --git a/tutorials/source_en/beginner/train.md b/tutorials/source_en/beginner/train.md index e6de7d50ff..f8f85c67d7 100644 --- a/tutorials/source_en/beginner/train.md +++ b/tutorials/source_en/beginner/train.md @@ -6,9 +6,7 @@ After learning how to create a model and build a dataset in the preceding tutori ## Hyperparameters -Hyperparameters can be adjusted to control the model training and optimization process. Different hyperparameter values may affect the model training and convergence speed. Currently, deep learning models are optimized using the batch stochastic gradient descent algorithm. The principle of the stochastic gradient descent algorithm is as follows: - -$$w_{t+1}=w_{t}-\eta \frac{1}{n} \sum_{x \in \mathcal{B}} \nabla l\left(x, w_{t}\right)$$ +Hyperparameters can be adjusted to control the model training and optimization process. Different hyperparameter values may affect the model training and convergence speed. Currently, deep learning models are optimized using the batch stochastic gradient descent algorithm. The principle of the stochastic gradient descent algorithm is as follows: $w_{t+1}=w_{t}-\eta \frac{1}{n} \sum_{x \in \mathcal{B}} \nabla l\left(x, w_{t}\right)$ In the formula, $n$ is the batch size, and $η$ is a learning rate. In addition, $w_{t}$ is the weight parameter in the training batch t, and $\nabla l$ is the derivative of the loss function. In addition to the gradient itself, the two factors directly determine the weight update of the model. From the perspective of the optimization itself, the two factors are the most important parameters that affect the convergence of the model performance. Generally, the following hyperparameters are defined for training: @@ -29,9 +27,7 @@ learning_rate = 1e-2 ## Loss Functions -The **loss function** is used to evaluate the difference between **predicted value** and **target value** of a model. Here, the absolute error loss function `L1Loss` is used: - -$$\text { L1 Loss Function }=\sum_{i=1}^{n}\left|y_{true}-y_{predicted}\right|$$ +The **loss function** is used to evaluate the difference between **predicted value** and **target value** of a model. 
Here, the absolute error loss function `L1Loss` is used: $\text { L1 Loss Function }=\sum_{i=1}^{n}\left|y_{true}-y_{predicted}\right|$ `mindspore.nn.loss` provides many common loss functions, such as `SoftmaxCrossEntropyWithLogits`, `MSELoss`, and `SmoothL1Loss`. -- Gitee