# NOTE(review): the three lines that were here were Gitee web-page chrome
# ("code pull finished / force-sync warning") accidentally captured with the
# source; they are not Python and have been commented out of the module.
"""train process"""
import argparse
import os
import random
import time
import numpy as np
import torch
from torch import nn
from torch.utils import data as Data
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import (
size_based_auto_wrap_policy,
)
import torch.multiprocessing as mp
import functools
import pathlib
from src import (
calculate_l2_error,
create_datasets,
load_yaml_config,
pad_collate,
tokmak_model,
)
# Fix the RNG seeds for Python, NumPy, and PyTorch so runs are reproducible.
# NOTE(review): torch.cuda.manual_seed_all is not called — CUDA RNG streams
# (e.g. dropout on GPU) may still vary across runs; confirm if that matters.
seed = 123456
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
def parse_args(argv=None):
    """Parse command-line arguments for the tokamak training entry point.

    Args:
        argv: Optional list of argument strings. Defaults to ``sys.argv[1:]``,
            preserving the original no-argument call signature.

    Returns:
        argparse.Namespace with the parsed options.
    """

    def _str_to_bool(value):
        # argparse's ``type=bool`` is a trap: bool("False") is True, so every
        # non-empty string parsed as True.  Convert explicitly instead.
        lowered = value.lower()
        if lowered in ("true", "1", "yes"):
            return True
        if lowered in ("false", "0", "no"):
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean, got {value!r}")

    parser = argparse.ArgumentParser(description="tokmak train")
    parser.add_argument(
        "--config_file_path", type=str, default="./configs/tokamak.yaml"
    )
    parser.add_argument(
        "--device_target",
        type=str,
        default="Ascend",
        choices=["GPU", "Ascend"],
        help="The target device to run, support 'Ascend', 'GPU'",
    )
    parser.add_argument(
        "--device_id", type=int, default=0, help="ID of the target device"
    )
    parser.add_argument(
        "--mode",
        type=str,
        default="GRAPH",
        choices=["GRAPH", "PYNATIVE"],
        help="Running in GRAPH_MODE OR PYNATIVE_MODE",
    )
    parser.add_argument(
        "--save_graphs",
        type=_str_to_bool,
        default=False,
        help="Whether to save intermediate compilation graphs",
    )
    parser.add_argument("--save_graphs_path", type=str, default="./graphs")
    input_args = parser.parse_args(argv)
    return input_args
def setup(rank, world_size):
    """Join the default NCCL process group as worker ``rank`` of ``world_size``.

    The rendezvous endpoint is a fixed localhost address/port, so all workers
    are expected to live on a single machine.
    """
    os.environ.update(
        {
            "MASTER_ADDR": "localhost",
            "MASTER_PORT": "12355",
        }
    )
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
def train_process(
    model,
    rank,
    cur_steps,
    optimizer,
    scheduler,
    train_loader,
    out_channels,
    lossfn_config,
):
    """Run one training epoch and return the all-reduced mean loss.

    Args:
        model: the FSDP-wrapped network to train.
        rank: this worker's CUDA device index.
        cur_steps: running count of samples seen so far (forwarded to the
            loss limiter inside ``calculate_l2_error``).
        optimizer: the optimizer stepping ``model``'s parameters.
        scheduler: per-step learning-rate scheduler.
        train_loader: DataLoader yielding
            ``(X, valid_len, valid_channels, h5_files)`` batches.
        out_channels: number of trailing channels of ``X`` used as targets.
        lossfn_config: config dict forwarded to ``calculate_l2_error``.

    Returns:
        Tuple of (summed loss / summed batch size across all ranks,
        updated ``cur_steps``).
    """
    model.train()
    torch.cuda.set_device(rank)
    loss_and_count = torch.zeros(2).cuda()
    for batch, seq_lens, chan_mask, _h5_files in train_loader:
        optimizer.zero_grad(set_to_none=True)
        batch = batch.cuda()
        seq_lens = seq_lens.cuda()
        chan_mask = chan_mask.cuda()
        n_samples = batch.shape[0]
        cur_steps += n_samples
        # Trailing ``out_channels`` channels are the prediction targets; the
        # rest feed the encoder.
        enc_inputs = batch[:, :, :-out_channels]
        label = batch[:, :, -out_channels:]
        chan_mask = chan_mask[:, -out_channels:]
        # Teacher forcing: prepend a zero "start token" row and drop the
        # final step, so the decoder sees the label shifted right by one.
        start_token = torch.zeros(
            label.shape[0],
            1,
            label.shape[-1],
            device=batch.device,
            dtype=batch.dtype,
        )
        dec_inputs = torch.cat((start_token, label), 1)[:, :-1, :]
        loss = calculate_l2_error(
            model,
            enc_inputs,
            dec_inputs,
            cur_steps,
            label,
            seq_lens,
            chan_mask,
            lossfn_config,
        )
        loss.backward()
        optimizer.step()
        scheduler.step()
        loss_and_count[0] += loss.detach().item()
        loss_and_count[1] += n_samples
    dist.all_reduce(loss_and_count, op=dist.ReduceOp.SUM)
    return loss_and_count[0] / loss_and_count[1], cur_steps
