diff --git a/mindspore/ccsrc/pybind_api/tensor_py_reg.cc b/mindspore/ccsrc/pybind_api/tensor_py_reg.cc
index 04a9de6bc94e9a2e6c0991fb8fcbbe845469d65c..89114ad9c1c6a74a010c21808a86fbbd2caa34f5 100644
--- a/mindspore/ccsrc/pybind_api/tensor_py_reg.cc
+++ b/mindspore/ccsrc/pybind_api/tensor_py_reg.cc
@@ -281,6 +281,7 @@ extern int TensorPython_set_dtypeObj(PyObject *self, PyObject *value, void *) {
   HANDLE_MS_EXCEPTION
   PyType<TensorPy> *obj = reinterpret_cast<PyType<TensorPy> *>(self);
   TypePtr dtype_object = py::cast<TypePtr>(value);
+  runtime::Pipeline::Get().WaitForward();
   obj->value.SetDtype(dtype_object);
   return 0;
   HANDLE_MS_EXCEPTION_RET_FAIL_END
@@ -841,6 +842,7 @@ extern PyObject *TensorPython_set_dtype(PyObject *self, PyObject *args) {
   }
   TypePtr type_ptr = py::cast<TypePtr>(py::handle(py_type));
   PyType<TensorPy> *tensor = (PyType<TensorPy> *)self;
+  runtime::Pipeline::Get().WaitForward();
   TypePtr result = tensor->value.SetDtype(type_ptr);
 
   return py::cast(result).release().ptr();
@@ -1050,6 +1052,7 @@ extern PyObject *TensorPython_set_device_address(PyObject *self, PyObject *args)
   TypePtr type_ptr = py::cast<TypePtr>(py::handle(type_ptr_obj));
   PyType<TensorPy> *tensor = (PyType<TensorPy> *)self;
   auto tensorTmp = tensor->value.GetTensor();
+  runtime::Pipeline::Get().WaitForward();
   TensorPybind::SetDeviceAddress(tensorTmp, addr, shape, type_ptr);
 
   Py_RETURN_NONE;
diff --git a/tests/st/collective_ops/test_hccl/test_distributed.py b/tests/st/collective_ops/test_hccl/test_distributed.py
index 9256b7a1eda6e7536260812082d9772011d83127..09dc62177a3bd7d33f77a1f743e0c355e41718e3 100644
--- a/tests/st/collective_ops/test_hccl/test_distributed.py
+++ b/tests/st/collective_ops/test_hccl/test_distributed.py
@@ -1683,3 +1683,31 @@
     output_handle = all_gather_into_tensor(output_tensor, input_tensor)
     assert output_handle is None
     assert np.allclose(output_tensor.asnumpy(), except_output_tensor.asnumpy())
+
+
+@log_function_entry_exit
+def test_hccl_overlap():
+    """
+    Feature: test distributed op
+    Description: test comm op overlap in PyNative mode
+    Expectation: success
+    """
+
+    input_np = np.ones((1024, 1024)).astype(np.float32)
+
+    x = ms.Tensor.from_numpy(input_np)
+    expect_sum_output = ms.Tensor(input_np * (sum(list(range(1, size + 1)))))
+
+    for _ in range(100):
+        w = x * (rank + 1)
+        w, sum_output_handle = ms.communication.comm_func.all_reduce(w, async_op=True)
+
+        # Communication/Compute overlap.
+        # The shape and dtype of empty are the same as those of w.
+        # Incorrect results will occur if cross-stream memory usage is handled improperly.
+        empty = ms.mint.empty_like(w)
+        zeros = ms.mint.zeros_like(w)
+        empty.copy_(zeros)
+
+        sum_output_handle.wait()
+        assert np.allclose(w.asnumpy(), expect_sum_output.asnumpy())
diff --git a/tests/st/pynative/forward/test_multi_stream.py b/tests/st/pynative/forward/test_multi_stream.py
new file mode 100644
index 0000000000000000000000000000000000000000..9b90c8cb0b00da053128fd3c59b539c3ebac761a
--- /dev/null
+++ b/tests/st/pynative/forward/test_multi_stream.py
@@ -0,0 +1,53 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import numpy as np
+import mindspore as ms
+from tests.mark_utils import arg_mark
+
+
+@arg_mark(plat_marks=['platform_ascend'],
+          level_mark='level0',
+          card_mark='onecard',
+          essential_mark='essential')
+def test_pynative_multi_stream():
+    """
+    Feature: PyNative multi-stream
+    Description: Test PyNative multi-stream with memory reuse.
+    Expectation: run success
+    """
+
+    x_np = np.ones((1024, 1024)).astype(np.float32)
+    x = ms.Tensor.from_numpy(x_np)
+    s1 = ms.runtime.Stream()
+    for _ in range(100):
+        with ms.runtime.StreamCtx(s1):
+            y = x + 1
+            z = ms.mint.matmul(y, y)
+            event = s1.record_event()
+
+            # Free tensor memory
+            del y
+
+        # Execute on default stream.
+        # Memory reuse is prevented as different streams utilize separate memory pools.
+        empty = ms.mint.empty_like(x)
+        zeros = ms.mint.zeros_like(x)
+        empty.copy_(zeros)
+
+        cur_stream = ms.runtime.current_stream()
+        cur_stream.wait_event(event)
+
+    assert np.allclose(z.asnumpy(), np.matmul(x_np + 1, x_np + 1))
diff --git a/tests/st/pynative/forward/test_pynative_heter.py b/tests/st/pynative/forward/test_pynative_heter.py
new file mode 100644
index 0000000000000000000000000000000000000000..ad5efcb74aecb30da65247bafe0839ef759757d7
--- /dev/null
+++ b/tests/st/pynative/forward/test_pynative_heter.py
@@ -0,0 +1,156 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================ +import numpy as np +import mindspore as ms +from mindspore import ops, nn +from tests.mark_utils import arg_mark + + +class Net1(nn.Cell): + def __init__(self): + super().__init__() + self.addn = ops.AddN() + self.sin1 = ops.Sin() + self.sin2 = ops.Sin() + def construct(self, x): + out1 = self.sin1(x) + out2 = self.sin2(x) + return self.addn((out1, out2)) + +class Net2(nn.Cell): + def __init__(self): + super().__init__() + self.addn1 = ops.AddN() + self.addn2 = ops.AddN() + self.sin1 = ops.Sin() + self.sin2 = ops.Sin() + def construct(self, x): + out1 = self.addn1((x, x)) + out2 = self.addn2((x, out1)) + out3 = self.sin1(out2) + return self.sin2(out3) + + +class Net3(nn.Cell): + def __init__(self): + super().__init__() + self.addn1 = ops.AddN() + self.addn2 = ops.AddN() + self.sin1 = ops.Sin() + self.sin2 = ops.Sin() + def construct(self, x): + out1 = self.addn1((x, x)) + out2 = self.sin1(out1) + out3 = self.addn2((x, out2)) + return self.sin2(out3) +def test_pynative_heterogeneous1(): + """ + Feature: PyNative Heterogeneous + Description: Test PyNative heterogeneous with aclnn/aclop/cpu + Expectation: run success + """ + input_np = np.ones((1024,)).astype(np.float32) + output_expect = np.sin(input_np) + np.sin(input_np) + + net = Net1() + net.sin1.set_device("CPU") + output = net(ms.Tensor.from_numpy(input_np)) + assert np.allclose(output.asnumpy(), output_expect) + + net = Net1() + net.sin2.set_device("CPU") + output = net(ms.Tensor.from_numpy(input_np)) + assert np.allclose(output.asnumpy(), output_expect) + + net = Net1() + net.addn.set_device("CPU") + output = net(ms.Tensor.from_numpy(input_np)) + assert np.allclose(output.asnumpy(), output_expect) + + +def test_pynative_heterogeneous2(): + """ + Feature: PyNative Heterogeneous + Description: Test PyNative heterogeneous with aclnn/aclop/cpu + Expectation: run success + """ + input_np = np.ones((1024,)).astype(np.float32) + output_expect = np.sin(np.sin(input_np * 3)) + + net = Net2() + net.sin1.set_device("CPU") + output = net(ms.Tensor.from_numpy(input_np)) + assert np.allclose(output.asnumpy(), output_expect) + + net = Net2() + net.sin2.set_device("CPU") + output = net(ms.Tensor.from_numpy(input_np)) + assert np.allclose(output.asnumpy(), output_expect) + + net = Net2() + net.addn1.set_device("CPU") + output = net(ms.Tensor.from_numpy(input_np)) + assert np.allclose(output.asnumpy(), output_expect) + + net = Net2() + net.addn2.set_device("CPU") + output = net(ms.Tensor.from_numpy(input_np)) + assert np.allclose(output.asnumpy(), output_expect) + + +def test_pynative_heterogeneous3(): + """ + Feature: PyNative Heterogeneous + Description: Test PyNative heterogeneous with aclnn/aclop/cpu + Expectation: run success + """ + input_np = np.ones((1024,)).astype(np.float32) + output_expect = np.sin(np.sin(input_np * 2) + input_np) + + net = Net3() + net.sin1.set_device("CPU") + output = net(ms.Tensor.from_numpy(input_np)) + assert np.allclose(output.asnumpy(), output_expect) + + net = Net3() + net.sin2.set_device("CPU") + output = net(ms.Tensor.from_numpy(input_np)) + assert np.allclose(output.asnumpy(), output_expect) + + net = Net3() + net.addn1.set_device("CPU") + output = net(ms.Tensor.from_numpy(input_np)) + assert np.allclose(output.asnumpy(), output_expect) + + net = Net3() + net.addn2.set_device("CPU") + output = net(ms.Tensor.from_numpy(input_np)) + assert np.allclose(output.asnumpy(), output_expect) + + + +@arg_mark(plat_marks=['platform_ascend'], + 
level_mark='level0', + card_mark='onecard', + essential_mark='essential') +def test_pynative_heterogeneous(): + """ + Feature: PyNative Heterogeneous + Description: Test PyNative heterogeneous + Expectation: run success + """ + test_pynative_heterogeneous1() + test_pynative_heterogeneous2() + test_pynative_heterogeneous3() diff --git a/tests/st/pynative/forward/test_pynative_with_jit.py b/tests/st/pynative/forward/test_pynative_with_jit.py new file mode 100644 index 0000000000000000000000000000000000000000..d47c531cc0d3b778db46c4a2c01798d5a4a0c852 --- /dev/null +++ b/tests/st/pynative/forward/test_pynative_with_jit.py @@ -0,0 +1,166 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import numpy as np +import mindspore +from mindspore import Tensor, ops, jit, mint +from tests.mark_utils import arg_mark + +@jit +def func_jit(x): + return ops.sin(x) + +def func(x): + return ops.sin(x) + +def test_pynative_with_jit(): + """ + Feature: PyNative with jit. + Description: Test running PyNative with jit. + Expectation: run success + """ + def test_func(input_np, run_func): + for _ in range(10): + x = Tensor(input_np) + y = Tensor(input_np) + + for _ in range(1000): + out = mint.matmul(x, y) + out = run_func(out) + assert np.allclose(out.asnumpy(), np.sin(np.matmul(input_np, input_np))) + + input_np = np.ones((1024, 1024)).astype(np.float32) + + test_func(input_np, func) + test_func(input_np, func_jit) + + +def test_multi_stream_with_jit(): + """ + Feature: PyNative with jit. + Description: Test running PyNative multi-stream wait before jit. + Expectation: run success + """ + input_np = np.ones((1024, 1024)).astype(np.float32) + s1 = mindspore.hal.Stream() + for _ in range(10): + x = Tensor(input_np) + y = Tensor(input_np) + + with mindspore.hal.StreamCtx(s1): + for _ in range(1000): + out = mint.matmul(x, y) + event = s1.record_event() + + cur_stream = mindspore.hal.current_stream() + cur_stream.wait_event(event) + + out = func_jit(out) + assert np.allclose(out.asnumpy(), np.sin(np.matmul(input_np, input_np))) + +def test_multi_stream_with_event(): + """ + Feature: PyNative multi-stream. + Description: Test running PyNative with stream/event. + Expectation: run success + """ + input_np = np.ones((1024, 1024)).astype(np.float32) + s1 = mindspore.hal.Stream() + for _ in range(10): + x = Tensor(input_np) + y = Tensor(input_np) + + with mindspore.hal.StreamCtx(s1): + for _ in range(1000): + out = mint.matmul(x, y) + event = s1.record_event() + + cur_stream = mindspore.hal.current_stream() + cur_stream.wait_event(event) + + out = func(out) + assert np.allclose(out.asnumpy(), np.sin(np.matmul(input_np, input_np))) + + +def test_jit_within_multi_stream(): + """ + Feature: PyNative jit multi-stream. + Description: Test running PyNative with stream/event. 
+ Expectation: run success + """ + input_np = np.ones((1024, 1024)).astype(np.float32) + s1 = mindspore.hal.Stream() + for _ in range(10): + x = Tensor(input_np) + y = Tensor(input_np) + + @jit + def matmul_jit(a, b): + return mint.matmul(a, b) + + with mindspore.hal.StreamCtx(s1): + for _ in range(1000): + out = matmul_jit(x, y) + event = s1.record_event() + + cur_stream = mindspore.hal.current_stream() + cur_stream.wait_event(event) + + out = ops.sin(out) + assert np.allclose(out.asnumpy(), np.sin(np.matmul(input_np, input_np))) + + +def test_multi_stream_with_jit_output(): + """ + Feature: PyNative jit multi-stream. + Description: Test running PyNative with stream/event. + Expectation: run success + """ + input_np = np.ones((1024, 1024)).astype(np.float32) + s1 = mindspore.hal.Stream() + for _ in range(10): + x = Tensor(input_np) + y = Tensor(input_np) + + @jit + def matmul_jit(a, b): + return mint.matmul(a, b) + + for _ in range(1000): + out = matmul_jit(x, y) + + cur_stream = mindspore.hal.current_stream() + event = cur_stream.record_event() + with mindspore.hal.StreamCtx(s1): + s1.wait_event(event) + out = ops.sin(out) + + assert np.allclose(out.asnumpy(), np.sin(np.matmul(input_np, input_np))) + + +@arg_mark(plat_marks=['platform_ascend'], + level_mark='level0', + card_mark='onecard', + essential_mark='essential') +def test_pynative_and_graph_mixed_run(): + """ + Feature: test pynative and graph mixed run + Description: single op run in pynative, the output to net input which run in graph + Expectation: run success + """ + test_pynative_with_jit() + test_multi_stream_with_jit() + test_multi_stream_with_event() + test_jit_within_multi_stream() + test_multi_stream_with_jit_output() diff --git a/tests/st/runtime/test_runtime_heter.py b/tests/st/runtime/test_runtime_heter.py index 37f1accb58727e5cbe2337724e9bbbd33657445a..8e500546b6464d170f2842862ad635518a4561ef 100644 --- a/tests/st/runtime/test_runtime_heter.py +++ b/tests/st/runtime/test_runtime_heter.py @@ -225,3 +225,21 @@ def test_heter_lenet(): context.set_context(mode=context.GRAPH_MODE, jit_config={"jit_level": "O0"}) out_ascend = train_lenet() print(out_ascend) + + +@arg_mark(plat_marks=['platform_ascend'], level_mark='level1', card_mark='onecard', essential_mark='essential') +@test_utils.run_test_with_On +def test_heter_lenet_pynative(): + """ + Feature: PyNative special format in the heterogeneous scene. + Description: Test special format in the heterogeneous scene. + Expectation: success. + """ + context.set_context(mode=context.PYNATIVE_MODE) + ms.set_seed(42) + np.random.seed(42) + out_ascend = train_lenet() + expect_data = np.array([[2.6390028e-06, -1.0571928e-05, 5.6523363e-06, 5.9253930e-06, + 9.7876073e-06, 3.1337552e-06, 5.2174191e-06, 9.4886109e-06, + 7.1082345e-06, -3.6553743e-06]]) + assert np.allclose(out_ascend.asnumpy(), expect_data) diff --git a/tests/st/train/test_fork.py b/tests/st/train/test_fork.py index 7996c25d085d152359d33a029ecb6ceb880e360e..710e70c0892e49d21db76d9a6f668c01d1232395 100644 --- a/tests/st/train/test_fork.py +++ b/tests/st/train/test_fork.py @@ -210,3 +210,43 @@ def test_fork_with_pynative_pipeline(): p.start() assert q.get().asnumpy() == ops.log(ms.Tensor(2.0)).asnumpy() p.join() + + +@arg_mark(plat_marks=['platform_ascend'], + level_mark='level0', + card_mark='onecard', + essential_mark='essential') +def test_fork_with_pynative_pipeline_unfinished(): + """ + Feature: Fork test + Description: Test multiprocessing with PyNative pipeline. 
+ Expectation: No exception + """ + ms.set_context(mode=ms.PYNATIVE_MODE) + + input_np = np.ones((1024,)).astype(np.float32) + addn = ops.AddN() + + def my_child_process(q): + out = None + for _ in range(1000): + x = ms.Tensor(input_np) + out = ops.sin(x) + out = addn((out, x)) + + q.put(out) + + mp.set_start_method("fork", force=True) + + for _ in range(10): + x = ms.Tensor.from_numpy(input_np) + out = ops.sin(x) + addn((out, x)) + + q = mp.Queue() + p = mp.Process(target=my_child_process, args=(q,)) + p.start() + out = q.get() + assert out.device == "CPU" + assert np.allclose(out.asnumpy(), np.sin(input_np) + input_np) + p.join()