diff --git a/PyTorch/built-in/nlp/LSTM_ID0468_for_PyTorch/timit/steps/train_ctc.py b/PyTorch/built-in/nlp/LSTM_ID0468_for_PyTorch/timit/steps/train_ctc.py index 98e3f7ae16da47b7850b408c0c88515f788c62fb..d3508aaf1350407914e7f4e078482e9cc7e15329 100644 --- a/PyTorch/built-in/nlp/LSTM_ID0468_for_PyTorch/timit/steps/train_ctc.py +++ b/PyTorch/built-in/nlp/LSTM_ID0468_for_PyTorch/timit/steps/train_ctc.py @@ -41,6 +41,7 @@ import time import yaml import argparse import numpy as np +import apex from apex import amp import torch import torch.nn as nn @@ -188,9 +189,9 @@ def main(args,conf): print(params) loss_fn = nn.CTCLoss(reduction='sum') - optimizer = torch.optim.Adam(model.parameters(), lr=init_lr, weight_decay=weight_decay) + optimizer = apex.optimizers.NpuFusedAdam(model.parameters(), lr=init_lr, weight_decay=weight_decay) if args.apex: - model, optimizer = amp.initialize(model, optimizer, opt_level=args.opt_level, loss_scale=args.loss_scale) + model, optimizer = amp.initialize(model, optimizer, opt_level=args.opt_level, loss_scale=args.loss_scale,combine_grad=True) #visualization for training # from visdom import Visdom # viz = Visdom()