@torch.no_grad()
def eval_process(model, rank, val_loader, out_channels, lossfn_config):
    """Evaluate ``model`` on ``val_loader`` and return the global mean loss.

    Args:
        model: the FSDP-wrapped network to evaluate.
        rank: this worker's CUDA device index.
        val_loader: DataLoader yielding
            ``(X, valid_len, valid_channels, h5_files)`` batches.
        out_channels: number of trailing channels of ``X`` used as targets.
        lossfn_config: config dict forwarded to ``calculate_l2_error``.

    Returns:
        Scalar tensor: summed loss / summed batch size, all-reduced over ranks.
    """
    model.eval()
    torch.cuda.set_device(rank)
    # Use .cuda() for consistency with train_process (equivalent to .to(rank)
    # after set_device above).
    ddp_loss = torch.zeros(2).cuda()
    # cur_steps only drives the loss limiter during training and is
    # meaningless at evaluation time; hoist the constant out of the loop.
    cur_steps = 0
    for X, valid_len, valid_channels, _h5_files in val_loader:
        X = X.cuda()
        valid_len = valid_len.cuda()
        valid_channels = valid_channels.cuda()
        batch_size = X.shape[0]
        # Trailing ``out_channels`` channels are the targets; the rest feed
        # the encoder.
        enc_inputs = X[:, :, :-out_channels]
        label = X[:, :, -out_channels:]
        valid_channels = valid_channels[:, -out_channels:]
        # Teacher forcing: prepend a zero start token and drop the last step.
        padded_values = torch.zeros(
            label.shape[0], 1, label.shape[-1], device=X.device, dtype=X.dtype
        )
        dec_padded_inputs = torch.cat((padded_values, label), 1)
        dec_inputs = dec_padded_inputs[:, :-1, :]
        l2_loss = calculate_l2_error(
            model,
            enc_inputs,
            dec_inputs,
            cur_steps,
            label,
            valid_len,
            valid_channels,
            lossfn_config,
        )
        ddp_loss[0] += l2_loss.detach().item()
        ddp_loss[1] += batch_size
    dist.all_reduce(ddp_loss, op=dist.ReduceOp.SUM)
    return ddp_loss[0] / ddp_loss[1]
