diff --git a/MindFlow/applications/data_driven/airfoil/2D_unsteady/train.py b/MindFlow/applications/data_driven/airfoil/2D_unsteady/train.py index e02dd94731642bba67cbcbfb13a0124113d22880..e9d50dec1713573d011116492a31bb20095e667e 100644 --- a/MindFlow/applications/data_driven/airfoil/2D_unsteady/train.py +++ b/MindFlow/applications/data_driven/airfoil/2D_unsteady/train.py @@ -168,6 +168,7 @@ if __name__ == '__main__': print_log(f"pid: {os.getpid()}") print_log(datetime.datetime.now()) context.set_context(mode=context.GRAPH_MODE if args.mode.upper().startswith("GRAPH") else context.PYNATIVE_MODE, + jit_config={"jit_level": "O2"}, device_target=args.device_target, device_id=args.device_id) print_log(f"device_id: {context.get_context(attr_key='device_id')}")