diff --git a/debug/accuracy_tools/atat/pytorch/common/utils.py b/debug/accuracy_tools/atat/pytorch/common/utils.py index aa4a3214071121ead083d9195752ac7db8e3ef07..6ec5d6c9e57d06b51cc56feb2a6004d0b0048245 100644 --- a/debug/accuracy_tools/atat/pytorch/common/utils.py +++ b/debug/accuracy_tools/atat/pytorch/common/utils.py @@ -195,7 +195,8 @@ class Const: MAX_SEED_VALUE = 2**32 - 1 INPLACE_LIST = ["broadcast", "all_reduce", "reduce", "all_gather", "gather", "scatter", "reduce_scatter", - "_reduce_scatter_base", "_all_gather_base", "all_to_all_single"] + "_reduce_scatter_base", "_all_gather_base", "all_to_all_single", "all_gather_into_tensor", + "reduce_scatter_tensor"] TASK_LIST = ["tensor", "statistics", "overflow_check", "free_benchmark"] LEVEL_LIST = ["L0", "L1", "L2", "mix"] diff --git a/debug/accuracy_tools/atat/pytorch/hook_module/support_wrap_ops.yaml b/debug/accuracy_tools/atat/pytorch/hook_module/support_wrap_ops.yaml index d64c577ff38b7b7e3478d59eb1754845973c7103..d6b4c6896f3af4730c9c7f620a474dc0468f33d5 100644 --- a/debug/accuracy_tools/atat/pytorch/hook_module/support_wrap_ops.yaml +++ b/debug/accuracy_tools/atat/pytorch/hook_module/support_wrap_ops.yaml @@ -1873,4 +1873,6 @@ distributed: - reduce_scatter - _reduce_scatter_base - _all_gather_base - - all_to_all_single \ No newline at end of file + - all_to_all_single + - all_gather_into_tensor + - reduce_scatter_tensor \ No newline at end of file