陈睿敏/PIDM
dist_accuracy_pidm.py
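
# Distributed training loop for the PIDM diffusion model on Ascend NPUs
# (torch_npu, HCCL backend). Each step logs the training loss and timing;
# torch.compile, bf16 autocast (amp), channels_last memory format and
# use_checkpoint are optional toggles.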
import argparse
import os
import sys
import time

import torch
import torch.distributed as dist
import torch_npu
import npu_inductor_mlir.npu.npu_inductor_plugin  # side-effect import, also re-imported before torch.compile in make_model
from tensorfn import load_config
from torch import nn
from torch_npu.npu.amp import GradScaler
from tqdm import tqdm

from config.diffconfig import DiffusionConfig, get_model_conf
from diffusion import create_gaussian_diffusion

# Fixed (non-dynamic) loss scale shared by every training step.
scaler = GradScaler(init_scale=1024, dynamic=False)

torch._dynamo.config.optimize_ddp_lazy_compile = True


def init_distributed():
    # Initialize the distributed backend, which takes care of synchronizing processes/NPUs.
    dist_url = "env://"  # default
    # Only works when launched with torch.distributed.launch / torchrun,
    # which set RANK, WORLD_SIZE and LOCAL_RANK (see the sample launch after this function).
    rank = int(os.environ["RANK"])
    world_size = int(os.environ['WORLD_SIZE'])
    local_rank = int(os.environ['LOCAL_RANK'])
    dist.init_process_group(
        backend="hccl",
        init_method=dist_url,
        world_size=world_size,
        rank=rank)
    # Bind this process to its local NPU so device ops land on the right card.
    torch_npu.npu.set_device(local_rank)
    # Synchronize all processes before moving on.
    dist.barrier()
    setup_for_distributed(rank == 0)
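
# Example launch (illustrative values; the dataset path is a placeholder):
#
#   torchrun --nproc_per_node=8 dist_accuracy_pidm.py \
#       --bs 8 --height 256 --compiled --data_path <path-to-deepfashion>
#
# torchrun exports RANK, WORLD_SIZE and LOCAL_RANK, which init_distributed()
# reads before creating the HCCL process group.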


def setup_for_distributed(is_master):
    """
    Disable printing on non-master processes (pass force=True to print anyway).
    """
    import builtins as __builtin__
    builtin_print = __builtin__.print

    def print(*args, **kwargs):
        force = kwargs.pop('force', False)
        if is_master or force:
            builtin_print(*args, **kwargs)

    __builtin__.print = print


def is_main_process():
    try:
        return dist.get_rank() == 0
    except Exception:
        # Distributed is not initialized; treat this process as the main one.
        return True


def make_model(
        diffconf, image_size, disable_attention,
        use_checkpoint, device, compiled, channel_last):
    betas = diffconf.diffusion.beta_schedule.make()
    model_conf = get_model_conf()
    model_conf.dropout = 0.0
    model_conf.use_checkpoint = use_checkpoint
    if image_size == 1024:
        model_conf.attention_resolutions = (128, 64, 32)  # type: ignore
    elif image_size == 512:
        model_conf.attention_resolutions = (64, 32, 16)  # type: ignore
    # to disable attention
    if disable_attention:
        model_conf.attention_resolutions = ()  # type: ignore
    model_conf.image_size = image_size
    model = model_conf.make_model().to(device=device)
    if channel_last:
        model = model.to(
            device=device, memory_format=torch.channels_last)  # type: ignore
    if compiled:
        import npu_inductor_mlir.npu.npu_inductor_plugin
        model = torch.compile(model=model, dynamic=False)
    dfu = create_gaussian_diffusion(betas, predict_xstart=False)
    optimizer = diffconf.training.optimizer.make(model.parameters())
    scheduler = diffconf.training.scheduler.make(optimizer)  # type: ignore
    return model, dfu, optimizer, scheduler


def training_forward(
        amp, model, dfu, img, target_img,
        target_pose, guidance_prob, time_t):
    with torch.autocast(  # type: ignore
            device_type='npu', dtype=torch.bfloat16, enabled=amp):
        loss_dict = dfu.training_losses(
            model, x_start=target_img, t=time_t,
            cond_input=[img, target_pose], prob=1 - guidance_prob)
        loss = loss_dict['loss'].mean()
        # loss_mse = loss_dict['mse'].mean()
        # loss_vb = loss_dict['vb'].mean()
    return loss


def training_update(model, optimizer, scheduler, loss):
    optimizer.zero_grad()
    scaler.scale(loss).backward()
    # Unscale before clipping so the norm threshold applies to the true gradients.
    scaler.unscale_(optimizer)
    nn.utils.clip_grad_norm_(model.parameters(), 1)  # type: ignore
    scaler.step(optimizer)
    scaler.update()
    # Step the LR scheduler after the optimizer step.
    scheduler.step()


def equal_training(
        bs, h, w, seed, disable_attention, use_checkpoint,
        compiled: bool, amp: bool, channel_last: bool, data_path: str):
    print(f"bs {bs}, h {h}, w {w}, seed {seed}, "
          f"use_checkpoint {use_checkpoint}, compiled {compiled}, amp {amp}, "
          f"channel_last {channel_last}")
    torch_npu.npu.empty_cache()
    torch.manual_seed(seed)

    from config.dataconfig import Config as DataConfig
    import data as deepfashion_data

    DataConf = DataConfig('./config/data.yaml')
    diffconf: DiffusionConfig = load_config(
        DiffusionConfig, config='./config/diffusion.conf', show=False)
    diffconf.distributed = True
    local_rank = int(os.environ['LOCAL_RANK'])
    DataConf.data.path = data_path
    val_dataset, train_dataset = deepfashion_data.get_train_val_dataloader(
        DataConf.data, labels_required=True, distributed=diffconf.distributed)

    device = 'npu'
    model, dfu, optimizer, scheduler = make_model(
        diffconf, image_size=h,
        disable_attention=disable_attention,
        use_checkpoint=use_checkpoint,
        device=device, compiled=compiled,
        channel_last=channel_last)
    if diffconf.distributed:
        model = nn.parallel.DistributedDataParallel(
            model,
            device_ids=[local_rank],
            find_unused_parameters=True
        )

    loss_lst = []
    latency_lst = []
    data_size = 10000
    for epoch in range(100):
        start = time.time()
        step_time = []
        data_time = []
        for batch_idx, batch in enumerate(tqdm(train_dataset)):
            guidance_prob = 0.1
            begin = time.time()
            # Double the batch by swapping source/target roles.
            img = torch.cat([batch['source_image'], batch['target_image']], 0)
            target_img = torch.cat([batch['target_image'], batch['source_image']], 0)
            target_pose = torch.cat([batch['target_skeleton'], batch['source_skeleton']], 0)
            img = img.to(device)
            target_img = target_img.to(device)
            target_pose = target_pose.to(device)
            time_t = torch.randint(0, 1000, (bs,), device=device)
            if channel_last:
                target_pose = target_pose.to(
                    memory_format=torch.channels_last)  # type: ignore
                img = img.to(
                    memory_format=torch.channels_last)  # type: ignore
                target_img = target_img.to(
                    memory_format=torch.channels_last)  # type: ignore
            data_time.append(time.time() - begin)
            loss = training_forward(
                amp=amp, model=model, dfu=dfu, img=img, target_img=target_img,
                target_pose=target_pose, guidance_prob=guidance_prob,
                time_t=time_t)
            training_update(
                model=model, optimizer=optimizer, scheduler=scheduler,
                loss=loss)
            loss_lst.append(loss)
            print(f"loss: {loss.item()}", flush=True)
            # Exclude the first few warm-up iterations from the mean step time.
            if batch_idx > 3:
                step_time.append(time.time() - begin)
        print(f"epoch time {epoch}, {time.time() - start}, "
              f"mean data prep time {sum(data_time) / len(data_time)}, "
              f"mean step time {sum(step_time) / len(step_time)}")
    return loss_lst, latency_lst


def run_train(h, w, bs, disable_attention, compiled, amp, data_path):
    seed = int(time.time())
    channel_last = False
    use_checkpoint = False  # False, True
    equal_training(
        bs, h, w, seed=seed,
        disable_attention=disable_attention,
        use_checkpoint=use_checkpoint,
        channel_last=channel_last,
        compiled=compiled,
        amp=amp,
        data_path=data_path,
    )


if __name__ == '__main__':
    init_distributed()

    parser = argparse.ArgumentParser()
    parser.add_argument('--bs', type=int)
    parser.add_argument('--height', type=int)
    parser.add_argument('--compiled', action='store_true')
    parser.add_argument('--data_path', type=str)
    args = parser.parse_args()

    torch.npu.config.allow_internal_format = False

    bs = args.bs
    h = args.height
    compiled = args.compiled
    data_path = args.data_path
    assert h in [64, 256, 512, 1024]
    w = h * 3 // 4
    amp = True
    disable_attention = False  # False, True
    run_train(h, w, bs, disable_attention, compiled, amp, data_path)