"""train process"""
import argparse
import os
import random
import time
import numpy as np
import torch
from torch import nn
from torch.utils import data as Data
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import (
    size_based_auto_wrap_policy,
)
import torch.multiprocessing as mp
import functools
import pathlib
from src import (
    calculate_l2_error,
    create_datasets,
    load_yaml_config,
    pad_collate,
    tokmak_model,
)
seed = 123456
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
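# The calls above seed Python, NumPy, and PyTorch's CPU RNG. For stricter GPU
# reproducibility one could additionally (assumption; not done in this script):
#   torch.cuda.manual_seed_all(seed)
#   torch.backends.cudnn.deterministic = True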
def parse_args():
    """Parse input args."""
    parser = argparse.ArgumentParser(description="tokmak train")
    parser.add_argument(
        "--config_file_path", type=str, default="./configs/tokamak.yaml"
    )
    parser.add_argument(
        "--device_target",
        type=str,
        default="Ascend",
        choices=["GPU", "Ascend"],
        help="The target device to run on; supports 'Ascend' and 'GPU'",
    )
    parser.add_argument(
        "--device_id", type=int, default=0, help="ID of the target device"
    )
    parser.add_argument(
        "--mode",
        type=str,
        default="GRAPH",
        choices=["GRAPH", "PYNATIVE"],
        help="Run in GRAPH_MODE or PYNATIVE_MODE",
    )
    parser.add_argument(
        "--save_graphs",
        action="store_true",
        help="Whether to save intermediate compilation graphs",
    )
    parser.add_argument("--save_graphs_path", type=str, default="./graphs")
    input_args = parser.parse_args()
    return input_args
def setup(rank, world_size):
    """Initialize the default process group (NCCL backend) for this rank."""
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"
    # initialize the process group
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
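# setup() has no matching teardown in this script; a minimal counterpart
# (a sketch only, not called anywhere in this file) would be:
def cleanup():
    """Destroy the default process group once training is finished."""
    dist.destroy_process_group()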
def train_process(
    model,
    rank,
    cur_steps,
    optimizer,
    scheduler,
    train_loader,
    out_channels,
    lossfn_config,
):
    """Run one training epoch.

    Args:
        model: the (FSDP-wrapped) network to train.
        rank: local process rank / CUDA device index.
        cur_steps: number of samples consumed so far (updated and returned).
        optimizer: optimizer stepped once per batch.
        scheduler: learning-rate scheduler stepped once per batch.
        train_loader: DataLoader yielding (X, valid_len, valid_channels, h5_files).
        out_channels: number of trailing channels of X that form the label.
        lossfn_config: configuration dict forwarded to calculate_l2_error.

    Returns:
        Tuple of (summed batch losses divided by total sample count, reduced
        across all ranks, and the updated cur_steps).
    """
    model.train()
    torch.cuda.set_device(rank)
    # ddp_loss[0] accumulates the summed loss, ddp_loss[1] the sample count.
    ddp_loss = torch.zeros(2).cuda()
    for X, valid_len, valid_channels, h5_files in train_loader:
        optimizer.zero_grad(set_to_none=True)
        X = X.cuda()
        valid_len = valid_len.cuda()
        valid_channels = valid_channels.cuda()
        device = X.device
        dtype = X.dtype
        batch_size = X.shape[0]
        cur_steps += batch_size
        # The last `out_channels` features of X are the targets; the rest feed the encoder.
        enc_inputs = X[:, :, :-out_channels]
        label = X[:, :, -out_channels:]
        valid_channels = valid_channels[:, -out_channels:]
        # Decoder inputs are the labels shifted right by one step, padded with zeros.
        padded_values = torch.zeros(
            label.shape[0], 1, label.shape[-1], device=device, dtype=dtype
        )
        dec_padded_inputs = torch.cat((padded_values, label), 1)
        dec_inputs = dec_padded_inputs[:, :-1, :]
        l2_loss = calculate_l2_error(
            model,
            enc_inputs,
            dec_inputs,
            cur_steps,
            label,
            valid_len,
            valid_channels,
            lossfn_config,
        )
        l2_loss.backward()
        optimizer.step()
        scheduler.step()
        ddp_loss[0] += l2_loss.detach().item()
        ddp_loss[1] += batch_size
    dist.all_reduce(ddp_loss, op=dist.ReduceOp.SUM)
    return ddp_loss[0] / ddp_loss[1], cur_steps
@torch.no_grad()
def eval_process(model, rank, val_loader, out_channels, lossfn_config):
    """Evaluate the model on the validation loader.

    Returns the summed batch losses divided by the total sample count,
    reduced across all ranks.
    """
    model.eval()
    torch.cuda.set_device(rank)
    ddp_loss = torch.zeros(2).to(rank)
    for X, valid_len, valid_channels, h5_files in val_loader:
        X = X.cuda()
        valid_len = valid_len.cuda()
        valid_channels = valid_channels.cuda()
        device = X.device
        dtype = X.dtype
        batch_size = X.shape[0]
        enc_inputs = X[:, :, :-out_channels]
        label = X[:, :, -out_channels:]
        valid_channels = valid_channels[:, -out_channels:]
        padded_values = torch.zeros(
            label.shape[0], 1, label.shape[-1], device=device, dtype=dtype
        )
        dec_padded_inputs = torch.cat((padded_values, label), 1)
        dec_inputs = dec_padded_inputs[:, :-1, :]
        # cur_steps is meaningless during evaluation, so pass 0.
        cur_steps = 0
        l2_loss = calculate_l2_error(
            model,
            enc_inputs,
            dec_inputs,
            cur_steps,
            label,
            valid_len,
            valid_channels,
            lossfn_config,
        )
        ddp_loss[0] += l2_loss.detach().item()
        ddp_loss[1] += batch_size
    dist.all_reduce(ddp_loss, op=dist.ReduceOp.SUM)
    return ddp_loss[0] / ddp_loss[1]
def fsdp_train(rank, world_size, config):
    """FSDP training entry point for a single process/rank."""
    setup(rank, world_size)
    torch.cuda.set_device(rank)
    train_paras = config["data"]["train"]
    val_paras = config["data"]["validation"]
    loss_paras = config["loss"]["train"]
    # Kept for future improvements; please do not rewrite this.
    lossfn_config = {}
    lossfn_config["limiter_steps"] = loss_paras["limiter_steps"]
    # create dataset & dataloader
    train_set, val_set = create_datasets(config, is_debug=train_paras["is_debug"])
    train_loader = Data.DataLoader(
        train_set,
        batch_size=train_paras["batch_size"],
        num_workers=train_paras["num_workers"],
        collate_fn=pad_collate,
    )
    train_steps_per_epoch = len(train_loader) + (train_paras["num_workers"] - 1)
    val_loader = Data.DataLoader(
        val_set,
        batch_size=val_paras["batch_size"],
        num_workers=val_paras["num_workers"],
        collate_fn=pad_collate,
    )
    val_steps_per_epoch = len(val_loader) + (train_paras["num_workers"] - 1)
    # Wrap any submodule with more than `min_num_params` parameters in its own FSDP unit.
    my_auto_wrap_policy = functools.partial(
        size_based_auto_wrap_policy, min_num_params=100
    )
    # define models and optimizers
    model_params = config["model"]
    optim_params = config["optimizer"]
    model = tokmak_model(
        in_channels=model_params["in_channels"],
        hidden_size=model_params["hidden_size"],
        num_layers=model_params["num_layers"],
        dropout_rate=model_params["dropout_rate"],
        out_channels=model_params["out_channels"],
        noise_ratio=model_params["noise_ratio"],
    )
    model.cuda()
    model = FSDP(model, auto_wrap_policy=my_auto_wrap_policy)
    out_channels = model_params["out_channels"]
    # define optimizer & scheduler
    optimizer_fn = torch.optim.SGD
    optimizer = optimizer_fn(
        model.parameters(),
        lr=float(optim_params["lr"]),
        weight_decay=optim_params["weight_decay"],
    )
    scheduler_fn = torch.optim.lr_scheduler.OneCycleLR
    scheduler = scheduler_fn(
        optimizer,
        max_lr=float(optim_params["lr"]),
        steps_per_epoch=train_steps_per_epoch,
        epochs=train_paras["epochs"],
    )
    epochs = config["data"]["train"]["epochs"]
    cur_steps = 0
    for epoch in range(1, 1 + epochs):
        # train
        if rank == 0:
            local_time_beg = time.time()
        step_train_loss, cur_steps = train_process(
            model,
            rank,
            cur_steps,
            optimizer,
            scheduler,
            train_loader,
            out_channels,
            lossfn_config,
        )
        if rank == 0:
            local_time_end = time.time()
            epoch_seconds = local_time_end - local_time_beg
            step_seconds = epoch_seconds / train_steps_per_epoch
            # step_train_loss may differ from the exact epoch-average training loss.
            print(
                f"epoch: {epoch} train loss: {step_train_loss} "
                f"epoch time: {epoch_seconds:5.3f}s step time: {step_seconds:5.3f}s"
            )
        if epoch % config["summary"]["eval_interval_epochs"] == 0:
            if rank == 0:
                eval_time_start = time.time()
            step_val_loss = eval_process(
                model, rank, val_loader, out_channels, lossfn_config
            )
            if rank == 0:
                epoch_val_seconds = time.time() - eval_time_start
                step_val_ms = epoch_val_seconds / val_steps_per_epoch * 1000
                print(
                    f"epoch: {epoch} val loss: {step_val_loss} "
                    f"evaluation time: {epoch_val_seconds:5.3f}s step time: {step_val_ms:5.3f}ms"
                )
            if config["summary"]["save_ckpt"]:
                # state_dict() is collective under FSDP, so call it on every rank,
                # but let only rank 0 write the checkpoint to disk.
                model_state = model.state_dict()
                if rank == 0:
                    tempdir = config["summary"]["ckpt_dir"]
                    os.makedirs(tempdir, exist_ok=True)
                    torch.save(
                        {"epoch": epoch, "model_state": model_state},
                        os.path.join(
                            tempdir,
                            f"fsdp-{epoch}-{step_train_loss:.5f}-{step_val_loss:.5f}.pt",
                        ),
                    )
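# The DataLoaders in fsdp_train() iterate the full dataset on every rank. A common
# FSDP/DDP alternative (a sketch only; this helper is not used by the script) shards
# the data across ranks with a DistributedSampler:
def build_distributed_loader(dataset, batch_size, num_workers, rank, world_size):
    """Sketch: build a per-rank DataLoader backed by a DistributedSampler."""
    from torch.utils.data.distributed import DistributedSampler

    sampler = DistributedSampler(
        dataset, num_replicas=world_size, rank=rank, shuffle=True
    )
    loader = Data.DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        sampler=sampler,
        collate_fn=pad_collate,
    )
    # Call sampler.set_epoch(epoch) at the start of each epoch to reshuffle.
    return loader, sampler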
def train():
    """Load the configuration and spawn one fsdp_train process per GPU."""
    # load configurations
    config = load_yaml_config(args.config_file_path)
    world_size = torch.cuda.device_count()
    mp.spawn(
        fsdp_train,
        args=(
            world_size,
            config,
        ),
        nprocs=world_size,
        join=True,
    )
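# fsdp_train() saves model.state_dict() directly. With FSDP it is often preferable to
# gather a consolidated, CPU-offloaded copy on rank 0 first. A sketch (assuming the
# FullStateDictConfig / StateDictType APIs of torch.distributed.fsdp are available;
# this helper is not called by the script):
def save_full_state_dict(model, path, rank):
    """Sketch: save a consolidated FSDP state dict from rank 0 only."""
    from torch.distributed.fsdp import FullStateDictConfig, StateDictType

    cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
    with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, cfg):
        state = model.state_dict()
    if rank == 0:
        torch.save(state, path)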
if __name__ == "__main__":
    from src import log_config

    log_config("./logs", "tokmak")
    print("pid:", os.getpid())
    args = parse_args()
    print(f"Running in {args.mode.upper()} mode, using device id: {args.device_id}.")
    # use_ascend = context.get_context(attr_key='device_target') == "Ascend"
    train()
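# Example invocation (a sketch; the config path is the parser's default):
#   python train_fsdp.py --config_file_path ./configs/tokamak.yaml
# mp.spawn launches one fsdp_train process per visible CUDA device
# (world_size = torch.cuda.device_count()).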