代码拉取完成,页面将自动刷新
同步操作将从 逍遥叹wan/TokAI 强制同步,此操作会覆盖自 Fork 仓库以来所做的任何修改,且无法恢复!!!
确定后同步将在后台操作,完成时将刷新页面,请耐心等待。
"""train process"""
import argparse
import os
import random
import time
import numpy as np
import torch
from torch import nn
from torch.utils import data as Data
from src import (
calculate_l2_error,
create_datasets,
load_yaml_config,
pad_collate,
tokmak_model,
)
# Fix all RNG seeds (stdlib, NumPy, PyTorch) so training runs are reproducible.
seed = 123456
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
def parse_args(argv=None):
    """Parse command-line arguments for the tokamak training script.

    Args:
        argv: optional list of argument strings; defaults to ``sys.argv[1:]``
            (passing a list makes the parser testable without touching the
            real command line — backward compatible with zero-arg calls).

    Returns:
        argparse.Namespace: the parsed arguments.
    """

    def _str2bool(value):
        # argparse's `type=bool` is broken for flags like `--save_graphs False`:
        # bool("False") is True because any non-empty string is truthy.
        # Parse the text explicitly instead.
        if isinstance(value, bool):
            return value
        lowered = value.lower()
        if lowered in ("true", "1", "yes"):
            return True
        if lowered in ("false", "0", "no"):
            return False
        raise argparse.ArgumentTypeError(f"invalid boolean value: {value!r}")

    parser = argparse.ArgumentParser(description="tokmak train")
    parser.add_argument(
        "--config_file_path", type=str, default="./configs/tokamak.yaml"
    )
    parser.add_argument(
        "--device_target",
        type=str,
        default="Ascend",
        choices=["GPU", "Ascend"],
        help="The target device to run, support 'Ascend', 'GPU'",
    )
    parser.add_argument(
        "--device_id", type=int, default=0, help="ID of the target device"
    )
    parser.add_argument(
        "--mode",
        type=str,
        default="GRAPH",
        choices=["GRAPH", "PYNATIVE"],
        help="Running in GRAPH_MODE OR PYNATIVE_MODE",
    )
    parser.add_argument(
        "--save_graphs",
        type=_str2bool,
        default=False,
        choices=[True, False],
        help="Whether to save intermediate compilation graphs",
    )
    parser.add_argument("--save_graphs_path", type=str, default="./graphs")
    input_args = parser.parse_args(argv)
    return input_args
def train_process(
    model,
    cur_steps,
    optimizer,
    scheduler,
    train_loader,
    out_channels,
    lossfn_config,
):
    """Run one training epoch over ``train_loader``.

    Args:
        model: network being trained (already placed on CUDA by the caller).
        cur_steps: running count of samples seen so far; advanced per batch.
        optimizer: torch optimizer, stepped once per batch.
        scheduler: LR scheduler, stepped once per batch.
        train_loader: yields (X, valid_len, valid_channels, h5_files) batches.
        out_channels: number of trailing feature channels used as labels.
        lossfn_config: loss configuration forwarded to calculate_l2_error.

    Returns:
        tuple: (summed loss over the epoch, updated cur_steps).
    """
    model.train()
    loss_sum = 0
    for batch, seq_lens, chan_mask, _files in train_loader:
        optimizer.zero_grad(set_to_none=True)
        batch = batch.cuda()
        seq_lens = seq_lens.cuda()
        chan_mask = chan_mask.cuda()
        cur_steps += batch.shape[0]
        # The trailing `out_channels` features are the prediction target;
        # everything before them feeds the encoder.
        encoder_in = batch[:, :, :-out_channels]
        target = batch[:, :, -out_channels:]
        chan_mask = chan_mask[:, -out_channels:]
        # Teacher forcing: the decoder consumes the target shifted right by
        # one time step, with a zero vector prepended as the start token.
        start_token = torch.zeros(
            target.shape[0],
            1,
            target.shape[-1],
            device=batch.device,
            dtype=batch.dtype,
        )
        decoder_in = torch.cat((start_token, target), 1)[:, :-1, :]
        loss = calculate_l2_error(
            model,
            encoder_in,
            decoder_in,
            cur_steps,
            target,
            seq_lens,
            chan_mask,
            lossfn_config,
        )
        loss.backward()
        optimizer.step()
        scheduler.step()
        loss_sum += loss.detach().item()
    return loss_sum, cur_steps
@torch.no_grad()
def eval_process(model, val_loader, out_channels, lossfn_config):
    """Evaluate the model over the validation set (no gradients).

    Args:
        model: trained network (already placed on CUDA by the caller).
        val_loader: yields (X, valid_len, valid_channels, h5_files) batches.
        out_channels: number of trailing feature channels used as labels.
        lossfn_config: loss configuration forwarded to calculate_l2_error.

    Returns:
        float: sum of per-batch L2 losses over the validation set.
    """
    model.eval()
    epoch_eval_loss = 0
    for X, valid_len, valid_channels, _h5_files in val_loader:
        X = X.cuda()
        valid_len = valid_len.cuda()
        valid_channels = valid_channels.cuda()
        device = X.device
        dtype = X.dtype
        # The trailing `out_channels` features are the prediction target.
        enc_inputs = X[:, :, :-out_channels]
        label = X[:, :, -out_channels:]
        valid_channels = valid_channels[:, -out_channels:]
        # Teacher forcing: decoder input is the label shifted right by one
        # time step, with a zero vector prepended as the start token.
        padded_values = torch.zeros(
            label.shape[0], 1, label.shape[-1], device=device, dtype=dtype
        )
        dec_padded_inputs = torch.cat((padded_values, label), 1)
        dec_inputs = dec_padded_inputs[:, :-1, :]
        # cur_steps only drives training-time behavior inside the loss;
        # it is meaningless during evaluation, so pass 0.
        cur_steps = 0
        l2_loss = calculate_l2_error(
            model,
            enc_inputs,
            dec_inputs,
            cur_steps,
            label,
            valid_len,
            valid_channels,
            lossfn_config,
        )
        epoch_eval_loss += l2_loss.detach().item()
    return epoch_eval_loss
