From 4be1cb09c28d5a143b0690043dcd038c9840ead0 Mon Sep 17 00:00:00 2001 From: dechin Date: Tue, 25 Jul 2023 14:29:36 +0800 Subject: [PATCH] Fix the protein interface issue --- .../protein_relaxation/protein_relax.py | 57 ++++++++++--------- .../src/sponge/potential/bias/oscillator.py | 2 + 2 files changed, 33 insertions(+), 26 deletions(-) diff --git a/MindSPONGE/applications/molecular_dynamics/protein_relaxation/protein_relax.py b/MindSPONGE/applications/molecular_dynamics/protein_relaxation/protein_relax.py index 334817537..40d0d0f5f 100644 --- a/MindSPONGE/applications/molecular_dynamics/protein_relaxation/protein_relax.py +++ b/MindSPONGE/applications/molecular_dynamics/protein_relaxation/protein_relax.py @@ -19,23 +19,24 @@ $ python3 protein_relax.py -i examples/protein/case2.pdb -o examples/protein/cas """ import argparse -import numpy as np from mindspore import context, Tensor, nn from mindspore import numpy as msnp import mindspore as ms -from mindsponge import Sponge -from mindsponge import set_global_units -from mindsponge import Protein -from mindsponge import ForceField -from mindsponge import SimulationCell -from mindsponge.callback import RunInfo -from mindsponge.optimizer import SteepestDescent -from mindsponge.potential.bias import OscillatorBias -from mindsponge.system.modelling.pdb_generator import gen_pdb +from sponge import Sponge +from sponge import set_global_units +from sponge import Protein +from sponge import ForceField +from sponge.callback import RunInfo +from sponge.core import WithEnergyCell, WithForceCell, RunOneStepCell +from sponge.optimizer import SteepestDescent +from sponge.sampling import MaskedDriven +from sponge.partition import NeighbourList +from sponge.potential.bias import OscillatorBias +from sponge.system.modelling.pdb_generator import gen_pdb from mindsponge.common.utils import get_pdb_info -from mindsponge.metrics.structure_violations import get_structural_violations +from mindsponge.pipeline.structure_violations import get_structural_violations parser = argparse.ArgumentParser() parser.add_argument("-i", help="Set the input pdb file path.") @@ -51,7 +52,7 @@ if context.get_context("device_target") == "Ascend": context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, graph_kernel_flags="--enable_cluster_ops=ReduceSum --reduce_fuse_depth=10") else: - context.set_context(mode=context.GRAPH_MODE, device_target="GPU", device_id=0) + context.set_context(mode=context.GRAPH_MODE, device_target="GPU", device_id=1, enable_graph_kernel=True) def get_violation_loss(system): @@ -91,16 +92,20 @@ def optimize_strategy(system, gds, loops, ads, adm, nonh_mask, mode=1): energy = ForceField(system, "AMBER.FF14SB") learning_rate = 1e-07 factor = 1.003 - opt = SteepestDescent( - system.trainable_params(), - learning_rate=learning_rate, - factor=factor, - nonh_mask=nonh_mask, - ) + dynamic_lr = nn.ExponentialDecayLR(learning_rate, factor, 1, is_stair=True) + opt = SteepestDescent(system.trainable_params(), dynamic_lr, max_shift=1.0) + neighbours = NeighbourList(system, cutoff=None, cast_fp16=True) + with_energy = WithEnergyCell(system, energy, neighbour_list=neighbours) + modifier = MaskedDriven(length_unit=with_energy.length_unit, + energy_unit=with_energy.energy_unit, + mask=nonh_mask) + with_force = WithForceCell(system, neighbour_list=neighbours, modifier=modifier) + one_step = RunOneStepCell(energy=with_energy, force=with_force, optimizer=opt) + for i, param in enumerate(opt.trainable_params()): print(i, param.name, param.shape) - md = Sponge(system, energy, opt) + md = Sponge(network=one_step) run_info = RunInfo(1) md.run(gds, callbacks=[run_info]) @@ -114,13 +119,14 @@ def optimize_strategy(system, gds, loops, ads, adm, nonh_mask, mode=1): if mode in (1, 2): energy.set_energy_scale([1, 1, 1, 1, 1, 1]) - simulation_network = SimulationCell(system, energy, bias=[harmonic_energy]) + simulation_network = WithEnergyCell(system, energy, bias=[harmonic_energy], neighbour_list=neighbours) for _ in range(adm): opt = nn.Adam(system.trainable_params(), learning_rate=learning_rate) + one_step = RunOneStepCell(energy=simulation_network, optimizer=opt) for i, param in enumerate(opt.trainable_params()): print(i, param.name, param.shape) - md = Sponge(simulation_network, optimizer=opt) + md = Sponge(network=one_step) print(md.calc_energy()) run_info = RunInfo(1) md.run(ads, callbacks=[run_info]) @@ -129,13 +135,14 @@ def optimize_strategy(system, gds, loops, ads, adm, nonh_mask, mode=1): if mode in (1, 3): energy.set_energy_scale([1, 1, 1, 0, 0, 0]) - simulation_network = SimulationCell(system, energy, bias=[harmonic_energy]) + simulation_network = WithEnergyCell(system, energy, bias=[harmonic_energy], neighbour_list=neighbours) for _ in range(adm): opt = nn.Adam(system.trainable_params(), learning_rate=learning_rate) + one_step = RunOneStepCell(energy=simulation_network, optimizer=opt) for i, param in enumerate(opt.trainable_params()): print(i, param.name, param.shape) - md = Sponge(simulation_network, optimizer=opt) + md = Sponge(network=one_step) print(md.calc_energy()) run_info = RunInfo(1) md.run(ads, callbacks=[run_info]) @@ -150,9 +157,7 @@ def main(): ms.set_seed(seed) set_global_units("A", "kcal/mol") system = Protein(pdb=pdb_name, rebuild_hydrogen=True) - nonh_mask = Tensor( - np.where(system.atomic_number[0] > 1, 0, 1)[None, :, None], ms.int32 - ) + nonh_mask = system.heavy_atom_mask gds, loops, ads, adm = 100, 3, 200, 2 system = optimize_strategy(system, gds, loops, ads, adm, nonh_mask, mode=1) diff --git a/MindSPONGE/src/sponge/potential/bias/oscillator.py b/MindSPONGE/src/sponge/potential/bias/oscillator.py index 8af217674..a2cd1e41b 100644 --- a/MindSPONGE/src/sponge/potential/bias/oscillator.py +++ b/MindSPONGE/src/sponge/potential/bias/oscillator.py @@ -57,6 +57,8 @@ class OscillatorBias(Bias): self.old_crd = Tensor(old_crd, ms.float32) self.k = Tensor(k, ms.float32) self.nonh_mask = Tensor(1 - nonh_mask, ms.int32) + if self.nonh_mask.ndim == 1: + self.nonh_mask = self.nonh_mask[None, :, None] def construct(self, coordinate: Tensor, -- Gitee