def fsdp_train(rank, world_size, config):
    """FSDP training worker: build data, model, optimizer, run the epoch loop.

    Spawned once per GPU by ``mp.spawn``. Only rank 0 prints timing/loss and
    writes checkpoint files.

    Args:
        rank: this worker's rank / CUDA device index.
        world_size: total number of spawned workers.
        config: parsed YAML configuration dict.
    """
    setup(rank, world_size)
    torch.cuda.set_device(rank)
    train_paras = config["data"]["train"]
    val_paras = config["data"]["validation"]
    loss_paras = config["loss"]["train"]
    # NOTE: keep this dict as-is for future extension; do not rewrite.
    lossfn_config = {}
    lossfn_config["limiter_steps"] = loss_paras["limiter_steps"]
    # Create datasets & dataloaders.
    train_set, val_set = create_datasets(config, is_debug=train_paras["is_debug"])
    train_loader = Data.DataLoader(
        train_set,
        batch_size=train_paras["batch_size"],
        num_workers=train_paras["num_workers"],
        collate_fn=pad_collate,
    )
    # NOTE(review): the "+ (num_workers - 1)" step-count padding looks like a
    # workaround for extra worker batches — confirm against the data pipeline.
    train_steps_per_epoch = len(train_loader) + (train_paras["num_workers"] - 1)
    val_loader = Data.DataLoader(
        val_set,
        batch_size=val_paras["batch_size"],
        num_workers=val_paras["num_workers"],
        collate_fn=pad_collate,
    )
    # Fixed copy-paste bug: this previously used train_paras["num_workers"].
    val_steps_per_epoch = len(val_loader) + (val_paras["num_workers"] - 1)
    # Define model.
    model_params = config["model"]
    optim_params = config["optimizer"]
    model = tokmak_model(
        in_channels=model_params["in_channels"],
        hidden_size=model_params["hidden_size"],
        num_layers=model_params["num_layers"],
        dropout_rate=model_params["dropout_rate"],
        out_channels=model_params["out_channels"],
        noise_ratio=model_params["noise_ratio"],
    )
    model.cuda()
    # NOTE(review): no auto_wrap_policy is passed, so the whole model is one
    # FSDP unit. A size_based_auto_wrap_policy was previously constructed but
    # never used; pass it here if per-submodule sharding is intended.
    model = FSDP(model)
    out_channels = model_params["out_channels"]
    # Define optimizer & per-step OneCycleLR scheduler.
    optimizer = torch.optim.SGD(
        model.parameters(),
        lr=float(optim_params["lr"]),
        weight_decay=optim_params["weight_decay"],
    )
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=float(optim_params["lr"]),
        steps_per_epoch=train_steps_per_epoch,
        epochs=train_paras["epochs"],
    )
    epochs = train_paras["epochs"]
    cur_steps = 0
    for epoch in range(1, 1 + epochs):
        # ---- train ----
        if rank == 0:
            local_time_beg = time.time()
        step_train_loss, cur_steps = train_process(
            model,
            rank,
            cur_steps,
            optimizer,
            scheduler,
            train_loader,
            out_channels,
            lossfn_config,
        )
        if rank == 0:
            # Fixed unit bug: elapsed time was multiplied by 1000 (ms) but
            # printed with an "s" suffix. Report true seconds.
            epoch_seconds = time.time() - local_time_beg
            step_seconds = epoch_seconds / train_steps_per_epoch
            # step_train_loss differ from the real train_loss.
            print(
                f"epoch: {epoch} train loss: {step_train_loss} "
                f"epoch time: {epoch_seconds:5.3f}s step time: {step_seconds:5.3f}s"
            )
        # ---- evaluate & checkpoint ----
        if epoch % config["summary"]["eval_interval_epochs"] == 0:
            if rank == 0:
                eval_time_start = time.time()
            step_val_loss = eval_process(
                model, rank, val_loader, out_channels, lossfn_config
            )
            if rank == 0:
                # Fixed unit bug: eval time is seconds, step time is ms.
                epoch_val_seconds = time.time() - eval_time_start
                step_val_ms = epoch_val_seconds / val_steps_per_epoch * 1000
                print(
                    f"epoch: {epoch} val loss: {step_val_loss} "
                    f"evaluation time: {epoch_val_seconds:5.3f}s step time: {step_val_ms:5.3f}ms"
                )
            if config["summary"]["save_ckpt"]:
                # All ranks must enter state_dict() (full-state gathering is
                # collective across shards); only rank 0 writes the file to
                # avoid concurrent writes to the same path.
                model_state = model.state_dict()
                if rank == 0:
                    tempdir = config["summary"]["ckpt_dir"]
                    os.makedirs(tempdir, exist_ok=True)
                    torch.save(
                        {"epoch": epoch, "model_state": model_state},
                        os.path.join(
                            tempdir,
                            f"fsdp-{epoch}-{step_train_loss:.5f}-{step_val_loss:.5f}.pt",
                        ),
                    )
    # Release the process-group resources before the worker exits.
    dist.destroy_process_group()
def train():
    """Load the YAML config and launch one FSDP worker per visible GPU.

    Reads the config path from the module-level ``args`` parsed in the
    ``__main__`` block, then spawns ``fsdp_train`` across all CUDA devices
    and blocks until every worker finishes.
    """
    cfg = load_yaml_config(args.config_file_path)
    n_gpus = torch.cuda.device_count()
    mp.spawn(fsdp_train, args=(n_gpus, cfg), nprocs=n_gpus, join=True)
if __name__ == "__main__":
    # Local import keeps logging setup out of the import path of this module.
    from src import log_config

    # Route logs to ./logs under the "tokmak" name.
    log_config("./logs", "tokmak")
    print("pid:", os.getpid())
    args = parse_args()
    print(f"Running in {args.mode.upper()} mode, using device id: {args.device_id}.")
    # use_ascend = context.get_context(attr_key='device_target') == "Ascend"
    train()
# NOTE(review): the two lines that were here were Gitee content-moderation
# boilerplate captured with the page, not source code; commented out so the
# module stays valid Python.