def train():
    """Train and periodically evaluate the tokamak network.

    Reads all hyperparameters from the YAML config named by the module-level
    CLI ``args``, builds data loaders, model, optimizer and scheduler, then
    runs the epoch loop, printing losses and optionally saving checkpoints.
    """
    # load configurations
    config = load_yaml_config(args.config_file_path)
    train_paras = config["data"]["train"]
    val_paras = config["data"]["validation"]
    loss_paras = config["loss"]["train"]
    # Keep lossfn_config as a plain dict; future loss options are added here.
    lossfn_config = {"limiter_steps": loss_paras["limiter_steps"]}
    # create datasets and loaders
    train_set, val_set = create_datasets(config, is_debug=train_paras["is_debug"])
    train_loader = Data.DataLoader(
        train_set,
        batch_size=train_paras["batch_size"],
        num_workers=train_paras["num_workers"],
        collate_fn=pad_collate,
    )
    # NOTE(review): the "+ (num_workers - 1)" offset looks like a workaround
    # for multi-worker loader length accounting — confirm intent upstream.
    train_steps_per_epoch = len(train_loader) + (train_paras["num_workers"] - 1)
    val_loader = Data.DataLoader(
        val_set,
        batch_size=val_paras["batch_size"],
        num_workers=val_paras["num_workers"],
        collate_fn=pad_collate,
    )
    # Use the *validation* worker count here (was mistakenly the train one).
    val_steps_per_epoch = len(val_loader) + (val_paras["num_workers"] - 1)
    # define model
    model_params = config["model"]
    optim_params = config["optimizer"]
    model = tokmak_model(
        in_channels=model_params["in_channels"],
        hidden_size=model_params["hidden_size"],
        num_layers=model_params["num_layers"],
        dropout_rate=model_params["dropout_rate"],
        out_channels=model_params["out_channels"],
        noise_ratio=model_params["noise_ratio"],
    )
    model.cuda()
    out_channels = model_params["out_channels"]
    # define optimizer & scheduler (duplicate scheduler_fn assignment removed)
    optimizer = torch.optim.SGD(
        model.parameters(),
        lr=float(optim_params["lr"]),
        weight_decay=optim_params["weight_decay"],
    )
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=float(optim_params["lr"]),
        steps_per_epoch=train_steps_per_epoch,
        epochs=train_paras["epochs"],
    )
    epochs = train_paras["epochs"]
    cur_steps = 0
    for epoch in range(1, 1 + epochs):
        # train
        local_time_beg = time.time()
        epoch_train_loss, cur_steps = train_process(
            model,
            cur_steps,
            optimizer,
            scheduler,
            train_loader,
            out_channels,
            lossfn_config,
        )
        # Timing units fixed: keep epoch time in seconds (it was multiplied
        # by 1000 but still printed with an "s" suffix) and report per-step
        # time in milliseconds.
        epoch_seconds = time.time() - local_time_beg
        step_ms = epoch_seconds * 1000 / train_steps_per_epoch
        # step_train_loss approximates the mean train loss per step.
        step_train_loss = epoch_train_loss / train_steps_per_epoch
        print(
            f"epoch: {epoch} train loss: {step_train_loss} "
            f"epoch time: {epoch_seconds:5.3f}s step time: {step_ms:5.3f} ms"
        )
        if epoch % config["summary"]["eval_interval_epochs"] == 0:
            eval_time_start = time.time()
            epoch_val_loss = eval_process(
                model, val_loader, out_channels, lossfn_config
            )
            step_val_loss = epoch_val_loss / val_steps_per_epoch
            epoch_val_seconds = time.time() - eval_time_start
            step_val_ms = epoch_val_seconds * 1000 / val_steps_per_epoch
            print(
                f"epoch: {epoch} val loss: {step_val_loss} "
                f"evaluation time: {epoch_val_seconds:5.3f}s step time: {step_val_ms:5.3f}ms"
            )
            if config["summary"]["save_ckpt"]:
                tempdir = config["summary"]["ckpt_dir"]
                os.makedirs(tempdir, exist_ok=True)
                # Checkpoint name encodes epoch and both losses for easy triage.
                torch.save(
                    {"epoch": epoch, "model_state": model.state_dict()},
                    os.path.join(
                        tempdir,
                        f"{epoch}-{step_train_loss:.5f}-{step_val_loss:.5f}.pt",
                    ),
                )
# Script entry point: configure logging, parse CLI args, and run training.
if __name__ == "__main__":
    # Local import keeps the logging helper out of module import-time effects.
    from src import log_config
    log_config("./logs", "tokmak")
    print("pid:", os.getpid())
    args = parse_args()
    print(f"Running in {args.mode.upper()} mode, using device id: {args.device_id}.")
    # use_ascend = context.get_context(attr_key='device_target') == "Ascend"
    train()
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。