diff --git a/.gitee/ISSUE_TEMPLATE/.keep b/.gitee/ISSUE_TEMPLATE/.keep new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.gitee/ISSUE_TEMPLATE/1-documentation.yml b/.gitee/ISSUE_TEMPLATE/1-documentation.yml new file mode 100644 index 0000000000000000000000000000000000000000..128e7e8c88571b7c594c238dc28d5547cce3cda0 --- /dev/null +++ b/.gitee/ISSUE_TEMPLATE/1-documentation.yml @@ -0,0 +1,62 @@ +name: 📚 Documentation +description: Request updates or additions to MindScience documentation +title: "[Doc]: " +labels: ["documentation"] + +body: +- type: markdown + attributes: + value: | + Thanks for taking the time to help MindScience and improve our documentation! + - If this is your first time, please read [our contributor guidelines](https://gitee.com/mindspore/mindscience/blob/master/CONTRIBUTION.md). + - You also confirm that you have searched the [open documentation issues](https://gitee.com/mindspore/mindscience/issues) and have found no duplicates for this request +- type: dropdown + id: domain + attributes: + label: Which domain the issue belongs to? + options: + - MindSpore Science Core + - applications-SPONGE + - applications-Flow + - applications-Energy + - applications-Chemistry + - applications-Earth + - Others + validations: + required: true + +- type: dropdown + id: new_or_correction + attributes: + label: Is this for new documentation, or an update to existing docs? + options: + - New + - Update + validations: + required: true + +- type: textarea + attributes: + label: 📚 The doc issue + description: > + Describe the incorrect/future/missing documentation. + value: | + 1. 【Document Link】/【文档链接】 + + 2. 【Issues Section】/【问题文档片段】 + + 3. 【Existing Issues】/【存在的问题】 + + 4. 【Expected Result】【预期结果】 + + validations: + required: true +- type: textarea + attributes: + label: Suggest a potential alternative/fix + description: > + Tell us how we could improve the documentation in this regard. +- type: markdown + attributes: + value: > + Thanks for contributing 🎉! \ No newline at end of file diff --git a/.gitee/ISSUE_TEMPLATE/2-installation.yml b/.gitee/ISSUE_TEMPLATE/2-installation.yml new file mode 100644 index 0000000000000000000000000000000000000000..5938cfe5dab419c99dfd0e2d08a5b9f191dccfe5 --- /dev/null +++ b/.gitee/ISSUE_TEMPLATE/2-installation.yml @@ -0,0 +1,68 @@ +name: 🛠️ Installation +description: Report an issue here when you hit errors during installation. +title: "[Installation]: " +labels: ["installation"] + +body: +- type: markdown + attributes: + value: > + #### Before submitting an issue, please make sure the issue hasn't been already addressed by searching through [the existing and past issues](https://gitee.com/mindspore/mindscience/issues). +- type: dropdown + id: domain + attributes: + label: Which domain the issue belongs to? + options: + - MindSpore Science Core + - applications-SPONGE + - applications-Flow + - applications-Energy + - applications-Chemistry + - applications-Earth + - Others + validations: + required: true + +- type: textarea + attributes: + label: Your current environment + description: | + Environment / 环境信息 (Mandatory / 必填) + value: | + - **Hardware Environment / 硬件环境(Mandatory / 必填)**: + Hardware (e.g.`Atlas 800T A2`) + + 样例: + + | 后端类型| 硬件具体类别 | + | --- | --- | + | Server | Atlas 800T A2 | + | CPU| Mac CPU/Win CPU/Linux CPU| + + + - **Software Environment / 软件环境 (Mandatory / 必填)**: + 迭代版本新增问题样例:(根据实际修改和增删) + + | Software | Version(根据实际修改,必填)| + | --- | --- | + | MindSpore | MindSpore 2.4.0 | + | CANN | 8.0.0.beta1 | + | Python | Python XXXXXX | + | OS platform | Ubuntu XXXXXX | + | GCC/Compiler version | XXXXXX | + + validations: + required: true +- type: textarea + attributes: + label: How you are installing MindScience + description: | + Paste the full command you are trying to execute. + placeholder: | + ```sh + pip install mindsponge_*.whl + ``` +- type: markdown + attributes: + value: > + Thanks for contributing 🎉! \ No newline at end of file diff --git a/.gitee/ISSUE_TEMPLATE/3-bug-report.yml b/.gitee/ISSUE_TEMPLATE/3-bug-report.yml new file mode 100644 index 0000000000000000000000000000000000000000..ae8c1bae76121286453adbc054f4affaeebc0039 --- /dev/null +++ b/.gitee/ISSUE_TEMPLATE/3-bug-report.yml @@ -0,0 +1,308 @@ +name: 🐛 Bug report +description: Raise an issue here if you find a bug. +title: "[Bug]: " +labels: ["bug"] + +body: + - type: markdown + attributes: + value: | + Thanks for taking the time to help MindScience and fill out this bug report! + - If this is your first time, please read [our contributor guidelines](https://gitee.com/mindspore/mindscience/blob/master/CONTRIBUTION.md). + - You also confirm that you have searched the [open documentation issues](https://gitee.com/mindspore/mindscience/issues) and have found no duplicates for this request + - type: dropdown + id: domain + attributes: + label: Which domain the issue belongs to? + options: + - MindSpore Science Core + - applications-SPONGE + - applications-Flow + - applications-Energy + - applications-Chemistry + - applications-Earth + - Others + validations: + required: true + + - type: input + id: version + attributes: + label: Version + description: What version of MindScience are you running? + placeholder: "example: r0.7" + validations: + required: true + + - type: textarea + attributes: + label: installation-method + description: | + Paste the full command you are trying to execute. + placeholder: | + ```sh + pip install mindsponge_*.whl + ``` + + - type: textarea + attributes: + label: Your current environment + description: | + Environment / 环境信息 (Mandatory / 必填) + value: | + - **Hardware Environment / 硬件环境(Mandatory / 必填)**: + Hardware (e.g.`Atlas 800T A2`) + + 样例: + + | 后端类型| 硬件具体类别 | + | --- | --- | + | Server | Atlas 800T A2 | + | CPU| Mac CPU/Win CPU/Linux CPU| + + + - **Software Environment / 软件环境 (Mandatory / 必填)**: + 迭代版本新增问题样例:(根据实际修改和增删) + + | Software | Version(根据实际修改,必填)| + | --- | --- | + | MindSpore | MindSpore 2.4.0 | + | CANN | 8.0.0.beta1 | + | Python | Python XXXXXX | + | OS platform | Ubuntu XXXXXX | + | GCC/Compiler version | XXXXXX | + + + bugfix版本问题引入样例:(根据实际修改和增删) + + | Software | Version(根据实际修改,必填)| + | --- | --- | + | MindSpore | MindSpore 2.4.0 (成功)master_202407131XXXXXX _a4230c71d(失败)| + | CANN | 8.0.0.beta1 | + | CANN 归档地址 | | + | Python | Python XXXXXX | + | OS platform | Ubuntu XXXXXX | + | GCC/Compiler version | XXXXXX | + + validations: + required: true + + + - type: textarea + id: description + attributes: + label: Describe the issue + description: | + Please provide a complete and succinct description of the problem, including what you expected to happen. + value: | + #### 1.Describe the current behavior / 问题描述 (Mandatory / 必填) + + 样例: (根据实际修改和增删) + + > sponge.colvar.Distance()报错,同时 sponge.metrics.Metric其子类的 .update() 报错 + + #### 2. / 关联用例 (Mandatory / 必填)Related testcase + + ```python + from mindspore import Tensor + from sponge.colvar import Distance + from sponge.metrics import MetricCV + cv = Distance([0,1]) + coordinate = Tensor([[[0.0, 0.0, 0.0], [0.0, 0.0, 1.0]]]) + metric = MetricCV(cv) + metric.update(coordinate) + print(metric.eval()) + ``` + + #### 3.Steps to reproduce the issue / 重现步骤 (Mandatory / 必填) + + > 测试步骤:运行关联用例即可 + > 用例执行命令:来自CI日志或者用户执行命令 + + + #### 4.Describe the expected behavior / 预期结果 (Mandatory / 必填) + + > **【预期结果】**:MindSpore 1.10.1 版本下,可正常运行。预期输出为 [1.] + + #### 5.Related log / screenshot / 日志 / 截图 (Mandatory / 必填) + + ```shell + --------------------------------------------------------------------------- + ValueError Traceback (most recent call last) + Cell In[4], line 7 + 5 coordinate = Tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 1.0]]) + 6 metric = MetricCV(cv) + ----> 7 metric.update(coordinate) + 8 print(metric.eval()) + + File ~/mindscience/MindSPONGE/./src/sponge/metrics/metrics.py:190, in MetricCV.update(self, coordinate, pbc_box, energy, force, potentials, total_bias, biases) + 163 """ + 164 + 165 Args: + (...) + 186 V: Number of bias potential energies. + 187 """ + 188 #pylint: disable=unused-argument + --> 190 colvar = self.colvar(coordinate, pbc_box) + 192 self._value = self._convert_data(colvar) + + File ~/.local/lib/python3.8/site-packages/mindspore/nn/cell.py:705, in Cell.__call__(self, *args, **kwargs) + 703 except Exception as err: + 704 _pynative_executor.clear_res() + --> 705 raise err + 707 if isinstance(output, Parameter): + 708 output = output.data + + File ~/.local/lib/python3.8/site-packages/mindspore/nn/cell.py:701, in Cell.__call__(self, *args, **kwargs) + 699 try: + 700 _pynative_executor.new_graph(self, *args, **kwargs) + --> 701 output = self._run_construct(args, kwargs) + 702 _pynative_executor.end_graph(self, output, *args, **kwargs) + 703 except Exception as err: + + File ~/.local/lib/python3.8/site-packages/mindspore/nn/cell.py:482, in Cell._run_construct(self, cast_inputs, kwargs) + 480 output = self._shard_fn(*cast_inputs, **kwargs) + 481 else: + --> 482 output = self.construct(*cast_inputs, **kwargs) + 483 if self._enable_forward_hook: + 484 output = self._run_forward_hook(cast_inputs, output) + + File ~/mindscience/MindSPONGE/./src/sponge/colvar/basic/distance.py:146, in Distance.construct(self, coordinate, pbc_box) + 131 r"""calculate distance. + 132 + 133 Args: + (...) + 142 + 143 """ + 145 # (B, ..., D) + --> 146 vector = self.vector(coordinate, pbc_box) + 148 # (B, ...) or (B, ..., 1) + 149 if self.norm_last_dim is None: + + File ~/.local/lib/python3.8/site-packages/mindspore/nn/cell.py:705, in Cell.__call__(self, *args, **kwargs) + 703 except Exception as err: + 704 _pynative_executor.clear_res() + --> 705 raise err + 707 if isinstance(output, Parameter): + 708 output = output.data + + File ~/.local/lib/python3.8/site-packages/mindspore/nn/cell.py:701, in Cell.__call__(self, *args, **kwargs) + 699 try: + 700 _pynative_executor.new_graph(self, *args, **kwargs) + --> 701 output = self._run_construct(args, kwargs) + 702 _pynative_executor.end_graph(self, output, *args, **kwargs) + 703 except Exception as err: + + File ~/.local/lib/python3.8/site-packages/mindspore/nn/cell.py:482, in Cell._run_construct(self, cast_inputs, kwargs) + 480 output = self._shard_fn(*cast_inputs, **kwargs) + 481 else: + --> 482 output = self.construct(*cast_inputs, **kwargs) + 483 if self._enable_forward_hook: + 484 output = self._run_forward_hook(cast_inputs, output) + + File ~/mindscience/MindSPONGE/./src/sponge/colvar/atoms/vector.py:183, in Vector.construct(self, coordinate, pbc_box) + 180 atoms1 = self.atoms1(coordinate, pbc_box) + 181 else: + 182 # (B, ..., 2, D) + --> 183 atoms = self.atoms(coordinate, pbc_box) + 184 # (B, ..., 1, D) <- (B, ..., 2, D) + 185 atoms0, atoms1 = self.split2(atoms) + + File ~/.local/lib/python3.8/site-packages/mindspore/nn/cell.py:705, in Cell.__call__(self, *args, **kwargs) + 703 except Exception as err: + 704 _pynative_executor.clear_res() + --> 705 raise err + 707 if isinstance(output, Parameter): + 708 output = output.data + + File ~/.local/lib/python3.8/site-packages/mindspore/nn/cell.py:701, in Cell.__call__(self, *args, **kwargs) + 699 try: + 700 _pynative_executor.new_graph(self, *args, **kwargs) + --> 701 output = self._run_construct(args, kwargs) + 702 _pynative_executor.end_graph(self, output, *args, **kwargs) + 703 except Exception as err: + + File ~/.local/lib/python3.8/site-packages/mindspore/nn/cell.py:482, in Cell._run_construct(self, cast_inputs, kwargs) + 480 output = self._shard_fn(*cast_inputs, **kwargs) + 481 else: + --> 482 output = self.construct(*cast_inputs, **kwargs) + 483 if self._enable_forward_hook: + 484 output = self._run_forward_hook(cast_inputs, output) + + File ~/mindscience/MindSPONGE/./src/sponge/colvar/atoms/atoms.py:232, in Atoms.construct(self, coordinate, pbc_box) + 219 r"""get position coordinate(s) of specific atom(s) + 220 + 221 Args: + (...) + 229 + 230 """ + 231 # (B, a_1, a_2, ..., a_{n}, D) <- (B, A, D) + --> 232 atoms = func.gather_vector(coordinate, self.index) + 233 if self.keep_in_box: + 234 atoms = self.coordinate_in_pbc(atoms, pbc_box) + + File ~/.local/lib/python3.8/site-packages/mindspore/common/api.py:718, in jit..wrap_mindspore..staging_specialize(*args, **kwargs) + 716 if _is_pynative_parallel() and func.__name__ == _PYNATIVE_PARALLEL_FUNC_NAME: + 717 process_obj = hash_args + --> 718 out = _MindsporeFunctionExecutor(func, hash_obj, input_signature, process_obj, jit_config)(*args, **kwargs) + 719 return out + + File ~/.local/lib/python3.8/site-packages/mindspore/common/api.py:121, in _wrap_func..wrapper(*arg, **kwargs) + 119 @wraps(fn) + 120 def wrapper(*arg, **kwargs): + --> 121 results = fn(*arg, **kwargs) + 122 return _convert_python_data(results) + + File ~/.local/lib/python3.8/site-packages/mindspore/common/api.py:350, in _MindsporeFunctionExecutor.__call__(self, *args, **kwargs) + 348 except Exception as err: + 349 _pynative_executor.clear_res() + --> 350 raise err + 352 if context.get_context("precompile_only"): + 353 return None + + File ~/.local/lib/python3.8/site-packages/mindspore/common/api.py:344, in _MindsporeFunctionExecutor.__call__(self, *args, **kwargs) + 342 if context.get_context("mode") == context.PYNATIVE_MODE: + 343 _pynative_executor.set_jit_compile_status(True, phase) + --> 344 phase = self.compile(self.fn.__name__, *args_list, **kwargs) + 345 _pynative_executor.set_jit_compile_status(False, phase) + 346 else: + + File ~/.local/lib/python3.8/site-packages/mindspore/common/api.py:435, in _MindsporeFunctionExecutor.compile(self, method_name, *args, **kwargs) + 433 else: + 434 setattr(self.fn, "__jit_function__", True) + --> 435 is_compile = self._graph_executor.compile(self.fn, compile_args, kwargs, phase, True) + 436 if isinstance(self.fn, types.MethodType): + 437 delattr(self.fn.__func__, "__jit_function__") + + ValueError: For primitive[BroadcastTo], the attribute[x shape] must be less than or equal to 1, but got 2. + + ---------------------------------------------------- + - C++ Call Stack: (For framework developers) + ---------------------------------------------------- + mindspore/core/utils/check_convert_utils.cc:675 Check + ``` + + ### + + #### 6.Special notes for this issue/备注 (Optional / 选填) + + **【定位人】**吴某某(根据实际修改) + + validations: + required: true + + + - type: textarea + id: mvr + attributes: + label: Minimum reproducible example + description: Please supply a [minimum reproducible code example](https://matthewrocklin.com/blog/work/2018/02/28/minimal-bug-reports) here. + render: shell + + - type: textarea + id: logs + attributes: + label: Relevant log output + description: Please paste relevant error and log output here + render: shell + diff --git a/.gitee/ISSUE_TEMPLATE/4-ci-failure.yml b/.gitee/ISSUE_TEMPLATE/4-ci-failure.yml new file mode 100644 index 0000000000000000000000000000000000000000..b63c2480454460fe145fc4975a9dda124445f41e --- /dev/null +++ b/.gitee/ISSUE_TEMPLATE/4-ci-failure.yml @@ -0,0 +1,83 @@ +name: 🧪 CI failure report +description: Report a failing test. +title: "[CI Failure]: " +labels: ["ci-failure"] + +body: +- type: dropdown + id: domain + attributes: + label: Which domain the issue belongs to? + options: + - MindSpore Science Core + - applications-SPONGE + - applications-Flow + - applications-Energy + - applications-Chemistry + - applications-Earth + - Others + validations: + required: true + +- type: markdown + attributes: + value: > + #### Include the name of the failing Buildkite step and test file in the title. +- type: input + attributes: + label: Name of failing test + description: | + Paste in the fully-qualified name of the failing test from the logs. + placeholder: | + `path/to/test_file.py::test_name[params]` + validations: + required: true +- type: checkboxes + attributes: + label: Basic information + description: Select all items that apply to the failing test. + options: + - label: Flaky test + - label: Can reproduce locally + - label: Caused by external libraries (e.g. bug in `transformers`) +- type: textarea + attributes: + label: 🧪 Describe the failing test + description: | + Please provide a clear and concise description of the failing test. + placeholder: | + A clear and concise description of the failing test. + + ``` + The error message you got, with the full traceback and the error logs with [dump_input.py:##] if present. + ``` + validations: + required: true +- type: textarea + attributes: + label: 📝 History of failing test + description: | + Since when did the test start to fail? + + If you have time, identify the PR that caused the test to fail on main. You can do so via the following methods: + + - Use Buildkite Test Suites to find the PR where the test failure first occurred, and reproduce the failure locally. + + - Run [`git bisect`](https://git-scm.com/docs/git-bisect) locally. + + - Manually unblock Buildkite steps for suspected PRs on main and check the results. (authorized users only) + placeholder: | + Approximate timeline and/or problematic PRs + + A link to the Buildkite analytics of the failing test (if available) + validations: + required: true +- type: textarea + attributes: + label: CC List. + description: > + The list of people you want to CC. Usually, this includes those who worked on the PR that failed the test. +- type: markdown + attributes: + value: > + Thanks for reporting 🙏! \ No newline at end of file diff --git a/.gitee/ISSUE_TEMPLATE/5-feature-request.yml b/.gitee/ISSUE_TEMPLATE/5-feature-request.yml new file mode 100644 index 0000000000000000000000000000000000000000..a51b2526c0a5e9587e8c64e1fdb3bca24d8869aa --- /dev/null +++ b/.gitee/ISSUE_TEMPLATE/5-feature-request.yml @@ -0,0 +1,58 @@ +name: 🚀 Feature request +description: Submit a proposal/request for a new MindScience feature +title: "[Feature]: " +labels: ["feature"] + +body: +- type: markdown + attributes: + value: > + #### Before submitting an issue, please make sure the issue hasn't been already addressed by searching through [the existing and past issues](https://gitee.com/mindspore/mindscience/issues). +- type: dropdown + id: module + attributes: + label: Which module the issue belongs to? + options: + - MindScience data + - MindScience common + - MindScience e3nn + - MindScience models + - MindScience sciops + - MindScience solver + - MindScience sharker + - MindScience utils + - Others + validations: + required: true +- type: dropdown + id: new_or_improvement + attributes: + label: Is this a new feature, an improvement, or a change to existing functionality? + options: + - New Feature + - Improvement + - Change + validations: + required: true + +- type: textarea + attributes: + label: 🚀 The feature, motivation and pitch + description: > + A clear and concise description of the feature proposal. Please outline the motivation for the proposal. Is your feature request related to a specific problem? e.g., *"I'm working on X and would like Y to be possible"*. If this is related to another GitHub issue, please link here too. For feature design, you can refer to [feature design template](https://gitee.com/mindspore/mindscience/blob/br_refactor/docs/template/feature_design.md). + validations: + required: true +- type: textarea + attributes: + label: Alternatives + description: > + A description of any alternative solutions or features you've considered, if any. +- type: textarea + attributes: + label: Additional context + description: > + Add any other context or screenshots about the feature request. +- type: markdown + attributes: + value: > + Thanks for contributing 🎉! \ No newline at end of file diff --git a/.gitee/ISSUE_TEMPLATE/6-application-case.yml b/.gitee/ISSUE_TEMPLATE/6-application-case.yml new file mode 100644 index 0000000000000000000000000000000000000000..dfbf8e154b1fa9ff77b6e154895f911627c8ce97 --- /dev/null +++ b/.gitee/ISSUE_TEMPLATE/6-application-case.yml @@ -0,0 +1,45 @@ +name: 🤗 Support request for a new application case of science +description: Submit a proposal/request for a new application case of science +title: "[Application Case]: " +labels: ["application-case"] + +body: +- type: markdown + attributes: + value: > + #### Before submitting an issue, please make sure the issue hasn't been already addressed by searching through [the existing and past issues](https://gitee.com/mindspore/mindscience/issues). +- type: dropdown + id: domain + attributes: + label: Which domain the issue belongs to? + options: + - MindSPONGE + - MindFlow + - MindEnergy + - MindChemistry + - MindEarth + - Others + validations: + required: true + +- type: textarea + attributes: + label: The model to consider. + description: > + A url, pointing to the model, e.g. https://huggingface.co/openai-community/gpt2 . + validations: + required: true +- type: textarea + attributes: + label: The closest model MindScience already supports. + description: > + Here is the list of models already supported by MindScience: https://gitee.com/mindspore/mindscience#%E6%A6%82%E8%BF%B0 . Which model is the most similar to the model you want to add support for? +- type: textarea + attributes: + label: What's your difficulty of supporting the model you want? + description: > + For example, any new operators or new architecture? +- type: markdown + attributes: + value: > + Thanks for contributing 🎉! \ No newline at end of file diff --git a/.gitee/ISSUE_TEMPLATE/7-RFC.yml b/.gitee/ISSUE_TEMPLATE/7-RFC.yml new file mode 100644 index 0000000000000000000000000000000000000000..63a9035e773d522cb6109f71606521b4da24c005 --- /dev/null +++ b/.gitee/ISSUE_TEMPLATE/7-RFC.yml @@ -0,0 +1,87 @@ +name: 💬 Request for comments (RFC). +description: Ask for feedback on major architectural changes or design choices. +title: "[RFC]: " +labels: ["RFC"] + +body: +- type: markdown + attributes: + value: > + #### Please take a look at previous [RFCs](https://gitee.com/mindspore/mindscience/issues) for reference. +- type: dropdown + id: domain + attributes: + label: Which domain the issue belongs to? + options: + - MindSpore Science Core + - applications-SPONGE + - applications-Flow + - applications-Energy + - applications-Chemistry + - applications-Earth + - Others + validations: + required: true + +- type: textarea + attributes: + label: Backgroud. + description: > + Backgroud(背景信息) + placeholder: | + - Describe/Explain the status of the problem you wish to solve. + - Attach relevant issues if there is any. + validations: + required: true +- type: textarea + attributes: + label: Origin + description: > + Origin(信息来源) + placeholder: | + - Explain which department/team made this request so that its priority can be given. + validations: + required: true +- type: textarea + attributes: + label: Benefit / Necessity + description: > + Benefit / Necessity (价值/作用) + placeholder: | + - Describe/Explain the key value by fulfilling the request. + validations: + required: true +- type: textarea + attributes: + label: Design + description: > + Design(设计方案) + placeholder: | + - Describe/Explain the general idea of the design. Pseudo-code is allowed + validations: + required: true +- type: textarea + attributes: + label: Feedback Period. + description: > + The feedback period of the RFC. Usually at least one week. + validations: + required: false +- type: textarea + attributes: + label: CC List. + description: > + The list of people you want to CC. + validations: + required: false +- type: textarea + attributes: + label: Any Other Things. + description: > + Any other things you would like to mention. + validations: + required: false +- type: markdown + attributes: + value: > + Thanks for contributing 🎉! \ No newline at end of file diff --git a/.gitee/ISSUE_TEMPLATE/8-internship.yml b/.gitee/ISSUE_TEMPLATE/8-internship.yml new file mode 100644 index 0000000000000000000000000000000000000000..5bd065bcf087662ec146c951cf1271dfd3ec48f4 --- /dev/null +++ b/.gitee/ISSUE_TEMPLATE/8-internship.yml @@ -0,0 +1,81 @@ +name: 💻 Internship +description: This issue is intended for the MindScience open source internship project for college students +title: "[Internship]: " +labels: ["internship"] + + +body: +- type: markdown + attributes: + value: | + - This issue is intended for the MindSpore open source internship project for college students. Developers who do not participate in this project are not allowed to receive it. + - 本issue为面向高校学生的“MindSpore开源实习”项目的任务,非参加该项目的人员勿领。 +- type: dropdown + id: domain + attributes: + label: Which domain the issue belongs to? + options: + - MindSpore Science Core + - applications-SPONGE + - applications-Flow + - applications-Energy + - applications-Chemistry + - applications-Earth + - Others + validations: + required: true + +- type: textarea + attributes: + label: Your information. + description: > + Your information for intership. + value: | + 【Task score】 + 【Background description】 + 【Requirements】 + 【Development environment】 + - Hardware: + - Software: + + 【Programming language】 + 【Acceptance criteria】 + 【PR Submission address】 + 【Expected completion time】 + 【Development guide】 + 【Tutor & email】 + + Note: This issue is intended for the MindSpore open source internship project for college students. Developers who do not participate in this project are not allowed to receive it. + + --- + + 【任务分值】 + 【背景描述】 + 【需求描述】 + 【环境要求】 + - 硬件: + - 软件: + + 【编程语言】 + 【产出标准】 + 【PR提交地址】 + 【期望完成时间】 + 【开发指导】 + 【导师及邮箱】 + + 本issue为面向高校学生的“MindSpore开源实习”项目的任务,非参加该项目的人员勿领。 + + validations: + required: false + +- type: textarea + attributes: + label: Any Other Things. + description: > + Any other things you would like to mention. + validations: + required: false +- type: markdown + attributes: + value: > + Thanks for contributing 🎉! \ No newline at end of file diff --git a/.gitee/ISSUE_TEMPLATE/config.yaml b/.gitee/ISSUE_TEMPLATE/config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b8ffbb4b2dff618963057deaa280bc530aaef7d4 --- /dev/null +++ b/.gitee/ISSUE_TEMPLATE/config.yaml @@ -0,0 +1,5 @@ +blank_issues_enabled: false +contact_links: + - name: Gitee 帮助中心 + url: https://help.gitee.com/ + about: 提供 Git 使用指南、教程、Gitee.com 平台基本功能使用、介绍和常见问题解答 \ No newline at end of file diff --git a/.gitee/PULL_REQUEST_TEMPLATE.en.md b/.gitee/PULL_REQUEST_TEMPLATE.en.md new file mode 100644 index 0000000000000000000000000000000000000000..9108f279047b6c2a0c9e198f539f447cd9aa7654 --- /dev/null +++ b/.gitee/PULL_REQUEST_TEMPLATE.en.md @@ -0,0 +1,30 @@ +### PR Source +- [ ] Issue (Please link related issue) +- [ ] Feature request +- [ ] Bug report +- [ ] Community contributor + +### Change Description +- **Reason for Modification:** + +- **Content Modified:** + +### Function Validation +- [ ] **Self-verification** +- [ ] **Screenshots of local test cases** + +### Checklist +- [ ] **Code reviewed** +- [ ] **UT test coverage** (If not, explain reason: ____________________) +- [ ] **Involves public API changes in MindSpore Science** +- [ ] **Documentation updated** + +### Code Review Requirements +- Changes over 1000 lines require organized review meeting with conclusions +- PR without function validation cannot be merged +- PR with incomplete checklist cannot be merged +- PR without clear source identification or change description cannot be merged + +### Change Notification +- [ ] **Documentation modified** +- [ ] **API change description** (If API changed, detail description): \ No newline at end of file diff --git a/.gitee/PULL_REQUEST_TEMPLATE.zh-CN.md b/.gitee/PULL_REQUEST_TEMPLATE.zh-CN.md new file mode 100644 index 0000000000000000000000000000000000000000..c166c30dc0a742dc9666ec760788ece4d5c61963 --- /dev/null +++ b/.gitee/PULL_REQUEST_TEMPLATE.zh-CN.md @@ -0,0 +1,30 @@ +### PR来源 +- [ ] issue单(请关联issue) +- [ ] 需求特性 +- [ ] 问题单 +- [ ] 社区开发者贡献 + +### 修改描述 +- **修改原因:** + +- **修改内容:** + +### 功能验证 +- [ ] **功能自验** +- [ ] **本地自验用例截图** + +### 检查清单 +- [ ] **是否经过代码检视** +- [ ] **是否具备UT测试用例看护**(如不符合,请说明原因:____________________) +- [ ] **是否涉及MindSpore Science公共接口变更** +- [ ] **是否涉及文档更新** + +### 代码检视要求 +- 合入代码超过1000行,需组织会议检视并附上结论 +- 未完成功能验证不允许合入 +- 未完成检查清单不允许合入 +- PR来源未标识或修改描述不清晰不允许合入 + +### 变更说明 +- [ ] **文档修改** +- [ ] **接口变更说明**(如涉及接口变更需详细描述): \ No newline at end of file diff --git a/MindChemistry/applications/crystalflow/README.md b/MindChemistry/applications/crystalflow/README.md index bead72bb0f5c8f24292fff5177ffb80c370a2691..afdea75c1fca29bd71d682992d6e88e8296d270f 100644 --- a/MindChemistry/applications/crystalflow/README.md +++ b/MindChemistry/applications/crystalflow/README.md @@ -15,7 +15,7 @@ ## 快速入门 > 1. 将Mindchemistry/mindchemistry文件包下载到当前目录 -> 2. 在[数据集链接](https://download-mindspore.osinfra.cn/mindscience/mindchemistry/diffcsp/)下载相应的数据集 +> 2. 在[数据集链接](https://download-mindspore.osinfra.cn/mindscience/mindchemistry/diffcsp/dataset/)下载相应的数据集 > 3. 安装依赖包:`pip install -r requirement.txt` > 4. 训练命令: `python train.py` > 5. 预测命令: `python evaluate.py` @@ -54,7 +54,7 @@ applications ## 下载数据集 -在[数据集链接](https://download-mindspore.osinfra.cn/mindscience/mindchemistry/diffcsp/)中下载相应的数据集文件夹和dataset_prop.txt数据集属性文件放置于当前路径的dataset文件夹下(如果没有需要自己手动创建),文件路径参考: +在[数据集链接](https://download-mindspore.osinfra.cn/mindscience/mindchemistry/diffcsp/dataset/)中下载相应的数据集文件夹和dataset_prop.txt数据集属性文件放置于当前路径的dataset文件夹下(如果没有需要自己手动创建),文件路径参考: ```txt crystalflow @@ -87,8 +87,6 @@ python train.py ### 推理 -将权重的path写入config文件的checkpoint.last_path中。预训练模型可以从[预训练模型链接](https://download-mindspore.osinfra.cn/mindscience/mindchemistry/diffcsp/pre-train)中获取。 - 更改config文件中的test字段来更改推理参数,特别是test.num_eval,它**决定了对于每个组分生成多少个样本**,对于后续的评估阶段很重要。 ```bash diff --git a/MindChemistry/applications/crystalflow/test_crystalflow.py b/MindChemistry/applications/crystalflow/test_crystalflow.py index ffce4881fb9075c96dc7b1eb0afdd56a5a0a8779..0f9531b135384ba2e7486eb0caaf5fa412fd7ac5 100644 --- a/MindChemistry/applications/crystalflow/test_crystalflow.py +++ b/MindChemistry/applications/crystalflow/test_crystalflow.py @@ -1,6 +1,7 @@ """model test""" import math import os +import urllib.request import mindspore as ms import mindspore.numpy as mnp @@ -33,6 +34,10 @@ class SinusoidalTimeEmbeddings(nn.Cell): (ops.Sin()(embeddings), ops.Cos()(embeddings))) return embeddings +def download_file(url, filename): + urllib.request.urlretrieve(url, filename) + print(f"File downloaded successfully: {filename}") + def test_cspnet(): """test cspnet.py""" ms.set_seed(1234) @@ -153,7 +158,8 @@ def test_loss(): cspnet = CSPNet(num_layers=6, hidden_dim=512, num_freqs=256) cspflow = CSPFlow(cspnet) - mindspore_ckpt = load_checkpoint("./torch2ms_ckpt/ms_flow.ckpt") + download_file('https://download-mindspore.osinfra.cn/mindscience/mindchemistry/crystalflow/ms_flow.ckpt', 'ms_flow.ckpt') + mindspore_ckpt = load_checkpoint("ms_flow.ckpt") load_param_into_net(cspflow, mindspore_ckpt) loss_func_mse = L2LossMask(reduction='mean') diff --git a/MindChemistry/applications/orb/Parallel_Implementation.md b/MindChemistry/applications/orb/Parallel_Implementation.md new file mode 100644 index 0000000000000000000000000000000000000000..285183f0f6a68dc9179d6763412d6c080dc4e48b --- /dev/null +++ b/MindChemistry/applications/orb/Parallel_Implementation.md @@ -0,0 +1,121 @@ +# ORB模型并行训练说明文档 + +本文档说明了ORB模型从单卡训练到多卡并行训练的实现方案、启动方式以及性能提升结果。 + +## 一、并行实现 + +对比`finetune.py`和`finetune_parallel.py`,主要有以下几处改动: + +1、引入并行训练所需的mindspore通信模块: + +```python +from mindspore.communication import init +from mindspore.communication import get_rank, get_group_size +``` + +2、训练步骤中增加梯度聚合: + +```python +# 单卡版本 +grad_fn = ms.value_and_grad(model.loss, None, optimizer.parameters, has_aux=True) + +# 并行版本 +grad_fn = ms.value_and_grad(model.loss, None, optimizer.parameters, has_aux=True) +grad_reducer = nn.DistributedGradReducer(optimizer.parameters) # 新增梯度规约器 +``` + +3、数据加载时实现数据分片: + +```python +# 单卡版本 +dataloader = [base.batch_graphs([dataset[j] for j in range(i, min(i + batch_size, len(dataset)))]) + for i in range(0, len(dataset), batch_size)] + +# 并行版本 +rank_id = get_rank() +rank_size = get_group_size() +dataloader = [[dataset[j] for j in range(i, min(i + batch_size, len(dataset)))] + for i in range(0, len(dataset), batch_size)] +dataloader = [base.batch_graphs( + data[rank_id*len(data)//rank_size : (rank_id+1)*len(data)//rank_size] +) for data in dataloader] +``` + +4、初始化并行训练环境: + +```python +ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.DATA_PARALLEL, gradients_mean=True) +init() +``` + +## 二、启动方式 + +设置训练参数 + +> 1. 修改`configs/config_parallel.yaml`中的参数: +> a. 设置`data_path`字段指定训练和测试数据集 +> b. 设置`checkpoint_path`指定预训练模型权重路径 +> c. 根据需要调整其他训练参数 +> 2. 修改`run_parallel.sh`中的并行数: +> a. 通过`--worker_num=4 --local_worker_num=4`设置使用卡的数量 + +启动训练 + +```bash +pip install -r requirement.txt +bash run_parallel.sh +``` + +## 三、性能提升 + +单卡训练结果如下所示: + +```log +Loading datasets: dataset/train_mptrj_ase.dbTotal train dataset size: 800 samples +Loading datasets: dataset/val_mptrj_ase.dbTotal train dataset size: 200 samples +Model has 25213610 trainable parameters. +Epoch: 0/100, + train_metrics: {'data_time': 0.00010895108183224995, 'train_time': 386.58018293464556, 'energy_reference_mae': 5.598883946736653, 'energy_mae': 3.3611322244008384, 'energy_mae_raw': 103.14391835530598, 'stress_mae': 41.36046473185221, 'stress_mae_raw': 12.710869789123535, 'node_mae': 0.02808943825463454, 'node_mae_raw': 0.0228044210622708, 'node_cosine_sim': 0.7026202281316122, 'fwt_0.03': 0.23958333333333334, 'loss': 44.74968592325846} + val_metrics: {'energy_reference_mae': 5.316623687744141, 'energy_mae': 3.594848871231079, 'energy_mae_raw': 101.00129699707031, 'stress_mae': 30.630516052246094, 'stress_mae_raw': 9.707925796508789, 'node_mae': 0.017718862742185593, 'node_mae_raw': 0.014386476017534733, 'node_cosine_sim': 0.5506304502487183, 'fwt_0.03': 0.375, 'loss': 34.24308395385742} + +... + +Epoch: 99/100, + train_metrics: {'data_time': 7.802306208759546e-05, 'train_time': 59.67856075416785, 'energy_reference_mae': 5.5912095705668134, 'energy_mae': 0.007512244085470836, 'energy_mae_raw': 0.21813046435515085, 'stress_mae': 0.7020445863405863, 'stress_mae_raw': 2.222463607788086, 'node_mae': 0.04725319395462672, 'node_mae_raw': 0.042800972859064736, 'node_cosine_sim': 0.3720853428045909, 'fwt_0.03': 0.09895833333333333, 'loss': 0.7568100094795227} + val_metrics: {'energy_reference_mae': 5.308632850646973, 'energy_mae': 0.27756747603416443, 'energy_mae_raw': 3.251189708709717, 'stress_mae': 2.8720269203186035, 'stress_mae_raw': 9.094478607177734, 'node_mae': 0.05565642938017845, 'node_mae_raw': 0.05041291564702988, 'node_cosine_sim': 0.212838813662529, 'fwt_0.03': 0.19499999284744263, 'loss': 3.2052507400512695} +Checkpoint saved to orb_ckpts/ +Training time: 7333.08717 seconds + +``` + +四卡并行训练结果如下所示: + +```log +Loading datasets: dataset/train_mptrj_ase.dbTotal train dataset size: 800 samples +Loading datasets: dataset/train_mptrj_ase.dbTotal train dataset size: 800 samples +Loading datasets: dataset/train_mptrj_ase.dbTotal train dataset size: 800 samples +Loading datasets: dataset/train_mptrj_ase.dbTotal train dataset size: 800 samples +Loading datasets: dataset/val_mptrj_ase.dbTotal train dataset size: 200 samples +Loading datasets: dataset/val_mptrj_ase.dbTotal train dataset size: 200 samples +Loading datasets: dataset/val_mptrj_ase.dbTotal train dataset size: 200 samples +Loading datasets: dataset/val_mptrj_ase.dbTotal train dataset size: 200 samples +Model has 25213607 trainable parameters. +Model has 25213607 trainable parameters. +Model has 25213607 trainable parameters. +Model has 25213607 trainable parameters. + +... + +Training time: 2375.89474 seconds +Training time: 2377.02413 seconds +Training time: 2377.22778 seconds +Training time: 2376.63176 seconds + +``` + +在相同的训练配置下,并行训练相比单卡训练取得了显著的性能提升: + +- 单卡训练耗时:7293.28995 seconds +- 4卡并行训练耗时:2377.22778 seconds +- 性能提升:67.40% +- 加速比:3.07倍 diff --git a/MindChemistry/applications/orb/README.md b/MindChemistry/applications/orb/README.md new file mode 100644 index 0000000000000000000000000000000000000000..afb66a8648eb9650b4104d0d7d7cf8cb986266a3 --- /dev/null +++ b/MindChemistry/applications/orb/README.md @@ -0,0 +1,171 @@ + +# 模型名称 + +> Orb + +## 介绍 + +> 材料科学中,设计新型功能材料一直是新兴技术的关键部分。然而,传统的从头算计算方法在设计新型无机材料时速度慢且难以扩展到实际规模的系统。近年来,深度学习方法在多个领域展示了其强大的能力,能够通过并行架构高效运行。ORB模型的核心创新在于将这种深度学习方法应用于材料建模,通过可扩展的图神经网络架构学习原子间相互作用的复杂性。ORB模型是一个基于图神经网络(GNN)的机器学习力场(MLFF),设计为通用的原子间势能模型,适用于多种模拟任务(几何优化、蒙特卡洛模拟和分子动力学模拟)。该模型的输入是一个图结构,包含原子的位置、类型以及系统配置(如晶胞尺寸和边界条件);输出包括系统的总能量、每个原子的力向量以及单元格应力。与现有的开源神经网络势能模型(如MACE)相比,ORB模型在大系统规模下的速度提高了3-6倍。在Matbench Discovery基准测试中,ORB模型的误差比其他方法降低了31%,并且在发布时成为该基准测试的最新最佳模型。ORB模型在零样本评估中表现出色,即使在没有针对特定任务进行微调的情况下,也能在高温度非周期分子的分子动力学模拟中保持稳定。 + +![Orb模型预测自由能](docs/orb.png) + +> 上图中:(a) 通过Widom插入法在Mg-MOF-74中获得的MACE + D3(左)和Orb-D3(右)自由能表面。开放金属位点附近的蓝色区域代表最低自由能,表明这些是CO2的优势吸附位点。(b) CO2在Mg-MOF-74中的吸附位置,展示了通过Widom插入法获得的两个最有利的吸附位点,其吸附能分别为-54.5 kJ/mol和-54.4 kJ/mol。虽然Orb和MACE预测的能量极小值位置相似,但ORB的自由能最小值与实验测得的吸附热(-44 kJ/mol)数值更为接近。 + +## 环境要求 + +> 1. 安装`mindspore(2.5.0)` +> 2. 安装`mindchemistry` +> 3. 安装依赖包:`pip install -r requirement.txt` + +## 快速入门 + +> 1. 在[数据集链接](https://download-mindspore.osinfra.cn/mindscience/mindchemistry/orb/dataset/)下载相应的数据集并放在`dataset`目录下 +> 2. 在[模型链接](https://download-mindspore.osinfra.cn/mindscience/mindchemistry/orb/orb_ckpts/)下载orb预训练模型ckpt并放在`orb_ckpts`目录下 +> 3. 安装依赖包:`pip install -r requirement.txt` +> 4. 单卡训练命令: `bash run.sh` +> 5. 多卡训练命令: `bash run_parallel.sh` +> 6. 评估命令: `python evaluate.py` +> 7. 模型预测结果会存在`results`目录下 + +### 代码目录结构 + +```text +代码主要模块在src文件夹下,其中dataset文件夹下是数据集,orb_ckpts文件夹下是预训练模型和训练好的模型权重文件,configs文件夹下是各代码的参数配置文件。 + +orb_models # 模型名 +├── dataset + ├── train_mptrj_ase.db # 微调阶段训练数据集 + └── val_mptrj_ase.db # 微调阶段测试数据集 +├── orb_ckpts + └── orb-mptraj-only-v2.ckpt # 预训练模型checkpoint +├── configs + ├── config.yaml # 单卡训练参数配置文件 + ├── config_parallel.yaml # 多卡并行训练参数配置文件 + └── config_eval.yaml # 推理参数配置文件 +├── src + ├── __init__.py + ├── ase_dataset.py # 处理和加载数据集 + ├── atomic_system.py # 定义原子系统的数据结构 + ├── base.py # 基础类定义 + ├── featurization_utilities.py # 提供将原子系统转换为特征向量的工具 + ├── pretrained.py # 预训练模型相关函数 + ├── property_definitions.py # 定义原子系统中各种物理性质的计算方式和命名规则 + ├── trainer.py # 模型loss类定义 + ├── segment_ops.py # 提供对数据进行分段处理的工具 + └── utils.py # 工具模块 +├── finetune.py # 模型微调代码 +├── evaluate.py # 模型推理代码 +├── run.sh # 单卡训练启动脚本 +├── run_parallel.sh # 多卡并行训练启动脚本 +└── requirement.txt # 环境 +``` + +## 下载数据集 + +在[数据集链接](https://download-mindspore.osinfra.cn/mindscience/mindchemistry/orb/dataset/)下载训练和测试数据集放置于当前路径的dataset文件夹下(如果没有需要自己手动创建);在[模型链接](https://download-mindspore.osinfra.cn/mindscience/mindchemistry/orb/orb_ckpts/)下载orb预训练模型`orb-mptraj-only-v2.ckpt`放置于当前路径的orb_ckpts文件夹下(如果没有需要自己手动创建);文件路径参考[代码目录结构](#代码目录结构) + +## 训练过程 + +### 单卡训练 + +更改`configs/config.yaml`文件中训练参数: + +> 1. 设置微调阶段的训练和测试数据集,见`data_path`字段 +> 2. 设置训练加载的预训练模型权重文件,更改`checkpoint_path`路径字段 +> 3. 其它训练设置见Training Configuration部分 + +```bash +pip install -r requirement.txt +bash run.sh +``` + +代码运行结果如下所示: + +```log +============================================================================================================== +Please run the script as: +bash run.sh +============================================================================================================== +Loading datasets: dataset/train_mptrj_ase.dbTotal train dataset size: 800 samples +Loading datasets: dataset/val_mptrj_ase.dbTotal train dataset size: 200 samples +Model has 25213610 trainable parameters. +Epoch: 0/100, + train_metrics: {'data_time': 0.00010895108183224995, 'train_time': 386.58018293464556, 'energy_reference_mae': 5.598883946736653, 'energy_mae': 3.3611322244008384, 'energy_mae_raw': 103.14391835530598, 'stress_mae': 41.36046473185221, 'stress_mae_raw': 12.710869789123535, 'node_mae': 0.02808943825463454, 'node_mae_raw': 0.0228044210622708, 'node_cosine_sim': 0.7026202281316122, 'fwt_0.03': 0.23958333333333334, 'loss': 44.74968592325846} + val_metrics: {'energy_reference_mae': 5.316623687744141, 'energy_mae': 3.594848871231079, 'energy_mae_raw': 101.00129699707031, 'stress_mae': 30.630516052246094, 'stress_mae_raw': 9.707925796508789, 'node_mae': 0.017718862742185593, 'node_mae_raw': 0.014386476017534733, 'node_cosine_sim': 0.5506304502487183, 'fwt_0.03': 0.375, 'loss': 34.24308395385742} + +... + +Epoch: 99/100, + train_metrics: {'data_time': 7.802306208759546e-05, 'train_time': 59.67856075416785, 'energy_reference_mae': 5.5912095705668134, 'energy_mae': 0.007512244085470836, 'energy_mae_raw': 0.21813046435515085, 'stress_mae': 0.7020445863405863, 'stress_mae_raw': 2.222463607788086, 'node_mae': 0.04725319395462672, 'node_mae_raw': 0.042800972859064736, 'node_cosine_sim': 0.3720853428045909, 'fwt_0.03': 0.09895833333333333, 'loss': 0.7568100094795227} + val_metrics: {'energy_reference_mae': 5.308632850646973, 'energy_mae': 0.27756747603416443, 'energy_mae_raw': 3.251189708709717, 'stress_mae': 2.8720269203186035, 'stress_mae_raw': 9.094478607177734, 'node_mae': 0.05565642938017845, 'node_mae_raw': 0.05041291564702988, 'node_cosine_sim': 0.212838813662529, 'fwt_0.03': 0.19499999284744263, 'loss': 3.2052507400512695} +Checkpoint saved to orb_ckpts/ +Training time: 7333.08717 seconds +``` + +### 多卡并行训练 + +更改`configs/config_parallel.yaml`和`run_parallel.sh`文件中训练参数: + +> 1. 设置微调阶段的训练和测试数据集,见`data_path`字段 +> 2. 设置训练加载的预训练模型权重文件,更改`checkpoint_path`路径字段 +> 3. 其它训练设置见Training Configuration部分 +> 4. 修改`run_parallel.sh`文件中`--worker_num=4 --local_worker_num=4`来设置调用的卡的数量 + +```bash +pip install -r requirement.txt +bash run_parallel.sh +``` + +代码运行结果如下所示: + +```log +Loading datasets: dataset/train_mptrj_ase.dbTotal train dataset size: 800 samples +Loading datasets: dataset/train_mptrj_ase.dbTotal train dataset size: 800 samples +Loading datasets: dataset/train_mptrj_ase.dbTotal train dataset size: 800 samples +Loading datasets: dataset/train_mptrj_ase.dbTotal train dataset size: 800 samples +Loading datasets: dataset/val_mptrj_ase.dbTotal train dataset size: 200 samples +Loading datasets: dataset/val_mptrj_ase.dbTotal train dataset size: 200 samples +Loading datasets: dataset/val_mptrj_ase.dbTotal train dataset size: 200 samples +Loading datasets: dataset/val_mptrj_ase.dbTotal train dataset size: 200 samples +Model has 25213607 trainable parameters. +Model has 25213607 trainable parameters. +Model has 25213607 trainable parameters. +Model has 25213607 trainable parameters. + +... + +Training time: 2375.89474 seconds +Training time: 2377.02413 seconds +Training time: 2377.22778 seconds +Training time: 2376.63176 seconds +``` + +### 推理 + +更改`configs/config_eval.yaml`文件中推理参数: + +> 1. 设置测试数据集,见`val_data_path`字段 +> 2. 设置推理加载的预训练模型权重文件,更改`checkpoint_path`路径字段 +> 3. 其它训练设置见Evaluating Configuration部分 + +```bash +python evaluate.py +``` + +代码运行结果如下所示: + +```log +Loading datasets: dataset/val_mptrj_ase.dbTotal train dataset size: 200 samples +Model has 25213607 trainable parameters. +.Validation loss: 0.89507836 + energy_reference_mae: 5.3159098625183105 + energy_mae: 0.541229784488678 + energy_mae_raw: 4.244375228881836 + stress_mae: 0.22862032055854797 + stress_mae_raw: 10.575761795043945 + node_mae: 0.12522821128368378 + node_mae_raw: 0.04024107754230499 + node_cosine_sim: 0.38037967681884766 + fwt_0.03: 0.22499999403953552 + loss: 0.8950783610343933 +``` diff --git a/MindChemistry/applications/orb/configs/config.yaml b/MindChemistry/applications/orb/configs/config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..cbd4189a3ffa819b92524d4a815f7c107cbf0432 --- /dev/null +++ b/MindChemistry/applications/orb/configs/config.yaml @@ -0,0 +1,39 @@ +# Training Configuration +train_data_path: dataset/train_mptrj_ase.db +val_data_path: dataset/val_mptrj_ase.db +num_workers: 8 +batch_size: 64 +gradient_clip_val: 0.5 +max_epochs: 100 +checkpoint_path: orb_ckpts/ +lr: 3.0e-4 +random_seed: 1234 + +# Model Configuration +model: + # Energy Head Configuration + energy_head: + latent_dim: 256 + num_mlp_layers: 1 + mlp_hidden_dim: 256 + target: "energy" + node_aggregation: "mean" + reference_energy_name: "vasp-shifted" + train_reference: true + predict_atom_avg: true + + # Node Head Configuration + node_head: + latent_dim: 256 + num_mlp_layers: 1 + mlp_hidden_dim: 256 + target: "forces" + remove_mean: true + + # Stress Head Configuration + stress_head: + latent_dim: 256 + num_mlp_layers: 1 + mlp_hidden_dim: 256 + target: "stress" + compute_stress: true diff --git a/MindChemistry/applications/orb/configs/config_eval.yaml b/MindChemistry/applications/orb/configs/config_eval.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1e98c5f0b0d98f1c2dd87bcb095b0c19036ef2b4 --- /dev/null +++ b/MindChemistry/applications/orb/configs/config_eval.yaml @@ -0,0 +1,40 @@ +# Evaluating Configuration +mode: "PYNATIVE" +device_target: "Ascend" +device_id: 0 +# Dataset config +val_data_path: dataset/val_mptrj_ase.db +num_workers: 8 +batch_size: 64 +checkpoint_path: orb_ckpts/orb-ft-checkpoint_epoch99.ckpt +random_seed: 1234 +output_dir: results/ + +# Model Configuration +model: + # Energy Head Configuration + energy_head: + latent_dim: 256 + num_mlp_layers: 1 + mlp_hidden_dim: 256 + target: "energy" + node_aggregation: "mean" + reference_energy_name: "vasp-shifted" + train_reference: true + predict_atom_avg: true + + # Node Head Configuration + node_head: + latent_dim: 256 + num_mlp_layers: 1 + mlp_hidden_dim: 256 + target: "forces" + remove_mean: true + + # Stress Head Configuration + stress_head: + latent_dim: 256 + num_mlp_layers: 1 + mlp_hidden_dim: 256 + target: "stress" + compute_stress: true diff --git a/MindChemistry/applications/orb/configs/config_parallel.yaml b/MindChemistry/applications/orb/configs/config_parallel.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c6a5e0857b52d3cf543e0e0388160e6764184c1c --- /dev/null +++ b/MindChemistry/applications/orb/configs/config_parallel.yaml @@ -0,0 +1,39 @@ +# Training Configuration +train_data_path: dataset/train_mptrj_ase.db +val_data_path: dataset/val_mptrj_ase.db +num_workers: 8 +batch_size: 256 +gradient_clip_val: 0.5 +max_epochs: 100 +checkpoint_path: orb_ckpts/ +lr: 3.0e-4 +random_seed: 666 + +# Model Configuration +model: + # Energy Head Configuration + energy_head: + latent_dim: 256 + num_mlp_layers: 1 + mlp_hidden_dim: 256 + target: "energy" + node_aggregation: "mean" + reference_energy_name: "vasp-shifted" + train_reference: true + predict_atom_avg: true + + # Node Head Configuration + node_head: + latent_dim: 256 + num_mlp_layers: 1 + mlp_hidden_dim: 256 + target: "forces" + remove_mean: true + + # Stress Head Configuration + stress_head: + latent_dim: 256 + num_mlp_layers: 1 + mlp_hidden_dim: 256 + target: "stress" + compute_stress: true diff --git a/MindChemistry/applications/orb/docs/orb.png b/MindChemistry/applications/orb/docs/orb.png new file mode 100644 index 0000000000000000000000000000000000000000..6f9026b83e31dad2f48d626388e15262a831d02c Binary files /dev/null and b/MindChemistry/applications/orb/docs/orb.png differ diff --git a/MindChemistry/applications/orb/evaluate.py b/MindChemistry/applications/orb/evaluate.py new file mode 100644 index 0000000000000000000000000000000000000000..4e7157a82100284ff1c35106b98bae9643b82f23 --- /dev/null +++ b/MindChemistry/applications/orb/evaluate.py @@ -0,0 +1,102 @@ +# ============================================================================ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Evaluate.""" + +import argparse +import logging +import os +import pickle + +import mindspore as ms +from mindspore import context + +from finetune import build_loader +from src import pretrained, utils, trainer + +logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" +) + + +def evaluate(args): + """Evaluate the model.""" + # set seed + utils.seed_everything(args.random_seed) + + # load dataset + val_loader = build_loader( + dataset_path=args.val_data_path, + num_workers=args.num_workers, + batch_size=1000, + target_config={"graph": ["energy", "stress"], "node": ["forces"]}, + augmentation=False, # do not apply random augment + shuffle=False, + ) + + # load trained model + if args.checkpoint_path is None: + raise ValueError("Checkpoint path is not provided.") + model = pretrained.orb_mptraj_only_v2(args.checkpoint_path) + model_params = sum(p.size for p in model.trainable_params() if p.requires_grad) + logging.info("Model has %d trainable parameters.", model_params) + + # begin evaluation + model.set_train(False) + val_iter = iter(val_loader) + val_batch = next(val_iter) + + output = model( + val_batch.edge_features, + val_batch.node_features, + val_batch.senders, + val_batch.receivers, + val_batch.n_node, + ) + if not os.path.exists(args.output_dir): + os.makedirs(args.output_dir, exist_ok=True) + save_path = os.path.join(args.output_dir, "predictions.pkl") + with open(save_path, "wb") as f: + pickle.dump(output, f) + + loss_fn = trainer.OrbLoss(model) + loss, logs = loss_fn.loss(val_batch) + print(f"Validation loss: {loss}") + for key, value in logs.items(): + print(f" {key}: {value}") + + +def main(): + """Main.""" + parser = argparse.ArgumentParser( + description="Evaluate orb model", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "--config", type=str, default="configs/config_eval.yaml", help="Path to config file" + ) + args = parser.parse_args() + args = utils.load_cfg(args.config) + ms.set_context( + mode=context.PYNATIVE_MODE, + device_target=args.device_target, + device_id=args.device_id, + pynative_synchronize=True, + ) + evaluate(args) + + +if __name__ == "__main__": + main() diff --git a/MindChemistry/applications/orb/finetune.py b/MindChemistry/applications/orb/finetune.py new file mode 100644 index 0000000000000000000000000000000000000000..cb025e9dfd248bb6b9c958e449fedb054fe12eef --- /dev/null +++ b/MindChemistry/applications/orb/finetune.py @@ -0,0 +1,311 @@ +# ============================================================================ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Finetuning loop.""" + +import argparse +import logging +import warnings +import os +import timeit +from typing import Dict, Optional + +import mindspore as ms +from mindspore import nn, ops, context +import mindspore.dataset as ds +from mindspore.communication import init +from mindspore.communication import get_rank, get_group_size + +from src import base, pretrained, utils +from src.ase_dataset import AseSqliteDataset, BufferData +from src.trainer import OrbLoss + + +logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" +) + + +def finetune( + model: nn.Cell, + loss_fn: Optional[nn.Cell], + optimizer: nn.Optimizer, + train_dataloader: ds.GeneratorDataset, + val_dataloader: ds.GeneratorDataset, + lr_scheduler: Optional[ms.experimental.optim.lr_scheduler] = None, + clip_grad: Optional[float] = None, + log_freq: float = 10, + parallel_mode: str = "NONE", +): + """Train for a fixed number of steps. + + Args: + model: The model to optimize. + loss_fn: The loss function to use. + optimizer: The optimizer to use for the model. + train_dataloader: A Dataloader, which may be infinite if num_steps is passed. + val_dataloader: A Dataloader for validation. + lr_scheduler: Optional, a Learning rate scheduler for modifying the learning rate. + clip_grad: Optional, the gradient clipping threshold. + log_freq: The logging frequency for step metrics. + parallel_mode: The parallel mode to use, e.g., "DATA_PARALLEL" or "NONE". + + Returns + A dictionary of metrics. + """ + if clip_grad is not None: + hook_handles = utils.gradient_clipping(model, clip_grad) + + train_metrics = utils.ScalarMetricTracker() + val_metrics = utils.ScalarMetricTracker() + + epoch_metrics = { + "data_time": 0.0, + "train_time": 0.0, + } + + # Get gradient function + grad_fn = ms.value_and_grad(loss_fn.loss, None, optimizer.parameters, has_aux=True) + if parallel_mode == "DATA_PARALLEL": + grad_reducer = nn.DistributedGradReducer(optimizer.parameters) + + # Define function of one-step training + def train_step(data, label=None): + (loss, val_logs), grads = grad_fn(data, label) + if parallel_mode == "DATA_PARALLEL": + grads = grad_reducer(grads) + optimizer(grads) + return loss, val_logs + + step_begin = timeit.default_timer() + for i, batch in enumerate(train_dataloader): + epoch_metrics["data_time"] += timeit.default_timer() - step_begin + # Reset metrics so that it reports raw values for each step but still do averages on + # the gradient accumulation. + if i % log_freq == 0: + train_metrics.reset() + + model.set_train() + loss, train_logs = train_step(batch) + + epoch_metrics["train_time"] += timeit.default_timer() - step_begin + train_metrics.update(epoch_metrics) + train_metrics.update(train_logs) + + if ops.isnan(loss): + raise ValueError("nan loss encountered") + + if lr_scheduler is not None: + lr_scheduler.step() + step_begin = timeit.default_timer() + + if clip_grad is not None: + for h in hook_handles: + h.remove() + + # begin evaluation + model.set_train(False) + val_iter = iter(val_dataloader) + val_batch = next(val_iter) + loss, val_logs = loss_fn.loss(val_batch) + val_metrics.update(val_logs) + + return train_metrics.get_metrics(), val_metrics.get_metrics() + + +def build_loader( + dataset_path: str, + num_workers: int, + batch_size: int, + augmentation: Optional[bool] = True, + target_config: Optional[Dict] = None, + shuffle: Optional[bool] = True, + parallel_mode: str = "NONE", + **kwargs, +) -> ds.GeneratorDataset: + """Builds the dataloader from a config file. + + Args: + dataset_path: Dataset path. + num_workers: The number of workers for each dataset. + batch_size: The batch_size config for each dataset. + augmentation: If rotation augmentation is used. + target_config: The target config. + shuffle: If the dataset should be shuffled. + parallel_mode: The parallel mode to use, e.g., "DATA_PARALLEL" or "NONE". + + Returns: + The Dataloader. + """ + log_loading = f"Loading datasets: {dataset_path} with {num_workers} workers. " + dataset = AseSqliteDataset( + dataset_path, target_config=target_config, augmentation=augmentation, **kwargs + ) + + log_loading += f"Total dataset size: {len(dataset)} samples" + logging.info(log_loading) + + dataset = BufferData(dataset, shuffle=shuffle) + if parallel_mode == "DATA_PARALLEL": + rank_id = get_rank() + rank_size = get_group_size() + dataloader = [ + [dataset[j] for j in range(i, min(i + batch_size, len(dataset)))] \ + for i in range(0, len(dataset), batch_size) + ] + dataloader = [ + base.batch_graphs( + data[rank_id*len(data)//rank_size : (rank_id+1)*len(data)//rank_size] + ) for data in dataloader + ] + else: + dataloader = [ + base.batch_graphs( + [dataset[j] for j in range(i, min(i + batch_size, len(dataset)))] + ) for i in range(0, len(dataset), batch_size) + ] + + return dataloader + + +def run(args, parallel_mode="NONE"): + """Training Loop. + + Args: + config (DictConfig): Config for training loop. + parallel_mode (str): The parallel mode to use, e.g., "DATA_PARALLEL" or "NONE". + """ + utils.seed_everything(args.random_seed) + + # Load dataset + train_loader = build_loader( + dataset_path=args.train_data_path, + num_workers=args.num_workers, + batch_size=args.batch_size, + target_config={"graph": ["energy", "stress"], "node": ["forces"]}, + augmentation=True, + ) + val_loader = build_loader( + dataset_path=args.val_data_path, + num_workers=args.num_workers, + batch_size=1000, + target_config={"graph": ["energy", "stress"], "node": ["forces"]}, + augmentation=False, + shuffle=False, + ) + num_steps = len(train_loader) + + # Instantiate model + pretrained_weights_path = os.path.join(args.checkpoint_path, "orb-mptraj-only-v2.ckpt") + model = pretrained.orb_mptraj_only_v2(pretrained_weights_path) + loss_fn = OrbLoss(model) + model_params = sum(p.size for p in model.trainable_params() if p.requires_grad) + logging.info("Model has %d trainable parameters.", model_params) + + total_steps = args.max_epochs * num_steps + optimizer, lr_scheduler = utils.get_optim(args.lr, total_steps, model) + + # Fine-tuning loop + start_epoch = 0 + train_time = timeit.default_timer() + for epoch in range(start_epoch, args.max_epochs): + train_metrics, val_metrics = finetune( + model=model, + loss_fn=loss_fn, + optimizer=optimizer, + train_dataloader=train_loader, + val_dataloader=val_loader, + lr_scheduler=lr_scheduler, + clip_grad=args.gradient_clip_val, + parallel_mode=parallel_mode, + ) + print(f'Epoch: {epoch}/{args.max_epochs}, \n train_metrics: {train_metrics}\n val_metrics: {val_metrics}') + + # Save checkpoint from last epoch + if epoch == args.max_epochs - 1: + # create ckpts folder if it does not exist + if not os.path.exists(args.checkpoint_path): + os.makedirs(args.checkpoint_path) + if parallel_mode == "DATA_PARALLEL": + rank_id = get_rank() + rank_size = get_group_size() + ms.save_checkpoint( + model, + os.path.join( + args.checkpoint_path, + f"orb-ft-parallel[{rank_id}-{rank_size}]-checkpoint_epoch{epoch}.ckpt" + ), + ) + else: + ms.save_checkpoint( + model, + os.path.join(args.checkpoint_path, f"orb-ft-checkpoint_epoch{epoch}.ckpt"), + ) + logging.info("Checkpoint saved to %s", args.checkpoint_path) + logging.info("Training time: %.5f seconds", timeit.default_timer() - train_time) + + +def main(): + """Main.""" + parser = argparse.ArgumentParser( + description="Finetune orb model", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "--config", type=str, default="configs/config.yaml", help="Path to config file" + ) + parser.add_argument( + "--device_target", + type=str, + default="Ascend", + help="The target device to run, support 'Ascend'" + ) + parser.add_argument( + "--device_id", default=0, type=int, help="device index to use." + ) + parser.add_argument( + "--parallel_mode", + type=str, + default="NONE", + choices=["DATA_PARALLEL", "NONE"], + help="Parallel mode, support 'DATA_PARALLEL', 'NONE'" + ) + args = parser.parse_args() + + if args.parallel_mode.upper() == "DATA_PARALLEL": + ms.set_context( + mode=context.PYNATIVE_MODE, + device_target=args.device_target, + pynative_synchronize=True, + ) + # Set parallel context + ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.DATA_PARALLEL, gradients_mean=True) + init() + ms.set_seed(1) + else: + ms.set_context( + mode=context.PYNATIVE_MODE, + device_target=args.device_target, + device_id=args.device_id, + pynative_synchronize=True, + ) + configs = utils.load_cfg(args.config) + warnings.filterwarnings("ignore") + + run(configs, args.parallel_mode) + + +if __name__ == "__main__": + main() diff --git a/MindChemistry/applications/orb/requirement.txt b/MindChemistry/applications/orb/requirement.txt new file mode 100644 index 0000000000000000000000000000000000000000..e3f1a95c78cbcf5fb4b875a23e8d8c3853b64ac0 --- /dev/null +++ b/MindChemistry/applications/orb/requirement.txt @@ -0,0 +1,8 @@ +python>=3.10 +cached_path>=1.6.7 +ase>=3.24.0 +numpy>=1.26.4 +scipy>=1.15.1 +dm-tree==0.1.8 +tqdm>=4.66.5 +mindspore==2.5.0 \ No newline at end of file diff --git a/MindChemistry/applications/orb/run.sh b/MindChemistry/applications/orb/run.sh new file mode 100644 index 0000000000000000000000000000000000000000..e21df131ac10c208179285b99d65658116eb6b8a --- /dev/null +++ b/MindChemistry/applications/orb/run.sh @@ -0,0 +1,23 @@ +#!/bin/bash +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +export GLOG_v=3 +echo "==============================================================================================================" +echo "Please run the script as: " +echo "bash run.sh" +echo "==============================================================================================================" + +python finetune.py --device_target Ascend --device_id 7 diff --git a/MindChemistry/applications/orb/run_parallel.sh b/MindChemistry/applications/orb/run_parallel.sh new file mode 100644 index 0000000000000000000000000000000000000000..e49c6ab71cef9b615b5fca800f53c79a3e355f1c --- /dev/null +++ b/MindChemistry/applications/orb/run_parallel.sh @@ -0,0 +1,26 @@ +#!/bin/bash +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +rm -rf msrun_log +mkdir msrun_log + +export GLOG_v=3 +echo "==============================================================================================================" +echo "Please run the script as: " +echo "bash run_parallel.sh" +echo "==============================================================================================================" + +msrun --worker_num=4 --local_worker_num=4 --master_port=8118 --log_dir=msrun_log --join=True --cluster_time_out=300 finetune.py --config configs/config_parallel.yaml --parallel_mode DATA_PARALLEL \ No newline at end of file diff --git a/MindChemistry/applications/orb/src/__init__.py b/MindChemistry/applications/orb/src/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..328a08a650120b0d9aedfd04c203ecb52649a69d --- /dev/null +++ b/MindChemistry/applications/orb/src/__init__.py @@ -0,0 +1,16 @@ +# ============================================================================ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""init""" diff --git a/MindChemistry/applications/orb/src/ase_dataset.py b/MindChemistry/applications/orb/src/ase_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..f61b95aa33dab27abbdefc6650efdf429d7eb988 --- /dev/null +++ b/MindChemistry/applications/orb/src/ase_dataset.py @@ -0,0 +1,239 @@ +# ============================================================================ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""ASE dataset""" + +import os +from typing import Dict, Optional, Tuple, Union + +import ase +import ase.db +import ase.db.row +import ase.stress +import numpy as np +import mindspore as ms +from mindspore import Tensor + +from src import atomic_system, property_definitions +from src.base import AtomGraphs +from src.utils import rand_matrix + + +class AseSqliteDataset: + """AseSqliteDataset. + + A MindSpore Dataset for reading ASE Sqlite serialized Atoms objects. + + Args: + dataset_path: Local path to read. + system_config: A config for controlling how an atomic system is represented. + target_config: A config for regression/classification targets. + augmentation: If random rotation augmentation is used. + + Returns: + An AseSqliteDataset. + """ + + def __init__( + self, + dataset_path: Union[str, os.PathLike], + system_config: Optional[atomic_system.SystemConfig] = None, + target_config: Optional[Dict] = None, + augmentation: Optional[bool] = True, + ): + super().__init__() + self.augmentation = augmentation + self.path = dataset_path + self.db = ase.db.connect(str(self.path), serial=True, type="db") + + self.feature_config = system_config + if target_config is None: + target_config = { + "graph": ["energy", "stress"], + "node": ["forces"], + "edge": [], + } + self.target_config = target_config + + def __getitem__(self, idx) -> AtomGraphs: + """Fetch an item from the db. + + Args: + idx: An index to fetch from the db file and convert to an AtomGraphs. + + Returns: + A AtomGraphs object containing everything the model needs as input, + positions and atom types and other auxiliary information, such as + fine tuning targets, or global graph features. + """ + # Sqlite db is 1 indexed. + row = self.db.get(idx + 1) + atoms = row.toatoms() + node_properties = property_definitions.get_property_from_row( + self.target_config["node"], row + ) + graph_property_dict = {} + for target_property in self.target_config["graph"]: + system_properties = property_definitions.get_property_from_row( + target_property, row + ) + # transform stress to voigt6 representation + if target_property == "stress" and len(system_properties.reshape(-1)) == 9: + system_properties = Tensor( + ase.stress.full_3x3_to_voigt_6_stress(system_properties.reshape(3, 3)), + dtype=ms.float32, + ).reshape(1, -1) + graph_property_dict[target_property] = system_properties + extra_targets = { + "node": {"forces": node_properties}, + "edge": {}, + "graph": graph_property_dict, + } + if self.augmentation: + atoms, extra_targets = random_rotations_with_properties(atoms, extra_targets) + + atom_graph = atomic_system.ase_atoms_to_atom_graphs( + atoms, + system_id=idx, + brute_force_knn=False, + ) + atom_graph = self._add_extra_targets(atom_graph, extra_targets) + + return atom_graph + + def get_atom(self, idx: int) -> ase.Atoms: + """Return the Atoms object for the dataset index.""" + row = self.db.get(idx + 1) + return row.toatoms() + + def get_atom_and_metadata(self, idx: int) -> Tuple[ase.Atoms, Dict]: + """Return the Atoms object plus a dict of metadata for the dataset index.""" + row = self.db.get(idx + 1) + return row.toatoms(), row.data + + def __len__(self) -> int: + """Return the dataset length.""" + return len(self.db) + + def __repr__(self) -> str: + """String representation of class.""" + return f"AseSqliteDataset(path={self.path})" + + def _add_extra_targets( + self, + atom_graph: AtomGraphs, + extra_targets: Dict[str, Dict], + ): + """Add extra features and targets to the AtomGraphs object. + + Args: + atom_graph: AtomGraphs object to add extra features and targets to. + extra_targets: Dictionary of extra targets to add. + """ + node_targets = ( + atom_graph.node_targets if atom_graph.node_targets is not None else {} + ) + node_targets = {**node_targets, **extra_targets["node"]} + + edge_targets = ( + atom_graph.edge_targets if atom_graph.edge_targets is not None else {} + ) + edge_targets = {**edge_targets, **extra_targets["edge"]} + + system_targets = ( + atom_graph.system_targets if atom_graph.system_targets is not None else {} + ) + system_targets = {**system_targets, **extra_targets["graph"]} + + return atom_graph._replace( + node_targets=node_targets if node_targets != {} else None, + edge_targets=edge_targets if edge_targets != {} else None, + system_targets=system_targets if system_targets != {} else None, + ) + + +def random_rotations_with_properties( + atoms: ase.Atoms, properties: dict +) -> Tuple[ase.Atoms, dict]: + """Randomly rotate atoms in ase.Atoms object. + + This exists to handle the case where we also need to rotate properties. + Currently we only ever do this for random rotations, but it could be extended. + + Args: + atoms (ase.Atoms): Atoms object to rotate. + properties (dict): Dictionary of properties to rotate. + """ + rand_rotation = rand_matrix(1)[0] + atoms.positions = atoms.positions @ rand_rotation + if atoms.cell is not None: + atoms.set_cell(atoms.cell.array @ rand_rotation) + + new_node_properties = {} + for key, v in properties["node"].items(): + if tuple(v.shape) == tuple(atoms.positions.shape): + new_node_properties[key] = v @ rand_rotation + else: + new_node_properties[key] = v + properties["node"] = new_node_properties + + if "stress" in properties["graph"]: + # Transformation rule of stress tensor + stress = properties["graph"]["stress"] + full_stress = ase.stress.voigt_6_to_full_3x3_stress(stress) + + if full_stress.shape != (3, 3): + full_stress = full_stress.reshape(3, 3) + + transformed = np.dot(np.dot(rand_rotation, full_stress), rand_rotation.T) + # Back to voigt notation, and shape (1, 6) for consistency with batching + properties["graph"]["stress"] = Tensor( + [ + transformed[0, 0], + transformed[1, 1], + transformed[2, 2], + transformed[1, 2], + transformed[0, 2], + transformed[0, 1], + ], + dtype=ms.float32, + ).unsqueeze(0) + + return atoms, properties + +class BufferData: + """Wrapper for a dataset. Loads all data into memory.""" + + def __init__(self, dataset, shuffle: bool = True): + """BufferData. + Args: + dataset: The dataset to wrap. + shuffle: If True, shuffle the data. + """ + self.data_objects = [dataset[i] for i in range(len(dataset))] + if shuffle: + self.shuffle() + + def __len__(self): + return len(self.data_objects) + + def __getitem__(self, index): + return self.data_objects[index] + + def shuffle(self): + """Shuffle the data.""" + indices = np.arange(len(self.data_objects)) + np.random.shuffle(indices) + self.data_objects = [self.data_objects[i] for i in indices] diff --git a/MindChemistry/applications/orb/src/atomic_system.py b/MindChemistry/applications/orb/src/atomic_system.py new file mode 100644 index 0000000000000000000000000000000000000000..b9895e9bd78be394a1ee9fa09e0c5a0b054de787 --- /dev/null +++ b/MindChemistry/applications/orb/src/atomic_system.py @@ -0,0 +1,222 @@ +# ============================================================================ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""atomic system""" + +from dataclasses import dataclass +from typing import List, Optional + +import ase +from ase import constraints +from ase.calculators.singlepoint import SinglePointCalculator + +import mindspore as ms +from mindspore import Tensor, mint + +from src import featurization_utilities +from src.base import AtomGraphs + + +@dataclass +class SystemConfig: + """Config controlling how to featurize a system of atoms. + + Args: + radius: radius for edge construction + max_num_neighbors: maximum number of neighbours each node can send messages to. + use_timestep_0: (unused - purely for compatibility with internal models) + """ + + radius: float + max_num_neighbors: int + use_timestep_0: bool = True + + +def atom_graphs_to_ase_atoms( + graphs: AtomGraphs, + energy: Optional[Tensor] = None, + forces: Optional[Tensor] = None, + stress: Optional[Tensor] = None, +) -> List[ase.Atoms]: + """Converts a list of graphs to a list of ase.Atoms.""" + if "atomic_numbers_embedding" in graphs.node_features: + atomic_numbers = mint.argmax( + graphs.node_features["atomic_numbers_embedding"], dim=-1 + ) + else: + atomic_numbers = graphs.node_features["atomic_numbers"] + atomic_numbers_split = mint.split(atomic_numbers, graphs.n_node.tolist()) + positions_split = mint.split(graphs.positions, graphs.n_node.tolist()) + assert graphs.tags is not None and graphs.system_features is not None + tags = mint.split(graphs.tags, graphs.n_node.tolist()) + + calculations = {} + if energy is not None: + energy_list = mint.unbind(energy.detach()) + assert len(energy_list) == len(atomic_numbers_split) + calculations["energy"] = energy_list + if forces is not None: + forces_list = mint.split(forces.detach(), graphs.n_node.tolist()) + assert len(forces_list) == len(atomic_numbers_split) + calculations["forces"] = forces_list + if stress is not None: + stress_list = mint.unbind(stress.detach()) + assert len(stress_list) == len(atomic_numbers_split) + calculations["stress"] = stress_list + + atoms_list = [] + for index, (n, p, c, t) in enumerate( + zip(atomic_numbers_split, positions_split, graphs.cell, tags) + ): + atoms = ase.Atoms( + numbers=n.detach(), + positions=p.detach(), + cell=c.detach(), + tags=t.detach(), + pbc=mint.any(c != 0), + ) + if calculations != {}: + spc = SinglePointCalculator( + atoms=atoms, + **{ + key: ( + val[index].item() + if val[index].nelement() == 1 + else val[index].numpy() + ) + for key, val in calculations.items() + }, + ) + atoms.calc = spc + atoms_list.append(atoms) + + return atoms_list + + +def ase_atoms_to_atom_graphs( + atoms: ase.Atoms, + system_config: SystemConfig = SystemConfig( + radius=10.0, max_num_neighbors=20, use_timestep_0=True + ), + system_id: Optional[int] = None, + brute_force_knn: Optional[bool] = None, +) -> AtomGraphs: + """Generate AtomGraphs from an ase.Atoms object. + + Args: + atoms: ase.Atoms object + system_config: SystemConfig object + system_id: Optional system_id + brute_force_knn: whether to use a 'brute force' knn approach with torch.cdist for kdtree construction. + Defaults to None, in which case brute_force is used if we a GPU is available (2-6x faster), + but not on CPU (1.5x faster - 4x slower). For very large systems, brute_force may OOM on GPU, + so it is recommended to set to False in that case. + device: device to put the tensors on. + + Returns: + AtomGraphs object + """ + atomic_numbers = ms.from_numpy(atoms.numbers).long() + atom_type_embedding = mint.nn.functional.one_hot( + atomic_numbers, num_classes=118 + ).type(ms.float32) + + node_feats = { + "atomic_numbers": atomic_numbers.to(ms.int64), + "atomic_numbers_embedding": atom_type_embedding.to(ms.float32), + "positions": ms.from_numpy(atoms.positions).to(ms.float32), + } + system_feats = {"cell": Tensor(atoms.cell.array[None, ...]).to(ms.float32)} + edge_feats, senders, receivers = _get_edge_feats( + node_feats["positions"], + system_feats["cell"][0], + system_config.radius, + system_config.max_num_neighbors, + brute_force=brute_force_knn, + ) + + num_atoms = len(node_feats["positions"]) + atom_graph = AtomGraphs( + senders=senders, + receivers=receivers, + n_node=Tensor([num_atoms]), + n_edge=Tensor([len(senders)]), + node_features=node_feats, + edge_features=edge_feats, + system_features=system_feats, + system_id=Tensor([system_id]) if system_id is not None else system_id, + fix_atoms=ase_fix_atoms_to_tensor(atoms), + tags=_get_ase_tags(atoms), + radius=system_config.radius, + max_num_neighbors=system_config.max_num_neighbors, + ) + return atom_graph + + +def _get_edge_feats( + positions: Tensor, + cell: Tensor, + radius: float, + max_num_neighbours: int, + brute_force: Optional[bool] = None, +): + """Get edge features. + + Args: + positions: (n_nodes, 3) positions tensor + cell: 3x3 tensor unit cell for a system + radius: radius for edge construction + max_num_neighbours: maximum number of neighbours each node can send messages to. + n_kdtree_workers: number of workers to use for kdtree construction. + brute_force: whether to use brute force for kdtree construction. + """ + # Construct a graph from a 3x3 supercell (as opposed to an infinite supercell). + ( + edge_index, + edge_vectors, + ) = featurization_utilities.compute_pbc_radius_graph( + positions=positions, + periodic_boundaries=cell, + radius=radius, + max_number_neighbors=max_num_neighbours, + brute_force=brute_force, + ) + edge_feats = { + "vectors": edge_vectors.to(ms.float32), + "r": edge_vectors.norm(dim=-1), + } + senders, receivers = edge_index[0], edge_index[1] + return edge_feats, senders, receivers + + +def _get_ase_tags(atoms: ase.Atoms) -> Tensor: + """Get tags from ase.Atoms object.""" + tags = atoms.get_tags() + if tags is not None: + tags = Tensor(tags) + else: + tags = mint.zeros(len(atoms)) + return tags + + +def ase_fix_atoms_to_tensor(atoms: ase.Atoms) -> Optional[Tensor]: + """Get fixed atoms from ase.Atoms object.""" + fixed_atoms = None + if atoms.constraints is not None and atoms.constraints: + constraint = atoms.constraints[0] + if isinstance(constraint, constraints.FixAtoms): + fixed_atoms = mint.zeros((len(atoms)), dtype=ms.bool_) + fixed_atoms[constraint.index] = True + return fixed_atoms diff --git a/MindChemistry/applications/orb/src/base.py b/MindChemistry/applications/orb/src/base.py new file mode 100644 index 0000000000000000000000000000000000000000..046e5950f209f46aa72b2933468211a74d9ae67d --- /dev/null +++ b/MindChemistry/applications/orb/src/base.py @@ -0,0 +1,486 @@ +# ============================================================================ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Base Model class.""" + +from collections import defaultdict +from copy import deepcopy +from typing import Any, Dict, List, Mapping, NamedTuple, Optional, Sequence, Union +import tree + +import mindspore as ms +from mindspore import ops, Tensor, mint + +from src import featurization_utilities + +Metric = Union[Tensor, int, float] +TensorDict = Mapping[str, Optional[Tensor]] + + +class ModelOutput(NamedTuple): + """A model's output.""" + + loss: Tensor + log: Mapping[str, Metric] + + +class AtomGraphs(NamedTuple): + """A class representing the input to a model for a graph. + + Args: + senders (torch.Tensor): The integer source nodes for each edge. + receivers (torch.Tensor): The integer destination nodes for each edge. + n_node (torch.Tensor): A (batch_size, ) shaped tensor containing the number of nodes per graph. + n_edge (torch.Tensor): A (batch_size, ) shaped tensor containing the number of edges per graph. + node_features (Dict[str, torch.Tensor]): A dictionary containing node feature tensors. + It will always contain "atomic_numbers" and "positions" keys, representing the + atomic numbers of each node, and the 3d cartesian positions of them respectively. + edge_features (Dict[str, torch.Tensor]): A dictionary containing edge feature tensors. + system_features (Optional[TensorDict]): An optional dictionary containing system-level features. + node_targets (Optional[Dict[torch.Tensor]]): An optional dict of tensors containing targets + for individual nodes. This tensor is commonly expected to have shape (num_nodes, *). + edge_target (Optional[torch.Tensor]): An optional tensor containing targets for individual edges. + This tensor is commonly expected to have (num_edges, *). + system_targets (Optional[Dict[torch.Tensor]]): An optional dict of tensors containing targets for the + entire system. system_id (Optional[torch.Tensor]): An optional tensor containing the ID of the system. + fix_atoms (Optional[torch.Tensor]): An optional tensor containing information on fixed atoms in the system. + """ + + senders: Tensor + receivers: Tensor + n_node: Tensor + n_edge: Tensor + node_features: Dict[str, Tensor] + edge_features: Dict[str, Tensor] + system_features: Dict[str, Tensor] + node_targets: Optional[Dict[str, Tensor]] = None + edge_targets: Optional[Dict[str, Tensor]] = None + system_targets: Optional[Dict[str, Tensor]] = None + system_id: Optional[Tensor] = None + fix_atoms: Optional[Tensor] = None + tags: Optional[Tensor] = None + radius: Optional[float] = None + max_num_neighbors: Optional[int] = None + + @property + def positions(self): + """Get positions of atoms.""" + return self.node_features["positions"] + + @positions.setter + def positions(self, val: Tensor): + self.node_features["positions"] = val + + @property + def atomic_numbers(self): + """Get integer atomic numbers.""" + return self.node_features["atomic_numbers"] + + @atomic_numbers.setter + def atomic_numbers(self, val: Tensor): + self.node_features["atomic_numbers"] = val + + @property + def cell(self): + """Get unit cells.""" + assert self.system_features + return self.system_features.get("cell") + + @cell.setter + def cell(self, val: Tensor): + assert self.system_features + self.system_features["cell"] = val + + def clone(self) -> "AtomGraphs": + """Clone the AtomGraphs object. + + Note: this differs from deepcopy() because it preserves gradients. + """ + + def _clone(x): + if isinstance(x, Tensor): + return x.clone() + return x + + return tree.map_structure(_clone, self) + + def to(self, device: Any = None) -> "AtomGraphs": + """Move AtomGraphs child tensors to a device.""" + + print(f"Moving AtomGraphs to device: {device}") + def _to(x): + if hasattr(x, "to"): + return x + return x + + return tree.map_structure(_to, self) + + def tachdetach(self) -> "AtomGraphs": + """Detach all child tensors.""" + + def _detach(x): + if hasattr(x, "detach"): + return x.detach() + return x + + return tree.map_structure(_detach, self) + + def equals(self, graphs: "AtomGraphs") -> bool: + """Check two atomgraphs are equal.""" + + def _is_equal(x, y): + if isinstance(x, Tensor): + return mint.equal(x, y) + return x == y + + flat_results = tree.flatten(tree.map_structure(_is_equal, self, graphs)) + return all(flat_results) + + def allclose(self, graphs: "AtomGraphs", rtol=1e-5, atol=1e-8) -> bool: + """Check all tensors/scalars of two atomgraphs are close.""" + + def _is_close(x, y): + if isinstance(x, Tensor): + return mint.allclose(x, y, rtol=rtol, atol=atol) + if isinstance(x, (float, int)): + return mint.allclose( + Tensor(x), Tensor(y), rtol=rtol, atol=atol + ) + return x == y + + flat_results = tree.flatten(tree.map_structure(_is_close, self, graphs)) + return all(flat_results) + + def to_dict(self): + """Return a dictionary mapping each AtomGraph property to a corresponding tensor/scalar. + + Any nested attributes of the AtomGraphs are unpacked so the + returned dict has keys like "positions" and "atomic_numbers". + + Any None attributes are not included in the dictionary. + + Returns: + dict: A dictionary mapping attribute_name -> tensor/scalar + """ + ret = {} + for key, val in self._asdict().items(): + if val is None: + continue + if isinstance(val, dict): + for k, v in val.items(): + ret[k] = v + else: + ret[key] = val + + return ret + + def to_batch_dict(self) -> Dict[str, Any]: + """Return a single dictionary mapping each AtomGraph property to a corresponding list of tensors/scalars. + + Returns: + dict: A dict mapping attribute_name -> list of length batch_size containing tensors/scalars. + """ + batch_dict = defaultdict(list) + for graph in self.split(self): + for key, value in graph.to_dict().items(): + batch_dict[key].append(value) + return batch_dict + + def split(self, clone=True) -> List["AtomGraphs"]: + """Splits batched AtomGraphs into constituent system AtomGraphs. + + Args: + graphs (AtomGraphs): A batched AtomGraphs object. + clone (bool): Whether to clone the graphs before splitting. + Cloning removes risk of side-effects, but uses more memory. + """ + graphs = self.clone() if clone else self + + batch_nodes = graphs.n_node.tolist() + batch_edges = graphs.n_edge.tolist() + + if not batch_nodes: + raise ValueError("Cannot split empty batch") + if len(batch_nodes) == 1: + return [graphs] + + batch_systems = mint.ones(len(batch_nodes), dtype=ms.int32).tolist() + node_features = _split_features(graphs.node_features, batch_nodes) + node_targets = _split_features(graphs.node_targets, batch_nodes) + edge_features = _split_features(graphs.edge_features, batch_edges) + edge_targets = _split_features(graphs.edge_targets, batch_edges) + system_features = _split_features(graphs.system_features, batch_systems) + system_targets = _split_features(graphs.system_targets, batch_systems) + system_ids = _split_tensors(graphs.system_id, batch_systems) + fix_atoms = _split_tensors(graphs.fix_atoms, batch_nodes) + tags = _split_tensors(graphs.tags, batch_nodes) + batch_nodes = [Tensor([n]) for n in batch_nodes] + batch_edges = [Tensor([e]) for e in batch_edges] + + # calculate the new senders and receivers + senders = list(_split_tensors(graphs.senders, batch_edges)) + receivers = list(_split_tensors(graphs.receivers, batch_edges)) + n_graphs = graphs.n_node.shape[0] + offsets = mint.cumsum(graphs.n_node[:-1], 0) + offsets = mint.cat([Tensor([0]), offsets]) + unbatched_senders = [] + unbatched_recievers = [] + for graph_index in range(n_graphs): + s = senders[graph_index] - offsets[graph_index] + r = receivers[graph_index] - offsets[graph_index] + unbatched_senders.append(s) + unbatched_recievers.append(r) + + return [ + AtomGraphs(*args) + for args in zip( + unbatched_senders, + unbatched_recievers, + batch_nodes, + batch_edges, + node_features, + edge_features, + system_features, + node_targets, + edge_targets, + system_targets, + system_ids, + fix_atoms, + tags, + [graphs.radius for _ in range(len(batch_nodes))], + [graphs.max_num_neighbors for _ in range(len(batch_nodes))], + ) + ] + + +def batch_graphs(graphs: List[AtomGraphs]) -> AtomGraphs: + """Batch graphs together by concatenating their nodes, edges, and features. + + Args: + graphs (List[AtomGraphs]): A list of AtomGraphs to be batched together. + + Returns: + AtomGraphs: A new AtomGraphs object with the concatenated nodes, + edges, and features from the input graphs, along with concatenated target, + system ID, and other information. + """ + # Calculates offsets for sender and receiver arrays, caused by concatenating + # the nodes arrays. + offsets = mint.cumsum( + Tensor([0] + [mint.sum(g.n_node) for g in graphs[:-1]]), 0 + ) + radius = graphs[0].radius + assert {graph.radius for graph in graphs} == {radius} + max_num_neighbours = graphs[0].max_num_neighbors + assert {graph.max_num_neighbors for graph in graphs} == {max_num_neighbours} + + return AtomGraphs( + n_node=mint.concat([g.n_node for g in graphs], dim=0).to(ms.int64), + n_edge=mint.concat([g.n_edge for g in graphs], dim=0).to(ms.int64), + senders=mint.concat( + [g.senders + o for g, o in zip(graphs, offsets)], dim=0 + ).to(ms.int64), + receivers=mint.concat( + [g.receivers + o for g, o in zip(graphs, offsets)], dim=0 + ).to(ms.int64), + node_features=_map_concat([g.node_features for g in graphs]), + edge_features=_map_concat([g.edge_features for g in graphs]), + system_features=_map_concat([g.system_features for g in graphs]), + node_targets=_map_concat([g.node_targets for g in graphs]), + edge_targets=_map_concat([g.edge_targets for g in graphs]), + system_targets=_map_concat([g.system_targets for g in graphs]), + system_id=_concat([g.system_id for g in graphs]), + fix_atoms=_concat([g.fix_atoms for g in graphs]), + tags=_concat([g.tags for g in graphs]), + radius=radius, + max_num_neighbors=max_num_neighbours, + ) + + +def refeaturize_atomgraphs( + atoms: AtomGraphs, + positions: Tensor, + atomic_number_embeddings: Optional[Tensor] = None, + cell: Optional[Tensor] = None, + recompute_neighbors=True, + updates: Optional[Tensor] = None, + fixed_atom_pos: Optional[Tensor] = None, + fixed_atom_type_embedding: Optional[Tensor] = None, + differentiable: bool = False, +) -> AtomGraphs: + """Return a graph updated according to the new positions, and (if given) atomic numbers and unit cells. + + Note: if a unit cell is given, it will *both* be used to do the + pbc-remapping and be set on the returned AtomGraphs + + Args: + atoms (AtomGraphs): The original AtomGraphs object. + positions (torch.Tensor): The new positions of the atoms. + atomic_number_embeddings (Optional[torch.Tensor]): The new atomic number embeddings. + cell (Optional[torch.Tensor]): The new unit cell. + recompute_neighbors (bool): Whether to recompute the neighbor list. + updates (Optional[torch.Tensor]): The updates to the positions. + fixed_atom_pos (Optional[torch.Tensor]): The positions of atoms + which are fixed when diffusing on a fixed trajectory. + fixed_atom_type_embedding (Optional[torch.Tensor]): If using atom type diffusion + with a fixed trajectory, the unormalized vectors of the fixed atoms. Shape (n_atoms, 118). + differentiable (bool): Whether to make the graph inputs require_grad. This includes + the positions and atomic number embeddings, if passed. + exact_pbc_image_neighborhood: bool: If the exact pbc image neighborhood calculation (from torch nl) + which considers boundary crossing for more than cell is used. + + Returns: + AtomGraphs: A refeaturized AtomGraphs object. + """ + if cell is None: + cell = atoms.cell + + if atoms.fix_atoms is not None and fixed_atom_pos is not None: + positions[atoms.fix_atoms] = fixed_atom_pos[atoms.fix_atoms] + + if ( + atoms.fix_atoms is not None + and fixed_atom_type_embedding is not None + and atomic_number_embeddings is not None + ): + atomic_number_embeddings[atoms.fix_atoms] = fixed_atom_type_embedding[ + atoms.fix_atoms + ] + + num_atoms = atoms.n_node + positions = featurization_utilities.batch_map_to_pbc_cell( + positions, cell, num_atoms + ) + + if differentiable: + positions.requires_grad = True + if atomic_number_embeddings is not None: + atomic_number_embeddings.requires_grad = True + + if recompute_neighbors: + assert atoms.radius is not None and atoms.max_num_neighbors is not None + ( + edge_index, + edge_vectors, + batch_num_edges, + ) = featurization_utilities.batch_compute_pbc_radius_graph( + positions=positions, + periodic_boundaries=cell, + radius=atoms.radius, + image_idx=num_atoms, + max_number_neighbors=atoms.max_num_neighbors, + ) + new_senders = edge_index[0] + new_receivers = edge_index[1] + else: + assert updates is not None + new_senders = atoms.senders + new_receivers = atoms.receivers + edge_vectors = recompute_edge_vectors(atoms, updates) + batch_num_edges = atoms.n_edge + + edge_features = { + "vectors": edge_vectors.to(ms.float32), + } + + new_node_features = {} + if atoms.node_features is not None: + new_node_features = deepcopy(atoms.node_features) + new_node_features["positions"] = positions + if atomic_number_embeddings is not None: + new_node_features["atomic_numbers_embedding"] = atomic_number_embeddings + + new_system_features = {} + if atoms.system_features is not None: + new_system_features = deepcopy(atoms.system_features) + new_system_features["cell"] = cell + + new_atoms = AtomGraphs( + senders=new_senders, + receivers=new_receivers, + n_node=atoms.n_node, + n_edge=batch_num_edges, + node_features=new_node_features, + edge_features=edge_features, + system_features=new_system_features, + node_targets=atoms.node_targets, + system_targets=atoms.system_targets, + fix_atoms=atoms.fix_atoms, + tags=atoms.tags, + radius=atoms.radius, + max_num_neighbors=atoms.max_num_neighbors, + ) + + return new_atoms + + +def recompute_edge_vectors(atoms, updates): + """Recomputes edge vectors with per node updates.""" + updates = -updates + senders = atoms.senders + receivers = atoms.receivers + edge_translation = updates[senders] - updates[receivers] + return atoms.edge_features["vectors"] + edge_translation + + +def volume_atomgraphs(atoms: AtomGraphs): + """Returns the volume of the unit cell.""" + cell = atoms.cell + cross = ops.Cross(dim=1) + return (cell[:, 0] * cross(cell[:, 1], cell[:, 2])).sum(-1) + + +def _map_concat(nests): + concat = lambda *args: _concat(args) + return tree.map_structure(concat, *nests) + + +def _concat( + tensors: List[Optional[Tensor]], +) -> Optional[Tensor]: + """Splits tensors based on the intended split sizes.""" + if any([x is None for x in tensors]): + return None + return mint.concat(tensors, dim=0) + + +def _split_tensors( + features: Optional[Tensor], + split_sizes: List[int], +) -> Sequence[Optional[Tensor]]: + """Splits tensors based on the intended split sizes.""" + if features is None: + return [None] * len(split_sizes) + + return mint.split(features, split_sizes) + + +def _split_features( + features: Optional[TensorDict], + split_sizes: List[int], +) -> Sequence[Optional[TensorDict]]: + """Splits features based on the intended split sizes.""" + if features is None: + return [None] * len(split_sizes) + + split_dict = { + k: mint.split(v, split_sizes) if v is not None else [None] * len(split_sizes) + for k, v in features.items() + } + individual_tuples = zip(*[v for v in split_dict.values()]) + individual_dicts: List[Optional[TensorDict]] = list( + map(lambda k: dict(zip(split_dict.keys(), k)), individual_tuples) + ) + return individual_dicts diff --git a/MindChemistry/applications/orb/src/featurization_utilities.py b/MindChemistry/applications/orb/src/featurization_utilities.py new file mode 100644 index 0000000000000000000000000000000000000000..dd68d4ad8d4322bb334d9c70109ae7ce029b9050 --- /dev/null +++ b/MindChemistry/applications/orb/src/featurization_utilities.py @@ -0,0 +1,438 @@ +# ============================================================================ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Featurization utilities for molecular models.""" + +from typing import Callable, Optional, Tuple, Union +from pynanoflann import KDTree as NanoKDTree +from scipy.spatial import KDTree as SciKDTree + +import numpy as np +import mindspore as ms +from mindspore import ops, Tensor, mint + +DistanceFeaturizer = Callable[[Tensor], Tensor] + + + +def gaussian_basis_function( + scalars: Tensor, + num_bases: Union[Tensor, int], + radius: Union[Tensor, float], + scale: Union[Tensor, float] = 1.0, +) -> Tensor: + """Gaussian basis function applied to a tensor of scalars. + + Args: + scalars (Tensor): Scalars to compute the gbf on. Shape [num_scalars]. + num_bases (Tensor): The number of bases. An Int. + radius (Tensor): The largest centre of the bases. A Float. + scale (Tensor, optional): The width of the gaussians. Defaults to 1. + + Returns: + Tensor: A tensor of shape [num_scalars, num_bases]. + """ + assert len(scalars.shape) == 1 + gaussian_means = ops.arange( + 0, float(radius), float(radius) / num_bases + ) + return mint.exp( + -(scale**2) * (scalars.unsqueeze(1) - gaussian_means.unsqueeze(0)).abs() ** 2 + ) + + +def featurize_edges( + edge_vectors: Tensor, distance_featurization: DistanceFeaturizer +) -> Tensor: + """Featurizes edge features, provides concatenated unit vector along with featurized distances. + + Args: + edge_vectors (tensor): Edge vectors to featurize. Shape [num_edge, 3] + distance_featurization (DistanceFeaturization): A function that featurizes the distances of the vectors. + + Returns: + tensor: Edge features, shape [num_edge, num_edge_features]. + """ + edge_features = [] + edge_norms = mint.linalg.norm(edge_vectors, dim=1) + featurized_edge_norms = distance_featurization(edge_norms) + unit_vectors = edge_vectors / edge_norms.unsqueeze(1) + unit_vectors = mint.nan_to_num(unit_vectors, nan=0, posinf=0, neginf=0) + edge_features.append(featurized_edge_norms) + edge_features.append(unit_vectors) + return mint.cat(edge_features, dim=-1).to(ms.float32) + + +def compute_edge_vectors( + edge_index: Tensor, positions: Tensor +) -> Tensor: + """Computes edge vectors from positions. + + Args: + edge_index (tensor): The edge index. First position the senders, second + position the receivers. Shape [2, num_edge]. + positions (tensor): Positions of each node. Shape [num_nodes, 3] + + Returns: + tensor: The vectors of each edge. + """ + senders = edge_index[0] + receivers = edge_index[1] + return positions[receivers] - positions[senders] + + +# These are offsets applied to coordinates to create a 3x3x3 +# tiled periodic image of the input structure. +OFFSETS = np.array( + [ + [-1.0, 1.0, -1.0], + [0.0, 1.0, -1.0], + [1.0, 1.0, -1.0], + [-1.0, 0.0, -1.0], + [0.0, 0.0, -1.0], + [1.0, 0.0, -1.0], + [-1.0, -1.0, -1.0], + [0.0, -1.0, -1.0], + [1.0, -1.0, -1.0], + [-1.0, 1.0, 0.0], + [0.0, 1.0, 0.0], + [1.0, 1.0, 0.0], + [-1.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + [1.0, 0.0, 0.0], + [-1.0, -1.0, 0.0], + [0.0, -1.0, 0.0], + [1.0, -1.0, 0.0], + [-1.0, 1.0, 1.0], + [0.0, 1.0, 1.0], + [1.0, 1.0, 1.0], + [-1.0, 0.0, 1.0], + [0.0, 0.0, 1.0], + [1.0, 0.0, 1.0], + [-1.0, -1.0, 1.0], + [0.0, -1.0, 1.0], + [1.0, -1.0, 1.0], + ] +) + +NUM_OFFSETS = len(OFFSETS) + + +def _compute_img_positions_torch( + positions: Tensor, periodic_boundaries: Tensor +) -> Tensor: + """Computes the positions of the periodic images of the input structure. + + Consider the following 2D periodic boundary image. + + --- + --- + --- + + | | | | + + --- + --- + --- + + | | x | | + + --- + --- + --- + + | | | | + + --- + --- + --- + + + Each tile in this has an associated translation to translate + 'x'. For example, the top left would by (-1, +1). These are + the 'OFFSETS', but OFFSETS are for a 3x3x3 grid. + + This is complicated by the fact that our periodic + boundaries are not orthogonal to each other, and so we form a new + translation by taking a linear combination of the unit cell axes. + + Args: + positions (Tensor): Positions of the atoms. Shape [num_atoms, 3]. + periodic_boundaries (Tensor): Periodic boundaries of the unit cell. + This can be 2 shapes - [3, 3] or [num_atoms, 3, 3]. If the shape is + [num_atoms, 3, 3], it is assumed that the PBC has been repeat_interleaved + for each atom, i.e this function is agnostic as to whether it is computing + with respect to a batch or not. + Returns: + Tensor: The positions of the periodic images. Shape [num_atoms, 27, 3]. + """ + num_positions = len(positions) + + has_unbatched_pbc = periodic_boundaries.shape == (3, 3) + if has_unbatched_pbc: + periodic_boundaries = periodic_boundaries.unsqueeze(0) + periodic_boundaries = periodic_boundaries.expand((num_positions, 3, 3)) + + assert periodic_boundaries.shape[0] == positions.shape[0] + offsets = Tensor(OFFSETS, dtype=positions.dtype) + offsets = mint.unsqueeze(offsets, 0) + repeated_offsets = offsets.expand((num_positions, NUM_OFFSETS, 3)) + repeated_offsets = mint.unsqueeze(repeated_offsets, 3) + periodic_boundaries = mint.unsqueeze(periodic_boundaries, 1) + translations = repeated_offsets * periodic_boundaries + translations = translations.sum(2) + + # Expand the positions so we can broadcast add the translations per PBC image. + expanded_positions = positions.unsqueeze(1) + translated_positions = expanded_positions + translations + return translated_positions + + +def brute_force_knn( + img_positions: Tensor, positions: Tensor, k: int +) -> Tuple[Tensor, Tensor]: + """Brute force k-nearest neighbors. + + Args: + img_positions (Tensor): The positions of the images. Shape [num_atoms * 27, 3]. + positions (Tensor): The positions of the query atoms. Shape [num_atoms, 3]. + k (int): The number of nearest neighbors to find. + + Returns: + return_types.topk: The indices of the nearest neighbors. Shape [num_atoms, k]. + """ + dist = mint.cdist(positions, img_positions) + return mint.topk(dist, k, largest=False, sorted=True) + + +def compute_pbc_radius_graph( + *, + positions: Tensor, + periodic_boundaries: Tensor, + radius: Union[float, Tensor], + max_number_neighbors: int = 20, + brute_force: Optional[bool] = None, + library: str = "pynanoflann", + n_workers: int = 1, +) -> Tuple[Tensor, Tensor]: + """Computes periodic condition radius graph from positions. + + Args: + positions (Tensor): 3D positions of particles. Shape [num_particles, 3]. + periodic_boundaries (Tensor): A 3x3 matrix where the periodic boundary axes are rows or columns. + radius (Union[float, tensor]): The radius within which to connect atoms. + max_number_neighbors (int, optional): The maximum number of neighbors for each particle. Defaults to 20. + brute_force (bool, optional): Whether to use brute force knn. Defaults to None, in which case brute_force + is used if GPU is available (2-6x faster), but not on CPU (1.5x faster - 4x slower, depending on + system size). + library (str, optional): The KDTree library to use. Currently, either 'scipy' or 'pynanoflann'. + n_workers (int, optional): The number of workers to use for KDTree construction. Defaults to 1. + + Returns: + Tuple[Tensor, Tensor]: A 2-Tuple. First, an edge_index tensor, where the first index are the + sender indices and the second are the receiver indices. Second, the vector displacements between edges. + """ + if brute_force is None: + brute_force = ms.get_context("device_target") == "GPU" + + if mint.any(periodic_boundaries != 0.0): + supercell_positions = _compute_img_positions_torch( + positions=positions, periodic_boundaries=periodic_boundaries + ) + # CRITICALLY IMPORTANT: We need to reshape the supercell_positions to be + # flat, so we can use them for the nearest neighbors. The *way* in which + # they are flattened is important, because we need to be able to map the + # indices returned from the nearest neighbors to the original positions. + # The easiest way to do this is to transpose, so that when we flatten, we + # have: + # [ + # img_0_atom_0, + # img_0_atom_1, + # ..., + # img_0_atom_N, + # img_1_atom_0, + # img_1_atom_1, + # ..., + # img_N_atom_N, + # etc + # ] + # This way, we can take the mod of the indices returned from the nearest + # neighbors to get the original indices. + # Shape (27, num_positions, 3) + supercell_positions = supercell_positions.transpose(0, 1) + supercell_positions = supercell_positions.reshape(-1, 3) + else: + supercell_positions = positions + + num_positions = positions.shape[0] + + if brute_force: + # Brute force + distance_values, nearest_img_neighbors = brute_force_knn( + supercell_positions, + positions, + min(max_number_neighbors + 1, len(supercell_positions)), + ) + + # remove distances greater than radius, and exclude self + within_radius = distance_values[:, 1:] < (radius + 1e-6) + + num_neighbors_per_position = within_radius.sum(-1) + # remove the self node which will be closest + index_array = nearest_img_neighbors[:, 1:] + + senders = mint.repeat_interleave( + mint.arange(num_positions), num_neighbors_per_position + ) + receivers_imgs = index_array[within_radius] + + receivers = receivers_imgs % num_positions + vectors = supercell_positions[receivers_imgs] - positions[senders] + stacked = mint.stack((senders, receivers), dim=0) + return stacked, vectors + + # Build a KDTree from the supercell positions. + # Query that KDTree just for the positions in the central cell. + tree_data = supercell_positions.clone().numpy() + tree_query = positions.clone().numpy() + distance_upper_bound = np.array(radius) + 1e-8 + if library == "scipy": + tree = SciKDTree(tree_data, leafsize=100) + _, nearest_img_neighbors = tree.query( + tree_query, + max_number_neighbors + 1, + distance_upper_bound=distance_upper_bound, + workers=n_workers, + p=2, + ) + # Remove the self-edge that will be closest + index_array = np.array(nearest_img_neighbors)[:, 1:] + # Remove any entry that equals len(supercell_positions), which are negative hits + receivers_imgs = index_array[index_array != len(supercell_positions)] + num_neighbors_per_position = (index_array != len(supercell_positions)).sum( + -1 + ) + elif library == "pynanoflann": + tree = NanoKDTree( + n_neighbors=min(max_number_neighbors + 1, len(supercell_positions)), + radius=radius, + leaf_size=100, + metric="l2", + ) + tree.fit(tree_data) + distance_values, nearest_img_neighbors = tree.kneighbors( + tree_query, n_jobs=n_workers + ) + nearest_img_neighbors = nearest_img_neighbors.astype(np.int32) + + # remove the self node which will be closest + index_array = nearest_img_neighbors[:, 1:] + # remove distances greater than radius + within_radius = distance_values[:, 1:] < (radius + 1e-6) + receivers_imgs = index_array[within_radius] + num_neighbors_per_position = within_radius.sum(-1) + + # We construct our senders and receiver indexes. + senders = np.repeat(np.arange(num_positions), list(num_neighbors_per_position)) + receivers_img_torch = Tensor(receivers_imgs, ms.int32) + # Map back to indexes on the central image. + receivers = receivers_img_torch % num_positions + senders_torch = Tensor(senders, ms.int32) + + # Finally compute the vector displacements between senders and receivers. + vectors = supercell_positions[receivers_img_torch] - positions[senders_torch] + return mint.stack((senders_torch, receivers), dim=0), vectors + + +def batch_map_to_pbc_cell( + positions: Tensor, + periodic_boundary_conditions: Tensor, + num_atoms: Tensor, +) -> Tensor: + """Maps positions to within a periodic boundary cell, for a batched system. + + Args: + positions (Tensor): The positions to be mapped. Shape [num_particles, 3] + periodic_boundary_conditions (Tensor): The periodic boundary conditions. Shape [num_batches, 3, 3] + num_atoms (LongTensor): The number of atoms in each batch. Shape [num_batches] + """ + dtype = positions.dtype + positions = positions.double() + periodic_boundary_conditions = periodic_boundary_conditions.double() + + pbc_nodes = mint.repeat_interleave(periodic_boundary_conditions, num_atoms, dim=0) + + # To use the stable linalg.solve, we need to mask batch elements which don't + # have periodic boundaries. We do this by adding the identity matrix as their PBC, + # because we need the PBCs to be non-singular. + null_pbc = pbc_nodes.abs().sum(dim=[1, 2]) == 0 + identity = mint.eye(3, dtype=ms.bool_) + # Broadcast the identity to the elements of the batch that have a null pbc. + null_pbc_identity_mask = null_pbc.view(-1, 1, 1) & identity.view(1, 3, 3) + pbc_nodes_masked = pbc_nodes + null_pbc_identity_mask.double() + + lattice_coords = ops.matrix_solve(pbc_nodes_masked.transpose(1, 2), positions) + frac_coords = lattice_coords % 1.0 + + cartesian = mint.einsum("bi,bij->bj", frac_coords, pbc_nodes) + return mint.where(null_pbc.unsqueeze(1), positions, cartesian).to(dtype) + + +def batch_compute_pbc_radius_graph( + *, + positions: Tensor, + periodic_boundaries: Tensor, + radius: Union[float, Tensor], + image_idx: Tensor, + max_number_neighbors: int = 20, + brute_force: Optional[bool] = None, + library: str = "scipy", +): + """Computes batched periodic boundary condition radius graph from positions. + + This function is optimised for computation on CPU, and work work on device. GPU implementations + are likely to be significantly slower because of the irregularly sized tensor computations and the + lack of extremely fast GPU knn routines. + + Args: + positions (Tensor): 3D positions of a batch of particles. Shape [num_particles, 3]. + periodic_boundaries (Tensor): A batch where each element 3x3 matrix where the periodic boundary axes + are rows or columns. + radius (Union[float, tensor]): The radius within which to connect atoms. + image_idx (Tensor): A vector where each element indicates the number of particles in each element of + the batch. Of size len(batch). + max_number_neighbors (int, optional): The maximum number of neighbors for each particle. Defaults to 20. + brute_force (bool, optional): Whether to use brute force knn. Defaults to None, in which case brute_force + is used if we are on GPU (2-6x faster), but not on CPU (1.5x faster - 4x slower). + library (str, optional): The KDTree library to use. Currently, either 'scipy' or 'pynanoflann'. + + Returns: + Tuple[Tensor, Tensor]: A 2-Tuple. First, an edge_index tensor, where the first index are the + sender indices and the second are the receiver indices. Second, the vector displacements between edges. + """ + idx = 0 + all_edges = [] + all_vectors = [] + num_edges = [] + + for p, pbc in zip( + ops.tensor_split(positions, mint.cumsum(image_idx, 0)[:-1]), + periodic_boundaries, + ): + edges, vectors = compute_pbc_radius_graph( + positions=p, + periodic_boundaries=pbc, + radius=radius, + max_number_neighbors=max_number_neighbors, + brute_force=brute_force, + library=library, + ) + if idx == 0: + offset = 0 + else: + offset += image_idx[idx - 1] + all_edges.append(edges + offset) + all_vectors.append(vectors) + num_edges.append(len(edges[0])) + idx += 1 + + all_edges = ms.numpy.concatenate(all_edges, 1) + all_vectors = ms.numpy.concatenate(all_vectors, 0) + num_edges = Tensor(num_edges, dtype=ms.int64) + return all_edges, all_vectors, num_edges diff --git a/MindChemistry/applications/orb/src/pretrained.py b/MindChemistry/applications/orb/src/pretrained.py new file mode 100644 index 0000000000000000000000000000000000000000..cb66db19fe55e2694a492e836c78d0a175bca67d --- /dev/null +++ b/MindChemistry/applications/orb/src/pretrained.py @@ -0,0 +1,116 @@ +# ============================================================================ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""pretrained.""" + +import os +from typing import Optional + +from mindspore import nn, load_checkpoint, load_param_into_net + +from mindchemistry.cell import ( + EnergyHead, + GraphHead, + Orb, + NodeHead, + MoleculeGNS, +) + + +def get_gns( + latent_dim: int = 256, + mlp_hidden_dim: int = 512, + num_message_passing_steps: int = 15, + num_edge_in_features: int = 23, + distance_cutoff: bool = True, + attention_gate: str = "sigmoid", +) -> MoleculeGNS: + """Define the base pretrained model architecture.""" + return MoleculeGNS( + num_node_in_features=256, + num_node_out_features=3, + num_edge_in_features=num_edge_in_features, + latent_dim=latent_dim, + interactions="simple_attention", + interaction_params={ + "distance_cutoff": distance_cutoff, + "polynomial_order": 4, + "cutoff_rmax": 6, + "attention_gate": attention_gate, + }, + num_message_passing_steps=num_message_passing_steps, + num_mlp_layers=2, + mlp_hidden_dim=mlp_hidden_dim, + use_embedding=True, + node_feature_names=["feat"], + edge_feature_names=["feat"], + ) + + +def load_model_for_inference(model: nn.Cell, weights_path: str) -> nn.Cell: + """ + Load a pretrained model in inference mode, using GPU if available. + """ + if not os.path.exists(weights_path): + raise FileNotFoundError(f"Checkpoint file {weights_path} not found.") + param_dict = load_checkpoint(weights_path) + load_param_into_net(model, param_dict) + model.set_train(False) + + return model + +def orb_v2( + weights_path: Optional[str] = None, +): + """Load ORB v2.""" + gns = get_gns() + + model = Orb( + graph_head=EnergyHead( + latent_dim=256, + num_mlp_layers=1, + mlp_hidden_dim=256, + target_property_dim=1, + node_aggregation="mean", + reference_energy_name="vasp-shifted", + train_reference=True, + predict_atom_avg=True, + ), + node_head=NodeHead( + latent_dim=256, + num_mlp_layers=1, + mlp_hidden_dim=256, + target_property_dim=3, + remove_mean=True, + ), + stress_head=GraphHead( + latent_dim=256, + num_mlp_layers=1, + mlp_hidden_dim=256, + target_property_dim=6, + compute_stress=True, + ), + model=gns, + ) + model = load_model_for_inference(model, weights_path) + return model + + +def orb_mptraj_only_v2( + weights_path: Optional[str] = None, +): + """Load ORB MPTraj Only v2.""" + + return orb_v2(weights_path,) diff --git a/MindChemistry/applications/orb/src/property_definitions.py b/MindChemistry/applications/orb/src/property_definitions.py new file mode 100644 index 0000000000000000000000000000000000000000..3951c06c0e0ca8b540bc8ad4c4c69649aa7affd6 --- /dev/null +++ b/MindChemistry/applications/orb/src/property_definitions.py @@ -0,0 +1,239 @@ +# ============================================================================ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Classes that define prediction targets.""" + +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union + +import ase.data +import ase.db +import ase.db.row +import ase.db.sqlite +import numpy as np + +import mindspore as ms +from mindspore import ops, Tensor, mint + +HARTREE_TO_EV = 27.211386245988 + + +def recursive_getattr(obj: object, attr: str) -> Any: + """Recursively access an object property using dot notation.""" + for sub_attr in attr.split("."): + obj = getattr(obj, sub_attr) + + return obj + + +def get_property_from_row( + name: Union[str, List[str]], + row: ase.db.row.AtomsRow, + conversion_factor: float = 1.0, +) -> Tensor: + """Retrieve arbitrary values from ase db data dict.""" + if isinstance(name, str): + names = [name] + else: + names = name + values = [] + for name_ in names: + attribute = recursive_getattr(row, name_) + target = np.array(attribute) + values.append(target) + + property_tensor = ms.from_numpy(np.hstack(values)).to(ms.float32) + + while len(property_tensor.shape) < 2: + property_tensor = property_tensor[None, ...] + + if "stress" in name and property_tensor.shape == (3, 3): + # convert stress tensor to voigt notation + property_tensor = Tensor( + [ + property_tensor[0, 0], + property_tensor[1, 1], + property_tensor[2, 2], + property_tensor[1, 2], + property_tensor[0, 2], + property_tensor[0, 1], + ], + dtype=ms.float32, + ).unsqueeze(0) + return property_tensor * conversion_factor + + +@dataclass +class PropertyDefinition: + """Defines how to extract and transform a quantative property from an ase db. + + Such properties have two primary use-cases: + - as features for the model to use / condition on. + - as target variables for regression tasks. + + Args: + name: The name of the property. + dim: The dimensionality of the property variable. + domain: Whether the variable is real, binary or categorical. If using + this variable as a regression target, then var_type determines + the loss function used e.g. MSE, BCE or cross-entropy loss. + row_to_property_fn: A function defining how a target can be + retrieved from an ase database row. + means: The mean to transform this by in the model. + stds: The std to scale this by in the model. + """ + + name: str + dim: int + domain: Literal["real", "binary", "categorical"] + row_to_property_fn: Optional[Callable] = None + means: Optional[Tensor] = None + stds: Optional[Tensor] = None + + +def energy_row_fn(row: ase.db.row.AtomsRow, dataset: str) -> float: + """Energy data in eV. + + - Some datasets use sums of energy values e.g. PBE + D3. + - For external datasets, we should explicitly register how + to extract the energy property by adding it to `extract_info'. + - Unregistered datasets default to using the `energy` attribute + and a conversion factor of 1, which is always correct for our + internally generated datasets. + """ + extract_info: Dict[str, List[Tuple]] = { + "mp-traj": [("energy", 1)], + "mp-traj-d3": [("energy", 1), ("data.d3.energy", 1)], + "alexandria-d3": [("energy", 1), ("data.d3.energy", 1)], + } + if dataset not in extract_info: + if not hasattr(row, "energy"): + raise ValueError( + f"db row {row.id} doesn't have an energy attribute directly " + ", but also doesn't define a method to extract energy info." + ) + return get_property_from_row("energy", row, 1) + + energy_ = 0.0 + for row_attribute, conversion_factor in extract_info[dataset]: + energy_ += get_property_from_row(row_attribute, row, conversion_factor) + return energy_ + + +def forces_row_fn(row: ase.db.row.AtomsRow, dataset: str): + """Force data in eV / Angstrom. + + - Some datasets use sums of energy values e.g. PBE + D3. + - For external datasets, we should explicitly register how + to extract the energy property by adding it to `extract_info'. + - Unregistered datasets default to using the `energy` attribute + and a conversion factor of 1, which is always correct for our + internally generated datasets. + """ + extract_info: Dict[str, List[Tuple]] = { + "mp-traj": [("forces", 1)], + "mp-traj-d3": [("forces", 1), ("data.d3.forces", 1)], + "alexandria-d3": [("forces", 1), ("data.d3.forces", 1)], + } + if dataset not in extract_info: + if not hasattr(row, "forces"): + raise ValueError( + f"db row {row.id} doesn't have a forces attribute directly, " + "but also doesn't define a method to extract forces info." + ) + return get_property_from_row("forces", row, 1) + + forces_ = 0.0 + for row_attribute, conversion_factor in extract_info[dataset]: + forces_ += get_property_from_row(row_attribute, row, conversion_factor) + return forces_ + + +def stress_row_fn(row: ase.db.row.AtomsRow, dataset: str) -> float: + """Extract stress data.""" + extract_info: Dict[str, List[Tuple]] = { + "mp-traj": [("stress", 1)], + "mp-traj-d3": [("stress", 1), ("data.d3.stress", 1)], + "alexandria-d3": [("stress", 1), ("data.d3.stress", 1)], + } + if dataset not in extract_info: + if not hasattr(row, "stress"): + raise ValueError( + f"db row {row.id} doesn't have an stress attribute directly " + ", but also doesn't define a method to extract stress info." + ) + return get_property_from_row("stress", row, 1) + + stress_ = 0.0 + for row_attribute, conversion_factor in extract_info[dataset]: + stress_ += get_property_from_row(row_attribute, row, conversion_factor) + return stress_ + + +def test_fixture_node_row_fn(row: ase.db.row.AtomsRow): + """Just return random noise.""" + + pos = ms.from_numpy(row.toatoms().positions) + return ops.rand_like(pos).to(ms.float32) + + +def test_fixture_graph_row_fn(): + """Just return random noise.""" + return mint.randn((1, 1)).to(ms.float32) + + +energy = PropertyDefinition( + name="energy", + dim=1, + domain="real", + row_to_property_fn=energy_row_fn, +) + +forces = PropertyDefinition( + name="forces", + dim=3, + domain="real", + row_to_property_fn=forces_row_fn, +) + +stress = PropertyDefinition( + name="stress", + dim=6, + domain="real", + row_to_property_fn=stress_row_fn, +) + +test_fixture = PropertyDefinition( + name="test-fixture", + dim=3, + domain="real", + row_to_property_fn=test_fixture_node_row_fn, +) + +test_graph_fixture = PropertyDefinition( + name="test-graph-fixture", + dim=1, + domain="real", + row_to_property_fn=test_fixture_graph_row_fn, +) + + +PROPERTIES = { + "energy": energy, + "forces": forces, + "stress": stress, + "test-fixture": test_fixture, + "test-graph-fixture": test_graph_fixture, +} diff --git a/MindChemistry/applications/orb/src/segment_ops.py b/MindChemistry/applications/orb/src/segment_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..d8b671e4091723fbaa03d02fdfa853535d2ee3b8 --- /dev/null +++ b/MindChemistry/applications/orb/src/segment_ops.py @@ -0,0 +1,202 @@ +# ============================================================================ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Segment operations.""" + +from typing import Optional +import numpy as np +import mindspore as ms +from mindspore import ops, Tensor, mint + +MSINT = [ms.int64, ms.int32, ms.int16, ms.int8, ms.uint8] + + +def aggregate_nodes(tensor: Tensor, n_node: Tensor, reduction: str = "mean", deterministic: bool = False) -> Tensor: + """Aggregates over a tensor based on graph sizes.""" + count = len(n_node) + if deterministic: + ms.set_seed(1) + segments = ops.arange(count).repeat_interleave(n_node).astype(ms.int32) + if reduction == "sum": + return scatter_sum(tensor, segments, dim=0) + if reduction == "mean": + return scatter_mean(tensor, segments, dim=0) + if reduction == "max": + return scatter_max(tensor, segments, dim=0) + raise ValueError("Invalid reduction argument. Use sum, mean or max.") + + +def segment_sum(data: Tensor, segment_ids: Tensor, num_segments: int): + """Computes index based sum over segments of a tensor.""" + return scatter_sum(data, segment_ids, dim=0, dim_size=num_segments) + + +def segment_max(data: Tensor, segment_ids: Tensor, num_segments: int): + """Computes index based max over segments of a tensor.""" + assert segment_ids is not None, "segment_ids must not be None" + assert num_segments > 0, "num_segments must be greater than 0" + max_op = ops.ArgMaxWithValue(axis=0) + _, max_values = max_op(data) + return max_values + + +def segment_mean(data: Tensor, segment_ids: Tensor, num_segments: int): + """Computes index based mean over segments of a tensor.""" + sum_v = segment_sum(data, segment_ids, num_segments) + count = ops.scatter_add(ops.zeros( + (num_segments,), dtype=ms.int32), segment_ids, ops.ones_like(segment_ids)) + return sum_v / count.astype(sum_v.dtype) + + +def segment_softmax(data: Tensor, segment_ids: Tensor, num_segments: int, weights: Optional[Tensor] = None): + """Computes a softmax over segments of the tensor.""" + data_max = segment_max(data, segment_ids, num_segments) + data = data - data_max[segment_ids] + + unnormalised_probs = ops.exp(data) + if weights is not None: + unnormalised_probs = unnormalised_probs * weights + denominator = segment_sum(unnormalised_probs, segment_ids, num_segments) + + return safe_division(unnormalised_probs, denominator, segment_ids) + + +def safe_division(numerator: Tensor, denominator: Tensor, segment_ids: Tensor): + """Divides logits by denominator, setting 0 where the denominator is zero.""" + result = ops.where(denominator[segment_ids] == + 0, 0, numerator / denominator[segment_ids]) + return result + + +def _broadcast(src: Tensor, other: Tensor, dim: int): + """Broadcasts the source tensor to match the shape of the other tensor along the specified dimension.""" + if dim < 0: + dim = other.ndim + dim + if src.ndim == 1: + for _ in range(0, dim): + src = src.unsqueeze(0) + for _ in range(src.ndim, other.ndim): + src = src.unsqueeze(-1) + src = src.expand_as(other) + return src + + +def scatter_sum( + src: Tensor, index: Tensor, dim: int = -1, out: Optional[Tensor] = None, + dim_size: Optional[int] = None, reduce: str = "sum" +) -> Tensor: + """Applies a sum reduction of the orb_models tensor along the specified dimension.""" + assert reduce == "sum" + index = _broadcast(index, src, dim) + if out is None: + size = list(src.shape) + if dim_size is not None: + size[dim] = dim_size + elif index.numel() == 0: + size[dim] = 0 + else: + size[dim] = int(index.max()) + 1 + out = ops.zeros(size, dtype=src.dtype) + return mint.scatter_add(out, dim, index, src) + return mint.scatter_add(out, dim, index, src) + + +def scatter_std( + src: Tensor, index: Tensor, dim: int = -1, out: Optional[Tensor] = None, + dim_size: Optional[int] = None, unbiased: bool = True +) -> Tensor: + """Computes the standard deviation of the orb_models tensor along the specified dimension.""" + if out is not None: + dim_size = out.shape[dim] + + if dim < 0: + dim = src.ndim + dim + + count_dim = dim + if index.ndim <= dim: + count_dim = index.ndim - 1 + + ones = ops.ones(index.shape, dtype=src.dtype) + count = scatter_sum(ones, index, count_dim, dim_size=dim_size) + + index = _broadcast(index, src, dim) + tmp = scatter_sum(src, index, dim, dim_size=dim_size) + count = _broadcast(count, tmp, dim).clip(1) + mean = tmp / count + + var = src - mean.gather(dim, index) + var = var * var + out = scatter_sum(var, index, dim, out=out, dim_size=dim_size) + + if unbiased: + count = count - 1 + count = count.clip(1) + out = out / (count + 1e-6) + out = ops.sqrt(out) + return out + + +def scatter_mean( + src: Tensor, index: Tensor, dim: int = -1, out: Optional[Tensor] = None, + dim_size: Optional[int] = None +) -> Tensor: + """Computes the mean of the orb_models tensor along the specified dimension.""" + out = scatter_sum(src, index, dim, out=out, dim_size=dim_size) + dim_size = out.shape[dim] + + index_dim = dim + if index_dim < 0: + index_dim = index_dim + src.ndim + if index.ndim <= index_dim: + index_dim = index.ndim - 1 + + ones = ops.ones(index.shape, dtype=src.dtype) + count = scatter_sum(ones, index, index_dim, dim_size=dim_size) + count = count.clip(1) + count = _broadcast(count, out, dim) + out = out / count + return out + + +def scatter_max( + src: Tensor, index: Tensor, dim: int = -1, out: Optional[Tensor] = None, + dim_size: Optional[int] = None +) -> Tensor: + """Computes the maximum of the orb_models tensor for each group defined by index along the specified dimension.""" + if out is not None: + raise NotImplementedError( + "The 'out' argument is not supported for scatter_max") + + if src.dtype in MSINT: + init_value = np.iinfo(src.dtype).min + else: + init_value = np.finfo(src.dtype).min + + if dim < 0: + dim = src.ndim + dim + + if dim_size is None: + dim_size = int(index.max()) + 1 + + result = ops.ones( + (dim_size, *src.shape[:dim], *src.shape[dim + 1:]), dtype=src.dtype) + result = init_value * result + broadcasted_index = _broadcast(index, src, dim) + + scatter_result = ops.ZerosLike()(result) + index = ops.expand_dims(broadcasted_index, dim) + scatter_result = scatter_result.scatter_update(index, src) + result = ops.Maximum()(result, scatter_result) + return result diff --git a/MindChemistry/applications/orb/src/trainer.py b/MindChemistry/applications/orb/src/trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..5b1404f6b59c4d696458d11150e306f183329fe7 --- /dev/null +++ b/MindChemistry/applications/orb/src/trainer.py @@ -0,0 +1,329 @@ +# ============================================================================ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Trainer.""" + +from typing import Dict, Optional, Tuple + +import mindspore as ms +from mindspore import ops, Tensor, mint + +from src import base, segment_ops + +class OrbLoss: + """Loss function for ORB models. + + This class is used to compute the loss for the ORB model. + It can be used to compute the loss for both node and graph predictions. + """ + + def __init__(self, model): + """Initializes the OrbLoss. + + Args: + target: either the name of a PropertyDefinition or a PropertyDefinition itself. + """ + self.model = model + + def loss_node(self, batch, out_batch=None): + """Apply mlp to compute loss and metrics.""" + batch_n_node = batch.n_node + assert batch.node_targets is not None + target = batch.node_targets['forces'].squeeze(-1) + pred = out_batch["node_pred"].squeeze(-1) + # make sure we remove fixed atoms before normalization + pred, target, batch_n_node = _remove_fixed_atoms( + pred, target, batch_n_node, batch.fix_atoms, self.model.training + ) + mae = mint.abs(pred - self.model.node_head.normalizer(target)) + raw_pred = self.model.node_head.normalizer.inverse(pred) + raw_mae = mint.abs(raw_pred - target) + + mae = mae.mean(dim=-1) + mae = segment_ops.aggregate_nodes( + mae, batch_n_node, reduction="mean" + ).mean() + raw_mae = raw_mae.mean(dim=-1) + raw_mae = segment_ops.aggregate_nodes( + raw_mae, batch_n_node, reduction="mean" + ).mean() + + metrics = { + "node_mae": mae.item(), + "node_mae_raw": raw_mae.item(), + "node_cosine_sim": ops.cosine_similarity(raw_pred, target, dim=-1).mean().item(), + "fwt_0.03": forces_within_threshold(raw_pred, target, batch_n_node), + } + return mae, base.ModelOutput(loss=mae, log=metrics) + + def loss_graph(self, batch, out_batch=None): + """Apply mlp to compute loss and metrics. + + Depending on whether the target is real/binary/categorical, we + use an MSE/cross-entropy loss. In the case of cross-entropy, the + preds are logits (not normalised) to take advantage of numerically + stable log-softmax. + """ + assert batch.system_targets is not None + target = batch.system_targets['stress'].squeeze(-1) + if self.model.stress_head.compute_stress: + pred = out_batch["stress_pred"].squeeze(-1) + else: + pred = out_batch["graph_pred"].squeeze(-1) + + normalized_target = self.model.stress_head.normalizer(target) + errors = normalized_target - pred + mae = mint.abs(errors).mean() + + raw_pred = self.model.stress_head.normalizer.inverse(pred) + raw_mae = mint.abs(raw_pred - target).mean() + metrics = {"stress_mae": mae.item(), "stress_mae_raw": raw_mae.item()} + return mae, base.ModelOutput(loss=mae, log=metrics) + + + def loss_energy(self, batch, out_batch=None): + """Apply mlp to compute loss and metrics.""" + assert batch.system_targets is not None + target = batch.system_targets['energy'].squeeze(-1) + pred = out_batch["graph_pred"].squeeze(-1) + + reference = self.model.graph_head.reference(batch.atomic_numbers, batch.n_node).squeeze(-1) + reference_target = target - reference + if self.model.graph_head.atom_avg: + reference_target = reference_target / batch.n_node + + normalized_reference = self.model.graph_head.normalizer(reference_target) + model_loss = normalized_reference - pred + + raw_pred = self.model.graph_head.normalizer.inverse(pred) + if self.model.graph_head.atom_avg: + raw_pred = raw_pred * batch.n_node + raw_mae = mint.abs((raw_pred + reference) - target).mean() + + reference_mae = mint.abs(reference_target).mean() + model_mae = mint.abs(model_loss).mean() + metrics = { + "energy_reference_mae": reference_mae.item(), + "energy_mae": model_mae.item(), + "energy_mae_raw": raw_mae.item(), + } + return model_mae, base.ModelOutput(loss=model_mae, log=metrics) + + def loss(self, batch, label=None): + """Loss function of Orb GraphRegressor.""" + assert label is None, "Orb GraphRegressor does not support labels." + + out = self.model( + batch.edge_features, + batch.node_features, + batch.senders, + batch.receivers, + batch.n_node, + ) + loss = Tensor(0.0, ms.float32) + metrics: Dict = {} + + loss1, graph_out = self.loss_energy(batch, out) + metrics.update(graph_out.log) + loss = loss.type_as(loss1) + loss1 + + loss2, stress_out = self.loss_graph(batch, out) + metrics.update(stress_out.log) + loss = loss.type_as(loss2) + loss2 + + loss3, node_out = self.loss_node(batch, out) + metrics.update(node_out.log) + loss = loss.type_as(loss3) + loss3 + + metrics["loss"] = loss.item() + return loss, metrics + + +def binary_accuracy( + pred: Tensor, target: Tensor, threshold: float = 0.5 +) -> float: + """Calculate binary accuracy between 2 tensors. + + Args: + pred: the prediction tensor. + target: the tensor of target values. + threshold: Binary classification threshold. Default 0.5. + + Returns: + mean accuracy. + """ + return ((pred > threshold) == target).to(ms.float32).mean().item() + + +def categorical_accuracy(pred: Tensor, target: Tensor) -> float: + """Calculate accuracy for K class classification. + + Args: + pred: the tensor of logits for K classes of shape (..., K) + target: tensor of integer target values of shape (...) + + Returns: + mean accuracy. + """ + pred_labels = mint.argmax(pred, dim=-1) + return (pred_labels == target).to(ms.float32).mean().item() + + +def error_within_threshold( + pred: Tensor, target: Tensor, threshold: float = 0.02 +) -> float: + """Calculate MAE between 2 tensors within a threshold. + + Args: + pred: the prediction tensor. + target: the tensor of target values. + threshold: margin threshold. Default 0.02 (derived from OCP metrics). + + Returns: + Mean predictions within threshold. + """ + error = mint.abs(pred - target) + within_threshold = error < threshold + return within_threshold.to(ms.float32).mean().item() + + +def forces_within_threshold( + pred: Tensor, + target: Tensor, + batch_num_nodes: Tensor, + threshold: float = 0.03, +) -> float: + """Calculate MAE between batched graph tensors within a threshold. + + The predictions for a graph are counted as being within the threshold + only if all nodes in the graph have predictions within the threshold. + + Args: + pred: the prediction tensor. + target: the tensor of target values. + batch_num_nodes: A tensor containing the number of nodes per + graph. + threshold: margin threshold. Default 0.03 (derived from OCP metrics). + + Returns: + Mean predictions within threshold. + """ + error = mint.abs(pred - target) + largest_dim_fwt = error.max(-1)[0] < threshold + + count_within_threshold = segment_ops.aggregate_nodes( + largest_dim_fwt.float(), batch_num_nodes, reduction="sum" + ) + # count equals batch_num_nodes if all nodes within threshold + return (count_within_threshold == batch_num_nodes).to(ms.float32).mean().item() + + +def energy_and_forces_within_threshold( + pred_energy: Tensor, + pred_forces: Tensor, + target_energy: Tensor, + target_forces: Tensor, + batch_num_nodes: Tensor, + fixed_atoms: Optional[Tensor] = None, + threshold: Tuple[float, float] = (0.02, 0.03), +) -> float: + """Calculate MAE between batched graph energies and forces within a threshold. + + The predictions for a graph are counted as being within the threshold + only if all nodes in the graph have predictions within the threshold AND + the energies are also within a threshold. A combo of the two above functions. + + Args: + pred_*: the prediction tensors. + target_*: the tensor of target values. + batch_num_nodes: A tensor containing the number of nodes per + graph. + fixed_atoms: A tensor of bools indicating which atoms are fixed. + threshold: margin threshold. Default (0.02, 0.03) (derived from OCP metrics). + Returns: + Mean predictions within threshold. + """ + energy_err = mint.abs(pred_energy - target_energy) + ewt = energy_err < threshold[0] + + forces_err = mint.abs(pred_forces - target_forces) + largest_dim_fwt = forces_err.max(-1).values < threshold[1] + + working_largest_dim_fwt = largest_dim_fwt + + if fixed_atoms is not None: + fixed_per_graph = segment_ops.aggregate_nodes( + fixed_atoms.int(), batch_num_nodes, reduction="sum" + ) + # remove the fixed atoms from the counts + batch_num_nodes = batch_num_nodes - fixed_per_graph + # remove the fixed atoms from the forces + working_largest_dim_fwt = largest_dim_fwt[not fixed_atoms] + + force_count_within_threshold = segment_ops.aggregate_nodes( + working_largest_dim_fwt.int(), batch_num_nodes, reduction="sum" + ) + fwt = force_count_within_threshold == batch_num_nodes + + # count equals batch_num_nodes if all nodes within threshold + return (fwt & ewt).to(ms.float32).mean().item() + + +def _remove_fixed_atoms( + pred_node: Tensor, + node_target: Tensor, + batch_n_node: Tensor, + fix_atoms: Optional[Tensor], + training: bool, +): + """We use inf targets on purpose to designate nodes for removal.""" + assert len(pred_node) == len(node_target) + if fix_atoms is not None and not training: + pred_node = pred_node[~fix_atoms] + node_target = node_target[~fix_atoms] + batch_n_node = segment_ops.aggregate_nodes( + (~fix_atoms).int(), batch_n_node, reduction="sum" + ) + return pred_node, node_target, batch_n_node + + +def bce_loss( + pred: Tensor, target: Tensor, metric_prefix: str = "" +) -> Tuple: + """Binary cross-entropy loss with accuracy metric.""" + loss = mint.nn.BCEWithLogitsLoss()(pred, target.float()) + accuracy = binary_accuracy(pred, target) + return ( + loss, + { + f"{metric_prefix}_accuracy": accuracy, + f"{metric_prefix}_loss": loss.item(), + }, + ) + + +def cross_entropy_loss( + pred: Tensor, target: Tensor, metric_prefix: str = "" +) -> Tuple: + """Cross-entropy loss with accuracy metric.""" + loss = mint.nn.CrossEntropyLoss()(pred, target.long()) + accuracy = categorical_accuracy(pred, target) + return ( + loss, + { + f"{metric_prefix}_accuracy": accuracy, + f"{metric_prefix}_loss": loss.item(), + }, + ) diff --git a/MindChemistry/applications/orb/src/utils.py b/MindChemistry/applications/orb/src/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..34bd4623f1c452d269085eea8aa03c45550fb06a --- /dev/null +++ b/MindChemistry/applications/orb/src/utils.py @@ -0,0 +1,296 @@ +# ============================================================================ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Experiment utilities.""" + +import math +import random +import re +from collections import defaultdict +from typing import Dict, List, Mapping, Optional, Tuple, TypeVar, Any + +import yaml +import numpy as np +import mindspore as ms +from mindspore import Tensor, mint + +from src import base + +T = TypeVar("T") + + +def load_cfg(filename): + """load_cfg + + Load configurations from yaml file and return a namespace object + """ + from argparse import Namespace + with open(filename, "r", encoding="utf-8") as f: + cfg = yaml.safe_load(f) + return Namespace(**cfg) + + +def ensure_detached(x: base.Metric) -> base.Metric: + """Ensure that the tensor is detached and on the CPU.""" + return x + + +def to_item(x: base.Metric) -> base.Metric: + """Convert a tensor to a python scalar.""" + if isinstance(x, Tensor): + return x.item() + return x + + +def prefix_keys( + dict_to_prefix: Dict[str, T], prefix: str, sep: str = "/" +) -> Dict[str, T]: + """Add a prefix to dictionary keys with a separator.""" + return {f"{prefix}{sep}{k}": v for k, v in dict_to_prefix.items()} + + +def seed_everything(seed: int, rank: int = 0) -> None: + """Set the seed for all pseudo random number generators.""" + random.seed(seed + rank) + np.random.seed(seed + rank) + ms.manual_seed(seed + rank) + + +class ScalarMetricTracker: + """Keep track of average scalar metric values.""" + + def __init__(self): + self.reset() + + def reset(self): + """Reset the AverageMetrics.""" + self.sums = defaultdict(float) + self.counts = defaultdict(int) + + def update(self, metrics: Mapping[str, base.Metric]) -> None: + """Update the metric counts with new values.""" + for k, v in metrics.items(): + if isinstance(v, Tensor) and v.nelement() > 1: + continue # only track scalar metrics + if isinstance(v, Tensor) and v.isnan().any(): + continue + self.sums[k] += ensure_detached(v) + self.counts[k] += 1 + + def get_metrics(self): + """Get the metric values, possibly reducing across gpu processes.""" + return {k: to_item(v) / self.counts[k] for k, v in self.sums.items()} + + +def gradient_clipping( + model: ms.nn.Cell, clip_value: float +) -> List[Any]: + """Add gradient clipping hooks to a model. + + This is the correct way to implement gradient clipping, because + gradients are clipped as gradients are computed, rather than after + all gradients are computed - this means expoding gradients are less likely, + because they are "caught" earlier. + + Args: + model: The model to add hooks to. + clip_value: The upper and lower threshold to clip the gradients to. + + Returns: + A list of handles to remove the hooks from the parameters. + """ + handles = [] + + def _clip(grad): + if grad is None: + return grad + return grad.clamp(min=-clip_value, max=clip_value) + + for parameter in model.trainable_params(): + if parameter.requires_grad: + h = parameter.register_hook(_clip) + handles.append(h) + + return handles + + +def get_optim( + lr: float, total_steps: int, model: ms.nn.Cell +) -> Tuple[ms.experimental.optim.Optimizer, Optional[ms.experimental.optim.lr_scheduler.LRScheduler]]: + """Configure optimizers, LR schedulers and EMA.""" + + # Initialize parameter groups + params = [] + + # Split parameters based on the regex + for param in model.trainable_params(): + name = param.name + if re.search(r"(.*bias|.*layer_norm.*|.*batch_norm.*)", name): + params.append({"params": param, "weight_decay": 0.0}) + else: + params.append({"params": param}) + + # Create the optimizer with the parameter groups + optimizer = ms.experimental.optim.Adam(params, lr=lr) + + # Create the learning rate scheduler + scheduler = ms.experimental.optim.lr_scheduler.CyclicLR( + optimizer, base_lr=1.0e-9, max_lr=lr, step_size_up=int(total_steps*0.04), step_size_down=total_steps + ) + + return optimizer, scheduler + + +def rand_angles(*shape, dtype=None): + r"""random rotation angles + + Parameters + ---------- + *shape : int + + Returns + ------- + alpha : `Tensor` + tensor of shape :math:`(\mathrm{shape})` + + beta : `Tensor` + tensor of shape :math:`(\mathrm{shape})` + + gamma : `Tensor` + tensor of shape :math:`(\mathrm{shape})` + """ + alpha, gamma = 2 * math.pi * mint.rand(2, *shape, dtype=dtype) + beta = mint.rand(shape, dtype=dtype).mul(2).sub(1).acos() + return alpha, beta, gamma + + +def matrix_x(angle: Tensor) -> Tensor: + r"""matrix of rotation around X axis + + Parameters + ---------- + angle : `Tensor` + tensor of any shape :math:`(...)` + + Returns + ------- + `Tensor` + matrices of shape :math:`(..., 3, 3)` + """ + c = angle.cos() + s = angle.sin() + o = mint.ones_like(angle) + z = mint.zeros_like(angle) + return mint.stack( + [ + mint.stack([o, z, z], dim=-1), + mint.stack([z, c, -s], dim=-1), + mint.stack([z, s, c], dim=-1), + ], + dim=-2, + ) + + +def matrix_y(angle: Tensor) -> Tensor: + r"""matrix of rotation around Y axis + + Parameters + ---------- + angle : `Tensor` + tensor of any shape :math:`(...)` + + Returns + ------- + `Tensor` + matrices of shape :math:`(..., 3, 3)` + """ + c = angle.cos() + s = angle.sin() + o = mint.ones_like(angle) + z = mint.zeros_like(angle) + return mint.stack( + [ + mint.stack([c, z, s], dim=-1), + mint.stack([z, o, z], dim=-1), + mint.stack([-s, z, c], dim=-1), + ], + dim=-2, + ) + + +def matrix_z(angle: Tensor) -> Tensor: + r"""matrix of rotation around Z axis + + Parameters + ---------- + angle : `Tensor` + tensor of any shape :math:`(...)` + + Returns + ------- + `Tensor` + matrices of shape :math:`(..., 3, 3)` + """ + c = angle.cos() + s = angle.sin() + o = mint.ones_like(angle) + z = mint.zeros_like(angle) + return mint.stack( + [ + mint.stack([c, -s, z], dim=-1), + mint.stack([s, c, z], dim=-1), + mint.stack([z, z, o], dim=-1), + ], + dim=-2, + ) + + +def angles_to_matrix(alpha, beta, gamma): + r"""conversion from angles to matrix + + Parameters + ---------- + alpha : `Tensor` + tensor of shape :math:`(...)` + + beta : `Tensor` + tensor of shape :math:`(...)` + + gamma : `Tensor` + tensor of shape :math:`(...)` + + Returns + ------- + `Tensor` + matrices of shape :math:`(..., 3, 3)` + """ + alpha, beta, gamma = ms.numpy.broadcast_arrays(alpha, beta, gamma) + return matrix_y(alpha) @ matrix_x(beta) @ matrix_y(gamma) + + +def rand_matrix(*shape, dtype=None): + r"""random rotation matrix + + Parameters + ---------- + *shape : int + + Returns + ------- + `Tensor` + tensor of shape :math:`(\mathrm{shape}, 3, 3)` + """ + rotation_matrix = angles_to_matrix(*rand_angles(*shape, dtype=dtype)) + return rotation_matrix diff --git a/MindChemistry/mindchemistry/cell/__init__.py b/MindChemistry/mindchemistry/cell/__init__.py index f92153c675690f6e66162a99535340533a424760..5660308bc1d2931107aaf328c36f1e927f1d8fc2 100644 --- a/MindChemistry/mindchemistry/cell/__init__.py +++ b/MindChemistry/mindchemistry/cell/__init__.py @@ -21,6 +21,7 @@ from .deephe3nn import * from .matformer import * from .dimenet import * from .gemnet import * +from .orb import * __all__ = [ "Nequip", 'AutoEncoder', 'FCNet', 'MLPNet', 'CSPNet' @@ -30,3 +31,4 @@ __all__.extend(matformer.__all__) __all__.extend(allegro.__all__) __all__.extend(dimenet.__all__) __all__.extend(gemnet.__all__) +__all__.extend(orb.__all__) diff --git a/MindChemistry/mindchemistry/cell/orb/__init__.py b/MindChemistry/mindchemistry/cell/orb/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..709978030161be6191c7d5bc96b466e777c6ae3a --- /dev/null +++ b/MindChemistry/mindchemistry/cell/orb/__init__.py @@ -0,0 +1,36 @@ +# ============================================================================ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""init""" + +from .orb import ( + NodeHead, + GraphHead, + EnergyHead, + Orb, +) +from .gns import ( + AttentionInteractionNetwork, + MoleculeGNS, +) + +__all__ = [ + "AttentionInteractionNetwork", + "EnergyHead", + "GraphHead", + "MoleculeGNS", + "NodeHead", + "Orb", +] diff --git a/MindChemistry/mindchemistry/cell/orb/gns.py b/MindChemistry/mindchemistry/cell/orb/gns.py new file mode 100644 index 0000000000000000000000000000000000000000..ab53083f8605dbe0463a5d9a0554d5e5990428f9 --- /dev/null +++ b/MindChemistry/mindchemistry/cell/orb/gns.py @@ -0,0 +1,690 @@ +# ============================================================================ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""GNS Molecule.""" + + +from typing import List, Literal, Optional, Dict, Any, Union +from functools import partial + +import numpy as np +from mindspore import nn, ops, Tensor, mint +from mindspore.common.initializer import Uniform +import mindspore.ops.operations as P + +from mindchemistry.cell.orb.utils import build_mlp + +_KEY = "feat" + + +def mlp_and_layer_norm(in_dim: int, out_dim: int, hidden_dim: int, n_layers: int) -> nn.SequentialCell: + """Create an MLP followed by layer norm. + + Args: + in_dim (int): Input dimension. + out_dim (int): Output dimension. + hidden_dim (int): Hidden dimension. + n_layers (int): Number of hidden layers. + + Returns: + nn.SequentialCell: A sequential cell containing the MLP and layer norm. + """ + layers = build_mlp( + in_dim, + [hidden_dim for _ in range(n_layers)], + out_dim, + ) + layers.append(nn.LayerNorm((out_dim,))) + return layers + + +def get_cutoff(p: int, r: Tensor, r_max: float) -> Tensor: + """Get the cutoff function for attention. + + Args: + p (int): Polynomial order. + r (Tensor): Distance tensor. + r_max (float): Maximum distance for the cutoff. + + Returns: + Tensor: Cutoff tensor. + """ + envelope = 1.0 - ((p + 1.0) * (p + 2.0) / 2.0) * ops.pow(r / r_max, p) + \ + p * (p + 2.0) * ops.pow(r / r_max, p + 1) - \ + (p * (p + 1.0) / 2) * ops.pow(r / r_max, p + 2) + cutoff = ops.expand_dims( + ops.where(r < r_max, envelope, ops.zeros_like(envelope)), -1) + return cutoff + + +def gaussian_basis_function( + scalars: Tensor, + num_bases: Union[Tensor, int], + radius: Union[Tensor, float], + scale: Union[Tensor, float] = 1.0, +) -> Tensor: + """Gaussian basis function applied to a tensor of scalars. + + Args: + scalars (Tensor): Scalars to compute the gbf on. Shape [num_scalars]. + num_bases (Tensor): The number of bases. An Int. + radius (Tensor): The largest centre of the bases. A Float. + scale (Tensor, optional): The width of the gaussians. Defaults to 1. + + Returns: + Tensor: A tensor of shape [num_scalars, num_bases]. + """ + assert len(scalars.shape) == 1 + gaussian_means = ops.arange( + 0, float(radius), float(radius) / num_bases + ) + return mint.exp( + -(scale**2) * (scalars.unsqueeze(1) - gaussian_means.unsqueeze(0)).abs() ** 2 + ) + + +class AtomEmbedding(nn.Cell): + r""" + AtomEmbedding Layer. + + This layer initializes atom embeddings based on the atomic number of elements in the periodic table. + It uses an embedding table initialized with a uniform distribution over the range [-sqrt(3), sqrt(3)]. + + Args: + emb_size (int): Size of the embedding vector for each atom. + num_elements (int): Number of elements in the periodic table (typically 118 for known elements). + + Inputs: + - **x** (Tensor) - Input tensor of shape [..., num_atoms], where + each value represents the atomic number of an atom in the periodic table. + + Outputs: + - **h** (Tensor) - Output tensor of shape [..., num_atoms, emb_size], + where each atom's embedding is represented as a vector of size `emb_size`. + + Supported Platforms: + ``Ascend`` + """ + def __init__(self, emb_size, num_elements): + """init + """ + super().__init__() + self.emb_size = emb_size + self.embeddings = nn.Embedding( + num_elements + 1, emb_size, embedding_table=Uniform(np.sqrt(3))) + + def construct(self, x): + """construct + """ + h = self.embeddings(x) + return h + + +class Encoder(nn.Cell): + r""" + Encoder for Graph Network States (GNS). + + This encoder processes node and edge features using MLPs and layer normalization. + It concatenates the features of nodes and edges, applies MLPs to update their states, + and returns the updated features. + + Args: + num_node_in_features (int): Number of input features for nodes. + num_node_out_features (int): Number of output features for nodes. + num_edge_in_features (int): Number of input features for edges. + num_edge_out_features (int): Number of output features for edges. + num_mlp_layers (int): Number of MLP layers. + mlp_hidden_dim (int): Hidden dimension for the MLP. + node_feature_names (List[str]): List of node feature names. + edge_feature_names (List[str]): List of edge feature names. + + Inputs: + - **nodes** (Dict[str, Tensor]) - Dictionary of node features, where keys are feature names + and values are tensors of shape (num_nodes, num_node_in_features). + - **edges** (Dict[str, Tensor]) - Dictionary of edge features, where keys are feature names + and values are tensors of shape (num_edges, num_edge_in_features). + + Outputs: + - **edges** (Dict[str, Tensor]) - Updated edge features dictionary, where key "feat" contains + the updated edge features of shape (num_edges, num_edge_out_features). + - **nodes** (Dict[str, Tensor]) - Updated node features dictionary, where key "feat" contains + the updated node features of shape (num_nodes, num_node_out_features). + + Supported Platforms: + ``Ascend`` + """ + + def __init__(self, + num_node_in_features: int, + num_node_out_features: int, + num_edge_in_features: int, + num_edge_out_features: int, + num_mlp_layers: int, + mlp_hidden_dim: int, + node_feature_names: List[str], + edge_feature_names: List[str]): + """init + """ + super().__init__() + self.node_feature_names = node_feature_names + self.edge_feature_names = edge_feature_names + self._node_fn = mlp_and_layer_norm( + num_node_in_features, num_node_out_features, mlp_hidden_dim, num_mlp_layers) + self._edge_fn = mlp_and_layer_norm( + num_edge_in_features, num_edge_out_features, mlp_hidden_dim, num_mlp_layers) + + def construct(self, nodes, edges): + """construct + """ + edge_features = ops.cat([edges[k] for k in self.edge_feature_names], axis=-1) + node_features = ops.cat([nodes[k] for k in self.node_feature_names], axis=-1) + + edges.update({_KEY: self._edge_fn(edge_features)}) + nodes.update({_KEY: self._node_fn(node_features)}) + return edges, nodes + + +class InteractionNetwork(nn.Cell): + r""" + Interaction Network. + + Implements a message passing neural network layer that updates node and edge features based on interactions. + This layer combines node and edge features, applies MLPs to update their states, and returns the updated features. + + Args: + num_node_in (int): Number of input features for nodes. + num_node_out (int): Number of output features for nodes. + num_edge_in (int): Number of input features for edges. + num_edge_out (int): Number of output features for edges. + num_mlp_layers (int): Number of MLP layers. + mlp_hidden_dim (int): Hidden dimension for the MLP. + + Inputs: + - **graph_edges** (Dict[str, Tensor]) - Dictionary of edge features, where key "feat" contains + the edge features of shape (num_edges, num_edge_in). + - **graph_nodes** (Dict[str, Tensor]) - Dictionary of node features, where key "feat" contains + the node features of shape (num_nodes, num_node_in). + - **senders** (Tensor) - Indices of the sender nodes for each edge, shape (num_edges,). + - **receivers** (Tensor) - Indices of the receiver nodes for each edge, shape (num_edges,). + + Outputs: + - **edges** (Dict[str, Tensor]) - Updated edge features dictionary, where key "feat" contains + the updated edge features of shape (num_edges, num_edge_out). + - **nodes** (Dict[str, Tensor]) - Updated node features dictionary, where key "feat" contains + the updated node features of shape (num_nodes, num_node_out). + + Supported Platforms: + ``Ascend`` + """ + def __init__(self, + num_node_in: int, + num_node_out: int, + num_edge_in: int, + num_edge_out: int, + num_mlp_layers: int, + mlp_hidden_dim: int): + """init + """ + super().__init__() + self._node_mlp = mlp_and_layer_norm( + num_node_in + num_edge_out, num_node_out, mlp_hidden_dim, num_mlp_layers) + self._edge_mlp = mlp_and_layer_norm( + num_node_in + num_node_in + num_edge_in, num_edge_out, mlp_hidden_dim, num_mlp_layers) + + def construct(self, graph_edges, graph_nodes, senders, receivers): + """construct + """ + nodes = graph_nodes[_KEY] + edges = graph_edges[_KEY] + + sent_attributes = ops.gather(nodes, senders, 0) + received_attributes = ops.gather(nodes, receivers, 0) + + edge_features = ops.cat( + [edges, sent_attributes, received_attributes], axis=1) + updated_edges = self._edge_mlp(edge_features) + + received_attributes = ops.scatter_add( + ops.zeros_like(nodes), receivers, updated_edges) + + node_features = ops.cat([nodes, received_attributes], axis=1) + updated_nodes = self._node_mlp(node_features) + + nodes = graph_nodes[_KEY] + updated_nodes + edges = graph_edges[_KEY] + updated_edges + + node_features = {**graph_nodes, _KEY: nodes} + edge_features = {**graph_edges, _KEY: edges} + return edge_features, node_features + + +# pylint: disable=C0301 +class AttentionInteractionNetwork(nn.Cell): + r""" + Attention interaction network. + Implements attention-based message passing neural network layer for edge updates in molecular graphs. + + Args: + num_node_in (int): Number of input node features. + num_node_out (int): Number of output node features. + num_edge_in (int): Number of input edge features. + num_edge_out (int): Number of output edge features. + num_mlp_layers (int): Number of hidden layers in node and edge update MLPs. + mlp_hidden_dim (int): Hidden dimension size of MLPs. + attention_gate (str, optional): Attention gate type, ``"sigmoid"`` or ``"softmax"``. Default: ``"sigmoid"``. + distance_cutoff (bool, optional): Whether to use distance-based edge cutoff. Default: ``True``. + polynomial_order (int, optional): Order of polynomial cutoff function. Default: ``4``. + cutoff_rmax (float, optional): Maximum distance for cutoff. Default: ``6.0``. + + Inputs: + - **graph_edges** (dict) - Edge feature dictionary, must contain key "feat" with shape :math:`(n_{edges}, num\_edge\_in)`. + - **graph_nodes** (dict) - Node feature dictionary, must contain key "feat" with shape :math:`(n_{nodes}, num\_node\_in)`. + - **senders** (Tensor) - Sender node indices for each edge, shape :math:`(n_{edges},)`. + - **receivers** (Tensor) - Receiver node indices for each edge, shape :math:`(n_{edges},)`. + + Outputs: + - **edges** (dict) - Updated edge feature dictionary with key "feat" of shape :math:`(n_{edges}, num\_edge\_out)`. + - **nodes** (dict) - Updated node feature dictionary with key "feat" of shape :math:`(n_{nodes}, num\_node\_out)`. + + Raises: + ValueError: If `attention_gate` is not "sigmoid" or "softmax". + ValueError: If edge or node features do not contain the required "feat" key. + + Supported Platforms: + ``Ascend`` + + Examples: + >>> import numpy as np + >>> import mindspore + >>> from mindspore import Tensor + >>> from mindchemistry.cell.orb.gns import AttentionInteractionNetwork + >>> attn_net = AttentionInteractionNetwork( + ... num_node_in=256, + ... num_node_out=256, + ... num_edge_in=256, + ... num_edge_out=256, + ... num_mlp_layers=2, + ... mlp_hidden_dim=512, + ... ) + >>> n_atoms = 4 + >>> n_edges = 10 + >>> atomic_numbers = Tensor(np.random.randint(1, 119, size=(n_atoms,), dtype=np.int32)) + >>> atomic_numbers_embedding_np = np.zeros((n_atoms, 118), dtype=np.float32) + >>> for i, num in enumerate(atomic_numbers.asnumpy()): + ... atomic_numbers_embedding_np[i, num - 1] = 1.0 + >>> node_features = { + ... "atomic_numbers": atomic_numbers, + ... "atomic_numbers_embedding": Tensor(atomic_numbers_embedding_np), + ... "positions": Tensor(np.random.randn(n_atoms, 3).astype(np.float32)), + ... "feat": Tensor(np.random.randn(n_atoms, 256).astype(np.float32)) + ... } + >>> edge_features = { + ... "vectors": Tensor(np.random.randn(n_edges, 3).astype(np.float32)), + ... "r": Tensor(np.abs(np.random.randn(n_edges).astype(np.float32) * 10)), + ... "feat": Tensor(np.random.randn(n_edges, 256).astype(np.float32)) + ... } + >>> senders = Tensor(np.random.randint(0, n_atoms, size=(n_edges,), dtype=np.int32)) + >>> receivers = Tensor(np.random.randint(0, n_atoms, size=(n_edges,), dtype=np.int32)) + >>> edges, nodes = attn_net( + ... edge_features, + ... node_features, + ... senders, + ... receivers, + ... ) + >>> print(edges["feat"].shape, nodes["feat"].shape) + (10, 256) (4, 256) + """ + + def __init__(self, + num_node_in: int, + num_node_out: int, + num_edge_in: int, + num_edge_out: int, + num_mlp_layers: int, + mlp_hidden_dim: int, + attention_gate: Literal["sigmoid", "softmax"] = "sigmoid", + distance_cutoff: bool = True, + polynomial_order: Optional[int] = 4, + cutoff_rmax: Optional[float] = 6.0): + """init + """ + super().__init__() + self._num_node_in = num_node_in + self._num_node_out = num_node_out + self._num_edge_in = num_edge_in + self._num_edge_out = num_edge_out + self._num_mlp_layers = num_mlp_layers + self._mlp_hidden_dim = mlp_hidden_dim + self._node_mlp = mlp_and_layer_norm( + num_node_in + num_edge_out + num_edge_out, num_node_out, mlp_hidden_dim, num_mlp_layers) + self._edge_mlp = mlp_and_layer_norm( + num_node_in + num_node_in + num_edge_in, num_edge_out, mlp_hidden_dim, num_mlp_layers) + self._receive_attn = nn.Dense(num_edge_in, 1) + self._send_attn = nn.Dense(num_edge_in, 1) + self._distance_cutoff = distance_cutoff + self._r_max = cutoff_rmax + self._polynomial_order = polynomial_order + self._attention_gate = attention_gate + + self.scatter_add = P.TensorScatterAdd() + + def construct(self, graph_edges, graph_nodes, senders, receivers): + """construct + """ + nodes = graph_nodes[_KEY] + edges = graph_edges[_KEY] + + p = self._polynomial_order + r_max = self._r_max + r = graph_edges['r'] + cutoff = get_cutoff(p, r, r_max) + + sent_attributes = ops.gather(nodes, senders, 0) + received_attributes = ops.gather(nodes, receivers, 0) + + if self._attention_gate == "softmax": + receive_attn = ops.softmax(self._receive_attn(edges), axis=0) + send_attn = ops.softmax(self._send_attn(edges), axis=0) + else: + receive_attn = ops.sigmoid(self._receive_attn(edges)) + send_attn = ops.sigmoid(self._send_attn(edges)) + + if self._distance_cutoff: + receive_attn = receive_attn * cutoff + send_attn = send_attn * cutoff + + edge_features = ops.cat( + [edges, sent_attributes, received_attributes], axis=1) + updated_edges = self._edge_mlp(edge_features) + + if senders.ndim < 2: + senders = senders.unsqueeze(-1) + sent_attributes = self.scatter_add( + ops.zeros_like(nodes), senders, updated_edges * send_attn) + if receivers.ndim < 2: + receivers = receivers.unsqueeze(-1) + received_attributes = self.scatter_add( + ops.zeros_like(nodes), receivers, updated_edges * receive_attn) + + node_features = ops.cat( + [nodes, received_attributes, sent_attributes], axis=1) + updated_nodes = self._node_mlp(node_features) + + nodes = graph_nodes[_KEY] + updated_nodes + edges = graph_edges[_KEY] + updated_edges + + node_features = {**graph_nodes, _KEY: nodes} + edge_features = {**graph_edges, _KEY: edges} + return edge_features, node_features + +class Decoder(nn.Cell): + r""" + Decoder for Graph Network States (GNS). + + This decoder processes node features using an MLP to produce predictions. + It takes the node features as input and outputs updated node features with predictions. + + Args: + num_node_in (int): Number of input features for nodes. + num_node_out (int): Number of output features for nodes. + num_mlp_layers (int): Number of MLP layers. + mlp_hidden_dim (int): Hidden dimension for the MLP. + batch_norm (bool, optional): Whether to apply batch normalization. Defaults to False. + + Inputs: + - **graph_nodes** (Dict[str, Tensor]) - Dictionary of node features, where key "feat" contains + the node features of shape (num_nodes, num_node_in). + + Outputs: + - **graph_nodes** (Dict[str, Tensor]) - Updated node features dictionary, where key "pred" contains + the predicted node features of shape (num_nodes, num_node_out). + + Supported Platforms: + ``Ascend`` + """ + def __init__(self, + num_node_in: int, + num_node_out: int, + num_mlp_layers: int, + mlp_hidden_dim: int, + batch_norm: bool = False): + """Initialization. + Args: + num_node_in (int): Number of input features for nodes. + num_node_out (int): Number of output features for nodes. + num_mlp_layers (int): Number of MLP layers. + mlp_hidden_dim (int): Hidden dimension for the MLP. + batch_norm (bool, optional): Whether to apply batch normalization. Defaults to False. + """ + super().__init__() + seq = build_mlp( + num_node_in, + [mlp_hidden_dim for _ in range(num_mlp_layers)], + num_node_out, + ) + if batch_norm: + seq.append(nn.BatchNorm1d(num_node_out)) + self.node_fn = nn.SequentialCell(seq) + + def construct(self, graph_nodes): + """Forward pass of the decoder. + Args: + graph_nodes (Dict[str, Tensor]): Dictionary of node features. + Returns: + Dict[str, Tensor]: Updated node features with predictions. + """ + nodes = graph_nodes[_KEY] + updated = self.node_fn(nodes) + return {**graph_nodes, "pred": updated} + + +# pylint: disable=C0301 +class MoleculeGNS(nn.Cell): + r""" + Molecular graph neural network. + Implements flexible modular graph neural network for molecular property prediction based on message passing + with attention or other interaction mechanisms. Supports node and edge embeddings, multiple message passing + steps, and customizable interaction layers for complex molecular graphs. + + Args: + num_node_in_features (int): Number of input features per node. + num_node_out_features (int): Number of output features per node. + num_edge_in_features (int): Number of input features per edge. + latent_dim (int): Latent dimension for node and edge representations. + num_message_passing_steps (int): Number of message passing layers. + num_mlp_layers (int): Number of hidden layers in node and edge update MLPs. + mlp_hidden_dim (int): Hidden dimension size of MLPs. + node_feature_names (List[str]): List of node feature keys to use from input dictionary. + edge_feature_names (List[str]): List of edge feature keys to use from input dictionary. + use_embedding (bool, optional): Whether to use atomic number embedding for nodes. Default: ``True``. + interactions (str, optional): Type of interaction layer to use (e.g., ``"simple_attention"``). Default: ``"simple_attention"``. + interaction_params (Optional[Dict[str, Any]], optional): Parameters for interaction layer, e.g., cutoff, + polynomial order, gate type. Default: ``None``. + + Inputs: + - **edge_features** (dict) - Edge feature dictionary, must contain keys specified in `edge_feature_names`. + - **node_features** (dict) - Node feature dictionary, must contain keys specified in `node_feature_names`. + - **senders** (Tensor) - Sender node indices for each edge, shape :math:`(n_{edges},)`. + - **receivers** (Tensor) - Receiver node indices for each edge, shape :math:`(n_{edges},)`. + + Outputs: + - **edges** (dict) - Updated edge feature dictionary with key "feat" of shape :math:`(n_{edges}, latent\_dim)`. + - **nodes** (dict) - Updated node feature dictionary with key "feat" of shape :math:`(n_{nodes}, latent\_dim)`. + + Raises: + ValueError: If required feature keys are missing in `edge_features` or `node_features`. + ValueError: If `interactions` is not a supported type. + + Supported Platforms: + ``Ascend`` + + Examples: + >>> import numpy as np + >>> import mindspore + >>> from mindspore import Tensor + >>> from mindchemistry.cell.orb.gns import MoleculeGNS + >>> gns_model = MoleculeGNS( + ... num_node_in_features=256, + ... num_node_out_features=3, + ... num_edge_in_features=23, + ... latent_dim=256, + ... interactions="simple_attention", + ... interaction_params={ + ... "distance_cutoff": True, + ... "polynomial_order": 4, + ... "cutoff_rmax": 6, + ... "attention_gate": "sigmoid", + ... }, + ... num_message_passing_steps=15, + ... num_mlp_layers=2, + ... mlp_hidden_dim=512, + ... use_embedding=True, + ... node_feature_names=["feat"], + ... edge_feature_names=["feat"], + ... ) + >>> n_atoms = 4 + >>> n_edges = 10 + >>> atomic_numbers = Tensor(np.random.randint(1, 119, size=(n_atoms,), dtype=np.int32)) + >>> atomic_numbers_embedding_np = np.zeros((n_atoms, 118), dtype=np.float32) + >>> for i, num in enumerate(atomic_numbers.asnumpy()): + ... atomic_numbers_embedding_np[i, num - 1] = 1.0 + >>> node_features = { + ... "atomic_numbers": atomic_numbers, + ... "atomic_numbers_embedding": Tensor(atomic_numbers_embedding_np), + ... "positions": Tensor(np.random.randn(n_atoms, 3).astype(np.float32)), + ... "feat": Tensor(np.random.randn(n_atoms, 256).astype(np.float32)) + ... } + >>> edge_features = { + ... "vectors": Tensor(np.random.randn(n_edges, 3).astype(np.float32)), + ... "r": Tensor(np.abs(np.random.randn(n_edges).astype(np.float32) * 10)), + ... "feat": Tensor(np.random.randn(n_edges, 256).astype(np.float32)) + ... } + >>> senders = Tensor(np.random.randint(0, n_atoms, size=(n_edges,), dtype=np.int32)) + >>> receivers = Tensor(np.random.randint(0, n_atoms, size=(n_edges,), dtype=np.int32)) + >>> edges, nodes = gns_model( + ... edge_features, + ... node_features, + ... senders, + ... receivers, + ... ) + >>> print(edges["feat"].shape, nodes["feat"].shape) + (10, 256) (4, 256) + """ + + def __init__(self, + num_node_in_features: int, + num_node_out_features: int, + num_edge_in_features: int, + latent_dim: int, + num_message_passing_steps: int, + num_mlp_layers: int, + mlp_hidden_dim: int, + node_feature_names: List[str], + edge_feature_names: List[str], + use_embedding: bool = True, + interactions: Literal["default", + "simple_attention"] = "simple_attention", + interaction_params: Optional[Dict[str, Any]] = None): + """init + """ + super().__init__() + self._encoder = Encoder( + num_node_in_features=num_node_in_features, + num_node_out_features=latent_dim, + num_edge_in_features=num_edge_in_features, + num_edge_out_features=latent_dim, + num_mlp_layers=num_mlp_layers, + mlp_hidden_dim=mlp_hidden_dim, + node_feature_names=node_feature_names, + edge_feature_names=edge_feature_names + ) + if interactions == "default": + InteractionNetworkClass = InteractionNetwork + elif interactions == "simple_attention": + InteractionNetworkClass = AttentionInteractionNetwork + self.num_message_passing_steps = num_message_passing_steps + if interaction_params is None: + interaction_params = {} + self.gnn_stacks = nn.CellList([ + InteractionNetworkClass( + num_node_in=latent_dim, + num_node_out=latent_dim, + num_edge_in=latent_dim, + num_edge_out=latent_dim, + num_mlp_layers=num_mlp_layers, + mlp_hidden_dim=mlp_hidden_dim, + **interaction_params + ) for _ in range(self.num_message_passing_steps) + ]) + self._decoder = Decoder( + num_node_in=latent_dim, + num_node_out=num_node_out_features, + num_mlp_layers=num_mlp_layers, + mlp_hidden_dim=mlp_hidden_dim + ) + self.rbf = partial(gaussian_basis_function, num_bases=20, radius=10.0) + self.use_embedding = use_embedding + if self.use_embedding: + self.atom_emb = AtomEmbedding(latent_dim, 118) + + def construct(self, edge_features, node_features, senders, receivers): + """construct + """ + edge_features = self.featurize_edges(edge_features) + node_features = self.featurize_nodes(node_features) + edges, nodes = self._encoder(node_features, edge_features) + for gnn in self.gnn_stacks: + edges, nodes = gnn(edges, nodes, senders, receivers) + nodes = self._decoder(nodes) + return edges, nodes + + def featurize_nodes(self, node_features): + """Featurize the nodes of a graph. + + Args: + node_features (Dict[str, Tensor]): Dictionary of node features. + + Returns: + Dict[str, Tensor]: Updated node features with atomic embeddings. + """ + one_hot_atomic = ops.OneHot()( + node_features["atomic_numbers"], 118, Tensor(1.0), Tensor(0.0) + ) + if self.use_embedding: + atomic_embedding = self.atom_emb(node_features["atomic_numbers"]) + else: + atomic_embedding = one_hot_atomic + + node_features = {**node_features, **{_KEY: atomic_embedding}} + return node_features + + def featurize_edges(self, edge_features): + """Featurize the edges of a graph. + + Args: + edge_features (Dict[str, Tensor]): Dictionary of edge features. + + Returns: + Dict[str, Tensor]: Updated edge features with radial basis functions and unit vectors. + """ + lengths = ops.norm(edge_features['vectors'], dim=1) + non_zero_divisor = ops.where( + lengths == 0, ops.ones_like(lengths), lengths) + unit_vectors = edge_features['vectors'] / ops.expand_dims(non_zero_divisor, 1) + rbfs = self.rbf(lengths) + edges = ops.cat([rbfs, unit_vectors], axis=1) + + edge_features = {**edge_features, **{_KEY: edges}} + return edge_features diff --git a/MindChemistry/mindchemistry/cell/orb/orb.py b/MindChemistry/mindchemistry/cell/orb/orb.py new file mode 100644 index 0000000000000000000000000000000000000000..8afc55dbf5e931505ad6747b53ff9fcd4e1fbe23 --- /dev/null +++ b/MindChemistry/mindchemistry/cell/orb/orb.py @@ -0,0 +1,698 @@ +# ============================================================================ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Orb GraphRegressor.""" + +from typing import Literal, Optional, Union +import numpy + +import mindspore as ms +from mindspore import Parameter, ops, Tensor, mint + +from mindchemistry.cell.orb.gns import _KEY, MoleculeGNS +from mindchemistry.cell.orb.utils import ( + aggregate_nodes, + build_mlp, + REFERENCE_ENERGIES, +) + + +class LinearReferenceEnergy(ms.nn.Cell): + r""" + Linear reference energy (no bias term). + + This class implements a linear reference energy model that can be used + to compute the reference energy for a given set of atomic numbers. + + Args: + weight_init (numpy.ndarray, optional): Initial weights for the linear layer. + If not provided, the weights will be initialized randomly. + trainable (bool, optional): Whether the weights are trainable or not. + If not provided, the weights will be trainable by default. + + Inputs: + - **atom_types** (Tensor) - A tensor of atomic numbers of shape (n_atoms,). + - **n_node** (Tensor) - A tensor of shape (n_graphs,) containing the number of nodes in each graph. + + Outputs: + - **Tensor** - A tensor of shape (n_graphs, 1) containing the reference energy. + + Raises: + ValueError: If the input tensor shapes are not compatible with the expected shapes. + TypeError: If the input types are not compatible with the expected types. + + Supported Platforms: + ``Ascend`` + """ + def __init__( + self, + weight_init: Optional[numpy.ndarray] = None, + trainable: Optional[bool] = None, + ): + """init + """ + super().__init__() + + if trainable is None: + trainable = weight_init is None + + self.linear = ms.nn.Dense(118, 1, has_bias=False) + if weight_init is not None: + self.linear.weight.set_data(Tensor(weight_init, dtype=ms.float32).reshape(1, 118)) + if not trainable: + self.linear.weight.requires_grad = False + + def construct(self, atom_types: Tensor, n_node: Tensor): + """construct + """ + one_hot_atomic = ops.OneHot()(atom_types, 118, Tensor(1.0, ms.float32), Tensor(0.0, ms.float32)) + + reduced = aggregate_nodes(one_hot_atomic, n_node, reduction="sum") + return self.linear(reduced) + + +class ScalarNormalizer(ms.nn.Cell): + r""" + Scalar normalizer that learns mean and std from data. + + NOTE: Multi-dimensional tensors are flattened before updating + the running mean/std. This is desired behaviour for force targets. + + Args: + init_mean (Tensor or float, optional): Initial mean value for normalization. + If not provided, defaults to 0.0. + init_std (Tensor or float, optional): Initial standard deviation value for normalization. + If not provided, defaults to 1.0. + init_num_batches (int, optional): Initial number of batches for normalization. + If not provided, defaults to 1000. + + Inputs: + - **x** (Tensor) - A tensor of shape (n_samples, n_features) to normalize. + + Outputs: + - **Tensor** - A tensor of the same shape as x, normalized by the running mean and std. + + Raises: + ValueError: If the input tensor is not of the expected shape. + TypeError: If the input types are not compatible with the expected types. + + Supported Platforms: + ``Ascend`` + """ + def __init__( + self, + init_mean: Optional[Union[Tensor, float]] = None, + init_std: Optional[Union[Tensor, float]] = None, + init_num_batches: Optional[int] = 1000, + ): + """init + """ + super().__init__() + self.bn = mint.nn.BatchNorm1d(1, affine=False, momentum=None) + self.bn.running_mean = Parameter(Tensor([0], ms.float32)) + self.bn.running_var = Parameter(Tensor([1], ms.float32)) + self.bn.num_batches_tracked = Parameter(Tensor([1000], ms.float32)) + self.stastics = { + "running_mean": init_mean if init_mean is not None else 0.0, + "running_var": init_std**2 if init_std is not None else 1.0, + "num_batches_tracked": init_num_batches if init_num_batches is not None else 1000, + } + + def construct(self, x: Tensor): + """construct + """ + if self.training: + self.bn(x.view(-1, 1)) + if hasattr(self, "running_mean"): + return (x - self.running_mean) / mint.sqrt(self.running_var) + return (x - self.bn.running_mean) / mint.sqrt(self.bn.running_var) + + def inverse(self, x: Tensor): + """Reverse the construct normalization. + + Args: + x: A tensor of shape (n_samples, n_features) to inverse normalize. + + Returns: + A tensor of the same shape as x, inverse normalized by the running mean and std. + """ + if hasattr(self, "running_mean"): + return x * mint.sqrt(self.running_var) + self.running_mean + return x * mint.sqrt(self.bn.running_var) + self.bn.running_mean + + +# pylint: disable=C0301 +class NodeHead(ms.nn.Cell): + r""" + Node-level prediction head. + + Implements neural network head for predicting node-level properties from node features. This head can be + added to base models to enable auxiliary tasks during pretraining or added in fine-tuning steps. + + Args: + latent_dim (int): Input feature dimension for each node. + num_mlp_layers (int): Number of hidden layers in MLP. + mlp_hidden_dim (int): Hidden dimension size of MLP. + target_property_dim (int): Output dimension of node-level target property. + dropout (Optional[float], optional): Dropout rate for MLP. Default: ``None``. + remove_mean (bool, optional): If True, remove mean from output, typically used for force prediction. + Default: ``True``. + + Inputs: + - **node_features** (dict) - Node feature dictionary, must contain key "feat" with shape :math:`(n_{nodes}, latent\_dim)`. + - **n_node** (Tensor) - Number of nodes in graph, shape :math:`(1,)`. + + Outputs: + - **output** (dict) - Dictionary containing key "node_pred" with value of shape :math:`(n_{nodes}, target\_property\_dim)`. + + Raises: + ValueError: If required feature keys are missing in `node_features`. + + Supported Platforms: + ``Ascend`` + + Examples: + >>> import numpy as np + >>> import mindspore + >>> from mindspore import Tensor + >>> from mindchemistry.cell.orb.gns import NodeHead + >>> node_head = NodeHead( + ... latent_dim=256, + ... num_mlp_layers=1, + ... mlp_hidden_dim=256, + ... target_property_dim=3, + ... remove_mean=True, + ... ) + >>> n_atoms = 4 + >>> n_node = Tensor([n_atoms], mindspore.int32) + >>> atomic_numbers = Tensor(np.random.randint(1, 119, size=(n_atoms,), dtype=np.int32)) + >>> atomic_numbers_embedding_np = np.zeros((n_atoms, 118), dtype=np.float32) + >>> for i, num in enumerate(atomic_numbers.asnumpy()): + ... atomic_numbers_embedding_np[i, num - 1] = 1.0 + >>> node_features = { + ... "atomic_numbers": atomic_numbers, + ... "atomic_numbers_embedding": Tensor(atomic_numbers_embedding_np), + ... "positions": Tensor(np.random.randn(n_atoms, 3).astype(np.float32)), + ... "feat": Tensor(np.random.randn(n_atoms, 256).astype(np.float32)) + ... } + >>> output = node_head(node_features, n_node) + >>> print(output['node_pred'].shape) + (4, 3) + """ + def __init__( + self, + latent_dim: int, + num_mlp_layers: int, + mlp_hidden_dim: int, + target_property_dim: int, + dropout: Optional[float] = None, + remove_mean: bool = True, + ): + """init + """ + super().__init__() + self.target_property_dim = target_property_dim + self.normalizer = ScalarNormalizer() + + self.mlp = build_mlp( + input_size=latent_dim, + hidden_layer_sizes=[mlp_hidden_dim] * num_mlp_layers, + output_size=self.target_property_dim, + dropout=dropout, + ) + + self.remove_mean = remove_mean + + def construct(self, node_features, n_node): + """construct + """ + feat = node_features[_KEY] + pred = self.mlp(feat) + if self.remove_mean: + system_means = aggregate_nodes( + pred, n_node, reduction="mean" + ) + node_broadcasted_means = mint.repeat_interleave( + system_means, n_node, dim=0 + ) + pred = pred - node_broadcasted_means + res = {"node_pred": pred} + return res + + def predict(self, node_features, n_node): + """Predict node-level attributes. + + Args: + node_features: Node features tensor of shape (n_nodes, latent_dim). + n_node: Number of nodes in the graph. + + Returns: + node_pred: Node-level predictions of shape (n_nodes, target_property_dim). + """ + out = self(node_features, n_node) + pred = out["node_pred"] + return self.normalizer.inverse(pred) + + +# pylint: disable=C0301 +class GraphHead(ms.nn.Cell): + r""" + Graph-level prediction head. Implements graph-level prediction head that can be attached to base models + for predicting graph-level properties (e.g., stress tensor) from node features using aggregation and MLP. + + Args: + latent_dim (int): Input feature dimension for each node. + num_mlp_layers (int): Number of hidden layers in MLP. + mlp_hidden_dim (int): Hidden dimension size of MLP. + target_property_dim (int): Output dimension of graph-level property. + node_aggregation (str, optional): Aggregation method for node predictions, e.g., ``"mean"`` or ``"sum"``. Default: ``"mean"``. + dropout (Optional[float], optional): Dropout rate for MLP. Default: ``None``. + compute_stress (bool, optional): Whether to compute and output stress tensor. Default: ``False``. + + Inputs: + - **node_features** (dict) - Node feature dictionary, must contain key "feat" with shape :math:`(n_{nodes}, latent\_dim)`. + - **n_node** (Tensor) - Number of nodes in graph, shape :math:`(1,)`. + + Outputs: + - **output** (dict) - Dictionary containing key "stress_pred" with value of shape :math:`(1, target\_property\_dim)`. + + Raises: + ValueError: If required feature keys are missing in `node_features`. + + Supported Platforms: + ``Ascend`` + + Examples: + >>> import numpy as np + >>> import mindspore + >>> from mindspore import Tensor + >>> from mindchemistry.cell.orb.gns import GraphHead + >>> graph_head = GraphHead( + ... latent_dim=256, + ... num_mlp_layers=1, + ... mlp_hidden_dim=256, + ... target_property_dim=6, + ... compute_stress=True, + ... ) + >>> n_atoms = 4 + >>> n_node = Tensor([n_atoms], mindspore.int32) + >>> atomic_numbers = Tensor(np.random.randint(1, 119, size=(n_atoms,), dtype=np.int32)) + >>> atomic_numbers_embedding_np = np.zeros((n_atoms, 118), dtype=np.float32) + >>> for i, num in enumerate(atomic_numbers.asnumpy()): + ... atomic_numbers_embedding_np[i, num - 1] = 1.0 + >>> node_features = { + ... "atomic_numbers": atomic_numbers, + ... "atomic_numbers_embedding": Tensor(atomic_numbers_embedding_np), + ... "positions": Tensor(np.random.randn(n_atoms, 3).astype(np.float32)), + ... "feat": Tensor(np.random.randn(n_atoms, 256).astype(np.float32)) + ... } + >>> output = graph_head(node_features, n_node) + >>> print(output['stress_pred'].shape) + (1, 6) + """ + + def __init__( + self, + latent_dim: int, + num_mlp_layers: int, + mlp_hidden_dim: int, + target_property_dim: int, + node_aggregation: Literal["sum", "mean"] = "mean", + dropout: Optional[float] = None, + compute_stress: Optional[bool] = False, + ): + """init + """ + super().__init__() + self.target_property_dim = target_property_dim + self.normalizer = ScalarNormalizer() + + self.node_aggregation = node_aggregation + self.mlp = build_mlp( + input_size=latent_dim, + hidden_layer_sizes=[mlp_hidden_dim] * num_mlp_layers, + output_size=self.target_property_dim, + dropout=dropout, + ) + self.output_activation = ops.Identity() + self.compute_stress = compute_stress + + def construct(self, node_features, n_node): + """construct + """ + feat = node_features[_KEY] + + # aggregate to get a tensor of shape (num_graphs, latent_dim) + mlp_input = aggregate_nodes( + feat, + n_node, + reduction=self.node_aggregation, + ) + + pred = self.mlp(mlp_input) + if self.compute_stress: + # name the stress prediction differently + res = {"stress_pred": pred} + else: + res = {"graph_pred": pred} + return res + + def predict(self, node_features, n_node, atomic_numbers=None): + """Predict graph-level attributes. + + Args: + node_features: Node features tensor + n_node: Number of nodes + atomic_numbers: Optional atomic numbers for reference energy calculation + + Returns: + probs: Graph-level predictions of shape (n_graphs, target_property_dim). + If compute_stress is True, this will be the stress tensor. + If compute_stress is False, this will be the graph-level property (e.g., energy). + """ + pred = self(node_features, n_node) + if self.compute_stress: + logits = pred["stress_pred"].squeeze(-1) + else: + assert atomic_numbers is not None, "atomic_numbers must be provided for graph prediction" + logits = pred["graph_pred"].squeeze(-1) + probs = self.output_activation(logits) + probs = self.normalizer.inverse(probs) + return probs + + +# pylint: disable=C0301 +class EnergyHead(GraphHead): + r""" + Graph-level energy prediction head. + Implements neural network head for predicting total energy or per-atom average energy of molecular graphs. + Supports node-level aggregation, reference energy offset, and flexible output modes. + + Args: + latent_dim (int): Input feature dimension for each node. + num_mlp_layers (int): Number of hidden layers in MLP. + mlp_hidden_dim (int): Hidden dimension size of MLP. + target_property_dim (int): Output dimension of energy property (typically 1). + predict_atom_avg (bool, optional): Whether to predict per-atom average energy instead of total energy. Default: ``True``. + reference_energy_name (str, optional): Reference energy name for offset, e.g., ``"vasp-shifted"``. Default: ``"mp-traj-d3"``. + train_reference (bool, optional): Whether to train reference energy as learnable parameter. Default: ``False``. + dropout (Optional[float], optional): Dropout rate for MLP. Default: ``None``. + node_aggregation (str, optional): Aggregation method for node predictions, e.g., ``"mean"`` or ``"sum"``. Default: ``None``. + + Inputs: + - **node_features** (dict) - Node feature dictionary, must contain key "feat" with shape :math:`(n_{nodes}, latent\_dim)`. + - **n_node** (Tensor) - Number of nodes in graph, shape :math:`(1,)`. + + Outputs: + - **output** (dict) - Dictionary containing key "graph_pred" with value of shape :math:`(1, target\_property\_dim)`. + + Raises: + ValueError: If required feature keys are missing in `node_features`. + ValueError: If `node_aggregation` is not a supported type. + + Supported Platforms: + ``Ascend`` + + Examples: + >>> import numpy as np + >>> import mindspore + >>> from mindspore import Tensor + >>> from mindchemistry.cell.orb.gns import EnergyHead + >>> energy_head = EnergyHead( + ... latent_dim=256, + ... num_mlp_layers=1, + ... mlp_hidden_dim=256, + ... target_property_dim=1, + ... node_aggregation="mean", + ... reference_energy_name="vasp-shifted", + ... train_reference=True, + ... predict_atom_avg=True, + ... ) + >>> n_atoms = 4 + >>> n_node = Tensor([n_atoms], mindspore.int32) + >>> atomic_numbers = Tensor(np.random.randint(1, 119, size=(n_atoms,), dtype=np.int32)) + >>> atomic_numbers_embedding_np = np.zeros((n_atoms, 118), dtype=np.float32) + >>> for i, num in enumerate(atomic_numbers.asnumpy()): + ... atomic_numbers_embedding_np[i, num - 1] = 1.0 + >>> node_features = { + ... "atomic_numbers": atomic_numbers, + ... "atomic_numbers_embedding": Tensor(atomic_numbers_embedding_np), + ... "positions": Tensor(np.random.randn(n_atoms, 3).astype(np.float32)), + ... "feat": Tensor(np.random.randn(n_atoms, 256).astype(np.float32)) + ... } + >>> output = energy_head(node_features, n_node) + >>> print(output['graph_pred'].shape) + (1, 1) + """ + + def __init__( + self, + latent_dim: int, + num_mlp_layers: int, + mlp_hidden_dim: int, + target_property_dim: int, + predict_atom_avg: bool = True, + reference_energy_name: str = "mp-traj-d3", + train_reference: bool = False, + dropout: Optional[float] = None, + node_aggregation: Optional[str] = "mean", + ): + """init + """ + ref = REFERENCE_ENERGIES[reference_energy_name] + + super().__init__( + latent_dim=latent_dim, + num_mlp_layers=num_mlp_layers, + mlp_hidden_dim=mlp_hidden_dim, + target_property_dim=target_property_dim, + node_aggregation=node_aggregation, + dropout=dropout, + ) + self.reference = LinearReferenceEnergy( + weight_init=ref.coefficients, trainable=train_reference + ) + self.atom_avg = predict_atom_avg + + def predict(self, node_features, n_node, atomic_numbers=None): + """Predict energy. + + Args: + node_features: Node features tensor + n_node: Number of nodes + atomic_numbers: Optional atomic numbers for reference energy calculation + + Returns: + graph_pred: Energy prediction + """ + if atomic_numbers is None: + raise ValueError("atomic_numbers is required for energy prediction") + + pred = self(node_features, n_node)["graph_pred"] + pred = self.normalizer.inverse(pred).squeeze(-1) + if self.atom_avg: + pred = pred * n_node + pred = pred + self.reference(atomic_numbers, n_node) + return pred + + +# pylint: disable=C0301 +class Orb(ms.nn.Cell): + r""" + Orb graph regressor. + Combines a pretrained base model (e.g., MoleculeGNS) with optional node, graph, and stress regression heads, supporting + fine-tuning or feature extraction workflows. + + Args: + model (MoleculeGNS): Pretrained or randomly initialized base model for message passing and feature extraction. + node_head (NodeHead, optional): Regression head for node-level property prediction. Default: ``None``. + graph_head (GraphHead, optional): Regression head for graph-level property prediction (e.g., energy). Default: ``None``. + stress_head (GraphHead, optional): Regression head for stress prediction. Default: ``None``. + model_requires_grad (bool, optional): Whether to fine-tune the base model (True) or freeze its parameters (False). Default: ``True``. + cutoff_layers (int, optional): If provided, only use the first ``cutoff_layers`` message passing layers of the base model. + Default: ``None``. + + Inputs: + - **edge_features** (dict) - Edge feature dictionary (e.g., `{"vectors": Tensor, "r": Tensor}`). + - **node_features** (dict) - Node feature dictionary (e.g., `{"atomic_numbers": Tensor, ...}`). + - **senders** (Tensor) - Sender node indices for each edge. Shape: :math:`(n_{edges},)`. + - **receivers** (Tensor) - Receiver node indices for each edge. Shape: :math:`(n_{edges},)`. + - **n_node** (Tensor) - Number of nodes for each graph in the batch. Shape: :math:`(n_{graphs},)`. + + Outputs: + - **output** (dict) - Dictionary containing: + - **edges** (dict) - Edge features after message passing, e.g., `{..., "feat": Tensor}`. + - **nodes** (dict) - Node features after message passing, e.g., `{..., "feat": Tensor}`. + - **graph_pred** (Tensor) - Graph-level predictions, e.g., energy. Shape: :math:`(n_{graphs}, target\_property\_dim)`. + - **node_pred** (Tensor) - Node-level predictions. Shape: :math:`(n_{nodes}, target\_property\_dim)`. + - **stress_pred** (Tensor) - Stress predictions (if stress_head is provided). Shape: :math:`(n_{graphs}, 6)`. + + Raises: + ValueError: If neither node_head nor graph_head is provided. + ValueError: If cutoff_layers exceeds the number of message passing steps in the base model. + ValueError: If atomic_numbers is not provided when graph_head is required. + + Supported Platforms: + ``Ascend`` + + Examples: + >>> import numpy as np + >>> import mindspore + >>> from mindspore import Tensor + >>> from mindchemistry.cell.orb import Orb, MoleculeGNS, EnergyHead, NodeHead, GraphHead + >>> Orb = Orb( + ... model=MoleculeGNS( + ... num_node_in_features=256, + ... num_node_out_features=3, + ... num_edge_in_features=23, + ... latent_dim=256, + ... interactions="simple_attention", + ... interaction_params={ + ... "distance_cutoff": True, + ... "polynomial_order": 4, + ... "cutoff_rmax": 6, + ... "attention_gate": "sigmoid", + ... }, + ... num_message_passing_steps=15, + ... num_mlp_layers=2, + ... mlp_hidden_dim=512, + ... use_embedding=True, + ... node_feature_names=["feat"], + ... edge_feature_names=["feat"], + ... ), + ... graph_head=EnergyHead( + ... latent_dim=256, + ... num_mlp_layers=1, + ... mlp_hidden_dim=256, + ... target_property_dim=1, + ... node_aggregation="mean", + ... reference_energy_name="vasp-shifted", + ... train_reference=True, + ... predict_atom_avg=True, + ... ), + ... node_head=NodeHead( + ... latent_dim=256, + ... num_mlp_layers=1, + ... mlp_hidden_dim=256, + ... target_property_dim=3, + ... remove_mean=True, + ... ), + ... stress_head=GraphHead( + ... latent_dim=256, + ... num_mlp_layers=1, + ... mlp_hidden_dim=256, + ... target_property_dim=6, + ... compute_stress=True, + ... ), + ... ) + >>> n_atoms = 4 + >>> n_edges = 10 + >>> n_node = Tensor([n_atoms], mindspore.int32) + >>> atomic_numbers = Tensor(np.random.randint(1, 119, size=(n_atoms,), dtype=np.int32)) + >>> atomic_numbers_embedding_np = np.zeros((n_atoms, 118), dtype=np.float32) + >>> for i, num in enumerate(atomic_numbers.asnumpy()): + ... atomic_numbers_embedding_np[i, num - 1] = 1.0 + >>> node_features = { + ... "atomic_numbers": atomic_numbers, + ... "atomic_numbers_embedding": Tensor(atomic_numbers_embedding_np), + ... "positions": Tensor(np.random.randn(n_atoms, 3).astype(np.float32)) + ... } + >>> edge_features = { + ... "vectors": Tensor(np.random.randn(n_edges, 3).astype(np.float32)), + ... "r": Tensor(np.abs(np.random.randn(n_edges).astype(np.float32) * 10)) + ... } + >>> senders = Tensor(np.random.randint(0, n_atoms, size=(n_edges,), dtype=np.int32)) + >>> receivers = Tensor(np.random.randint(0, n_atoms, size=(n_edges,), dtype=np.int32)) + >>> output = Orb(edge_features, node_features, senders, receivers, n_node) + >>> print(output['graph_pred'].shape, output['node_pred'].shape, output['stress_pred'].shape) + (1, 1) (4, 3) (1, 6) + """ + + def __init__( + self, + model: MoleculeGNS, + node_head: Optional[NodeHead] = None, + graph_head: Optional[GraphHead] = None, + stress_head: Optional[GraphHead] = None, + model_requires_grad: bool = True, + cutoff_layers: Optional[int] = None, + ): + """init + """ + super().__init__() + + if (node_head is None) and (graph_head is None): + raise ValueError("Must provide at least one node/graph head.") + self.node_head = node_head + self.graph_head = graph_head + self.stress_head = stress_head + self.cutoff_layers = cutoff_layers + + self.model = model + + if self.cutoff_layers is not None: + if self.cutoff_layers > self.model.num_message_passing_steps: + raise ValueError( + f"cutoff_layers ({self.cutoff_layers}) must be less than or equal to" + f" the number of message passing steps ({self.model.num_message_passing_steps})" + ) + self.model.gnn_stacks = self.model.gnn_stacks[: self.cutoff_layers] + self.model.num_message_passing_steps = self.cutoff_layers + + self.model_requires_grad = model_requires_grad + + if not model_requires_grad: + for param in self.model.parameters(): + param.requires_grad = False + + + def predict(self, edge_features, node_features, senders, receivers, n_node, atomic_numbers): + """Predict node and/or graph level attributes. + + Args: + edge_features: A dictionary, e.g., `{"vectors": Tensor, "r": Tensor}`. + node_features: A dictionary, e.g., `{"atomic_numbers": Tensor, "positions": Tensor, + "atomic_numbers_embedding": Tensor}`. + senders: A tensor of shape (n_edges,) containing the sender node indices. + receivers: A tensor of shape (n_edges,) containing the receiver node indices. + n_node: A tensor of shape (1,) containing the number of nodes. + atomic_numbers: A tensor of atomic numbers for reference energy calculation. + + Returns: + ouput_dict: A dictionary containing the predictions: + - `graph_pred`: Graph-level predictions (e.g., energy) of shape (n_graphs, graph_property_dim). + - `stress_pred`: Stress predictions (if stress_head is provided) of shape (n_graphs, stress_dim). + - `node_pred`: Node-level predictions of shape (n_nodes, node_property_dim). + """ + _, nodes = self.model(edge_features, node_features, senders, receivers) + + output = {} + output["graph_pred"] = self.graph_head.predict(nodes, n_node, atomic_numbers) + output["stress_pred"] = self.stress_head.predict(nodes, n_node) + output["node_pred"] = self.node_head.predict(nodes, n_node) + + return output + + def construct(self, edge_features, node_features, senders, receivers, n_node): + """construct + """ + edges, nodes = self.model(edge_features, node_features, senders, receivers) + + res = {"edges": edges, "nodes": nodes} + res.update(self.graph_head(nodes, n_node)) + res.update(self.stress_head(nodes, n_node)) + res.update(self.node_head(nodes, n_node)) + + return res diff --git a/MindChemistry/mindchemistry/cell/orb/utils.py b/MindChemistry/mindchemistry/cell/orb/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..26797d9df20cc84e83ce9a20ebf63653fc2be33b --- /dev/null +++ b/MindChemistry/mindchemistry/cell/orb/utils.py @@ -0,0 +1,737 @@ +# ============================================================================ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Utils.""" + +from typing import NamedTuple, List, Optional, Type + +import numpy as np +import mindspore as ms +from mindspore import nn, ops, Tensor, mint, context + +MSINT = [ms.int64, ms.int32, ms.int16, ms.int8, ms.uint8] + + +def aggregate_nodes(tensor: Tensor, n_node: Tensor, reduction: str = "mean", deterministic: bool = False) -> Tensor: + """Aggregates over a tensor based on graph sizes.""" + count = len(n_node) + if deterministic: + ms.set_seed(1) + segments = ops.arange(count).repeat_interleave(n_node).astype(ms.int32) + if reduction == "sum": + return scatter_sum(tensor, segments, dim=0) + if reduction == "mean": + return scatter_mean(tensor, segments, dim=0) + if reduction == "max": + return scatter_max(tensor, segments, dim=0) + raise ValueError("Invalid reduction argument. Use sum, mean or max.") + + +def segment_sum(data: Tensor, segment_ids: Tensor, num_segments: int): + """Computes index based sum over segments of a tensor.""" + return scatter_sum(data, segment_ids, dim=0, dim_size=num_segments) + + +def segment_max(data: Tensor, segment_ids: Tensor, num_segments: int): + """Computes index based max over segments of a tensor.""" + assert segment_ids is not None, "segment_ids must not be None" + assert num_segments > 0, "num_segments must be greater than 0" + max_op = ops.ArgMaxWithValue(axis=0) + _, max_values = max_op(data) + return max_values + + +def segment_mean(data: Tensor, segment_ids: Tensor, num_segments: int): + """Computes index based mean over segments of a tensor.""" + sum_v = segment_sum(data, segment_ids, num_segments) + count = ops.scatter_add(ops.zeros( + (num_segments,), dtype=ms.int32), segment_ids, ops.ones_like(segment_ids)) + return sum_v / count.astype(sum_v.dtype) + + +def segment_softmax(data: Tensor, segment_ids: Tensor, num_segments: int, weights: Optional[Tensor] = None): + """Computes a softmax over segments of the tensor.""" + data_max = segment_max(data, segment_ids, num_segments) + data = data - data_max[segment_ids] + + unnormalised_probs = ops.exp(data) + if weights is not None: + unnormalised_probs = unnormalised_probs * weights + denominator = segment_sum(unnormalised_probs, segment_ids, num_segments) + + return safe_division(unnormalised_probs, denominator, segment_ids) + + +def safe_division(numerator: Tensor, denominator: Tensor, segment_ids: Tensor): + """Divides logits by denominator, setting 0 where the denominator is zero.""" + result = ops.where(denominator[segment_ids] == + 0, 0, numerator / denominator[segment_ids]) + return result + + +def _broadcast(src: Tensor, other: Tensor, dim: int): + """Broadcasts the source tensor to match the shape of the other tensor along the specified dimension.""" + if dim < 0: + dim = other.ndim + dim + if src.ndim == 1: + for _ in range(0, dim): + src = src.unsqueeze(0) + for _ in range(src.ndim, other.ndim): + src = src.unsqueeze(-1) + src = src.expand_as(other) + return src + + +def scatter_sum( + src: Tensor, index: Tensor, dim: int = -1, out: Optional[Tensor] = None, + dim_size: Optional[int] = None, reduce: str = "sum" +) -> Tensor: + """Applies a sum reduction of the orb_models tensor along the specified dimension.""" + assert reduce == "sum" + index = _broadcast(index, src, dim) + if out is None: + size = list(src.shape) + if dim_size is not None: + size[dim] = dim_size + elif index.numel() == 0: + size[dim] = 0 + else: + size[dim] = int(index.max()) + 1 + out = ops.zeros(size, dtype=src.dtype) + return mint.scatter_add(out, dim, index, src) + return mint.scatter_add(out, dim, index, src) + + +def scatter_std( + src: Tensor, index: Tensor, dim: int = -1, out: Optional[Tensor] = None, + dim_size: Optional[int] = None, unbiased: bool = True +) -> Tensor: + """Computes the standard deviation of the orb_models tensor along the specified dimension.""" + if out is not None: + dim_size = out.shape[dim] + + if dim < 0: + dim = src.ndim + dim + + count_dim = dim + if index.ndim <= dim: + count_dim = index.ndim - 1 + + ones = ops.ones(index.shape, dtype=src.dtype) + count = scatter_sum(ones, index, count_dim, dim_size=dim_size) + + index = _broadcast(index, src, dim) + tmp = scatter_sum(src, index, dim, dim_size=dim_size) + count = _broadcast(count, tmp, dim).clip(1) + mean = tmp / count + + var = src - mean.gather(dim, index) + var = var * var + out = scatter_sum(var, index, dim, out=out, dim_size=dim_size) + + if unbiased: + count = count - 1 + count = count.clip(1) + out = out / (count + 1e-6) + out = ops.sqrt(out) + return out + + +def scatter_mean( + src: Tensor, index: Tensor, dim: int = -1, out: Optional[Tensor] = None, + dim_size: Optional[int] = None +) -> Tensor: + """Computes the mean of the orb_models tensor along the specified dimension.""" + out = scatter_sum(src, index, dim, out=out, dim_size=dim_size) + dim_size = out.shape[dim] + + index_dim = dim + if index_dim < 0: + index_dim = index_dim + src.ndim + if index.ndim <= index_dim: + index_dim = index.ndim - 1 + + ones = ops.ones(index.shape, dtype=src.dtype) + count = scatter_sum(ones, index, index_dim, dim_size=dim_size) + count = count.clip(1) + count = _broadcast(count, out, dim) + out = out / count + return out + + +def scatter_max( + src: Tensor, index: Tensor, dim: int = -1, out: Optional[Tensor] = None, + dim_size: Optional[int] = None +) -> Tensor: + """Computes the maximum of the orb_models tensor for each group defined by index along the specified dimension.""" + if out is not None: + raise NotImplementedError( + "The 'out' argument is not supported for scatter_max") + + if src.dtype in MSINT: + init_value = np.iinfo(src.dtype).min + else: + init_value = np.finfo(src.dtype).min + + if dim < 0: + dim = src.ndim + dim + + if dim_size is None: + dim_size = int(index.max()) + 1 + + result = ops.ones( + (dim_size, *src.shape[:dim], *src.shape[dim + 1:]), dtype=src.dtype) + result = init_value * result + broadcasted_index = _broadcast(index, src, dim) + + scatter_result = ops.ZerosLike()(result) + index = ops.expand_dims(broadcasted_index, dim) + scatter_result = scatter_result.scatter_update(index, src) + result = ops.Maximum()(result, scatter_result) + return result + + +class SSP(nn.Cell): + """Shifted Softplus activation function. + + This activation is twice differentiable so can be used when regressing + gradients for conservative force fields. + """ + + def __init__(self, beta: int = 1, threshold: int = 20): + super().__init__() + self.beta = beta + self.threshold = threshold + + def construct(self, input_x: Tensor) -> Tensor: + sp0 = ops.softplus(ops.zeros(1), self.beta, self.threshold) + return ops.softplus(input_x, self.beta, self.threshold) - sp0 + + +def build_mlp( + input_size: int, + hidden_layer_sizes: List[int], + output_size: Optional[int] = None, + output_activation: Type[nn.Cell] = nn.Identity, + activation: Type[nn.Cell] = SSP, + dropout: Optional[float] = None, +) -> nn.Cell: + """Build a MultiLayer Perceptron. + + Args: + input_size: Size of input layer. + hidden_layer_sizes: An array of input size for each hidden layer. + output_size: Size of the output layer. + output_activation: Activation function for the output layer. + activation: Activation function for the hidden layers. + dropout: Dropout rate for hidden layers. + checkpoint: Whether to use checkpointing. + + Returns: + mlp: An MLP sequential container. + """ + # Size of each layer + layer_sizes = [input_size] + hidden_layer_sizes + if output_size: + layer_sizes.append(output_size) + + # Number of layers + nlayers = len(layer_sizes) - 1 + + # Create a list of activation functions and + # set the last element to output activation function + act = [activation for _ in range(nlayers)] + act[-1] = output_activation + + # Create a list to hold layers + layers = [] + for i in range(nlayers): + if dropout is not None: + layers.append(nn.Dropout(keep_prob=1 - dropout)) + layers.append(nn.Dense(layer_sizes[i], layer_sizes[i + 1])) + layers.append(act[i]()) + + # Create a sequential container + mlp = nn.SequentialCell(layers) + return mlp + + +class CheckpointedSequential(nn.Cell): + """Sequential container with checkpointing.""" + + def __init__(self, *args, n_layers: int = 1): + super().__init__() + self.n_layers = n_layers + self.layers = nn.CellList(list(args)) + + def construct(self, input_x: Tensor) -> Tensor: + """Forward pass with checkpointing enabled in training mode.""" + if context.get_context("mode") == context.GRAPH_MODE: + # In graph mode, checkpointing is handled by MindSpore's graph optimization + for layer in self.layers: + input_x = layer(input_x) + else: + # In PyNative mode, we can manually checkpoint each layer + for i in range(self.n_layers): + input_x = self.layers[i](input_x) + return input_x + + +class ReferenceEnergies(NamedTuple): + """ + Reference energies for an atomic system. + + Our vasp reference energies are computed by running vasp + optimisations on a single atom of each atom-type. + + Other reference energies are fitted using least-squares. + + Doing so with mp-traj-d3 gives the following: + + ---------- LSTQ ---------- + Reference MAE: 13.35608855004781 + (energy - ref) mean: 1.3931169304958624 + (energy - ref) std: 22.45615276341948 + (energy - ref)/natoms mean: 0.16737045963056316 + (energy - ref)/natoms std: 0.8189314920219992 + CO2: Predicted vs DFT: -23.154158610392408 vs -22.97 + H2O: Predicted vs DFT: -11.020918107591324 vs - 14.23 + ---------- VASP ---------- + Reference MAE: 152.4722089438871 + (energy - ref) mean: -152.47090833346033 + (energy - ref) std: 153.89049784836962 + (energy - ref)/natoms mean: -4.734136414817941 + (energy - ref)/natoms std: 1.3603868419157275 + CO2: Predicted vs DFT: -4.35888857 vs -22.97 + H2O: Predicted vs DFT: -2.66521147 vs - 14.23 + ---------- Shifted VASP ---------- + Reference MAE: 28.95948216608197 + (energy - ref) mean: 0.7083632520428979 + (energy - ref) std: 48.61861182844561 + (energy - ref)/natoms mean: 0.17320099403091083 + (energy - ref)/natoms std: 1.3603868419157275 + CO2: Predicted vs DFT: -19.080900796546562 vs -22.97 + H2O: Predicted vs DFT: -12.479886287697706 vs - 14.23 + + Args: + coefficients: Coefficients for each atom in the periodic table. + Must be of length 118 with first entry equal to 0. + residual_mean: Mean of (pred - target) + residual_std: Standard deviation of (pred - target) + residual_mean_per_atom: Mean of (pred - target)/n_atoms. + residual_std_per_atom: Standard deviation of (pred - target)/n_atoms. + """ + + coefficients: np.ndarray + residual_mean: float + residual_std: float + residual_mean_per_atom: float + residual_std_per_atom: float + + +# We have only computed these for the first +# 88 elements, and padded the remainder with 0. +vasp_reference_energies = ReferenceEnergies( + coefficients=np.array( + [ + 0.0, # padding + -1.11725225e00, + 7.69290000e-04, + -3.22788480e-01, + -4.47021900e-02, + -2.90627280e-01, + -1.26297013e00, + -3.12415058e00, + -1.54795922e00, + -4.39757050e-01, + -1.25673900e-02, + -2.63927430e-01, + -1.92670300e-02, + -2.11267040e-01, + -8.24799500e-01, + -1.88734631e00, + -8.91048980e-01, + -2.58371430e-01, + -2.50008000e-02, + -2.71936150e-01, + -7.11147600e-02, + -2.06076796e00, + -2.42753196e00, + -3.57144559e00, + -5.45540047e00, + -5.15708214e00, + -3.31393675e00, + -1.84639284e00, + -6.32812480e-01, + -2.38017450e-01, + -1.41047600e-02, + -2.06349980e-01, + -7.77477960e-01, + -1.70160351e00, + -7.84231510e-01, + -2.27541260e-01, + -2.26104900e-02, + -2.79760570e-01, + -9.92851900e-02, + -2.18560872e00, + -2.26603086e00, + -3.14842282e00, + -4.61199158e00, + -3.34329762e00, + -2.48233722e00, + -1.27872811e00, + -1.47784242e00, + -2.04068960e-01, + -1.89639300e-02, + -1.88520140e-01, + -6.76700640e-01, + -1.42966694e00, + -6.57608340e-01, + -1.89308030e-01, + -1.20491300e-02, + -3.07991050e-01, + -1.58601400e-01, + -4.89728600e-01, + -1.35031403e00, + -3.31509450e-01, + -3.23660410e-01, + -3.15316610e-01, + -3.11184530e-01, + -8.44684689e00, + -1.04408371e01, + -2.30922790e-01, + -2.26295040e-01, + -2.92747580e-01, + -2.92191740e-01, + -2.91465170e-01, + -3.80611000e-02, + -2.87691040e-01, + -3.51528971e00, + -3.51343142e00, + -4.64232388e00, + -2.88816624e00, + -1.46089612e00, + -5.36042350e-01, + -1.87182020e-01, + -1.33549100e-02, + -1.68142250e-01, + -6.25378750e-01, + -1.32291753e00, + -3.26246040e-01, + -1.10239294e00, + -2.30839543e00, + -4.61968511e00, + -7.30638139e00, + -1.04613411e01, + 0.00000000e00, + 0.00000000e00, + 0.00000000e00, + 0.00000000e00, + 0.00000000e00, + 0.00000000e00, + 0.00000000e00, + 0.00000000e00, + 0.00000000e00, + 0.00000000e00, + 0.00000000e00, + 0.00000000e00, + 0.00000000e00, + 0.00000000e00, + 0.00000000e00, + 0.00000000e00, + 0.00000000e00, + 0.00000000e00, + 0.00000000e00, + 0.00000000e00, + 0.00000000e00, + 0.00000000e00, + 0.00000000e00, + 0.00000000e00, + 0.00000000e00, + 0.00000000e00, + 0.00000000e00, + 0.00000000e00, + 0.00000000e00, + ] + ), + residual_mean=-152.47090833346033, + residual_std=153.89049784836962, + residual_mean_per_atom=-4.734136414817941, + residual_std_per_atom=1.3603868419157275, +) + +vasp_shifted_reference_energies = ReferenceEnergies( + coefficients=np.array( + [ + 0.0, # padding + -6.0245896588488534, + -4.9065681188488535, + -5.230125888848853, + -4.952039598848853, + -5.197964688848853, + -6.170307538848854, + -8.031487988848854, + -6.455296628848854, + -5.347094458848853, + -4.919904798848854, + -5.171264838848853, + -4.9266044388488535, + -5.118604448848854, + -5.732136908848854, + -6.794683718848853, + -5.798386388848853, + -5.165708838848854, + -4.932338208848853, + -5.179273558848854, + -4.978452168848854, + -6.968105368848853, + -7.334869368848853, + -8.478782998848853, + -10.362737878848854, + -10.064419548848853, + -8.221274158848853, + -6.7537302488488535, + -5.540149888848854, + -5.145354858848854, + -4.921442168848854, + -5.113687388848853, + -5.684815368848853, + -6.6089409188488535, + -5.691568918848853, + -5.134878668848853, + -4.929947898848853, + -5.187097978848853, + -5.006622598848853, + -7.092946128848853, + -7.173368268848853, + -8.055760228848854, + -9.519328988848853, + -8.250635028848853, + -7.389674628848853, + -6.186065518848854, + -6.3851798288488535, + -5.111406368848853, + -4.9263013388488535, + -5.095857548848853, + -5.5840380488488535, + -6.337004348848853, + -5.564945748848854, + -5.096645438848854, + -4.919386538848854, + -5.2153284588488535, + -5.065938808848854, + -5.397066008848854, + -6.257651438848853, + -5.238846858848854, + -5.230997818848854, + -5.2226540188488535, + -5.218521938848854, + -13.354184298848853, + -15.348174508848853, + -5.138260198848854, + -5.133632448848854, + -5.200084988848854, + -5.199529148848853, + -5.198802578848854, + -4.945398508848854, + -5.195028448848854, + -8.422627118848853, + -8.420768828848853, + -9.549661288848853, + -7.795503648848854, + -6.368233528848854, + -5.443379758848853, + -5.094519428848853, + -4.920692318848854, + -5.075479658848853, + -5.532716158848854, + -6.230254938848853, + -5.2335834488488535, + -6.009730348848853, + -7.2157328388488535, + -9.527022518848852, + -12.213718798848854, + -15.368678508848854, + -4.9073374088488535, + -4.9073374088488535, + -4.9073374088488535, + -4.9073374088488535, + -4.9073374088488535, + -4.9073374088488535, + -4.9073374088488535, + -4.9073374088488535, + -4.9073374088488535, + -4.9073374088488535, + -4.9073374088488535, + -4.9073374088488535, + -4.9073374088488535, + -4.9073374088488535, + -4.9073374088488535, + -4.9073374088488535, + -4.9073374088488535, + -4.9073374088488535, + -4.9073374088488535, + -4.9073374088488535, + -4.9073374088488535, + -4.9073374088488535, + -4.9073374088488535, + -4.9073374088488535, + -4.9073374088488535, + -4.9073374088488535, + -4.9073374088488535, + -4.9073374088488535, + -4.9073374088488535, + ] + ), + residual_mean=0.7083632520428979, + residual_std=48.61861182844561, + residual_mean_per_atom=0.17320099403091083, + residual_std_per_atom=1.3603868419157275, +) + +mp_traj_d3_reference_energies = ReferenceEnergies( + coefficients=np.array( + [ + 0.0, # padding + -3.6818229500085327, + -1.3199148098871394, + -3.688797198716366, + -4.938608191337134, + -7.901604711660046, + -8.475968295226822, + -7.42601366967988, + -7.339095157582792, + -4.9239197309790725, + -0.061236726924086424, + -3.0526401941340806, + -3.0836199809602105, + -5.055909838526647, + -7.875649504560413, + -7.175538036602013, + -4.814514763424572, + -2.9198, + -0.13127266880110078, + -2.8792125576832865, + -5.635016298424046, + -8.164720105254204, + -10.712143655281858, + -9.00292017736733, + -9.619640942931085, + -8.610981088341331, + -7.3506162257219385, + -5.943664565392655, + -5.592846831852426, + -3.6868017794232077, + -1.579885044321145, + -3.744040760877656, + -4.945137332817033, + -4.2021571924020655, + -4.045303645442562, + -2.652667661940346, + 6.497305115069106, + -2.806819346028444, + -5.164089337915934, + -10.493037547114369, + -12.256967896681578, + -12.642602087796805, + -9.20874164629371, + -9.292405362859506, + -8.304141715175632, + -7.49355696426791, + -5.44150554776011, + -2.5621691409635474, + -0.9687174918829102, + -3.055905969721681, + -4.02975498585447, + -3.847125028451477, + -3.1016305514702203, + -1.8001556831915142, + 9.742275211909387, + -3.045410331644577, + -5.204088972093178, + -9.267561428901118, + -9.031458669303145, + -8.345252241333469, + -8.584977779192705, + -7.955970517402418, + -8.519743221802353, + -13.927799873369949, + -19.12242499580686, + -8.156787154342183, + -8.505944162624234, + -8.015433843487497, + -7.129355408977684, + -8.166165621829014, + -3.9995952334750644, + -7.884852034766514, + -13.281575162667238, + -14.598283494757041, + -9.729591400065184, + -11.798570715867179, + -9.878207068760076, + -7.891075131963705, + -5.964524120587406, + -2.9665634245721275, + -0.10530075207060852, + -2.649420791761001, + -4.00193074336809, + -3.7403644338639785, + -1.5543122344752192e-15, + -8.881784197001252e-16, + -8.881784197001252e-16, + 0.0, + 0.0, + -5.480602125607218, + -11.9439263006771, + -12.974770001312883, + -14.376719109855834, + -15.49262474740642, + -16.02533150334938, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + ] + ), + residual_mean=1.3931169304958624, + residual_std=22.45615276341948, + residual_mean_per_atom=0.16737045963056316, + residual_std_per_atom=0.8189314920219992, +) + +REFERENCE_ENERGIES = { + "vasp": vasp_reference_energies, + "vasp-shifted": vasp_shifted_reference_energies, + "mp-traj-d3": mp_traj_d3_reference_energies, +} diff --git a/MindEarth/applications/earthquake/G-TEAM/README.md b/MindEarth/applications/earthquake/G-TEAM/README.md index fde491dff2bd91a0151e4fc01a75b4fea21a6352..fc37db4f345362b658769cb6fc9c58b11ca03070 100644 --- a/MindEarth/applications/earthquake/G-TEAM/README.md +++ b/MindEarth/applications/earthquake/G-TEAM/README.md @@ -32,7 +32,7 @@ The model is trained using the [Diting Dataset 2.0 - Multifunctional Large AI Tr - Only retains initial P-wave and S-wave phases - Includes events recorded by ≥3 stations for reliability -The inference module has been open-sourced and supports prediction using provided checkpoint files (.ckpt). +This model has fully open-sourced both the inference and training modules. For the inference part, the provided [ckpt](https://download-mindspore.osinfra.cn/mindscience/mindearth/dataset/G-TEAM/) is used for inference, while the training part utilizes the provided [hdf5](https://download-mindspore.osinfra.cn/mindscience/mindearth/dataset/G-TEAM/) and [pkl](https://download-mindspore.osinfra.cn/mindscience/mindearth/dataset/G-TEAM/) files for training. ## Quick Start @@ -41,6 +41,9 @@ You can download the required data and ckpt files for training and inference at ### Execution Run via command line using the `main` script: +It is necessary to configure the istraining parameter in the config.yaml file in advance to set up inference or training: +istraining: false -- Inference +istraining: true -- Training ```python python main.py --cfg_path ./config/config.yaml --device_id 0 --device_target Ascend @@ -52,13 +55,15 @@ Parameters: --device_target: Hardware type (default: Ascend) --device_id: Device ID (default: 0) +### Inference + ### Visualization ![](./images/pga.png) Scatter plot compares predicted vs actual PGA values (x-axis vs y-axis). Closer alignment to y=x line indicates higher accuracy. -### 结果展示 +### Results Presentation | Parameter | NPU | |:----------------------:|:--------------------------:| @@ -73,8 +78,15 @@ Scatter plot compares predicted vs actual PGA values (x-axis vs y-axis). Closer | Inference Resource | 1NPU | | Inference Speed(ms/step) | 556 | +### Training + +### Results Presentation + +![](./images/train_loss.png) +Under normal circumstances, the Average Training Loss should continue to converge. + ## Contributors -gitee id: chengjie, longjundong, xujiabao, dinghongyang, funfunplus +gitee id: xujiabao, longjundong, dinghongyang, chengjie email: funniless@163.com \ No newline at end of file diff --git a/MindEarth/applications/earthquake/G-TEAM/README_CN.md b/MindEarth/applications/earthquake/G-TEAM/README_CN.md index 5bb7eff96f7929d3d92326666e23552f37e2c909..3475aa849629194b4375bac120ce2ec3225ecdc4 100644 --- a/MindEarth/applications/earthquake/G-TEAM/README_CN.md +++ b/MindEarth/applications/earthquake/G-TEAM/README_CN.md @@ -15,7 +15,7 @@ 本模型的训练数据来源于[谛听数据集2.0 -中国地震台网多功能大型人工智能训练数据集](http://www.esdc.ac.cn/article/137),该数据集汇集了中国大陆及其邻近地区(15°-50°N,65°-140°E)1177 个中国地震台网固定台站的波形记录,覆盖时间范围为 2020 年 3 月至 2023 年 2 月。数据集包含研究区域内所有震级大于 0 的地方震事件,共计 264,298 个。我们在训练过程中仅选取了初至 P 波和 S 波震相,并且只保留至少被三个台站记录到的地震事件,以确保数据的可靠性和稳定性。 -目前本模型已开源推理部分,可使用提供的[ckpt](https://download-mindspore.osinfra.cn/mindscience/mindearth/dataset/G-TEAM/)进行推理。 +本模型已全部开源推理和训练模块,其中推理部分使用提供的[ckpt](https://download-mindspore.osinfra.cn/mindscience/mindearth/dataset/G-TEAM/)进行推理,训练部分使用提供的[hdf5](https://download-mindspore.osinfra.cn/mindscience/mindearth/dataset/G-TEAM/)和[pkl](https://download-mindspore.osinfra.cn/mindscience/mindearth/dataset/G-TEAM/)进行训练。 ## 快速开始 @@ -23,7 +23,9 @@ ### 运行方式: 在命令行调用`main`脚本 -### 推理 +需提前在config.yaml中配置istraining参数设定推理/训练 +istraining: false -- 推理 +istraining: true -- 训练 ```python @@ -33,7 +35,9 @@ python main.py --cfg_path ./config/config.yaml --device_id 0 --device_target Asc 其中, --cfg_path表示配置文件路径,默认值"./config/config.yaml" --device_target 表示设备类型,默认Ascend。 --device_id 表示运行设备的编号,默认值0。 -### 结果可视化 +### 推理 + +### 可视化结果 ![](./images/pga.png) @@ -54,8 +58,15 @@ python main.py --cfg_path ./config/config.yaml --device_id 0 --device_target Asc | 推理资源 | 1NPU | | 推理速度(ms/step) | 556 | +### 训练 + +### 结果展示 + +![](./images/train_loss.png) +正常情况Average Training Loss会持续收敛。 + ## 贡献者 -gitee id: chengjie, longjundong, xujiabao, dinghongyang, funfunplus +gitee id: xujiabao, longjundong, dinghongyang, chengjie email: funniless@163.com \ No newline at end of file diff --git a/MindEarth/applications/earthquake/G-TEAM/config/GTEAM.yaml b/MindEarth/applications/earthquake/G-TEAM/config/GTEAM.yaml index 68c066929f8c5503703691f7f9494f86acc29fa6..0faf89c09619b16e1e69f0b67c1905c0d2696249 100644 --- a/MindEarth/applications/earthquake/G-TEAM/config/GTEAM.yaml +++ b/MindEarth/applications/earthquake/G-TEAM/config/GTEAM.yaml @@ -1,4 +1,6 @@ -model: +model: + istraining: false + use_mlp: False hidden_dim: 1000 hidden_dropout: 0.0 n_heads: 10 @@ -13,6 +15,7 @@ model: pga: true mode: test no_event_token : False + max_stations: 5 data: root_dir: "./dataset" batch_size: 64 @@ -35,6 +38,39 @@ data: waveform_shape: [3000, 3] overwrite_sampling_rate: None noise_seconds: 5 +training_params: + seed: 42 + clipnorm: 1.0 + data_path: ./diting2_2020-2022_sc_abridged.hdf5 + ensemble_rotation: true + epochs_full_model: 100 + epochs_single_station: 5 + filter_single_station_by_pick: true + generator_params: + - batch_size: 1 + cutout_end: 25 + cutout_start: -1 + disable_station_foreshadowing: true + key: Mag + magnitude_resampling: 1.5 + min_upsample_magnitude: 4 + pga_from_inactive: true + pga_key: pga + pga_selection_skew: 1000 + pos_offset: [30,102] + scale_metadata: false + selection_skew: 1000 + shuffle_train_dev: true + transform_target_only: false + translate: false + trigger_based: true + upsample_high_station_events: 10 + loss_weights: + location: 1 + magnitude: 0.3 + pga: 1 + lr: 1e-5 + workers: 1 summary: summary_dir: "./summary" - ckpt_path: "./dataset/ckpt/g_team.ckpt" + ckpt_path: "./dataset/ckpt/g_team.ckpt" \ No newline at end of file diff --git a/MindEarth/applications/earthquake/G-TEAM/images/train_loss.png b/MindEarth/applications/earthquake/G-TEAM/images/train_loss.png new file mode 100644 index 0000000000000000000000000000000000000000..77ee77fa11f4db7eec0504a01a2e0acbac987697 Binary files /dev/null and b/MindEarth/applications/earthquake/G-TEAM/images/train_loss.png differ diff --git a/MindEarth/applications/earthquake/G-TEAM/main.py b/MindEarth/applications/earthquake/G-TEAM/main.py index d8237f8eaa395f327859e53cb21566320c7a8ad6..2e12e479f90e8dbc6732172c1e66bbc4b59a56bf 100644 --- a/MindEarth/applications/earthquake/G-TEAM/main.py +++ b/MindEarth/applications/earthquake/G-TEAM/main.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"main function" +"""main function""" import argparse import mindspore as ms @@ -20,7 +20,7 @@ from mindspore import context from mindearth import load_yaml_config, make_dir from src.utils import init_model, get_logger -from src.forcast import GTeamInference +from src.forcast import GTeamInference, GTeamTrain def get_args(): @@ -32,7 +32,6 @@ def get_args(): parse_args = parser.parse_args() return parse_args - def test(cfg): """main test""" save_dir = cfg["summary"].get("summary_dir", "./summary") @@ -43,9 +42,22 @@ def test(cfg): processor.test() -if __name__ == "__main__": +def train(cfg): + """main train""" + save_dir = cfg["summary"].get("summary_dir", "./summary") + make_dir(save_dir) + model = init_model(cfg) + logger_obj = get_logger(cfg) + processor = GTeamTrain(model, cfg, save_dir, logger_obj) + processor.train() + + +if __name__ == '__main__': args = get_args() config = load_yaml_config(args.cfg_path) context.set_context(mode=ms.PYNATIVE_MODE) ms.set_device(device_target=args.device_target, device_id=args.device_id) - test(config) + if config['model']['istraining']: + train(config) + else: + test(config) diff --git a/MindEarth/applications/earthquake/G-TEAM/src/data.py b/MindEarth/applications/earthquake/G-TEAM/src/data.py index 80ee3066718e010224d3117064464defccb3dd46..52abb64ec4ca01c4bf1f9ea6410cccceb85932ee 100644 --- a/MindEarth/applications/earthquake/G-TEAM/src/data.py +++ b/MindEarth/applications/earthquake/G-TEAM/src/data.py @@ -19,7 +19,7 @@ import glob import h5py import numpy as np -import mindspore +import mindspore as ms from mindspore.dataset import Dataset # degrees to kilometers @@ -33,14 +33,6 @@ def load_pickle_data(filename): print(f"Data loaded from {filename}") return data - -def save_pickle_data(filename, data): - """Serialize and save data to a pickle file.""" - with open(filename, "wb") as file: - pickle.dump(data, file) - print(f"Data saved to {filename}") - - def load_data(cfg): """Load preprocessed seismic data from a configured pickle file.""" data_path = glob.glob(os.path.join(cfg["data"].get("root_dir"), "*.hdf5"))[0] @@ -74,6 +66,113 @@ def detect_location_keys(columns): return coord_keys +class DataGenerator(Dataset): + """ + A PyTorch Dataset subclass for generating earthquake detection training data. + Handles loading, preprocessing, and batching of seismic waveform data. + """ + def __init__(self, data_path, event_metadata_index, event_key, mag_key='M_J', batch_size=32, cutout=None, + sliding_window=False, windowlen=3000, shuffle=True, label_smoothing=False, decimate=1): + """ + Initialize the data generator. + + Args: + data_path (str): Path to the HDF5 file containing seismic data. + event_metadata_index (pd.DataFrame): DataFrame containing event metadata indices. + event_key (str): Column name in metadata used to identify events. + mag_key (str, optional): Column name containing magnitude values. Defaults to 'M_J'. + batch_size (int, optional): Number of samples per batch. Defaults to 32. + cutout (tuple, optional): Time window for data augmentation. Defaults to None. + sliding_window (bool, optional): Use sliding window for cutouts. Defaults to False. + windowlen (int, optional): Length of time window for analysis. Defaults to 3000. + shuffle (bool, optional): Shuffle data during epoch. Defaults to True. + label_smoothing (bool, optional): Apply label smoothing. Defaults to False. + decimate (int, optional): Decimation factor for waveform data. Defaults to 1. + """ + super().__init__() + self.data_path = data_path + self.event_metadata_index = event_metadata_index + self.event_key = event_key + self.batch_size = batch_size + self.mag_key = mag_key + self.cutout = cutout + self.sliding_window = sliding_window + self.windowlen = windowlen + self.shuffle = shuffle + self.label_smoothing = label_smoothing + self.decimate = decimate + self.indexes = np.arange(len(self.event_metadata_index)) + self.on_epoch_end() + + def __len__(self): + """ + Return the number of batches in the dataset. + + Returns: + int: Number of batches = total samples / batch size (floor division) + """ + return int(np.floor(len(self.event_metadata_index) / self.batch_size)) + + def __getitem__(self, index): + """ + Get a batch of data by index. + + Args: + index (int): Batch index + + Returns: + tuple: (x, y) where X is input tensor, y is target tensor + """ + indexes = self.indexes[index * self.batch_size:(index + 1) * self.batch_size] + batch_metadata = self.event_metadata_index.iloc[indexes] + x, y = self.__data_generation(batch_metadata) + return x, y + + def on_epoch_end(self): + """ + Called when an epoch ends. Resets indexes and shuffles if required. + """ + self.indexes = np.arange(len(self.event_metadata_index)) + if self.shuffle: + np.random.shuffle(self.indexes) + + def __data_generation(self, batch_metadata): + """ + Generate data for a batch of events. + + Args: + batch_metadata (pd.DataFrame): Metadata for batch of events + + Returns: + tuple: (x, y) where x is input tensor, y is target tensor + """ + x = [] + y = [] + with h5py.File(self.data_path, 'r') as f: + for _, event in batch_metadata.iterrows(): + event_name = str(int(event[self.event_key])) + if event_name not in f['data']: + continue + g_event = f['data'][event_name] + waveform = g_event['waveforms'][event['index'], ::self.decimate, :] + if self.cutout: + if self.sliding_window: + windowlen = self.windowlen + window_end = np.random.randint(max(windowlen, self.cutout[0]), + min(waveform.shape[1], self.cutout[1]) + 1) + waveform = waveform[:, window_end - windowlen:window_end] + else: + waveform[:, np.random.randint(*self.cutout):] = 0 + x.append(waveform) + y.append(event[self.mag_key]) + x = np.array(x) + y = np.array(y) + if self.label_smoothing: + y += (y > 4) * np.random.randn(y.shape[0]).reshape(y.shape) * (y - 4) * 0.05 + + return (ms.tensor(x, dtype=ms.float32), + ms.tensor(np.expand_dims(np.expand_dims(y, axis=1), axis=2), dtype=ms.float32)) + class EarthquakeDataset(Dataset): """ Dataset class for loading and processing seismic event data. @@ -105,7 +204,7 @@ class EarthquakeDataset(Dataset): **kwargs, ): - super(EarthquakeDataset, self).__init__() + super().__init__() self.data_path = data_path self.event_key = event_key @@ -199,421 +298,343 @@ class EarthquakeDataset(Dataset): if self.shuffle: np.random.shuffle(self.indexes) - -class DataProcessor: +class PreloadedEventGenerator(Dataset): """ - A data processor for seismic event analysis that handles waveform preprocessing, - station selection, and target preparation for machine learning models. - Key functionalities: - Batch processing of seismic waveforms and metadata - Station selection strategies for efficient processing - Multiple preprocessing techniques (cutout, integration, etc.) - Coordinate transformations and target preparations - PGA (Peak Ground Acceleration) target handling - Data augmentation techniques (label smoothing, station blinding) + A custom dataset generator for preloading seismic events. """ + def __init__(self, data_path, event_key, data, event_metadata, waveform_shape=(3000, 6), key='MA', batch_size=32, + cutout=None, sliding_window=False, windowlen=3000, shuffle=True, coords_target=True, oversample=1, + pos_offset=(-21, -69), label_smoothing=False, station_blinding=False, magnitude_resampling=3, + pga_targets=None, adjust_mean=True, transform_target_only=False, max_stations=None, trigger_based=None, + min_upsample_magnitude=2, disable_station_foreshadowing=False, selection_skew=None, + pga_from_inactive=False, integrate=False, differentiate=False, sampling_rate=100., select_first=False, + fake_borehole=False, scale_metadata=True, pga_key='pga', pga_mode=False, p_pick_limit=5000, + coord_keys=None, upsample_high_station_events=None, no_event_token=False, pga_selection_skew=None, + **kwargs): + ''' + Initializes the PreloadedEventGenerator. - def __init__( - self, - waveform_shape=(3000, 6), - max_stations=None, - cutout=None, - sliding_window=False, - windowlen=3000, - coords_target=True, - pos_offset=(-21, -69), - label_smoothing=False, - station_blinding=False, - pga_targets=None, - adjust_mean=True, - transform_target_only=False, - trigger_based=None, - disable_station_foreshadowing=False, - selection_skew=None, - pga_from_inactive=False, - integrate=False, - sampling_rate=100.0, - select_first=False, - scale_metadata=True, - p_pick_limit=5000, - pga_mode=False, - no_event_token=False, - pga_selection_skew=None, - **kwargs, - ): + Args: + data_path: Path to the HDF5 file containing waveform data. + event_key: The key in the event metadata DataFrame identifying each event. + data: Dictionary containing 'coords' and 'pga' keys for metadata and PGA values. + event_metadata: Pandas DataFrame with event metadata. + waveform_shape: Shape of each waveform (number of samples, number of channels). + key: The key in event metadata to use for magnitude. + batch_size: Number of events per batch. + cutout: Tuple specifying the range for random cutout in the waveform. + sliding_window: Whether to use a sliding window for cutout. + windowlen: Length of the sliding window. + shuffle: Whether to shuffle the events at the end of each epoch. + coords_target: Whether to include event coordinates as targets. + oversample: Factor by which to oversample the events. + pos_offset: Offset to apply to event coordinates. + label_smoothing: Whether to apply label smoothing to magnitudes. + station_blinding: Whether to randomly blind stations in the waveforms. + magnitude_resampling: Factor by which to resample events based on their magnitude. + pga_targets: Number of PGA targets to sample per event. + adjust_mean: Whether to adjust the mean of the waveforms. + transform_target_only: Whether to apply transformations only to the target coordinates. + max_stations: Maximum number of stations to include per event. + trigger_based: Whether to zero out waveforms before the P-wave trigger. + min_upsample_magnitude: Minimum magnitude for upsampling. + disable_station_foreshadowing: Whether to disable station foreshadowing. + selection_skew: Skew parameter for selecting stations when max_stations is reached. + pga_from_inactive: Whether to sample PGA from inactive stations. + integrate: Whether to integrate the waveforms. + differentiate: Whether to differentiate the waveforms. + sampling_rate: Sampling rate of the waveforms. + select_first: Whether to select the first stations when max_stations is reached. + fake_borehole: Whether to add fake borehole channels to the waveforms. + scale_metadata: Whether to scale the metadata coordinates. + pga_key: Key in the data dictionary for PGA values. + pga_mode: Whether to operate in PGA mode. + p_pick_limit: Limit for P-wave picks. + coord_keys: Keys in the event metadata for coordinates. + upsample_high_station_events: Whether to upsample events with high station counts. + no_event_token: Whether to include an event token in the outputs. + pga_selection_skew: Skew parameter for selecting PGA targets. + **kwargs: + ''' + super().__init__() + if kwargs: + print(f'Unused parameters: {", ".join(kwargs.keys())}') + self.data_path = data_path + self.event_key = event_key + self.batch_size = batch_size + self.shuffle = shuffle self.waveform_shape = waveform_shape - self.max_stations = max_stations + self.metadata = data['coords'] + self.event_metadata = event_metadata + self.pga = data[pga_key] + self.key = key self.cutout = cutout self.sliding_window = sliding_window self.windowlen = windowlen self.coords_target = coords_target + self.oversample = oversample self.pos_offset = pos_offset self.label_smoothing = label_smoothing self.station_blinding = station_blinding + self.magnitude_resampling = magnitude_resampling self.pga_targets = pga_targets self.adjust_mean = adjust_mean self.transform_target_only = transform_target_only + self.max_stations = max_stations self.trigger_based = trigger_based self.disable_station_foreshadowing = disable_station_foreshadowing self.selection_skew = selection_skew self.pga_from_inactive = pga_from_inactive + self.pga_selection_skew = pga_selection_skew self.integrate = integrate + self.differentiate = differentiate self.sampling_rate = sampling_rate self.select_first = select_first + self.fake_borehole = fake_borehole self.scale_metadata = scale_metadata - self.p_pick_limit = p_pick_limit - self.pga_mode = pga_mode + self.upsample_high_station_events = upsample_high_station_events self.no_event_token = no_event_token - self.pga_selection_skew = pga_selection_skew - self.key = kwargs["key"] - - def process_batch(self, batch_data): - """Main method to process a batch of data, now decomposed into smaller functions.""" - ( - indexes, - waveforms_list, - metadata_list, - pga_list, - p_picks_list, - event_info_list, - pga_indexes, - ) = self._extract_batch_data(batch_data) + self.triggers = data['p_picks'] + self.pga_mode = pga_mode + self.p_pick_limit = p_pick_limit + self.base_indexes = np.arange(self.event_metadata.shape[0]) + self.reverse_index = None + if magnitude_resampling > 1: + magnitude = self.event_metadata[key].values + for i in np.arange(min_upsample_magnitude, 9): + ind = np.where(np.logical_and(i < magnitude, magnitude <= i + 1))[0] + self.base_indexes = np.concatenate( + (self.base_indexes, np.repeat(ind, int(magnitude_resampling ** (i - 1) - 1)))) + if pga_mode: + new_base_indexes = [] + self.reverse_index = [] + c = 0 + for idx in self.base_indexes: + num_samples = (len(self.pga[idx]) - 1) // pga_targets + 1 + new_base_indexes += [(idx, i) for i in range(num_samples)] + self.reverse_index += [c] + c += num_samples + self.reverse_index += [c] + self.base_indexes = new_base_indexes + self.indexes = np.arange(len(self.event_metadata)) + if coord_keys is None: + self.coord_keys = detect_location_keys(event_metadata.columns) + else: + self.coord_keys = coord_keys + self.on_epoch_end() - true_batch_size = len(indexes) - true_max_stations_in_batch = self._get_max_stations_in_batch(metadata_list) - waveforms, metadata, pga, full_p_picks, p_picks, reverse_selections = ( - self._initialize_arrays( - true_batch_size, true_max_stations_in_batch, metadata_list - ) - ) - waveforms, metadata, pga, p_picks, full_p_picks, reverse_selections = ( - self._process_stations( - waveforms_list, - metadata_list, - pga_list, - p_picks_list, - waveforms, - metadata, - pga, - p_picks, - full_p_picks, - ) - ) - magnitude, target = self._process_magnitude_and_targets(event_info_list) - org_waveform_length = waveforms.shape[2] - waveforms, _ = self._process_waveforms(waveforms, org_waveform_length, p_picks) - metadata, target = self._transform_locations(metadata, target) - magnitude = self._apply_label_smoothing(magnitude) - metadata, pga = self._adjust_metadata_and_pga(metadata, pga) - pga_values, pga_targets_data = self._process_pga_targets( - true_batch_size, - pga, - metadata, - pga_indexes, - reverse_selections, - full_p_picks, - indexes, - ) - waveforms, metadata = self._apply_station_blinding(waveforms, metadata) - waveforms, metadata = self._handle_stations_without_trigger(waveforms, metadata) - waveforms, metadata = self._ensure_no_empty_arrays(waveforms, metadata) - inputs, outputs = self._prepare_model_io( - waveforms, metadata, magnitude, target, pga_targets_data, pga_values - ) + def __len__(self): + """ + Returns the number of batches in the dataset. + """ + return int(np.ceil(len(self.indexes) / self.batch_size)) - return inputs, outputs + def __getitem__(self, index): + """ + Retrieves a batch of events from the dataset. + """ + indexes = self.indexes[index * self.batch_size:(index + 1) * self.batch_size] + true_batch_size = len(indexes) + if self.pga_mode: + self.pga_indexes = [x[1] for x in indexes] + indexes = [x[0] for x in indexes] - def _extract_batch_data(self, batch_data): - """Extract data from the batch dictionary.""" - indexes = batch_data["indexes"] - waveforms_list = batch_data["waveforms"] - metadata_list = batch_data["metadata"] - pga_list = batch_data["pga"] - p_picks_list = batch_data["p_picks"] - event_info_list = batch_data["event_info"] - pga_indexes = batch_data.get("pga_indexes", None) - - return ( - indexes, - waveforms_list, - metadata_list, - pga_list, - p_picks_list, - event_info_list, - pga_indexes, - ) - - def _get_max_stations_in_batch(self, metadata_list): - """Calculate the maximum number of stations in the batch.""" - return max( - [len(m) for m in metadata_list if m is not None] + [self.max_stations] - ) - - def _initialize_arrays( - self, true_batch_size, true_max_stations_in_batch, metadata_list - ): - """Initialize arrays for batch processing.""" waveforms = np.zeros([true_batch_size, self.max_stations] + self.waveform_shape) - metadata = np.zeros( - (true_batch_size, true_max_stations_in_batch) + metadata_list[0].shape[1:] - ) + true_max_stations_in_batch = max(max([self.metadata[idx].shape[0] for idx in indexes]), self.max_stations) + metadata = np.zeros((true_batch_size, true_max_stations_in_batch) + self.metadata[0].shape[1:]) pga = np.zeros((true_batch_size, true_max_stations_in_batch)) full_p_picks = np.zeros((true_batch_size, true_max_stations_in_batch)) p_picks = np.zeros((true_batch_size, self.max_stations)) reverse_selections = [] - return waveforms, metadata, pga, full_p_picks, p_picks, reverse_selections + waveforms, metadata, pga, p_picks, reverse_selections, full_p_picks = ( + self.htpyfile_process(indexes, waveforms, metadata, pga, + p_picks, reverse_selections, full_p_picks)) - def _process_stations( - self, - waveforms_list, - metadata_list, - pga_list, - p_picks_list, - waveforms, - metadata, - pga, - p_picks, - full_p_picks, - ): - """Process stations and waveforms for each item in the batch.""" - reverse_selections = [] + magnitude = self.event_metadata.iloc[indexes][self.key].values.copy() + magnitude = magnitude.astype(np.float32) - for i, (waveform_data, meta, pga_data, p_pick_data) in enumerate( - zip(waveforms_list, metadata_list, pga_list, p_picks_list) - ): - if waveform_data is None: - continue - - num_stations = waveform_data.shape[0] - - if num_stations <= self.max_stations: - waveforms[i, :num_stations] = waveform_data - metadata[i, : len(meta)] = meta - pga[i, : len(pga_data)] = pga_data - p_picks[i, : len(p_pick_data)] = p_pick_data - reverse_selections += [[]] - else: - selection = self._select_stations(num_stations, p_pick_data) + target, waveforms, magnitude, metadata, pga_values, pga_targets = ( + self.data_preprocessing(indexes, waveforms, p_picks, magnitude, metadata, + pga, true_batch_size, reverse_selections, full_p_picks)) - metadata[i, : len(selection)] = meta[selection] - pga[i, : len(selection)] = pga_data[selection] - full_p_picks[i, : len(selection)] = p_pick_data[selection] + waveforms, metadata = self.data_processing(waveforms, metadata) - tmp_reverse_selection = [0 for _ in selection] - for j, s in enumerate(selection): - tmp_reverse_selection[s] = j - reverse_selections += [tmp_reverse_selection] + return self.get_result(waveforms, metadata, magnitude, target, pga_targets, pga_values) - selection = selection[: self.max_stations] - waveforms[i] = waveform_data[selection] - p_picks[i] = p_pick_data[selection] - - return waveforms, metadata, pga, p_picks, full_p_picks, reverse_selections - - def _select_stations(self, num_stations, p_pick_data): - """Select stations based on configured strategy.""" - if self.selection_skew is None: - selection = np.arange(0, num_stations) - np.random.shuffle(selection) + def htpyfile_process(self, indexes, waveforms, metadata, pga, + p_picks, reverse_selections, full_p_picks): + """ + Processes the HDF5 file to retrieve waveform data for a batch of events. + """ + with h5py.File(self.data_path, 'r') as f: + for i, idx in enumerate(indexes): + event = self.event_metadata.iloc[idx] + event_name = str(event[self.event_key]) + if event_name not in f['data']: + continue + g_event = f['data'][event_name] + waveform_data = g_event['waveforms'][:, :, :] + + num_stations = waveform_data.shape[0] + + if num_stations <= self.max_stations: + waveforms[i, :num_stations] = waveform_data + metadata[i, :len(self.metadata[idx])] = self.metadata[idx] + pga[i, :len(self.pga[idx])] = self.pga[idx] + p_picks[i, :len(self.triggers[idx])] = self.triggers[idx] + reverse_selections += [[]] + else: + if self.selection_skew is None: + selection = np.arange(0, num_stations) + np.random.shuffle(selection) + else: + tmp_p_picks = self.triggers[idx].copy() + mask = np.logical_and(tmp_p_picks <= 0, tmp_p_picks > self.p_pick_limit) + tmp_p_picks[mask] = min(np.max(tmp_p_picks), self.p_pick_limit) + coeffs = np.exp(-tmp_p_picks / self.selection_skew) + coeffs *= np.random.random(coeffs.shape) + coeffs[self.triggers[idx] == 0] = 0 + coeffs[self.triggers[idx] > self.waveform_shape[0]] = 0 + selection = np.argsort(-coeffs) + + if self.select_first: + selection = np.argsort(self.triggers[idx]) + + metadata[i, :len(selection)] = self.metadata[idx][selection] + pga[i, :len(selection)] = self.pga[idx][selection] + full_p_picks[i, :len(selection)] = self.triggers[idx][selection] + + tmp_reverse_selection = [0 for _ in selection] + for j, s in enumerate(selection): + tmp_reverse_selection[s] = j + reverse_selections += [tmp_reverse_selection] + + selection = selection[:self.max_stations] + waveforms[i] = waveform_data[selection] + p_picks[i] = self.triggers[idx][selection] + return waveforms, metadata, pga, p_picks, reverse_selections, full_p_picks + + def pga_mode_process(self, waveforms, reverse_selections, metadata, + pga_values, pga_targets, pga, indexes, full_p_picks): + """ + Processes the data in PGA mode. + """ + if self.pga_mode: + for i in range(waveforms.shape[0]): + pga_index = self.pga_indexes[i] + if reverse_selections[i]: + sorted_pga = pga[i, reverse_selections[i]] + sorted_metadata = metadata[i, reverse_selections[i]] + else: + sorted_pga = pga[i] + sorted_metadata = metadata[i] + pga_values_pre = sorted_pga[pga_index * self.pga_targets:(pga_index + 1) * self.pga_targets] + pga_values[i, :len(pga_values_pre)] = pga_values_pre + pga_targets_pre = sorted_metadata[pga_index * self.pga_targets:(pga_index + 1) * self.pga_targets, :] + if pga_targets_pre.shape[-1] == 4: + pga_targets_pre = pga_targets_pre[:, (0, 1, 3)] + pga_targets[i, :len(pga_targets_pre), :] = pga_targets_pre else: - tmp_p_picks = p_pick_data.copy() - mask = np.logical_and(tmp_p_picks <= 0, tmp_p_picks > self.p_pick_limit) - tmp_p_picks[mask] = min(np.max(tmp_p_picks), self.p_pick_limit) - coeffs = np.exp(-tmp_p_picks / self.selection_skew) - coeffs *= np.random.random(coeffs.shape) - coeffs[p_pick_data == 0] = 0 - coeffs[p_pick_data > self.waveform_shape[0]] = 0 - selection = np.argsort(-coeffs) - - if self.select_first: - selection = np.argsort(p_pick_data) - - return selection - - def _process_magnitude_and_targets(self, event_info_list): - """Process magnitude and coordinate targets.""" - magnitude = np.array([e[self.key] for e in event_info_list], dtype=np.float32) + pga[np.logical_or(np.isnan(pga), np.isinf(pga))] = 0 + for i in range(waveforms.shape[0]): + active = np.where(pga[i] != 0)[0] + l = len(active) + if l == 0: + raise ValueError(f'Found event without PGA idx={indexes[i]}') + while len(active) < self.pga_targets: + active = np.repeat(active, 2) + if self.pga_selection_skew is not None: + active_p_picks = full_p_picks[i, active] + mask = np.logical_and(active_p_picks <= 0, active_p_picks > self.p_pick_limit) + active_p_picks[mask] = min(np.max(active_p_picks), self.p_pick_limit) + coeffs = np.exp(-active_p_picks / self.pga_selection_skew) + coeffs *= np.random.random(coeffs.shape) + active = active[np.argsort(-coeffs)] + else: + np.random.shuffle(active) + + samples = active[:self.pga_targets] + if metadata.shape[-1] == 3: + pga_targets[i] = metadata[i, samples, :] + else: + full_targets = metadata[i, samples] + pga_targets[i] = full_targets[:, (0, 1, 3)] + pga_values[i] = pga[i, samples] + return pga_values, pga_targets + + def data_preprocessing(self, indexes, waveforms, p_picks, magnitude, metadata, + pga, true_batch_size, reverse_selections, full_p_picks): + """ + Data preprocessing. + """ target = None - if self.coords_target: - coord_keys = detect_location_keys( - [col for e in event_info_list for col in e.index] - ) - target = np.array( - [[e[k] for k in coord_keys] for e in event_info_list], dtype=np.float32 - ) - - magnitude = np.expand_dims(np.expand_dims(magnitude, axis=-1), axis=-1) - return magnitude, target - - def _process_waveforms(self, waveforms, org_waveform_length, p_picks): - """Apply cutout, sliding window, trigger-based, and integration transformations to waveforms.""" - cutout = org_waveform_length - + target = self.event_metadata.iloc[indexes][self.coord_keys].values + target = target.astype(np.float32) + org_waveform_length = waveforms.shape[2] if self.cutout: if self.sliding_window: windowlen = self.windowlen - window_end = np.random.randint( - max(windowlen, self.cutout[0]), - min(waveforms.shape[2], self.cutout[1]) + 1, - ) - waveforms = waveforms[:, :, window_end - windowlen : window_end] - + window_end = np.random.randint(max(windowlen, self.cutout[0]), + min(waveforms.shape[2], self.cutout[1]) + 1) + waveforms = waveforms[:, :, window_end - windowlen: window_end] cutout = window_end if self.adjust_mean: waveforms -= np.mean(waveforms, axis=2, keepdims=True) else: cutout = np.random.randint(*self.cutout) if self.adjust_mean: - waveforms -= np.mean( - waveforms[:, :, : cutout + 1], axis=2, keepdims=True - ) + waveforms -= np.mean(waveforms[:, :, :cutout + 1], axis=2, keepdims=True) waveforms[:, :, cutout:] = 0 + else: + cutout = waveforms.shape[2] if self.trigger_based: p_picks[p_picks <= 0] = org_waveform_length waveforms[cutout < p_picks, :, :] = 0 - if self.integrate: waveforms = np.cumsum(waveforms, axis=2) / self.sampling_rate + if self.differentiate: + waveforms = np.diff(waveforms, axis=2) - return waveforms, cutout - - def _transform_locations(self, metadata, target): - """Transform locations using the location_transformation method.""" + magnitude = np.expand_dims(np.expand_dims(magnitude, axis=-1), axis=-1) if self.coords_target: metadata, target = self.location_transformation(metadata, target) else: metadata = self.location_transformation(metadata) - return metadata, target - def _apply_label_smoothing(self, magnitude): - """Apply label smoothing to magnitude if enabled.""" if self.label_smoothing: - magnitude += ( - (magnitude > 4) - * np.random.randn(magnitude.shape[0]).reshape(magnitude.shape) - * (magnitude - 4) - * 0.05 - ) - return magnitude - - def _adjust_metadata_and_pga(self, metadata, pga): - """Adjust metadata and PGA arrays based on configuration.""" + magnitude += (magnitude > 4) * np.random.randn(magnitude.shape[0]).reshape(magnitude.shape) * ( + magnitude - 4) * 0.05 if not self.pga_from_inactive and not self.pga_mode: - metadata = metadata[:, : self.max_stations] - pga = pga[:, : self.max_stations] - return metadata, pga - - def _process_pga_targets( - self, - true_batch_size, - pga, - metadata, - pga_indexes, - reverse_selections, - full_p_picks, - indexes, - ): - """Process PGA targets if enabled.""" - pga_values = None - pga_targets_data = None - + metadata = metadata[:, :self.max_stations] + pga = pga[:, :self.max_stations] + pga_values = () + pga_targets = () if self.pga_targets: pga_values = np.zeros((true_batch_size, self.pga_targets)) - pga_targets_data = np.zeros((true_batch_size, self.pga_targets, 3)) - - if self.pga_mode and pga_indexes is not None: - self._process_pga_mode( - pga_values, - pga_targets_data, - pga, - metadata, - pga_indexes, - reverse_selections, - ) - else: - self._process_pga_normal( - pga_values, pga_targets_data, pga, metadata, full_p_picks, indexes - ) + pga_targets = np.zeros((true_batch_size, self.pga_targets, 3)) + + pga_values, pga_targets = self.pga_mode_process(waveforms, reverse_selections, metadata, + pga_values, pga_targets, pga, indexes, full_p_picks) pga_values = pga_values.reshape((true_batch_size, self.pga_targets, 1, 1)) - return pga_values, pga_targets_data + return target, waveforms, magnitude, metadata, pga_values, pga_targets - def _process_pga_mode( - self, - pga_values, - pga_targets_data, - pga, - metadata, - pga_indexes, - reverse_selections, - ): - """Process PGA in PGA mode.""" - for i in range(len(pga_values)): - pga_index = pga_indexes[i] - if reverse_selections[i]: - sorted_pga = pga[i, reverse_selections[i]] - sorted_metadata = metadata[i, reverse_selections[i]] - else: - sorted_pga = pga[i] - sorted_metadata = metadata[i] - pga_values_pre = sorted_pga[ - pga_index * self.pga_targets : (pga_index + 1) * self.pga_targets - ] - pga_values[i, : len(pga_values_pre)] = pga_values_pre - pga_targets_pre = sorted_metadata[ - pga_index * self.pga_targets : (pga_index + 1) * self.pga_targets, - :, - ] - if pga_targets_pre.shape[-1] == 4: - pga_targets_pre = pga_targets_pre[:, (0, 1, 3)] - pga_targets_data[i, : len(pga_targets_pre), :] = pga_targets_pre - - def _process_pga_normal( - self, pga_values, pga_targets_data, pga, metadata, full_p_picks, indexes - ): - """Process PGA in normal mode.""" - pga[np.logical_or(np.isnan(pga), np.isinf(pga))] = 0 - for i in range(pga_values.shape[0]): - active = np.where(pga[i] != 0)[0] - if not active: - raise ValueError(f"Found event without PGA idx={indexes[i]}") - while len(active) < self.pga_targets: - active = np.repeat(active, 2) - - if self.pga_selection_skew is not None: - active = self._select_pga_with_skew(active, full_p_picks[i]) - else: - np.random.shuffle(active) - - samples = active[: self.pga_targets] - if metadata.shape[-1] == 3: - pga_targets_data[i] = metadata[i, samples, :] - else: - full_targets = metadata[i, samples] - pga_targets_data[i] = full_targets[:, (0, 1, 3)] - pga_values[i] = pga[i, samples] - - def _select_pga_with_skew(self, active, full_p_picks): - """Select PGA with skew-based selection.""" - active_p_picks = full_p_picks[active] - mask = np.logical_and(active_p_picks <= 0, active_p_picks > self.p_pick_limit) - active_p_picks[mask] = min(np.max(active_p_picks), self.p_pick_limit) - coeffs = np.exp(-active_p_picks / self.pga_selection_skew) - coeffs *= np.random.random(coeffs.shape) - return active[np.argsort(-coeffs)] - - def _apply_station_blinding(self, waveforms, metadata): - """Apply station blinding if enabled.""" + def data_processing(self, waveforms, metadata): + """ + Data process. + """ + metadata = metadata[:, :self.max_stations] if self.station_blinding: mask = np.zeros(waveforms.shape[:2], dtype=bool) for i in range(waveforms.shape[0]): active = np.where((waveforms[i] != 0).any(axis=(1, 2)))[0] - if not active == 0: + l = len(active) + if l == 0: active = np.zeros(1, dtype=int) blind_length = np.random.randint(0, len(active)) np.random.shuffle(active) @@ -623,58 +644,48 @@ class DataProcessor: waveforms[mask] = 0 metadata[mask] = 0 - return waveforms, metadata - - def _handle_stations_without_trigger(self, waveforms, metadata): - """Handle stations without trigger signal.""" - stations_without_trigger = (metadata != 0).any(axis=2) & (waveforms == 0).all( - axis=(2, 3) - ) - + stations_without_trigger = (metadata != 0).any(axis=2) & (waveforms == 0).all(axis=(2, 3)) if self.disable_station_foreshadowing: metadata[stations_without_trigger] = 0 else: waveforms[stations_without_trigger, 0, 0] += 1e-9 - return waveforms, metadata - - def _ensure_no_empty_arrays(self, waveforms, metadata): - """Ensure there are no empty arrays in the batch.""" - mask = np.logical_and( - (metadata == 0).all(axis=(1, 2)), (waveforms == 0).all(axis=(1, 2, 3)) - ) + mask = np.logical_and((metadata == 0).all(axis=(1, 2)), (waveforms == 0).all(axis=(1, 2, 3))) waveforms[mask, 0, 0, 0] = 1e-9 metadata[mask, 0, 0] = 1e-9 return waveforms, metadata - def _prepare_model_io( - self, waveforms, metadata, magnitude, target, pga_targets_data, pga_values - ): - """Prepare model inputs and outputs.""" - inputs = [ - mindspore.tensor(waveforms, dtype=mindspore.float32), - mindspore.tensor(metadata, dtype=mindspore.float32), - ] + def get_result(self, waveforms, metadata, magnitude, target, pga_targets, pga_values): + """ + get result. + """ + inputs = [ms.tensor(waveforms, dtype=ms.float32), ms.tensor(metadata, dtype=ms.float32)] outputs = [] - if not self.no_event_token: - outputs += [mindspore.tensor(magnitude, dtype=mindspore.float32)] + outputs += [ms.tensor(magnitude, dtype=ms.float32)] if self.coords_target: target = np.expand_dims(target, axis=-1) - outputs += [mindspore.tensor(target, dtype=mindspore.float32)] + outputs += [ms.tensor(target, dtype=ms.float32)] - if self.pga_targets and pga_values is not None and pga_targets_data is not None: - inputs += [mindspore.tensor(pga_targets_data, dtype=mindspore.float32)] - outputs += [mindspore.tensor(pga_values, dtype=mindspore.float32)] + if self.pga_targets: + inputs += [ms.tensor(pga_targets, dtype=ms.float32)] + outputs += [ms.tensor(pga_values, dtype=ms.float32)] return inputs, outputs + def on_epoch_end(self): + """ + Resets the indexes for a new epoch, optionally with oversampling and shuffling. + """ + self.indexes = np.repeat(self.base_indexes.copy(), self.oversample, axis=0) + if self.shuffle: + np.random.shuffle(self.indexes) + def location_transformation(self, metadata, target=None): """ - Apply transformations to the metadata and optionally to the target. - Adjusts positions based on a positional offset and scales the data if required. + Transforms the event coordinates and optionally the target coordinates. """ transform_target_only = self.transform_target_only metadata = metadata.copy() @@ -682,86 +693,32 @@ class DataProcessor: metadata_old = metadata metadata = metadata.copy() mask = (metadata == 0).all(axis=2) - if target is not None: target[:, 0] -= self.pos_offset[0] target[:, 1] -= self.pos_offset[1] - metadata[:, :, 0] -= self.pos_offset[0] metadata[:, :, 1] -= self.pos_offset[1] + + # Coordinates to kilometers (assuming a flat earth, which is okay close to equator) if self.scale_metadata: metadata[:, :, :2] *= D2KM if target is not None: target[:, :2] *= D2KM + metadata[mask] = 0 + if self.scale_metadata: metadata /= 100 if target is not None: target /= 100 + if transform_target_only: metadata = metadata_old + if target is None: return metadata - return metadata, target - - -class PreloadedEventGenerator(Dataset): - """ - A custom PyTorch Dataset class designed to generate preloaded event data for training or evaluation. - This class wraps an `EarthquakeDataset` and a `DataProcessor` to provide processed input-output pairs. - Attributes: - dataset (EarthquakeDataset): An instance of the EarthquakeDataset class, responsible for loading - raw earthquake-related data. - processor (DataProcessor): An instance of the DataProcessor class, responsible for processing - the raw data into model-ready inputs and outputs. - """ - - def __init__(self, data_path, event_key, data, event_metadata, **kwargs): - """ - Initializes the PreloadedEventGenerator. - Args: - data_path (str): The file path or directory where the dataset is stored. - event_key (str): A key used to identify specific events within the dataset. - data (dict or array-like): Raw data associated with the events. - event_metadata (dict or DataFrame): Metadata describing the events in the dataset. - **kwargs: Additional keyword arguments passed to both EarthquakeDataset and DataProcessor. - """ - super(PreloadedEventGenerator, self).__init__() - self.dataset = EarthquakeDataset( - data_path=data_path, - event_key=event_key, - data=data, - event_metadata=event_metadata, - **kwargs, - ) - self.processor = DataProcessor(**kwargs) - - def __len__(self): - """ - Returns the total number of samples in the dataset. - - Returns: - int: The length of the underlying EarthquakeDataset. - """ - return len(self.dataset) - - def __getitem__(self, index): - """ - Retrieves and processes a single batch of data at the given index. - - Args: - index (int): The index of the data sample to retrieve. - - Returns: - tuple: A tuple containing two elements: - inputs: Processed input data ready for model consumption. - outputs: Corresponding target outputs for the model. - """ - batch_data = self.dataset[index] - inputs, outputs = self.processor.process_batch(batch_data) - - return inputs, outputs + return metadata, target def generator_from_config( config, @@ -796,4 +753,5 @@ def generator_from_config( pga_mode=pga, **generator_params, ) + return generator diff --git a/MindEarth/applications/earthquake/G-TEAM/src/forcast.py b/MindEarth/applications/earthquake/G-TEAM/src/forcast.py index 2c67ab33b3ac1d27f13c9dd34e15175c90ef6dd0..b1752b8e4ac55119e310e2698bf312674f7d6ec5 100644 --- a/MindEarth/applications/earthquake/G-TEAM/src/forcast.py +++ b/MindEarth/applications/earthquake/G-TEAM/src/forcast.py @@ -12,19 +12,73 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"GTeam inference" +"GTeam forcast" +import os +from tqdm import tqdm import numpy as np +import mindspore as ms +import mindspore.nn as nn +import mindspore.ops as ops + from src.utils import ( predict_at_time, calc_mag_stats, calc_loc_stats, calc_pga_stats, ) -from src.data import load_data +from src.data import DataGenerator, PreloadedEventGenerator, load_pickle_data, load_data +from src.models import SingleStationModel +from src.utils import evaluation, seed_np_tf from src.visual import generate_true_pred_plot +class CustomWithLossCell(nn.Cell): + """ + A neural network cell that wraps a main network and loss function together, + allowing the entire forward pass including loss computation to be treated as a single cell. + + This class combines a neural network model and a loss function into a single computation unit, + which is useful for training loops and model encapsulation in deep learning frameworks. + + Attributes: + net (nn.Cell): The main neural network model whose output will be used in loss computation. + loss_fn (nn.Cell): The loss function cell that computes the difference between predictions + and true labels. + """ + + def __init__(self, net, loss_fn): + """ + Initializes the CustomWithLossCell with a network model and loss function. + + Args: + net (nn.Cell): The neural network model whose output will be used for loss calculation. + loss_fn (nn.Cell): The loss computation function that takes (true_labels, predictions) + and returns a scalar loss value. + """ + super().__init__() + self.net = net + self.loss_fn = loss_fn + + def construct(self, x, y): + ''' + Computes the loss by first passing input data through the network and then applying the loss function. + + Args: + X (Tensor): Input data tensor containing features. + y (Tensor): Ground truth labels tensor. + + Returns: + Tensor: Computed loss value. + + Note: + The input labels 'y' are squeezed along dimension 2 to match the output shape from the network. + This ensures the loss function receives inputs of the expected shape. + ''' + outputs = self.net(x) + return self.loss_fn(y.squeeze(2), outputs) + + class GTeamInference: """ Initialize the GTeamInference class. @@ -143,3 +197,325 @@ class GTeamInference: ) self._save_results() print("Inference completed and results saved") + +class GTeamTrain: + """ + A class to handle the training of a full model for earthquake detection and localization. + It manages data loading, training of single-station models, and full-model training. + """ + def __init__(self, model_ins, cfg, output_dir, logger): + """ + Initialize the GTeamTrain class with model, configuration, output directory, and logger. + Args: + model_ins (nn.Cell): The full model instance to be trained. + cfg (dict): Configuration dictionary containing training parameters and paths. + output_dir (str): Directory to save checkpoints and outputs. + logger (logging.Logger): Logger instance for logging messages. + """ + self.full_model = model_ins + self.cfg = cfg + self.output_dir = output_dir + self.logger = logger + self.waveform_shape = [3000, 3] + self.training_params = self.cfg['training_params'] + self.generator_params = self.training_params.get('generator_params', [self.training_params.copy()]) + self.file_basename = os.path.basename(self.training_params['data_path']).split('.')[0] + + def load_train_data(self): + """ + Load training data from a pickle file. + Returns: + Data structure: The loaded training data. + """ + data_path = self.cfg['data']["root_dir"] + filename_train = os.path.join(data_path, f"{self.file_basename}_train.pkl") + return load_pickle_data(filename_train) + + def load_val_data(self): + """ + Load validation data from a pickle file. + Returns: + Data structure: The loaded validation data. + """ + data_path = self.cfg['data']["root_dir"] + filename_val = os.path.join(data_path, f"{self.file_basename}_val.pkl") + return load_pickle_data(filename_val) + + def init_single_generator(self, sampling_rate, event_metadata_index_train, event_key_train, + event_metadata_index_val, event_key_val, decimate_train): + """ + Initialize the single-station model and its data generators for training and validation. + Args: + sampling_rate (float): Sampling rate of the seismic data. + event_metadata_index_train (list): Indices for training events in the metadata. + event_key_train (str): Key for selecting the training event data. + event_metadata_index_val (list): Indices for validation events in the metadata. + event_key_val (str): Key for selecting the validation event data. + decimate_train (bool): Whether to decimate the training data. + """ + self.single_station_model = SingleStationModel(output_mlp_dims=self.cfg['model']['output_mlp_dims'], + use_mlp=self.cfg['model']['use_mlp']) + noise_seconds = self.generator_params[0].get('noise_seconds', 5) + cutout = (sampling_rate * (noise_seconds + self.generator_params[0]['cutout_start']), + sampling_rate * (noise_seconds + self.generator_params[0]['cutout_end'])) + self.single_train_generator = DataGenerator(self.training_params['data_path'], + event_metadata_index_train, event_key_train, + mag_key=self.generator_params[0]['key'], + batch_size=self.generator_params[0]['batch_size'], + cutout=cutout, + label_smoothing=True, + sliding_window=self.generator_params[0].get('sliding_window', + False), + decimate=decimate_train) + self.single_validation_generator = DataGenerator(self.training_params['data_path'], + event_metadata_index_val, event_key_val, + mag_key=self.generator_params[0]['key'], + batch_size=self.generator_params[0]['batch_size'], + cutout=cutout, + label_smoothing=True, + sliding_window=self.generator_params[0].get('sliding_window', + False), + decimate=decimate_train) + optimizer_single = nn.Adam(self.single_station_model.trainable_params(), learning_rate=1e-4) + self.criterion_single_mse = nn.MSELoss() + + loss_net = CustomWithLossCell(self.single_station_model, self.criterion_single_mse) + self.single_train_network = nn.TrainOneStepCell(loss_net, optimizer_single) + + self.single_station_model.set_train(True) + + def single_station_train(self, sampling_rate, event_metadata_index_train, event_key_train, + event_metadata_index_val, event_key_val, decimate_train): + """ + Train the single-station model. Loads a pre-trained model if specified, otherwise + initializes the generator and trains from scratch. + Args: + sampling_rate (float): Sampling rate of the seismic data. + event_metadata_index_train (list): Indices for training events in the metadata. + event_key_train (str): Key for selecting the training event data. + event_metadata_index_val (list): Indices for validation events in the metadata. + event_key_val (str): Key for selecting the validation event data. + decimate_train (bool): Whether to decimate the training data. + """ + if 'single_station_model_path' in self.training_params: + print('Loading single station model') + param_dict = ms.load_checkpoint(self.training_params['single_station_model_path']) + ms.load_param_into_net(self.single_station_model, param_dict) + elif 'transfer_model_path' not in self.training_params: + self.init_single_generator(sampling_rate, event_metadata_index_train, event_key_train, + event_metadata_index_val, event_key_val, decimate_train) + + for epoch in tqdm(range(self.training_params['epochs_single_station']), + desc='training single station model'): + train_loss = 0.0 + + for i in range(len(self.single_train_generator)): + x, y = self.single_train_generator[i] + loss = self.single_train_network(x, y) + train_loss += loss.asnumpy() + + train_loss /= len(self.single_train_generator) + + val_loss = 0.0 + for i in range(len(self.single_validation_generator)): + x, y = self.single_validation_generator[i] + outputs = self.single_station_model(x) + loss = self.criterion_single_mse(y.squeeze(2), outputs) + val_loss += loss.item() + + val_loss /= len(self.single_validation_generator) + + print(f'Epoch {epoch + 1}/{self.training_params["epochs_single_station"]}, ' + f'Training Loss: {train_loss}, Validation Loss: {val_loss}') + + ms.save_checkpoint(self.single_station_model, + os.path.join(self.output_dir, f'single-station-{epoch + 1}')) + + def init_full_generator(self, sampling_rate, event_key_train, data_train, event_metadata_train, + max_stations, event_key_val, data_val, event_metadata_val): + """ + Initialize the full model's data generators and optimizer. + Args: + sampling_rate (float): Sampling rate of the seismic data. + event_key_train (str): Key for selecting the training event data. + data_train: Training data. + event_metadata_train: Metadata for training events. + max_stations (int): Maximum number of stations to consider. + event_key_val (str): Key for selecting the validation event data. + data_val: Validation data. + event_metadata_val: Metadata for validation events. + """ + if 'load_model_path' in self.training_params: + print('Loading full model') + param_dict = ms.load_checkpoint(self.training_params['load_model_path']) + ms.load_param_into_net(self.full_model, param_dict) + + n_pga_targets = self.cfg['model'].get('n_pga_targets', 0) + no_event_token = self.cfg['model'].get('no_event_token', False) + + self.optimizer_full = nn.Adam(self.full_model.trainable_params(), learning_rate=1e-4) + self.losses_full_mse = {'magnitude': nn.MSELoss(), 'location': nn.MSELoss(), 'pga': nn.MSELoss()} + + generator_param_set = self.generator_params[0] + noise_seconds = generator_param_set.get('noise_seconds', 5) + cutout = (sampling_rate * (noise_seconds + generator_param_set['cutout_start']), + sampling_rate * (noise_seconds + generator_param_set['cutout_end'])) + + generator_param_set['transform_target_only'] = generator_param_set.get('transform_target_only', True) + + if 'data_path' in generator_param_set: + del generator_param_set['data_path'] + + self.full_train_generator = PreloadedEventGenerator(self.training_params['data_path'], + event_key_train, + data_train, + event_metadata_train, + waveform_shape=self.waveform_shape, + coords_target=True, + label_smoothing=True, + station_blinding=True, + cutout=cutout, + pga_targets=n_pga_targets, + max_stations=max_stations, + sampling_rate=sampling_rate, + no_event_token=no_event_token, + **generator_param_set) + + old_oversample = generator_param_set.get('oversample', 1) + generator_param_set['oversample'] = 4 + + self.full_validation_generator = PreloadedEventGenerator(self.training_params['data_path'], + event_key_val, + data_val, + event_metadata_val, + waveform_shape=self.waveform_shape, + coords_target=True, + station_blinding=True, + cutout=cutout, + pga_targets=n_pga_targets, + max_stations=max_stations, + sampling_rate=sampling_rate, + no_event_token=no_event_token, + **generator_param_set) + + generator_param_set['oversample'] = old_oversample + print('len(full_train_generator)', len(self.full_train_generator)) + + self.loss_weights = self.training_params['loss_weights'] + print(f'The total number of parameters: {sum(p.numel() for p in self.full_model.trainable_params())}') + + def full_station_train(self, sampling_rate, event_key_train, data_train, event_metadata_train, + max_stations, event_key_val, data_val, event_metadata_val): + """ + Train the full station model using the initialized generators and optimizer. + + Args: + sampling_rate (float): Sampling rate of the seismic data + event_key_train (str): Key for selecting training event data + data_train: Training data + event_metadata_train: Training event metadata + max_stations (int): Maximum number of stations to consider + event_key_val (str): Key for selecting validation event data + data_val: Validation data + event_metadata_val: Validation event metadata + """ + self.init_full_generator(sampling_rate, event_key_train, data_train, event_metadata_train, + max_stations, event_key_val, data_val, event_metadata_val) + def calculate_total_loss(network, x, y): + train_mag_loss = 0 + train_loc_loss = 0 + train_pga_loss = 0 + outputs = network(x[0], x[1], x[2]) + total_loss = 0 + for k, loss_fn in self.losses_full_mse.items(): + if k == 'magnitude': + mag_pre = outputs[0] + mag_target = y[0] + mag_loss = loss_fn(mag_target.squeeze(2), mag_pre) * self.loss_weights[k] + train_mag_loss += mag_loss + total_loss += mag_loss + elif k == 'location': + loc_pre = outputs[1] + loc_target = y[1] + loc_loss = loss_fn(loc_target.squeeze(2), loc_pre) * self.loss_weights[k] + train_loc_loss += loc_loss + total_loss += loc_loss + elif k == 'pga': + pga_pre = outputs[2] + if 'italy' in self.file_basename: + pga_target = y[2] + else: + pga_target = ops.log(ops.abs(y[2])) + pga_loss = loss_fn(pga_target.squeeze(3), pga_pre) * self.loss_weights[k] + train_pga_loss += pga_loss + total_loss += pga_loss + return total_loss + + self.full_model.set_train() + grad_fn = ms.value_and_grad( + fn=calculate_total_loss, + grad_position=None, + weights=self.full_model.trainable_params(), + has_aux=False + ) + for epoch in tqdm(range(self.training_params['epochs_full_model']), desc='training full model'): + train_loss = 0 + + for i in range(len(self.full_train_generator)): + x, y = self.full_train_generator[i] + + total_loss, grads = grad_fn(self.full_model, x, y) + self.optimizer_full(grads) + + train_loss += total_loss.item() + avg_train_loss = train_loss / len(self.full_train_generator) + + avg_val_loss = evaluation(self.full_model, self.full_validation_generator, + self.losses_full_mse, self.loss_weights) + + print(f'Epoch {epoch + 1}/{self.training_params["epochs_full_model"]}', + f'Average Training Loss: {avg_train_loss}', f'Average val Loss: {avg_val_loss}') + + ms.save_checkpoint(self.full_model, os.path.join(self.output_dir, f'event-{epoch + 1}')) + + print('Training complete, and loss history saved.') + + def train(self): + """ + Train the full model for earthquake detection and localization. + + This method orchestrates the training process by: + 1. Setting the random seed for reproducibility. + 2. Loading training and validation datasets. + 3. Extracting key parameters like sampling rate and event metadata. + 4. Training single-station models for each station in the dataset. + 5. Training the full multi-station model using the pre-trained single-station models. + + Steps: + - Initialize random seed from configuration (default: 42) + - Load training data and extract metadata + - Load validation data + - Extract sampling rate and remove 'max_stations' from model config + - Train single-station models using training and validation data + - Train full model using combined data from all stations + + Note: This method assumes that the `single_station_train` and `full_station_train` methods are implemented. + """ + seed_np_tf(self.cfg['training_params'].get('seed', 42)) + + print('Loading data') + (event_metadata_index_train, event_metadata_train, metadata_train, + data_train, event_key_train, decimate_train) = self.load_train_data() + (event_metadata_index_val, event_metadata_val, _, + data_val, event_key_val, _) = self.load_val_data() + + sampling_rate = metadata_train['sampling_rate'] + max_stations = self.cfg['model']['max_stations'] + del self.cfg['model']['max_stations'] + + print('training') + self.single_station_train(sampling_rate, event_metadata_index_train, event_key_train, + event_metadata_index_val, event_key_val, decimate_train) + + self.full_station_train(sampling_rate, event_key_train, data_train, event_metadata_train, + max_stations, event_key_val, data_val, event_metadata_val) diff --git a/MindEarth/applications/earthquake/G-TEAM/src/models.py b/MindEarth/applications/earthquake/G-TEAM/src/models.py index d7a650c9d266161043ee21ea2a7abf3ed3955816..1c91673a673d2c80b0fbf448eb096240fcf393ed 100644 --- a/MindEarth/applications/earthquake/G-TEAM/src/models.py +++ b/MindEarth/applications/earthquake/G-TEAM/src/models.py @@ -29,20 +29,31 @@ class MLP(nn.Cell): final_activation: The activation function for the final layer. Default is nn.ReLU. """ - def __init__(self, input_shape, dims=(100, 50), final_activation=nn.ReLU()): + def __init__(self, input_shape, dims=(100, 50), final_activation=nn.ReLU(), is_mlp=False): super().__init__() layers = [] in_dim = input_shape[0] - - for dim in dims[:-1]: - layers.append(nn.Dense(in_dim, dim)) - layers.append(nn.ReLU()) - in_dim = dim - layers.append(nn.Dense(in_dim, dims[-1])) - - if final_activation: - layers.append(final_activation) - self.model = nn.SequentialCell(*layers) + if is_mlp: + for dim in dims[:-1]: + layers.append(nn.Dense(in_dim, dim)) + layers.append(nn.LayerNorm((dim,))) + layers.append(nn.ReLU()) + in_dim = dim + layers.append(nn.Dense(in_dim, dims[-1])) + + if final_activation: + layers.append(final_activation) + self.model = nn.SequentialCell(*layers) + else: + for dim in dims[:-1]: + layers.append(nn.Dense(in_dim, dim)) + layers.append(nn.ReLU()) + in_dim = dim + layers.append(nn.Dense(in_dim, dims[-1])) + + if final_activation: + layers.append(final_activation) + self.model = nn.SequentialCell(*layers) def construct(self, x): """ @@ -61,7 +72,7 @@ class NormalizedScaleEmbedding(nn.Cell): convolutional and pooling layers, and processes the features through a multi-layer perceptron (MLP). """ - def __init__(self, downsample=5, mlp_dims=(500, 300, 200, 150), eps=1e-8): + def __init__(self, downsample=5, mlp_dims=(500, 300, 200, 150), eps=1e-8, use_mlp=False): """ Initialize the module with given parameters. Parameters: @@ -98,8 +109,20 @@ class NormalizedScaleEmbedding(nn.Cell): self.conv1d_5 = nn.Conv1d(32, 16, kernel_size=4, has_bias=True, pad_mode="pad") self.flatten = nn.Flatten() - self.mlp = MLP((865,), dims=self.mlp_dims) + self.mlp = MLP((865,), dims=self.mlp_dims, is_mlp=use_mlp) self.leaky_relu = nn.LeakyReLU(alpha=0.01) + self._initialize_weights() + + def _initialize_weights(self): + self.conv2d_1.bias.set_data(ms.numpy.zeros_like(self.conv2d_1.bias)) + self.conv2d_2.bias.set_data(ms.numpy.zeros_like(self.conv2d_2.bias)) + + # For Conv1d layers + self.conv1d_1.bias.set_data(ms.numpy.zeros_like(self.conv1d_1.bias)) + self.conv1d_2.bias.set_data(ms.numpy.zeros_like(self.conv1d_2.bias)) + self.conv1d_3.bias.set_data(ms.numpy.zeros_like(self.conv1d_3.bias)) + self.conv1d_4.bias.set_data(ms.numpy.zeros_like(self.conv1d_4.bias)) + self.conv1d_5.bias.set_data(ms.numpy.zeros_like(self.conv1d_5.bias)) def construct(self, x): """ @@ -213,6 +236,7 @@ class PositionEmbedding(nn.Cell): min_lat, max_lat = wavelengths[0] min_lon, max_lon = wavelengths[1] min_depth, max_depth = wavelengths[2] + assert emb_dim % 10 == 0 lat_dim = emb_dim // 5 lon_dim = emb_dim // 5 depth_dim = emb_dim // 10 @@ -307,7 +331,43 @@ class AddEventToken(nn.Cell): return x +class SingleStationModel(nn.Cell): + """ + A neural network model for processing seismic waveforms from a single station. + This class implements a two-stage processing pipeline: waveform embedding followed by feature extraction. + """ + def __init__(self, waveform_model_dims=(500, 500, 500), + output_mlp_dims=(150, 100, 50, 30, 10), downsample=5, use_mlp=False): + """ + Initialize the SingleStationModel. + + Args: + waveform_model_dims (tuple): Dimensions of the MLP in the waveform embedding module. + Format: (input_dim, hidden_dim1, hidden_dim2, ...) + output_mlp_dims (tuple): Dimensions of the final MLP for feature extraction. + Format: (input_dim, hidden_dim1, hidden_dim2, ...) + downsample (int): Factor by which to downsample the input waveform data. + """ + super().__init__() + + self.waveform_model = NormalizedScaleEmbedding(downsample=downsample, mlp_dims=waveform_model_dims, + use_mlp=use_mlp) + self.mlp_mag_single_station = MLP((self.waveform_model.mlp_dims[-1],), output_mlp_dims) + + def construct(self, x): + """ + Forward pass of the SingleStationModel. + + Args: + x (Tensor): Input waveform data with shape (batch_size, time_steps, features) + + Returns: + Tensor: Extracted features with shape (batch_size, output_features) + """ + emb = self.waveform_model(x) + emb_mlp = self.mlp_mag_single_station(emb) + return emb_mlp def _init_pad_mask(waveforms, pga_targets): """ _init_pad_mask function, used to initialize the padding mask. @@ -344,10 +404,11 @@ class WaveformFullmodel(nn.Cell): hidden_dropout=0.0, n_pga_targets=0, downsample=5, + use_mlp=False ): super().__init__() self.waveform_model = NormalizedScaleEmbedding( - downsample=downsample, mlp_dims=waveform_model_dims + downsample=downsample, mlp_dims=waveform_model_dims, use_mlp=use_mlp ) self.transformer = TransformerEncoder( d_model=waveform_model_dims[-1], @@ -357,12 +418,12 @@ class WaveformFullmodel(nn.Cell): dropout=hidden_dropout, ) - self.mlp_mag = MLP((waveform_model_dims[-1],), output_mlp_dims) + self.mlp_mag = MLP((waveform_model_dims[-1],), output_mlp_dims, is_mlp=use_mlp) self.mlp_loc = MLP( - (waveform_model_dims[-1],), output_location_dims, final_activation=None + (waveform_model_dims[-1],), output_location_dims, final_activation=None, is_mlp=use_mlp ) self.mlp_pga = MLP( - (waveform_model_dims[-1],), output_mlp_dims, final_activation=None + (waveform_model_dims[-1],), output_mlp_dims, final_activation=None, is_mlp=use_mlp ) self.position_embedding = PositionEmbedding( diff --git a/MindEarth/applications/earthquake/G-TEAM/src/utils.py b/MindEarth/applications/earthquake/G-TEAM/src/utils.py index f1a4ae25019e61ffeeb4621f2d2ccabd6d36ecfe..f8ae9cabf85584f8adcb6d33eb1eaf4acec0cffc 100644 --- a/MindEarth/applications/earthquake/G-TEAM/src/utils.py +++ b/MindEarth/applications/earthquake/G-TEAM/src/utils.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"GTeam util" +"""GTeam util""" import os import copy import numpy as np @@ -21,12 +21,11 @@ from geopy.distance import geodesic import mindspore as ms import mindspore.ops as ops +from mindearth import create_logger from src import data from src.data import generator_from_config, D2KM from src.models import WaveformFullmodel -from mindearth.utils import create_logger - def predict_at_time( model, @@ -71,32 +70,23 @@ def predict_at_time( loc_pred_filter = [] pga_pred_filter = [] - for i, (start, end) in enumerate( - zip(generator.dataset.reverse_index[:-1], generator.dataset.reverse_index[1:]) - ): - sample_mag_pred = predictions[0][start:end].reshape( - (-1,) + predictions[0].shape[-1:] - ) - sample_mag_pred = sample_mag_pred[: len(generator.dataset.pga[i])] + for i, (start, end) in enumerate(zip(generator.reverse_index[:-1], generator.reverse_index[1:])): + sample_mag_pred = predictions[0][start:end].reshape((-1,) + predictions[0].shape[-1:]) + sample_mag_pred = sample_mag_pred[:len(generator.pga[i])] mag_pred_filter += [sample_mag_pred] - sample_loc_pred = predictions[1][start:end].reshape( - (-1,) + predictions[1].shape[-1:] - ) - sample_loc_pred = sample_loc_pred[: len(generator.dataset.pga[i])] + sample_loc_pred = predictions[1][start:end].reshape((-1,) + predictions[1].shape[-1:]) + sample_loc_pred = sample_loc_pred[:len(generator.pga[i])] loc_pred_filter += [sample_loc_pred] - sample_pga_pred = predictions[2][start:end].reshape( - (-1,) + predictions[2].shape[-1:] - ) - sample_pga_pred = sample_pga_pred[: len(generator.dataset.pga[i])] + sample_pga_pred = predictions[2][start:end].reshape((-1,) + predictions[2].shape[-1:]) + sample_pga_pred = sample_pga_pred[:len(generator.pga[i])] pga_pred_filter += [sample_pga_pred] preds = [mag_pred_filter, loc_pred_filter, pga_pred_filter] return preds - def calc_mag_stats(mag_pred, event_metadata, key): """Calculate statistical information for magnitude predictions""" mean_mag = mag_pred @@ -109,7 +99,6 @@ def calc_mag_stats(mag_pred, event_metadata, key): mae = metrics.mean_absolute_error(true_mag, mean_mag) return r2, rmse, mae - def calc_pga_stats(pga_pred, pga_true, suffix=""): """Calculate statistical information for PGA predictions""" if suffix: @@ -123,7 +112,6 @@ def calc_pga_stats(pga_pred, pga_true, suffix=""): return [r2, rmse, mae] - def calc_loc_stats(loc_pred, event_metadata, pos_offset): """Calculate statistical information for location predictions""" coord_keys = data.detect_location_keys(event_metadata.columns) @@ -153,19 +141,70 @@ def calc_loc_stats(loc_pred, event_metadata, pos_offset): return rmse_hypo, mae_hypo, rmse_epi, mae_epi +def seed_np_tf(seed): + '''Set the random seed for numpy and manual seed for mindspore.''' + np.random.seed(seed) + ms.manual_seed(seed) + + +def evaluation(full_model, val_generator, losses, loss_weights): + """ + Evaluates the performance of the full_model on the validation data provided by val_generator. + Calculates the average validation loss by accumulating losses from different components (magnitude, location, pga) + using the specified loss functions and weights. + Args: + full_model (nn.Cell): The complete model to be evaluated in inference mode. + val_generator (generator): A generator that yields batches of validation data (x, y). + Each x is expected to be a tuple of three input tensors, and y is a tuple of three target tensors. + losses (dict): A dictionary mapping loss names to their respective loss functions. + Supported keys: 'magnitude', 'location', 'pga'. + loss_weights (dict): A dictionary mapping loss names to their corresponding weights. + Returns: + float: The average validation loss over the entire validation dataset. + """ + full_model.set_train(False) + epoch_val_loss = 0 + for i in range(len(val_generator)): + x, y = val_generator[i] + outputs = full_model(x[0], x[1], x[2]) + total_val_loss = ms.Tensor(0) + + for k, loss_fn in losses.items(): + if k == 'magnitude': + mag_pre = outputs[0] + mag_target = y[0] + mag_loss = loss_fn(mag_target.squeeze(2), mag_pre) * loss_weights[k] + total_val_loss += mag_loss + elif k == 'location': + loc_pre = outputs[1] + loc_target = y[1] + loc_loss = loss_fn(loc_target.squeeze(2), loc_pre) * loss_weights[k] + total_val_loss += loc_loss + elif k == 'pga': + pga_pre = outputs[2] + pga_target = ops.log(ops.abs(y[2])) + pga_loss = loss_fn(pga_target.squeeze(3), pga_pre) * loss_weights[k] + total_val_loss += pga_loss + epoch_val_loss += total_val_loss.item() + avg_val_loss = epoch_val_loss / len(val_generator) + return avg_val_loss def init_model(arg): """set model""" tmpcfg = copy.deepcopy(arg["model"]) + tmpcfg.pop("istraining") tmpcfg.pop("no_event_token") tmpcfg.pop("run_with_less_data") tmpcfg.pop("pga") tmpcfg.pop("mode") tmpcfg.pop("times") + tmpcfg.pop("max_stations") model = WaveformFullmodel(**tmpcfg) - param_dict = ms.load_checkpoint(arg["summary"].get("ckpt_path")) - # Load parameters into the network - ms.load_param_into_net(model, param_dict) - model.set_train(False) + if arg['model']['istraining']: + model.set_train(True) + else: + param_dict = ms.load_checkpoint(arg["summary"].get("ckpt_path")) + ms.load_param_into_net(model, param_dict) + model.set_train(False) return model diff --git a/MindEarth/applications/earthquake/G-TEAM/src/visual.py b/MindEarth/applications/earthquake/G-TEAM/src/visual.py index ef3dc119504fb9d8ef48bbd65dcebf626d60d044..78354a6f8fb4eb2a0b40cc4b8672a004eafda847 100644 --- a/MindEarth/applications/earthquake/G-TEAM/src/visual.py +++ b/MindEarth/applications/earthquake/G-TEAM/src/visual.py @@ -18,7 +18,6 @@ import matplotlib.pyplot as plt import numpy as np import sklearn.metrics as metrics - def generate_true_pred_plot(pred_values, true_values, time, path, suffix=""): """ Generate a plot comparing true values and predicted values, and calculate diff --git a/MindSPONGE/src/mindsponge/pipeline/models/esm_if1/module/features.py b/MindSPONGE/src/mindsponge/pipeline/models/esm_if1/module/features.py index a164be032b51cce8812337c6d9dee4dd1472e349..9fb79ddd1c40ed54369d1f56aa1a47bb9db10b19 100644 --- a/MindSPONGE/src/mindsponge/pipeline/models/esm_if1/module/features.py +++ b/MindSPONGE/src/mindsponge/pipeline/models/esm_if1/module/features.py @@ -147,7 +147,7 @@ class GVPInputFeaturizer(nn.Cell): d_neighbors, e_idx = ops.Sort(axis=-1, descending=True)(d_adjust) else: d_neighbors, e_idx = ops.TopK(sorted=True)(d_adjust, d_adjust.shape[-1]) - d_neighbors, e_idx = d_neighbors[..., ::-1], e_idx[..., ::-1] + d_neighbors, e_idx = ms.mint.flip(d_neighbors, [-1]), ms.mint.flip(e_idx, [-1]) d_neighbors, e_idx = d_neighbors[:, :, 0:int(min(top_k_neighbors, x.shape[1]))], \ e_idx[:, :, 0:int(min(top_k_neighbors, x.shape[1]))] d_neighbors = ms.Tensor(d_neighbors, ms.float32)*1e4 diff --git a/MindSPONGE/src/mindsponge/pipeline/models/ufold/ufold.py b/MindSPONGE/src/mindsponge/pipeline/models/ufold/ufold.py index 5e45acd649c14f3cbea4e626073dac5ac24920b9..742e61dbf385fc00a39191aecf49b07ea3f56738 100644 --- a/MindSPONGE/src/mindsponge/pipeline/models/ufold/ufold.py +++ b/MindSPONGE/src/mindsponge/pipeline/models/ufold/ufold.py @@ -143,7 +143,7 @@ class UFold(Model): seq_embedding_batch = Tensor(ops.Cast()(seq_embeddings, mstype.float32)) pred_contacts = self.network(seq_embedding_batch) contact_masks = ops.ZerosLike()(pred_contacts) - contact_masks[:, :seq_lens.item(0), :seq_lens.item(0)] = 1 + contact_masks[:, :seq_lens[0].item(), :seq_lens[0].item()] = 1 contact_masks = contact_masks.astype(ms.float32) feat = [seq_embedding_batch, contact_masks, contacts_batch] feat = mutable(feat) diff --git a/docs/api_python/mindchemistry/cell/mindchemistry.cell.orb.AttentionInteractionNetwork.rst b/docs/api_python/mindchemistry/cell/mindchemistry.cell.orb.AttentionInteractionNetwork.rst new file mode 100644 index 0000000000000000000000000000000000000000..7778f3a537968f0a3a629102da109eae1e335619 --- /dev/null +++ b/docs/api_python/mindchemistry/cell/mindchemistry.cell.orb.AttentionInteractionNetwork.rst @@ -0,0 +1,32 @@ +mindchemistry.cell.orb.AttentionInteractionNetwork +================================================== + +.. py:class:: mindchemistry.cell.orb.AttentionInteractionNetwork(num_node_in: int, num_node_out: int, num_edge_in: int, num_edge_out: int, num_mlp_layers: int, mlp_hidden_dim: int, attention_gate: str = "sigmoid", distance_cutoff: bool = True, polynomial_order: int = 4, cutoff_rmax: float = 6.0) + + 注意力交互网络。实现基于注意力机制的消息传递神经网络层,用于分子图的边更新。 + + 参数: + - **num_node_in** (int) - 节点输入特征数量。 + - **num_node_out** (int) - 节点输出特征数量。 + - **num_edge_in** (int) - 边输入特征数量。 + - **num_edge_out** (int) - 边输出特征数量。 + - **num_mlp_layers** (int) - 节点和边更新MLP的隐藏层数量。 + - **mlp_hidden_dim** (int) - MLP的隐藏维度大小。 + - **attention_gate** (str,可选) - 注意力门类型, ``"sigmoid"`` 或 ``"softmax"``。默认值: ``"sigmoid"``。 + - **distance_cutoff** (bool,可选) - 是否使用基于距离的边截断。默认值: ``True``。 + - **polynomial_order** (int,可选) - 多项式截断函数的阶数。默认值: ``4``。 + - **cutoff_rmax** (float,可选) - 截断的最大距离。默认值: ``6.0``。 + + 输入: + - **graph_edges** (dict) - 边特征字典,必须包含键"feat",形状为 :math:`(n_{edges}, num\_edge\_in)`。 + - **graph_nodes** (dict) - 节点特征字典,必须包含键"feat",形状为 :math:`(n_{nodes}, num\_node\_in)`。 + - **senders** (Tensor) - 每条边的发送节点索引,形状为 :math:`(n_{edges},)`。 + - **receivers** (Tensor) - 每条边的接收节点索引,形状为 :math:`(n_{edges},)`。 + + 输出: + - **edges** (dict) - 更新的边特征字典,键"feat"的形状为 :math:`(n_{edges}, num\_edge\_out)`。 + - **nodes** (dict) - 更新的节点特征字典,键"feat"的形状为 :math:`(n_{nodes}, num\_node\_out)`。 + + 异常: + - **ValueError** - 如果 `attention_gate` 不是"sigmoid"或"softmax"。 + - **ValueError** - 如果边或节点特征不包含必需的"feat"键。 \ No newline at end of file diff --git a/docs/api_python/mindchemistry/cell/mindchemistry.cell.orb.EnergyHead.rst b/docs/api_python/mindchemistry/cell/mindchemistry.cell.orb.EnergyHead.rst new file mode 100644 index 0000000000000000000000000000000000000000..fb549db328a7b26008669b6bbe43d7dc1a925bd9 --- /dev/null +++ b/docs/api_python/mindchemistry/cell/mindchemistry.cell.orb.EnergyHead.rst @@ -0,0 +1,28 @@ +mindchemistry.cell.orb.EnergyHead +================================== + +.. py:class:: mindchemistry.cell.orb.EnergyHead(latent_dim: int, num_mlp_layers: int, mlp_hidden_dim: int, target_property_dim: int, predict_atom_avg: bool = True, reference_energy_name: str = "mp-traj-d3", train_reference: bool = False, dropout: Optional[float] = None, node_aggregation: Optional[str] = None) + + 图级能量预测头。实现用于预测分子图总能量或原子平均能量的神经网络头。支持节点级聚合、参考能量偏移和灵活的输出模式。 + + 参数: + - **latent_dim** (int) - 每个节点的输入特征维度。 + - **num_mlp_layers** (int) - MLP中的隐藏层数量。 + - **mlp_hidden_dim** (int) - MLP的隐藏维度大小。 + - **target_property_dim** (int) - 能量属性的输出维度(通常为1)。 + - **predict_atom_avg** (bool,可选) - 是否预测每原子平均能量而不是总能量。默认值: ``True``。 + - **reference_energy_name** (str,可选) - 用于偏移的参考能量名称,例如 ``"vasp-shifted"``。默认值: ``"mp-traj-d3"``。 + - **train_reference** (bool,可选) - 是否将参考能量训练为可学习参数。默认值: ``False``。 + - **dropout** (Optional[float],可选) - MLP的dropout率。默认值: ``None``。 + - **node_aggregation** (str,可选) - 节点预测的聚合方法,例如 ``"mean"`` 或 ``"sum"``。默认值: ``None``。 + + 输入: + - **node_features** (dict) - 节点特征字典,必须包含键"feat",形状为 :math:`(n_{nodes}, latent\_dim)`。 + - **n_node** (Tensor) - 图中节点数量,形状为 :math:`(1,)`。 + + 输出: + - **output** (dict) - 包含键"graph_pred"的字典,值的形状为 :math:`(1, target\_property\_dim)`。 + + 异常: + - **ValueError** - 如果 `node_features` 中缺少必需的特征键。 + - **ValueError** - 如果 `node_aggregation` 不是支持的类型。 \ No newline at end of file diff --git a/docs/api_python/mindchemistry/cell/mindchemistry.cell.orb.GraphHead.rst b/docs/api_python/mindchemistry/cell/mindchemistry.cell.orb.GraphHead.rst new file mode 100644 index 0000000000000000000000000000000000000000..75ae5ad7c52da0a39f5b91142a8458c68fd51323 --- /dev/null +++ b/docs/api_python/mindchemistry/cell/mindchemistry.cell.orb.GraphHead.rst @@ -0,0 +1,25 @@ +mindchemistry.cell.orb.GraphHead +================================= + +.. py:class:: mindchemistry.cell.orb.GraphHead(latent_dim: int, num_mlp_layers: int, mlp_hidden_dim: int, target_property_dim: int, node_aggregation: str = "mean", dropout: Optional[float] = None, compute_stress: Optional[bool] = False) + + 图级预测头。实现可以附加到基础模型的图级预测头,用于从节点特征预测图级属性(例如应力张量),使用聚合和MLP。 + + 参数: + - **latent_dim** (int) - 每个节点的输入特征维度。 + - **num_mlp_layers** (int) - MLP中的隐藏层数量。 + - **mlp_hidden_dim** (int) - MLP的隐藏维度大小。 + - **target_property_dim** (int) - 图级属性的输出维度。 + - **node_aggregation** (str,可选) - 节点预测的聚合方法,例如 ``"mean"`` 或 ``"sum"``。默认值: ``"mean"``。 + - **dropout** (Optional[float],可选) - MLP的dropout率。默认值: ``None``。 + - **compute_stress** (bool,可选) - 是否计算和输出应力张量。默认值: ``False``。 + + 输入: + - **node_features** (dict) - 节点特征字典,必须包含键"feat",形状为 :math:`(n_{nodes}, latent\_dim)`。 + - **n_node** (Tensor) - 图中节点数量,形状为 :math:`(1,)`。 + + 输出: + - **output** (dict) - 包含键"stress_pred"的字典,值的形状为 :math:`(1, target\_property\_dim)`。 + + 异常: + - **ValueError** - 如果 `node_features` 中缺少必需的特征键。 \ No newline at end of file diff --git a/docs/api_python/mindchemistry/cell/mindchemistry.cell.orb.MoleculeGNS.rst b/docs/api_python/mindchemistry/cell/mindchemistry.cell.orb.MoleculeGNS.rst new file mode 100644 index 0000000000000000000000000000000000000000..c44551f329f4d6bd28dcacf08210fc60cc7a662d --- /dev/null +++ b/docs/api_python/mindchemistry/cell/mindchemistry.cell.orb.MoleculeGNS.rst @@ -0,0 +1,34 @@ +mindchemistry.cell.orb.MoleculeGNS +=================================== + +.. py:class:: mindchemistry.cell.orb.MoleculeGNS(num_node_in_features: int, num_node_out_features: int, num_edge_in_features: int, latent_dim: int, num_message_passing_steps: int, num_mlp_layers: int, mlp_hidden_dim: int, node_feature_names: List[str], edge_feature_names: List[str], use_embedding: bool = True, interactions: str = "simple_attention", interaction_params: Optional[Dict[str, Any]] = None) + + 分子图神经网络。实现用于分子性质预测的灵活模块化图神经网络,基于注意力或其他交互机制的消息传递。支持节点和边嵌入、多个消息传递步骤,以及用于复杂分子图的可定制交互层。 + + 参数: + - **num_node_in_features** (int) - 每个节点的输入特征数量。 + - **num_node_out_features** (int) - 每个节点的输出特征数量。 + - **num_edge_in_features** (int) - 每条边的输入特征数量。 + - **latent_dim** (int) - 节点和边表示的潜在维度。 + - **num_message_passing_steps** (int) - 消息传递层的数量。 + - **num_mlp_layers** (int) - 节点和边更新MLP的隐藏层数量。 + - **mlp_hidden_dim** (int) - MLP的隐藏维度大小。 + - **node_feature_names** (List[str]) - 从输入字典中使用的节点特征键列表。 + - **edge_feature_names** (List[str]) - 从输入字典中使用的边特征键列表。 + - **use_embedding** (bool,可选) - 是否对节点使用原子序数嵌入。默认值: ``True``。 + - **interactions** (str,可选) - 要使用的交互层类型(例如, ``"simple_attention"``)。默认值: ``"simple_attention"``。 + - **interaction_params** (Optional[Dict[str, Any]],可选) - 交互层的参数,例如截断、多项式阶数、门类型。默认值: ``None``。 + + 输入: + - **edge_features** (dict) - 边特征字典,必须包含 `edge_feature_names` 中指定的键。 + - **node_features** (dict) - 节点特征字典,必须包含 `node_feature_names` 中指定的键。 + - **senders** (Tensor) - 每条边的发送节点索引,形状为 :math:`(n_{edges},)`。 + - **receivers** (Tensor) - 每条边的接收节点索引,形状为 :math:`(n_{edges},)`。 + + 输出: + - **edges** (dict) - 更新的边特征字典,键"feat"的形状为 :math:`(n_{edges}, latent\_dim)`。 + - **nodes** (dict) - 更新的节点特征字典,键"feat"的形状为 :math:`(n_{nodes}, latent\_dim)`。 + + 异常: + - **ValueError** - 如果 `edge_features` 或 `node_features` 中缺少必需的特征键。 + - **ValueError** - 如果 `interactions` 不是支持的类型。 \ No newline at end of file diff --git a/docs/api_python/mindchemistry/cell/mindchemistry.cell.orb.NodeHead.rst b/docs/api_python/mindchemistry/cell/mindchemistry.cell.orb.NodeHead.rst new file mode 100644 index 0000000000000000000000000000000000000000..2e422d861892a40688faea4b8b38ca0a5848bb63 --- /dev/null +++ b/docs/api_python/mindchemistry/cell/mindchemistry.cell.orb.NodeHead.rst @@ -0,0 +1,26 @@ +mindchemistry.cell.orb.NodeHead +=============================== + +.. py:class:: mindchemistry.cell.orb.NodeHead(latent_dim: int, num_mlp_layers: int, mlp_hidden_dim: int, target_property_dim: int, dropout: Optional[float] = None, remove_mean: bool = True) + + 节点级预测头。 + + 实现用于从节点特征预测节点级属性的神经网络头。该头可以添加到基础模型中以在预训练期间启用辅助任务,或在微调步骤中添加。 + + 参数: + - **latent_dim** (int) - 每个节点的输入特征维度。 + - **num_mlp_layers** (int) - MLP中的隐藏层数量。 + - **mlp_hidden_dim** (int) - MLP的隐藏维度大小。 + - **target_property_dim** (int) - 节点级目标属性的输出维度。 + - **dropout** (Optional[float],可选) - MLP的dropout率。默认值: ``None``。 + - **remove_mean** (bool,可选) - 如果为True,从输出中移除均值,通常用于力预测。默认值: ``True``。 + + 输入: + - **node_features** (dict) - 节点特征字典,必须包含键 "feat",形状为 :math:`(n_{nodes}, latent\_dim)`。 + - **n_node** (Tensor) - 图中节点数量,形状为 :math:`(1,)`。 + + 输出: + - **output** (dict) - 包含键 "node_pred" 的字典,值的形状为 :math:`(n_{nodes}, target\_property\_dim)`。 + + 异常: + - **ValueError** - 如果 `node_features` 中缺少必需的特征键。 \ No newline at end of file diff --git a/docs/api_python/mindchemistry/cell/mindchemistry.cell.orb.Orb.rst b/docs/api_python/mindchemistry/cell/mindchemistry.cell.orb.Orb.rst new file mode 100644 index 0000000000000000000000000000000000000000..fe1d53a26ce54523eb47205d9a5fa5729892f8c2 --- /dev/null +++ b/docs/api_python/mindchemistry/cell/mindchemistry.cell.orb.Orb.rst @@ -0,0 +1,35 @@ +mindchemistry.cell.orb.Orb +=========================== + +.. py:class:: mindchemistry.cell.orb.Orb(model: MoleculeGNS, node_head: Optional[NodeHead] = None, graph_head: Optional[GraphHead] = None, stress_head: Optional[GraphHead] = None, model_requires_grad: bool = True, cutoff_layers: Optional[int] = None) + + Orb图回归器。将预训练的基础模型(如MoleculeGNS)与可选的节点、图和应力回归头结合,支持微调或特征提取工作流程。 + + 参数: + - **model** (MoleculeGNS) - 用于消息传递和特征提取的预训练或随机初始化基础模型。 + - **node_head** (NodeHead,可选) - 节点级属性预测的回归头。默认值: ``None``。 + - **graph_head** (GraphHead,可选) - 图级属性预测(例如能量)的回归头。默认值: ``None``。 + - **stress_head** (GraphHead,可选) - 应力预测的回归头。默认值: ``None``。 + - **model_requires_grad** (bool,可选) - 是否微调基础模型(True)或冻结其参数(False)。默认值: ``True``。 + - **cutoff_layers** (int,可选) - 如果提供,仅使用基础模型的前 ``"cutoff_layers"`` 个消息传递层。默认值: ``None``。 + + 输入: + - **edge_features** (dict) - 边特征字典(例如,`{"vectors": Tensor, "r": Tensor}`)。 + - **node_features** (dict) - 节点特征字典(例如,`{"atomic_numbers": Tensor, ...}`)。 + - **senders** (Tensor) - 每条边的发送节点索引。形状::math:`(n_{edges},)`。 + - **receivers** (Tensor) - 每条边的接收节点索引。形状::math:`(n_{edges},)`。 + - **n_node** (Tensor) - 批次中每个图的节点数量。形状::math:`(n_{graphs},)`。 + + 输出: + - **output** (dict) - 包含以下内容的字典: + + - **edges** (dict) - 消息传递后的边特征,例如 `{..., "feat": Tensor}`。 + - **nodes** (dict) - 消息传递后的节点特征,例如 `{..., "feat": Tensor}`。 + - **graph_pred** (Tensor) - 图级预测,例如能量。形状::math:`(n_{graphs}, target\_property\_dim)`。 + - **node_pred** (Tensor) - 节点级预测。形状::math:`(n_{nodes}, target\_property\_dim)`。 + - **stress_pred** (Tensor) - 应力预测(如果提供stress_head)。形状::math:`(n_{graphs}, 6)`。 + + 异常: + - **ValueError** - 如果既未提供node_head也未提供graph_head。 + - **ValueError** - 如果cutoff_layers超过基础模型中的消息传递步骤数。 + - **ValueError** - 如果graph_head需要时未提供atomic_numbers。 \ No newline at end of file diff --git a/docs/api_python/mindchemistry/mindchemistry.cell.rst b/docs/api_python/mindchemistry/mindchemistry.cell.rst index 3c509e9ec3ab7962b7ee447049d53aed40bc0178..d346532e43dc2c3414d4e35edcb73aa369809ab4 100644 --- a/docs/api_python/mindchemistry/mindchemistry.cell.rst +++ b/docs/api_python/mindchemistry/mindchemistry.cell.rst @@ -10,3 +10,9 @@ mindchemistry.cell mindchemistry.cell.AutoEncoder mindchemistry.cell.FCNet mindchemistry.cell.MLPNet + mindchemistry.cell.orb.AttentionInteractionNetwork + mindchemistry.cell.orb.EnergyHead + mindchemistry.cell.orb.GraphHead + mindchemistry.cell.orb.MoleculeGNS + mindchemistry.cell.orb.NodeHead + mindchemistry.cell.orb.Orb \ No newline at end of file diff --git a/docs/api_python_en/mindchemistry/mindchemistry.cell.rst b/docs/api_python_en/mindchemistry/mindchemistry.cell.rst index a5c9d41431795142d9524825cb3f42482dfe2cfc..78cd3bfffeded29ea9cbc13642292928e4f6d464 100644 --- a/docs/api_python_en/mindchemistry/mindchemistry.cell.rst +++ b/docs/api_python_en/mindchemistry/mindchemistry.cell.rst @@ -10,3 +10,9 @@ mindchemistry.cell mindchemistry.cell.AutoEncoder mindchemistry.cell.FCNet mindchemistry.cell.MLPNet + mindchemistry.cell.orb.AttentionInteractionNetwork + mindchemistry.cell.orb.EnergyHead + mindchemistry.cell.orb.GraphHead + mindchemistry.cell.orb.MoleculeGNS + mindchemistry.cell.orb.NodeHead + mindchemistry.cell.orb.Orb \ No newline at end of file diff --git a/tests/st/mindchemistry/cell/test_orb/base.py b/tests/st/mindchemistry/cell/test_orb/base.py new file mode 100644 index 0000000000000000000000000000000000000000..12ebf5a21d0b2a32f60d24e838446b03af9276ca --- /dev/null +++ b/tests/st/mindchemistry/cell/test_orb/base.py @@ -0,0 +1,119 @@ +# ============================================================================ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Base data class.""" + +from typing import Dict, Mapping, NamedTuple, Optional, Union + +from mindspore import Tensor + + +Metric = Union[Tensor, int, float] +TensorDict = Mapping[str, Optional[Tensor]] + + +class ModelOutput(NamedTuple): + """A model's output.""" + + loss: Tensor + log: Mapping[str, Metric] + + +class AtomGraphs(NamedTuple): + """A class representing the input to a model for a graph. + + Args: + senders (ms.Tensor): The integer source nodes for each edge. + receivers (ms.Tensor): The integer destination nodes for each edge. + n_node (ms.Tensor): A (batch_size, ) shaped tensor containing the number of nodes per graph. + n_edge (ms.Tensor): A (batch_size, ) shaped tensor containing the number of edges per graph. + node_features (Dict[str, ms.Tensor]): A dictionary containing node feature tensors. + It will always contain "atomic_numbers" and "positions" keys, representing the + atomic numbers of each node, and the 3d cartesian positions of them respectively. + edge_features (Dict[str, ms.Tensor]): A dictionary containing edge feature tensors. + system_features (Optional[TensorDict]): An optional dictionary containing system-level features. + node_targets (Optional[Dict[ms.Tensor]]): An optional dict of tensors containing targets + for individual nodes. This tensor is commonly expected to have shape (num_nodes, *). + edge_target (Optional[ms.Tensor]): An optional tensor containing targets for individual edges. + This tensor is commonly expected to have (num_edges, *). + system_targets (Optional[Dict[ms.Tensor]]): An optional dict of tensors containing targets for the + entire system. system_id (Optional[ms.Tensor]): An optional tensor containing the ID of the system. + fix_atoms (Optional[ms.Tensor]): An optional tensor containing information on fixed atoms in the system. + """ + + senders: Tensor + receivers: Tensor + n_node: Tensor + n_edge: Tensor + node_features: Dict[str, Tensor] + edge_features: Dict[str, Tensor] + system_features: Dict[str, Tensor] + node_targets: Optional[Dict[str, Tensor]] = None + edge_targets: Optional[Dict[str, Tensor]] = None + system_targets: Optional[Dict[str, Tensor]] = None + system_id: Optional[Tensor] = None + fix_atoms: Optional[Tensor] = None + tags: Optional[Tensor] = None + radius: Optional[float] = None + max_num_neighbors: Optional[int] = None + + @property + def positions(self): + """Get positions of atoms.""" + return self.node_features["positions"] + + @positions.setter + def positions(self, val: Tensor): + self.node_features["positions"] = val + + @property + def atomic_numbers(self): + """Get integer atomic numbers.""" + return self.node_features["atomic_numbers"] + + @atomic_numbers.setter + def atomic_numbers(self, val: Tensor): + self.node_features["atomic_numbers"] = val + + @property + def cell(self): + """Get unit cells.""" + assert self.system_features + return self.system_features.get("cell") + + @cell.setter + def cell(self, val: Tensor): + assert self.system_features + self.system_features["cell"] = val + + def clone(self): + """Clone the AtomGraphs object.""" + return AtomGraphs( + senders=self.senders.clone(), + receivers=self.receivers.clone(), + n_node=self.n_node.clone(), + n_edge=self.n_edge.clone(), + node_features={k: v.clone() for k, v in self.node_features.items()}, + edge_features={k: v.clone() for k, v in self.edge_features.items()}, + system_features={k: v.clone() for k, v in self.system_features.items()}, + node_targets={k: v.clone() for k, v in (self.node_targets or {}).items()}, + edge_targets=self.edge_targets.clone() if self.edge_targets is not None else None, + system_targets={k: v.clone() for k, v in (self.system_targets or {}).items()}, + system_id=self.system_id.clone() if self.system_id is not None else None, + fix_atoms=self.fix_atoms.clone() if self.fix_atoms is not None else None, + tags=self.tags.clone() if self.tags is not None else None, + radius=self.radius, + max_num_neighbors=self.max_num_neighbors + ) diff --git a/tests/st/mindchemistry/cell/test_orb/test_orb.py b/tests/st/mindchemistry/cell/test_orb/test_orb.py new file mode 100644 index 0000000000000000000000000000000000000000..cbb6534088433e267d87585eda9781da29755aac --- /dev/null +++ b/tests/st/mindchemistry/cell/test_orb/test_orb.py @@ -0,0 +1,451 @@ +# ============================================================================ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""test mindchemistry ORB""" + +import os +import sys +from typing import Optional +import pickle + +import requests +import pytest +import numpy as np +import mindspore +from mindspore import nn, Tensor, load_checkpoint, load_param_into_net + +from mindchemistry.cell import ( + AttentionInteractionNetwork, + MoleculeGNS, + NodeHead, + GraphHead, + EnergyHead, + Orb, +) +import base +from utils import numpy_to_tensor, tensor_to_numpy, is_equal + +# pylint: disable=C0413 +PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../")) +sys.path.append(PROJECT_ROOT) +from common.cell import compare_output + + +def load_graph_data(pkl_path: str): + """Load graph data from pickle file. + Args: + pkl_path: Path to the pickle file + Returns: + tuple: (input_graph_ms, output_graph_np) + """ + with open(pkl_path, "rb") as f: + loaded = pickle.load(f) + + input_graph_np = loaded["input_graph"] + output_graph_np = loaded["output_graph"] + + input_graph_ms = base.AtomGraphs( + *[numpy_to_tensor(getattr(input_graph_np, field)) + for field in input_graph_np._fields] + ) + + return input_graph_ms, output_graph_np + + +def get_gns( + latent_dim: int = 256, + mlp_hidden_dim: int = 512, + num_message_passing_steps: int = 15, + num_edge_in_features: int = 23, + distance_cutoff: bool = True, + attention_gate: str = "sigmoid", +) -> MoleculeGNS: + """Define the base pretrained model architecture. + Args: + latent_dim: The latent dimension of the model. + mlp_hidden_dim: The hidden dimension of the MLP layers. + num_message_passing_steps: The number of message passing steps. + num_edge_in_features: The number of edge input features. + distance_cutoff: Whether to use distance cutoff in the interaction. + attention_gate: The type of attention gate to use. + Returns: + MoleculeGNS: The MoleculeGNS model instance. + """ + return MoleculeGNS( + num_node_in_features=256, + num_node_out_features=3, + num_edge_in_features=num_edge_in_features, + latent_dim=latent_dim, + interactions="simple_attention", + interaction_params={ + "distance_cutoff": distance_cutoff, + "polynomial_order": 4, + "cutoff_rmax": 6, + "attention_gate": attention_gate, + }, + num_message_passing_steps=num_message_passing_steps, + num_mlp_layers=2, + mlp_hidden_dim=mlp_hidden_dim, + use_embedding=True, + node_feature_names=["feat"], + edge_feature_names=["feat"], + ) + + +def load_model_for_inference(model: nn.Cell, weights_path: str) -> nn.Cell: + """Load a pretrained model in inference mode. + Args: + model: The model to load the weights into. + weights_path: Path to the checkpoint file. + Returns: + nn.Cell: The model with loaded weights. + Raises: + FileNotFoundError: If the checkpoint file does not exist. + ValueError: If the checkpoint file has more parameters than the model. + """ + if not os.path.exists(weights_path): + raise FileNotFoundError(f"Checkpoint file {weights_path} not found.") + param_dict = load_checkpoint(weights_path) + + try: + load_param_into_net(model, param_dict) + except ValueError: + print("Warning: The checkpoint file has more parameters than the model. \ + This may be due to a mismatch in the model architecture or version.") + params = [] + for key in param_dict: + params.append(param_dict[key]) + for parameters in model.trainable_params(): + param_ckpt = params.pop(0) + assert parameters.shape == param_ckpt.shape, f"Shape mismatch: {parameters.name}" + param_ckpt = param_ckpt.reshape(parameters.shape) + parameters.set_data(param_ckpt) + + model.set_train(False) + return model + +def orb_v2(weights_path: Optional[str]) -> nn.Cell: + """Load ORB v2. + Args: + weights_path: Path to the checkpoint file. + Returns: + Orb GraphRegressor: The ORB v2 model instance. + """ + gns = get_gns() + + model = Orb( + graph_head=EnergyHead( + latent_dim=256, + num_mlp_layers=1, + mlp_hidden_dim=256, + target_property_dim=1, + node_aggregation="mean", + reference_energy_name="vasp-shifted", + train_reference=True, + predict_atom_avg=True, + ), + node_head=NodeHead( + latent_dim=256, + num_mlp_layers=1, + mlp_hidden_dim=256, + target_property_dim=3, + remove_mean=True, + ), + stress_head=GraphHead( + latent_dim=256, + num_mlp_layers=1, + mlp_hidden_dim=256, + target_property_dim=6, + compute_stress=True, + ), + model=gns, + ) + model = load_model_for_inference(model, weights_path) + return model + + +def orb_mptraj_only_v2( + weights_path: Optional[str] = None, +): + """Load ORB MPTraj Only v2.""" + + return orb_v2(weights_path,) + + +def download_file(url, local_filename): + """Download a file from a URL to a local path.""" + response = requests.get(url, timeout=30) + if response.status_code == 200: + with open(local_filename, 'wb') as f: + f.write(response.content) + else: + print(f"Failed to download file. HTTP Status Code: {response.status_code}") + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend910b_training +@pytest.mark.env_onecard +def test_attn(): + """ + Feature: Test AttentionInteractionNetwork in platform ascend. + Description: The forward output should has expected shape and accuracy. + Expectation: Success or throw AssertionError. + """ + mindspore.set_context(mode=mindspore.PYNATIVE_MODE) + # prepare data + download_file( + 'https://download-mindspore.osinfra.cn/mindscience/mindchemistry/orb/test/attn_input_output.pkl', + 'attn_input_output.pkl' + ) + download_file( + 'https://download-mindspore.osinfra.cn/mindscience/mindchemistry/orb/orb_ckpts/attn_net.ckpt', + 'attn_net.ckpt' + ) + input_graph_ms, output_graph_np = load_graph_data('attn_input_output.pkl') + + attn_net = AttentionInteractionNetwork( + num_node_in=256, + num_node_out=256, + num_edge_in=256, + num_edge_out=256, + num_mlp_layers=2, + mlp_hidden_dim=512, + ) + + # load checkpoint + param_dict = load_checkpoint('attn_net.ckpt') + load_param_into_net(attn_net, param_dict) + + # inference + edges, nodes = attn_net( + input_graph_ms.edge_features, + input_graph_ms.node_features, + input_graph_ms.senders, + input_graph_ms.receivers, + ) + + # Validate results + out_node_feats = tensor_to_numpy(nodes["feat"]) + out_edge_feats = tensor_to_numpy(edges["feat"]) + out_node_feats_np = output_graph_np.node_features["feat"] + out_edge_feats_np = output_graph_np.edge_features["feat"] + + flag_node = is_equal(out_node_feats, out_node_feats_np) + flag_edge = is_equal(out_edge_feats, out_edge_feats_np) + assert flag_node, "Failed! Node features mismatch in attention network" + assert flag_edge, "Failed! Edge features mismatch in attention network" + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend910b_training +@pytest.mark.env_onecard +def test_gns(): + """ + Feature: Test MoleculeGNS network in platform ascend. + Description: The forward output should has expected shape and accuracy. + Expectation: Success or throw AssertionError. + """ + mindspore.set_context(mode=mindspore.PYNATIVE_MODE) + download_file( + 'https://download-mindspore.osinfra.cn/mindscience/mindchemistry/orb/test/gns_input_output.pkl', + 'gns_input_output.pkl' + ) + download_file( + 'https://download-mindspore.osinfra.cn/mindscience/mindchemistry/orb/orb_ckpts/gns_net.ckpt', + 'gns_net.ckpt' + ) + input_graph_ms, output_graph_np = load_graph_data('gns_input_output.pkl') + + # load gns model and checkpoint + gns_model = get_gns() + + # load checkpoint + param_dict = load_checkpoint('gns_net.ckpt') + load_param_into_net(gns_model, param_dict) + + edges, nodes = gns_model( + input_graph_ms.edge_features, + input_graph_ms.node_features, + input_graph_ms.senders, + input_graph_ms.receivers, + ) + + out_node_feats = tensor_to_numpy(nodes["feat"]) + out_edge_feats = tensor_to_numpy(edges["feat"]) + out_node_feats_np = output_graph_np.node_features["feat"] + out_edge_feats_np = output_graph_np.edge_features["feat"] + + flag_node = is_equal(out_node_feats, out_node_feats_np) + flag_edge = is_equal(out_edge_feats, out_edge_feats_np) + assert flag_node, "Failed! Node features mismatch in MoleculeGNS network" + assert flag_edge, "Failed! Edge features mismatch in MoleculeGNS network" + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend910b_training +@pytest.mark.env_onecard +def test_node_head(): + """ + Feature: Test NodeHead in platform ascend. + Description: The forward output should has expected shape and accuracy. + Expectation: Success or throw AssertionError. + """ + mindspore.set_context(mode=mindspore.PYNATIVE_MODE) + node_head = NodeHead( + latent_dim=256, + num_mlp_layers=1, + mlp_hidden_dim=256, + target_property_dim=3, + remove_mean=True, + ) + + n_atoms = 4 + n_node = Tensor([n_atoms], mindspore.int32) + atomic_numbers = Tensor(np.random.randint(1, 119, size=(n_atoms,), dtype=np.int32)) + atomic_numbers_embedding_np = np.zeros((n_atoms, 118), dtype=np.float32) + for i, num in enumerate(atomic_numbers.asnumpy()): + atomic_numbers_embedding_np[i, num - 1] = 1.0 + + node_features = { + "atomic_numbers": atomic_numbers, + "atomic_numbers_embedding": Tensor(atomic_numbers_embedding_np), + "positions": Tensor(np.random.randn(n_atoms, 3).astype(np.float32)), + "feat": Tensor(np.random.randn(n_atoms, 256).astype(np.float32)) + } + + output = node_head(node_features, n_node) + assert output['node_pred'].shape == (4, 3), \ + f"Expected node_pred shape (4, 3), but got {output['node_pred'].shape}" + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend910b_training +@pytest.mark.env_onecard +def test_graph_head(): + """ + Feature: Test GraphHead in platform ascend. + Description: The forward output should has expected shape and accuracy. + Expectation: Success or throw AssertionError. + """ + mindspore.set_context(mode=mindspore.PYNATIVE_MODE) + graph_head = GraphHead( + latent_dim=256, + num_mlp_layers=1, + mlp_hidden_dim=256, + target_property_dim=6, + compute_stress=True, + ) + + n_atoms = 4 + n_node = Tensor([n_atoms], mindspore.int32) + atomic_numbers = Tensor(np.random.randint(1, 119, size=(n_atoms,), dtype=np.int32)) + atomic_numbers_embedding_np = np.zeros((n_atoms, 118), dtype=np.float32) + for i, num in enumerate(atomic_numbers.asnumpy()): + atomic_numbers_embedding_np[i, num - 1] = 1.0 + + node_features = { + "atomic_numbers": atomic_numbers, + "atomic_numbers_embedding": Tensor(atomic_numbers_embedding_np), + "positions": Tensor(np.random.randn(n_atoms, 3).astype(np.float32)), + "feat": Tensor(np.random.randn(n_atoms, 256).astype(np.float32)) + } + + output = graph_head(node_features, n_node) + assert output['stress_pred'].shape == (1, 6), \ + f"Expected stress_pred shape (1, 6), but got {output['stress_pred'].shape}" + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend910b_training +@pytest.mark.env_onecard +def test_energy_head(): + """ + Feature: Test EnergyHead in platform ascend. + Description: The forward output should has expected shape and accuracy. + Expectation: Success or throw AssertionError. + """ + mindspore.set_context(mode=mindspore.PYNATIVE_MODE) + energy_head = EnergyHead( + latent_dim=256, + num_mlp_layers=1, + mlp_hidden_dim=256, + target_property_dim=1, + node_aggregation="mean", + reference_energy_name="vasp-shifted", + train_reference=True, + predict_atom_avg=True, + ) + + n_atoms = 4 + n_node = Tensor([n_atoms], mindspore.int32) + atomic_numbers = Tensor(np.random.randint(1, 119, size=(n_atoms,), dtype=np.int32)) + atomic_numbers_embedding_np = np.zeros((n_atoms, 118), dtype=np.float32) + for i, num in enumerate(atomic_numbers.asnumpy()): + atomic_numbers_embedding_np[i, num - 1] = 1.0 + + node_features = { + "atomic_numbers": atomic_numbers, + "atomic_numbers_embedding": Tensor(atomic_numbers_embedding_np), + "positions": Tensor(np.random.randn(n_atoms, 3).astype(np.float32)), + "feat": Tensor(np.random.randn(n_atoms, 256).astype(np.float32)) + } + + output = energy_head(node_features, n_node) + assert output['graph_pred'].shape == (1, 1), \ + f"Expected graph_pred shape {(1, 1)}, but got {output['graph_pred'].shape}" + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend910b_training +@pytest.mark.env_onecard +def test_inference(): + """ + Feature: Test Orb network in platform ascend. + Description: The forward output should has expected shape and accuracy. + Expectation: Success or throw AssertionError. + """ + mindspore.set_context(mode=mindspore.PYNATIVE_MODE) + download_file( + 'https://download-mindspore.osinfra.cn/mindscience/mindchemistry/orb/test/orb_input_output.pkl', + 'orb_input_output.pkl' + ) + download_file( + 'https://download-mindspore.osinfra.cn/mindscience/mindchemistry/orb/orb_ckpts/orb-mptraj-only-v2.ckpt', + 'orb-mptraj-only-v2.ckpt' + ) + reference_path = 'orb_input_output.pkl' + with open(reference_path, "rb") as f: + loaded = pickle.load(f) + + atom_graph_ms = loaded["input_graph"] + output_pt = loaded["output"] + + regressor = orb_mptraj_only_v2(weights_path='orb-mptraj-only-v2.ckpt') + regressor.set_train(False) + + out_ms = regressor.predict( + atom_graph_ms.edge_features, + atom_graph_ms.node_features, + atom_graph_ms.senders, + atom_graph_ms.receivers, + atom_graph_ms.n_node, + atom_graph_ms.atomic_numbers, + ) + + out_ms = {k: tensor_to_numpy(v) for k, v in out_ms.items()} + + for k in out_ms: + flag = compare_output(out_ms[k], output_pt[k]) + assert flag, f"Failed! Orb network inference output {k} mismatch" diff --git a/tests/st/mindchemistry/cell/test_orb/utils.py b/tests/st/mindchemistry/cell/test_orb/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..ad100964ae2e4addcd18be6c699012889445031a --- /dev/null +++ b/tests/st/mindchemistry/cell/test_orb/utils.py @@ -0,0 +1,105 @@ +# ============================================================================ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Utils""" +import os +import sys +from typing import Any + +import numpy as np +from mindspore import Tensor + +# pylint: disable=C0413 +PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../")) +sys.path.append(PROJECT_ROOT) +from common.cell import compare_output, FP32_ATOL, FP32_RTOL + + +def tensor_to_numpy(data: Any) -> Any: + """Convert MindSpore Tensors to NumPy arrays recursively. + This function traverses the input data structure and converts all MindSpore Tensors + to NumPy arrays, while leaving other data types unchanged. + Args: + data (Any): Input data which can be a MindSpore Tensor, dict, list, tuple, or other types. + Returns: + Any: Data structure with MindSpore Tensors converted to NumPy arrays. + """ + if isinstance(data, Tensor): + return data.numpy() + if isinstance(data, dict): + return {k: tensor_to_numpy(v) for k, v in data.items()} + if isinstance(data, (list, tuple)): + return type(data)(tensor_to_numpy(v) for v in data) + return data + + +def numpy_to_tensor(data: Any) -> Any: + """Convert NumPy arrays to MindSpore Tensors recursively. + This function traverses the input data structure and converts all NumPy arrays + to MindSpore Tensors, while leaving other data types unchanged. + Args: + data (Any): Input data which can be a NumPy array, dict, list, tuple, or other types. + Returns: + Any: Data structure with NumPy arrays converted to MindSpore Tensors. + """ + if isinstance(data, np.ndarray): + return Tensor(data) + if isinstance(data, dict): + return {k: numpy_to_tensor(v) for k, v in data.items()} + if isinstance(data, (list, tuple)): + return type(data)(numpy_to_tensor(v) for v in data) + return data + + +def is_equal(a: Any, b: Any) -> bool: + """Compare two objects for equality with special handling for different types. + + This function performs a deep comparison between two objects, supporting: + - NumPy arrays (using tolerance-based comparison) + - Dictionaries (recursive comparison of values) + - Lists and tuples (element-wise comparison) + - NamedTuples (field-wise comparison) + - Other types (using standard equality comparison) + + Args: + a (Any): First object to compare + b (Any): Second object to compare + + Returns: + bool: True if objects are considered equal, False otherwise + + Examples: + >>> is_equal(np.array([1.0]), np.array([1.0])) + True + >>> is_equal({'a': 1, 'b': 2}, {'a': 1, 'b': 2}) + True + >>> is_equal([1, 2, 3], [1, 2, 3]) + True + """ + if isinstance(a, np.ndarray) and isinstance(b, np.ndarray): + return compare_output(a, b, FP32_ATOL, FP32_RTOL) + if isinstance(a, dict) and isinstance(b, dict): + if a.keys() != b.keys(): + return False + return all(is_equal(a[k], b[k]) for k in a) + if isinstance(a, (list, tuple)) and isinstance(b, (list, tuple)): + if len(a) != len(b): + return False + return all(is_equal(x, y) for x, y in zip(a, b)) + if hasattr(a, "_fields") and hasattr(b, "_fields"): + if a._fields != b._fields: + return False + return all(is_equal(getattr(a, f), getattr(b, f)) for f in a._fields) + return a == b diff --git a/tests/st/mindflow/cell/test_optimizers.py b/tests/st/mindflow/cell/test_optimizers.py index 1882e36df8e4902f5e29c5546b825ef76e473b39..4085507fd42dcb9b85facc4ea593c5fb03b249fe 100644 --- a/tests/st/mindflow/cell/test_optimizers.py +++ b/tests/st/mindflow/cell/test_optimizers.py @@ -22,9 +22,9 @@ import pytest import numpy as np import mindspore as ms -from mindspore import ops, set_seed, nn +from mindspore import ops, set_seed, nn, mint from mindspore import dtype as mstype -from mindflow import UNet2D, TransformerBlock, AdaHessian +from mindflow import UNet2D, TransformerBlock, MultiHeadAttention, AdaHessian from mindflow.cell.attention import FeedForward from mindflow.cell.unet2d import Down @@ -75,16 +75,64 @@ class TestUNet2D(UNet2D): class TestAttentionBlock(TransformerBlock): ''' Child class for testing optimizing Attention with AdaHessian ''' - def __init__(self, in_channels, num_heads, drop_mode="dropout", dropout_rate=0.0, compute_dtype=mstype.float16): - super().__init__( - in_channels, num_heads, drop_mode=drop_mode, dropout_rate=dropout_rate, compute_dtype=compute_dtype) + def __init__(self, + in_channels: int, + num_heads: int, + enable_flash_attn: bool = False, + fa_dtype: mstype = mstype.bfloat16, + drop_mode: str = "dropout", + dropout_rate: float = 0.0, + compute_dtype: mstype = mstype.float32, + ): + super().__init__(in_channels=in_channels, + num_heads=num_heads, + enable_flash_attn=enable_flash_attn, + fa_dtype=fa_dtype, + drop_mode=drop_mode, + dropout_rate=dropout_rate, + compute_dtype=compute_dtype, + ) class TestMlp(FeedForward): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.act_fn = nn.ReLU() # replace `gelu` with `relu` to avoid `vjp` problem - self.ffn = TestMlp(in_channels=in_channels, dropout_rate=dropout_rate, compute_dtype=compute_dtype) + class TestMultiHeadAttention(MultiHeadAttention): + ''' MultiHeadAttention modified to support vjp ''' + def get_qkv(self, x: ms.Tensor) -> tuple[ms.Tensor]: + ''' use masks to select out q, k, v, instead of tensor reshaping & indexing ''' + b, n, c_full = x.shape + c = c_full // self.num_heads + + # use matmul with masks to select out q, k, v to avoid vjp problem + q_mask = ms.Tensor(np.vstack([np.eye(c), np.zeros([2 * c, c])]), dtype=self.compute_dtype) + k_mask = ms.Tensor(np.vstack([np.zeros([c, c]), np.eye(c), np.zeros([c, c])]), dtype=self.compute_dtype) + v_mask = ms.Tensor(np.vstack([np.zeros([2 * c, c]), np.eye(c)]), dtype=self.compute_dtype) + + qkv = self.qkv(x) + qkv = qkv.reshape(b, n, self.num_heads, -1).swapaxes(1, 2) + + q = mint.matmul(qkv, q_mask) + k = mint.matmul(qkv, k_mask) + v = mint.matmul(qkv, v_mask) + + return q, k, v + + self.ffn = TestMlp( + in_channels=in_channels, + dropout_rate=dropout_rate, + compute_dtype=compute_dtype, + ) + self.attention = TestMultiHeadAttention( + in_channels=in_channels, + num_heads=num_heads, + enable_flash_attn=enable_flash_attn, + fa_dtype=fa_dtype, + drop_mode=drop_mode, + dropout_rate=dropout_rate, + compute_dtype=compute_dtype, + ) @pytest.mark.level0 @@ -110,7 +158,7 @@ def test_adahessian_accuracy(mode): in_channels=2, out_channels=4, kernel_size=3, has_bias=True, weight_init=weight_init, bias_init=bias_init) def forward(a): - return ops.mean(net(a)**2)**.5 + return ops.sqrt(ops.mean(ops.square(net(a)))) grad_fn = ms.grad(forward, grad_position=None, weights=net.trainable_params()) @@ -133,7 +181,7 @@ def test_adahessian_accuracy(mode): @pytest.mark.platform_arm_ascend910b_training @pytest.mark.env_onecard @pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) -@pytest.mark.parametrize('model_option', ['unet']) +@pytest.mark.parametrize('model_option', ['unet', 'attention']) def test_adahessian_st(mode, model_option): """ Feature: AdaHessian ST test @@ -146,7 +194,7 @@ def test_adahessian_st(mode, model_option): # default test with Attention network net = TestAttentionBlock(in_channels=256, num_heads=4) - inputs = ms.Tensor(np.sin(np.reshape(range(102400), [4, 100, 256])), dtype=ms.float32) + inputs = ms.Tensor(np.sin(np.arange(102400)).reshape(4, 100, 256), dtype=ms.float32) # test with UNet network if model_option.lower() == 'unet': @@ -159,12 +207,12 @@ def test_adahessian_st(mode, model_option): stride=2, activation='relu', data_format="NCHW", - enable_bn=True, + enable_bn=False, # bn leads to bug in PYNATIVE_MODE for MS2.5.0 ) - inputs = ms.Tensor(np.sin(np.reshape(range(16384), [2, 2, 64, 64])), dtype=ms.float32) + inputs = ms.Tensor(np.random.rand(2, 2, 64, 64), dtype=ms.float32) def forward(a): - return ops.mean(net(a)**2)**.5 + return ops.sqrt(ops.mean(ops.square(net(a)))) grad_fn = ms.grad(forward, grad_position=None, weights=net.trainable_params()) @@ -173,16 +221,17 @@ def test_adahessian_st(mode, model_option): learning_rate=0.1, beta1=0.9, beta2=0.999, eps=1e-8, weight_decay=0.) for _ in range(4): - loss = forward(inputs) optimizer(grad_fn, inputs) + loss = forward(inputs) assert ops.isfinite(loss) -@pytest.mark.level1 +@pytest.mark.level0 @pytest.mark.platform_arm_ascend910b_training @pytest.mark.env_onecard -def test_adahessian_compare(): +@pytest.mark.parametrize('mode', [ms.PYNATIVE_MODE]) +def test_adahessian_compare(mode): """ Feature: AdaHessian compare with Adam Description: Compare the algorithm results of the AdaHessian optimizer with Adam. @@ -190,15 +239,15 @@ def test_adahessian_compare(): The optimization runs 100 rounds to demonstrate an essential loss decrease. Expectation: The loss of AdaHessian outperforms Adam by 20% under the same configuration on an Attention network. """ - ms.set_context(mode=ms.PYNATIVE_MODE) + ms.set_context(mode=mode) def get_loss(optimizer_option): ''' compare Adam and AdaHessian ''' net = TestAttentionBlock(in_channels=256, num_heads=4) - inputs = ms.Tensor(np.sin(np.reshape(range(102400), [4, 100, 256])), dtype=ms.float32) + inputs = ms.Tensor(np.sin(np.arange(102400)).reshape(4, 100, 256), dtype=ms.float32) def forward(a): - return ops.mean(net(a)**2)**.5 + return ops.sqrt(ops.mean(ops.square(net(a)))) grad_fn = ms.grad(forward, grad_position=None, weights=net.trainable_params()) @@ -211,13 +260,13 @@ def test_adahessian_compare(): net.trainable_params(), learning_rate=0.01, beta1=0.9, beta2=0.999, eps=1e-8, weight_decay=0.) - for _ in range(100): - loss = forward(inputs) + for _ in range(20): if optimizer_option.lower() == 'adam': optimizer(grad_fn(inputs)) else: optimizer(grad_fn, inputs) + loss = forward(inputs) return loss loss_adam = get_loss('adam')