"""train process"""
import argparse
import os
import random
import time
import numpy as np
import torch
from torch import nn
from torch.utils import data as Data
from src import (
calculate_l2_error,
create_datasets,
load_yaml_config,
pad_collate,
tokmak_model,
)
seed = 123456
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)


def parse_args():
    """Parse input args."""
    parser = argparse.ArgumentParser(description="tokmak train")
    parser.add_argument(
        "--config_file_path", type=str, default="./configs/tokamak.yaml"
    )
    parser.add_argument(
        "--device_target",
        type=str,
        default="Ascend",
        choices=["GPU", "Ascend"],
        help="The target device to run on; supports 'Ascend' and 'GPU'",
    )
    parser.add_argument(
        "--device_id", type=int, default=0, help="ID of the target device"
    )
    parser.add_argument(
        "--mode",
        type=str,
        default="GRAPH",
        choices=["GRAPH", "PYNATIVE"],
        help="Run in GRAPH_MODE or PYNATIVE_MODE",
    )
    parser.add_argument(
        "--save_graphs",
        # argparse's type=bool treats any non-empty string (including "False")
        # as True, so parse the value explicitly.
        type=lambda s: str(s).lower() in ("true", "1", "yes"),
        default=False,
        help="Whether to save intermediate compilation graphs",
    )
    parser.add_argument("--save_graphs_path", type=str, default="./graphs")
    input_args = parser.parse_args()
    return input_args
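

# Example invocation (the config path is this script's default and is
# illustrative; adjust it to your setup):
#   python train.py --config_file_path ./configs/tokamak.yaml --device_id 0 --mode GRAPH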


def train_process(
    model,
    cur_steps,
    optimizer,
    scheduler,
    train_loader,
    out_channels,
    lossfn_config,
):
    """Run one training epoch.

    Args:
        model: the tokamak model to train.
        cur_steps: running count of training samples seen so far.
        optimizer: optimizer updating the model parameters.
        scheduler: learning-rate scheduler, stepped once per batch.
        train_loader: DataLoader yielding padded training batches.
        out_channels: number of output channels (the label columns of X).
        lossfn_config: configuration dict passed to the loss function.

    Returns:
        Tuple of (summed training loss over the epoch, updated cur_steps).
    """
    model.train()
    epoch_train_loss = 0
    for X, valid_len, valid_channels, h5_files in train_loader:
        optimizer.zero_grad(set_to_none=True)
        X = X.cuda()
        valid_len = valid_len.cuda()
        valid_channels = valid_channels.cuda()
        device = X.device
        dtype = X.dtype
        batch_size = X.shape[0]
        cur_steps += batch_size
        # The last `out_channels` columns of X are the labels; the rest are
        # encoder inputs.
        enc_inputs = X[:, :, :-out_channels]
        label = X[:, :, -out_channels:]
        valid_channels = valid_channels[:, -out_channels:]
        # Teacher forcing: the decoder sees the labels shifted right by one step.
        padded_values = torch.zeros(
            label.shape[0], 1, label.shape[-1], device=device, dtype=dtype
        )
        dec_padded_inputs = torch.cat((padded_values, label), 1)
        dec_inputs = dec_padded_inputs[:, :-1, :]
        l2_loss = calculate_l2_error(
            model,
            enc_inputs,
            dec_inputs,
            cur_steps,
            label,
            valid_len,
            valid_channels,
            lossfn_config,
        )
        l2_loss.backward()
        optimizer.step()
        scheduler.step()
        epoch_train_loss += l2_loss.detach().item()
    return epoch_train_loss, cur_steps
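

# Shape sketch of the decoder-input construction above (illustrative,
# assuming X is laid out as (batch, time, channels)):
#   label         : (B, T, out_channels)   last out_channels columns of X
#   padded_values : (B, 1, out_channels)   all-zero start-of-sequence step
#   dec_inputs    : (B, T, out_channels)   cat([padded_values, label], dim=1)[:, :-1, :]
# i.e. the decoder is fed the label sequence shifted right by one time step.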


@torch.no_grad()
def eval_process(model, val_loader, out_channels, lossfn_config):
    """Run one evaluation pass and return the summed validation loss."""
    model.eval()
    epoch_eval_loss = 0
    for X, valid_len, valid_channels, h5_files in val_loader:
        X = X.cuda()
        valid_len = valid_len.cuda()
        valid_channels = valid_channels.cuda()
        device = X.device
        dtype = X.dtype
        batch_size = X.shape[0]
        enc_inputs = X[:, :, :-out_channels]
        label = X[:, :, -out_channels:]
        valid_channels = valid_channels[:, -out_channels:]
        padded_values = torch.zeros(
            label.shape[0], 1, label.shape[-1], device=device, dtype=dtype
        )
        dec_padded_inputs = torch.cat((padded_values, label), 1)
        dec_inputs = dec_padded_inputs[:, :-1, :]
        # cur_steps is meaningless during evaluation, so pass 0.
        cur_steps = 0
        l2_loss = calculate_l2_error(
            model,
            enc_inputs,
            dec_inputs,
            cur_steps,
            label,
            valid_len,
            valid_channels,
            lossfn_config,
        )
        epoch_eval_loss += l2_loss.detach().item()
    return epoch_eval_loss


def train():
    """Train and evaluate the tokamak model."""
    # load configurations
    config = load_yaml_config(args.config_file_path)
    train_paras = config["data"]["train"]
    val_paras = config["data"]["validation"]
    loss_paras = config["loss"]["train"]
    # For future improvements, please do not rewrite this.
    lossfn_config = {}
    lossfn_config["limiter_steps"] = loss_paras["limiter_steps"]
    # create datasets
    train_set, val_set = create_datasets(config, is_debug=train_paras["is_debug"])
    train_loader = Data.DataLoader(
        train_set,
        batch_size=train_paras["batch_size"],
        num_workers=train_paras["num_workers"],
        collate_fn=pad_collate,
    )
    train_steps_per_epoch = len(train_loader) + (train_paras["num_workers"] - 1)
    val_loader = Data.DataLoader(
        val_set,
        batch_size=val_paras["batch_size"],
        num_workers=val_paras["num_workers"],
        collate_fn=pad_collate,
    )
    val_steps_per_epoch = len(val_loader) + (train_paras["num_workers"] - 1)
    # define model
    model_params = config["model"]
    optim_params = config["optimizer"]
    model = tokmak_model(
        in_channels=model_params["in_channels"],
        hidden_size=model_params["hidden_size"],
        num_layers=model_params["num_layers"],
        dropout_rate=model_params["dropout_rate"],
        out_channels=model_params["out_channels"],
        noise_ratio=model_params["noise_ratio"],
    )
    model.cuda()
    out_channels = model_params["out_channels"]
    # define optimizer & scheduler
    optimizer_fn = torch.optim.SGD
    optimizer = optimizer_fn(
        model.parameters(),
        lr=float(optim_params["lr"]),
        weight_decay=optim_params["weight_decay"],
    )
    scheduler_fn = torch.optim.lr_scheduler.OneCycleLR
    scheduler = scheduler_fn(
        optimizer,
        max_lr=float(optim_params["lr"]),
        steps_per_epoch=train_steps_per_epoch,
        epochs=train_paras["epochs"],
    )
    epochs = config["data"]["train"]["epochs"]
    cur_steps = 0
    for epoch in range(1, 1 + epochs):
        # train
        local_time_beg = time.time()
        epoch_train_loss, cur_steps = train_process(
            model,
            cur_steps,
            optimizer,
            scheduler,
            train_loader,
            out_channels,
            lossfn_config,
        )
        local_time_end = time.time()
        epoch_seconds = local_time_end - local_time_beg
        step_milliseconds = epoch_seconds / train_steps_per_epoch * 1000
        step_train_loss = epoch_train_loss / train_steps_per_epoch
        # step_train_loss differs from the real train_loss.
        print(
            f"epoch: {epoch} train loss: {step_train_loss} "
            f"epoch time: {epoch_seconds:5.3f}s step time: {step_milliseconds:5.3f}ms"
        )
        if epoch % config["summary"]["eval_interval_epochs"] == 0:
            eval_time_start = time.time()
            epoch_val_loss = eval_process(
                model, val_loader, out_channels, lossfn_config
            )
            step_val_loss = epoch_val_loss / val_steps_per_epoch
            epoch_val_seconds = time.time() - eval_time_start
            step_val_milliseconds = epoch_val_seconds / val_steps_per_epoch * 1000
            print(
                f"epoch: {epoch} val loss: {step_val_loss} "
                f"evaluation time: {epoch_val_seconds:5.3f}s step time: {step_val_milliseconds:5.3f}ms"
            )
            if config["summary"]["save_ckpt"]:
                tempdir = config["summary"]["ckpt_dir"]
                os.makedirs(tempdir, exist_ok=True)
                torch.save(
                    {"epoch": epoch, "model_state": model.state_dict()},
                    os.path.join(
                        tempdir,
                        f"{epoch}-{step_train_loss:.5f}-{step_val_loss:.5f}.pt",
                    ),
                )
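

# Minimal sketch (not part of the original script) of how a checkpoint saved
# by train() could be reloaded later; `ckpt_path` is a hypothetical path:
#   ckpt = torch.load(ckpt_path, map_location="cuda")
#   model.load_state_dict(ckpt["model_state"])
#   start_epoch = ckpt["epoch"] + 1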


if __name__ == "__main__":
    from src import log_config

    log_config("./logs", "tokmak")
    print("pid:", os.getpid())
    args = parse_args()
    print(f"Running in {args.mode.upper()} mode, using device id: {args.device_id}.")
    # use_ascend = context.get_context(attr_key='device_target') == "Ascend"
    train()