From 0e3f2f2fbc181f3192fdea00d7a10e7974575833 Mon Sep 17 00:00:00 2001 From: jializheng Date: Tue, 29 Aug 2023 15:29:02 +0800 Subject: [PATCH 1/2] fix hccl wrap --- torch_npu/distributed/hccl_dtype_wraper.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/torch_npu/distributed/hccl_dtype_wraper.py b/torch_npu/distributed/hccl_dtype_wraper.py index 2070746a824..5fb34fca8b2 100644 --- a/torch_npu/distributed/hccl_dtype_wraper.py +++ b/torch_npu/distributed/hccl_dtype_wraper.py @@ -59,14 +59,14 @@ def wrapper_dist_dtype_one_input(fn): args[0].copy_(new_args[0].to(raw_type)) return output elif 'tensor' in kwargs and str(kwargs['tensor'].dtype) in WRAP_DTYPE_DICT.keys(): - new_kwargs = copy.deepcopy(kwargs) + old_tensor = kwargs['tensor'] raw_type = kwargs['tensor'].dtype tar_type = WRAP_DTYPE_DICT[str(raw_type)] - new_kwargs['tensor'] = new_kwargs['tensor'].to(tar_type) - output = fn(*args, **new_kwargs) + kwargs['tensor'] = kwargs['tensor'].to(tar_type) + output = fn(*args, **kwargs) if output is not None: output.wait() - kwargs['tensor'].copy_(new_kwargs['tensor'].to(raw_type)) + kwargs['tensor'] = old_tensor.copy_(kwargs['tensor'].to(raw_type)) return output return fn(*args, **kwargs) -- Gitee From 223ee74553a4e212e56761d2957b22301029904e Mon Sep 17 00:00:00 2001 From: jializheng Date: Tue, 29 Aug 2023 15:31:08 +0800 Subject: [PATCH 2/2] remove int64, is support now --- torch_npu/distributed/hccl_dtype_wraper.py | 1 - 1 file changed, 1 deletion(-) diff --git a/torch_npu/distributed/hccl_dtype_wraper.py b/torch_npu/distributed/hccl_dtype_wraper.py index 5fb34fca8b2..22b526eb55a 100644 --- a/torch_npu/distributed/hccl_dtype_wraper.py +++ b/torch_npu/distributed/hccl_dtype_wraper.py @@ -25,7 +25,6 @@ WRAP_DTYPE_DICT = { 'torch.uint16': torch.int32, 'torch.uint32': torch.int32, 'torch.uint64': torch.int32, - 'torch.int64': torch.int32, 'torch.float64': torch.float32, } -- Gitee