From 2492b4ff9127c49fda4dbbabfd48e747352961c6 Mon Sep 17 00:00:00 2001 From: jializheng Date: Tue, 29 Aug 2023 15:37:53 +0800 Subject: [PATCH] fix hccl wapper bug and remove int64, is support now --- torch_npu/distributed/hccl_dtype_wraper.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/torch_npu/distributed/hccl_dtype_wraper.py b/torch_npu/distributed/hccl_dtype_wraper.py index 2070746a824..22b526eb55a 100644 --- a/torch_npu/distributed/hccl_dtype_wraper.py +++ b/torch_npu/distributed/hccl_dtype_wraper.py @@ -25,7 +25,6 @@ WRAP_DTYPE_DICT = { 'torch.uint16': torch.int32, 'torch.uint32': torch.int32, 'torch.uint64': torch.int32, - 'torch.int64': torch.int32, 'torch.float64': torch.float32, } @@ -59,14 +58,14 @@ def wrapper_dist_dtype_one_input(fn): args[0].copy_(new_args[0].to(raw_type)) return output elif 'tensor' in kwargs and str(kwargs['tensor'].dtype) in WRAP_DTYPE_DICT.keys(): - new_kwargs = copy.deepcopy(kwargs) + old_tensor = kwargs['tensor'] raw_type = kwargs['tensor'].dtype tar_type = WRAP_DTYPE_DICT[str(raw_type)] - new_kwargs['tensor'] = new_kwargs['tensor'].to(tar_type) - output = fn(*args, **new_kwargs) + kwargs['tensor'] = kwargs['tensor'].to(tar_type) + output = fn(*args, **kwargs) if output is not None: output.wait() - kwargs['tensor'].copy_(new_kwargs['tensor'].to(raw_type)) + kwargs['tensor'] = old_tensor.copy_(kwargs['tensor'].to(raw_type)) return output return fn(*args, **kwargs) -- Gitee