diff --git a/PyTorch/built-in/mm/LAVIS/lavis/tasks/base_task.py b/PyTorch/built-in/mm/LAVIS/lavis/tasks/base_task.py
index b9b70f71ebfe063ab3259173d0c882e20a028a2f..1998a725b8c2d7f75cb91f992d7993f90f37e4a3 100644
--- a/PyTorch/built-in/mm/LAVIS/lavis/tasks/base_task.py
+++ b/PyTorch/built-in/mm/LAVIS/lavis/tasks/base_task.py
@@ -9,6 +9,8 @@ import logging
 import os
 
 import torch
+if tuple(int(v) for v in torch.__version__.split(".")[:2]) >= (1, 8):  # numeric compare; as strings, "1.10" < "1.8"
+    import torch_npu  # registers the Ascend NPU backend with torch
 import torch.distributed as dist
 from lavis.common.dist_utils import get_rank, get_world_size, is_main_process, is_dist_avail_and_initialized
 from lavis.common.logger import MetricLogger, SmoothedValue
@@ -16,6 +18,23 @@ from lavis.common.registry import registry
 from lavis.datasets.data_utils import prepare_sample
 
 
+try:
+    from torch_npu.utils.profiler import Profile
+except ImportError:
+    print("Profile not found in torch_npu.utils.profiler; Auto Profile disabled.", flush=True)
+
+
+    class Profile:  # no-op fallback so training code can call start()/end() unconditionally
+        def __init__(self, *args, **kwargs):
+            pass
+
+        def start(self):
+            pass
+
+        def end(self):
+            pass
+
+
 class BaseTask:
     def __init__(self, **kwargs):
         super().__init__()
@@ -209,8 +228,12 @@ class BaseTask:
             inner_epoch = start_iters // iters_per_epoch
             header = header + "; inner epoch [{}]".format(inner_epoch)
 
+        profile = Profile(start_step=int(os.getenv('PROFILE_START_STEP', 10)),
+                          profile_type=os.getenv('PROFILE_TYPE'))
+
         for i in metric_logger.log_every(range(iters_per_epoch), log_freq, header):
             # if using iter-based runner, we stop after iters_per_epoch iterations.
+            profile.start()  # bracket the training step for the NPU profiler (no-op with the fallback class)
             if i >= iters_per_epoch:
                 break
             if self.test_performance and i > self.test_performance_steps:
@@ -248,6 +271,7 @@ class BaseTask:
             else:
                 optimizer.step()
             optimizer.zero_grad()
+            profile.end()  # close the profiled region for this step
 
             metric_logger.update(**loss_dict)
             metric_logger.update(lr=optimizer.param_groups[0]["lr"])
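
For reviewers who want to sanity-check the profiling hooks without Ascend hardware, here is a minimal, self-contained sketch of the same guarded-import pattern. It assumes only what the diff shows (torch_npu.utils.profiler.Profile plus the PROFILE_START_STEP and PROFILE_TYPE environment variables); the loop body is a hypothetical stand-in for the real forward/backward/optimizer step:

import os

try:
    from torch_npu.utils.profiler import Profile  # real profiler on Ascend NPU builds
except ImportError:
    class Profile:
        # No-op stand-in: call sites stay free of device checks.
        def __init__(self, *args, **kwargs):
            pass

        def start(self):
            pass

        def end(self):
            pass

# PROFILE_START_STEP defaults to 10; PROFILE_TYPE is passed through as-is.
profile = Profile(start_step=int(os.getenv('PROFILE_START_STEP', 10)),
                  profile_type=os.getenv('PROFILE_TYPE'))

for step in range(20):
    profile.start()
    _ = step * step  # placeholder for the actual training step
    profile.end()

On a machine without torch_npu, the fallback class makes every call a no-op, so the script runs unchanged; on an NPU build, the same two hooks drive the real profiler.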